Python利用Faiss库实现ANN近邻搜索的方法详解

作者:蚂蚁学Python 时间:2021-11-17 19:35:02 

Embedding的近邻搜索是当前图推荐系统非常重要的一种召回方式,通过item2vec、矩阵分解、双塔DNN等方式都能够产出训练好的user embedding、item embedding,对于embedding的使用非常的灵活:

  • 输入user embedding,近邻搜索item embedding,可以给user推荐感兴趣的items

  • 输入user embedding,近邻搜搜user embedding,可以给user推荐感兴趣的user

  • 输入item embedding,近邻搜索item embedding,可以给item推荐相关的items

然而有一个工程问题,一旦user embedding、item embedding数据量达到一定的程度,对他们的近邻搜索将会变得非常慢,如果离线阶段提前搜索好在高速缓存比如redis存储好结果当然没问题,但是这种方式很不实时,如果能在线阶段上线几十MS的搜索当然效果最好。

Faiss是Facebook AI团队开源的针对聚类和相似性搜索库,为稠密向量提供高效相似度搜索和聚类,支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库。

接下来通过jupyter notebook的代码,给大家演示下使用faiss的简单流程,内容包括:

  • 读取训练好的Embedding数据

  • 构建faiss索引,将待搜索的Embedding添加进去

  • 取得目标Embedding,实现搜索得到ID列表

  • 根据ID获取电影标题,返回结果

对于已经训练好的Embedding怎样实现高速近邻搜索是一个工程问题,facebook的faiss库可以构建多种embedding索引实现目标embedding的高速近邻搜索,能够满足在线使用的需要

安装命令:


conda install -c pytorch faiss-cpu

提前总结下faiss使用经验:

1. 为了支持自己的ID,可以用faiss.IndexIDMap包裹faiss.IndexFlatL2即可

2. embedding数据都需要转换成np.float32,包括索引中的embedding以及待搜索的embedding

3. ids需要转换成int64类型

1. 准备数据


import pandas as pd
import numpy as np

df = pd.read_csv("./datas/movielens_sparkals_item_embedding.csv")
df.head()


idfeatures
010[0.25866490602493286, 0.3560594320297241, 0.15…
120[0.12449632585048676, -0.29282501339912415, -0…
230[0.9557555317878723, 0.6764761805534363, 0.114…
340[0.3184879720211029, 0.6365472078323364, 0.596…
450[0.45523127913475037, 0.34402626752853394, -0….

构建ids


ids = df["id"].values.astype(np.int64)
type(ids), ids.shape
(numpy.ndarray, (3706,))
ids.dtype
dtype('int64')
ids_size = ids.shape[0]
ids_size
3706

构建datas


import json
import numpy as np
datas = []
for x in df["features"]:
datas.append(json.loads(x))
datas = np.array(datas).astype(np.float32)
datas.dtype
dtype('float32')
datas.shape
(3706, 10)
datas[0]
array([ 0.2586649 , 0.35605943, 0.15589039, -0.7067125 , -0.07414215,
-0.62500805, -0.0573845 , 0.4533663 , 0.26074877, -0.60799956],
dtype=float32)
# 维度
dimension = datas.shape[1]
dimension
10

2. 建立索引


import faiss
index = faiss.IndexFlatL2(dimension)
index2 = faiss.IndexIDMap(index)
ids.dtype
dtype('int64')
index2.add_with_ids(datas, ids)
index.ntotal
3706

4. 搜索近邻ID列表


df_user = pd.read_csv("./datas/movielens_sparkals_user_embedding.csv")
df_user.head()
id features


idfeatures
010[0.5974288582801819, 0.17486965656280518, 0.04…
120[1.3099910020828247, 0.5037978291511536, 0.260…
230[-1.1886241436004639, -0.13511677086353302, 0….
340[1.0809299945831299, 1.0048035383224487, 0.986…
450[0.42388680577278137, 0.5294889807701111, -0.6…


user_embedding = np.array(json.loads(df_user[df_user["id"] == 10]["features"].iloc[0]))
user_embedding = np.expand_dims(user_embedding, axis=0).astype(np.float32)
user_embedding
array([[ 0.59742886, 0.17486966, 0.04345559, -1.3193961 , 0.5313592 ,
-0.6052168 , -0.19088413, 1.5307966 , 0.09310367, -2.7573566 ]],
dtype=float32)
user_embedding.shape
(1, 10)
user_embedding.dtype
dtype('float32')
topk = 30
D, I = index.search(user_embedding, topk) # actual search
I.shape
(1, 30)
I
array([[3380, 2900, 1953, 121, 3285, 999, 617, 747, 2351, 601, 2347,
42, 2383, 538, 1774, 980, 2165, 3049, 2664, 367, 3289, 2866,
2452, 547, 1072, 2055, 3660, 3343, 3390, 3590]])

5. 根据电影ID取出电影信息


target_ids = pd.Series(I[0], name="MovieID")
target_ids.head()
0 3380
1 2900
2 1953
3 121
4 3285
Name: MovieID, dtype: int64
df_movie = pd.read_csv("./datas/ml-1m/movies.dat",
 sep="::", header=None, engine="python",
 names = "MovieID::Title::Genres".split("::"))
df_movie.head()


MovieIDTitleGenres
01Toy Story (1995)Animation|Children's|Comedy
12Jumanji (1995)Adventure|Children's|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama
45Father of the Bride Part II (1995)Comedy


df_result = pd.merge(target_ids, df_movie)
df_result.head()


MovieIDTitleGenres
03380Railroaded! (1947)Film-Noir
12900Monkey Shines (1988)Horror|Sci-Fi
21953French Connection, The (1971)Action|Crime|Drama|Thriller
3121Boys of St. Vincent, The (1993)Drama
43285Beach, The (2000)Adventure|Drama

来源:http://www.crazyant.net/2646.html

标签:python,faiss,近邻
0
投稿

猜你喜欢

  • PHP面向对象程序设计之类与反射API详解

    2023-11-19 12:44:12
  • Django+Celery实现定时任务的示例

    2023-07-27 19:48:18
  • PHP实现获取第一个中文首字母并进行排序的方法

    2023-10-30 12:29:08
  • SQL语句实现删除ACCESS重复记录的两种方法

    2024-01-24 21:43:25
  • js实现动态增加文件域表单功能

    2024-04-19 09:50:33
  • 下拉列表两级连动的新方法(一)

    2009-06-04 18:18:00
  • 浅谈pytorch卷积核大小的设置对全连接神经元的影响

    2022-08-19 04:49:55
  • Python使用树状图实现可视化聚类详解

    2021-07-08 13:41:25
  • Python中常用操作字符串的函数与方法总结

    2023-07-25 12:09:23
  • Python中Unittest框架的具体使用

    2023-02-20 11:41:09
  • golang cache带索引超时缓存库实战示例

    2023-07-24 04:43:11
  • python中注释用法简单示例

    2022-10-24 05:04:09
  • python根据开头和结尾字符串获取中间字符串的方法

    2021-01-02 01:44:28
  • python dlib人脸识别代码实例

    2021-04-05 12:57:33
  • SQLServer 2008 新增T-SQL 简写语法

    2024-01-28 23:49:25
  • 一条sql 语句搞定数据库分页

    2009-03-21 18:32:00
  • python logging类库使用例子

    2023-10-31 11:17:11
  • python使用xlrd模块读取xlsx文件中的ip方法

    2022-12-26 13:42:00
  • MySQL中Replace语句用法实例详解

    2024-01-15 03:26:28
  • 一个向上滚动代码

    2010-02-10 12:29:00
  • asp之家 网络编程 m.aspxhome.com