您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rerank_model.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. import requests
  19. import torch
  20. from FlagEmbedding import FlagReranker
  21. from huggingface_hub import snapshot_download
  22. import os
  23. from abc import ABC
  24. import numpy as np
  25. from api.utils.file_utils import get_home_cache_dir
  26. from rag.utils import num_tokens_from_string, truncate
  27. import json
  28. def sigmoid(x):
  29. return 1 / (1 + np.exp(-x))
  30. class Base(ABC):
  31. def __init__(self, key, model_name):
  32. pass
  33. def similarity(self, query: str, texts: list):
  34. raise NotImplementedError("Please implement encode method!")
  35. class DefaultRerank(Base):
  36. _model = None
  37. _model_lock = threading.Lock()
  38. def __init__(self, key, model_name, **kwargs):
  39. """
  40. If you have trouble downloading HuggingFace models, -_^ this might help!!
  41. For Linux:
  42. export HF_ENDPOINT=https://hf-mirror.com
  43. For Windows:
  44. Good luck
  45. ^_-
  46. """
  47. if not DefaultRerank._model:
  48. with DefaultRerank._model_lock:
  49. if not DefaultRerank._model:
  50. try:
  51. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  52. except Exception as e:
  53. model_dir = snapshot_download(repo_id= model_name,
  54. local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
  55. local_dir_use_symlinks=False)
  56. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  57. self._model = DefaultRerank._model
  58. def similarity(self, query: str, texts: list):
  59. pairs = [(query,truncate(t, 2048)) for t in texts]
  60. token_count = 0
  61. for _, t in pairs:
  62. token_count += num_tokens_from_string(t)
  63. batch_size = 4096
  64. res = []
  65. for i in range(0, len(pairs), batch_size):
  66. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
  67. scores = sigmoid(np.array(scores)).tolist()
  68. if isinstance(scores, float): res.append(scores)
  69. else: res.extend(scores)
  70. return np.array(res), token_count
  71. class JinaRerank(Base):
  72. def __init__(self, key, model_name="jina-reranker-v1-base-en",
  73. base_url="https://api.jina.ai/v1/rerank"):
  74. self.base_url = "https://api.jina.ai/v1/rerank"
  75. self.headers = {
  76. "Content-Type": "application/json",
  77. "Authorization": f"Bearer {key}"
  78. }
  79. self.model_name = model_name
  80. def similarity(self, query: str, texts: list):
  81. texts = [truncate(t, 8196) for t in texts]
  82. data = {
  83. "model": self.model_name,
  84. "query": query,
  85. "documents": texts,
  86. "top_n": len(texts)
  87. }
  88. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  89. return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
  90. class YoudaoRerank(DefaultRerank):
  91. _model = None
  92. _model_lock = threading.Lock()
  93. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  94. from BCEmbedding import RerankerModel
  95. if not YoudaoRerank._model:
  96. with YoudaoRerank._model_lock:
  97. if not YoudaoRerank._model:
  98. try:
  99. print("LOADING BCE...")
  100. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  101. get_home_cache_dir(),
  102. re.sub(r"^[a-zA-Z]+/", "", model_name)))
  103. except Exception as e:
  104. YoudaoRerank._model = RerankerModel(
  105. model_name_or_path=model_name.replace(
  106. "maidalun1020", "InfiniFlow"))
  107. self._model = YoudaoRerank._model
  108. def similarity(self, query: str, texts: list):
  109. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  110. token_count = 0
  111. for _, t in pairs:
  112. token_count += num_tokens_from_string(t)
  113. batch_size = 32
  114. res = []
  115. for i in range(0, len(pairs), batch_size):
  116. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
  117. scores = sigmoid(np.array(scores)).tolist()
  118. if isinstance(scores, float): res.append(scores)
  119. else: res.extend(scores)
  120. return np.array(res), token_count
  121. class XInferenceRerank(Base):
  122. def __init__(self, key="xxxxxxx", model_name="", base_url=""):
  123. if base_url.split("/")[-1] != "v1":
  124. base_url = os.path.join(base_url, "v1")
  125. self.model_name = model_name
  126. self.base_url = base_url
  127. self.headers = {
  128. "Content-Type": "application/json",
  129. "accept": "application/json"
  130. }
  131. def similarity(self, query: str, texts: list):
  132. if len(texts) == 0:
  133. return np.array([]), 0
  134. data = {
  135. "model": self.model_name,
  136. "query": query,
  137. "return_documents": "true",
  138. "return_len": "true",
  139. "documents": texts
  140. }
  141. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  142. return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
  143. class LocalAIRerank(Base):
  144. def __init__(self, key, model_name, base_url):
  145. pass
  146. def similarity(self, query: str, texts: list):
  147. raise NotImplementedError("The LocalAIRerank has not been implement")
  148. class NvidiaRerank(Base):
  149. def __init__(
  150. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  151. ):
  152. if not base_url:
  153. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  154. self.model_name = model_name
  155. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  156. self.base_url = os.path.join(
  157. base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
  158. )
  159. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  160. self.base_url = os.path.join(base_url, "reranking")
  161. self.model_name = "nv-rerank-qa-mistral-4b:1"
  162. self.headers = {
  163. "accept": "application/json",
  164. "Content-Type": "application/json",
  165. "Authorization": f"Bearer {key}",
  166. }
  167. def similarity(self, query: str, texts: list):
  168. token_count = num_tokens_from_string(query) + sum(
  169. [num_tokens_from_string(t) for t in texts]
  170. )
  171. data = {
  172. "model": self.model_name,
  173. "query": {"text": query},
  174. "passages": [{"text": text} for text in texts],
  175. "truncate": "END",
  176. "top_n": len(texts),
  177. }
  178. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  179. rank = np.array([d["logit"] for d in res["rankings"]])
  180. indexs = [d["index"] for d in res["rankings"]]
  181. return rank[indexs], token_count
  182. class LmStudioRerank(Base):
  183. def __init__(self, key, model_name, base_url):
  184. pass
  185. def similarity(self, query: str, texts: list):
  186. raise NotImplementedError("The LmStudioRerank has not been implement")
  187. class OpenAI_APIRerank(Base):
  188. def __init__(self, key, model_name, base_url):
  189. pass
  190. def similarity(self, query: str, texts: list):
  191. raise NotImplementedError("The api has not been implement")
  192. class CoHereRerank(Base):
  193. def __init__(self, key, model_name, base_url=None):
  194. from cohere import Client
  195. self.client = Client(api_key=key)
  196. self.model_name = model_name
  197. def similarity(self, query: str, texts: list):
  198. token_count = num_tokens_from_string(query) + sum(
  199. [num_tokens_from_string(t) for t in texts]
  200. )
  201. res = self.client.rerank(
  202. model=self.model_name,
  203. query=query,
  204. documents=texts,
  205. top_n=len(texts),
  206. return_documents=False,
  207. )
  208. rank = np.array([d.relevance_score for d in res.results])
  209. indexs = [d.index for d in res.results]
  210. return rank[indexs], token_count
  211. class TogetherAIRerank(Base):
  212. def __init__(self, key, model_name, base_url):
  213. pass
  214. def similarity(self, query: str, texts: list):
  215. raise NotImplementedError("The api has not been implement")
  216. class SILICONFLOWRerank(Base):
  217. def __init__(
  218. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  219. ):
  220. if not base_url:
  221. base_url = "https://api.siliconflow.cn/v1/rerank"
  222. self.model_name = model_name
  223. self.base_url = base_url
  224. self.headers = {
  225. "accept": "application/json",
  226. "content-type": "application/json",
  227. "authorization": f"Bearer {key}",
  228. }
  229. def similarity(self, query: str, texts: list):
  230. payload = {
  231. "model": self.model_name,
  232. "query": query,
  233. "documents": texts,
  234. "top_n": len(texts),
  235. "return_documents": False,
  236. "max_chunks_per_doc": 1024,
  237. "overlap_tokens": 80,
  238. }
  239. response = requests.post(
  240. self.base_url, json=payload, headers=self.headers
  241. ).json()
  242. rank = np.array([d["relevance_score"] for d in response["results"]])
  243. indexs = [d["index"] for d in response["results"]]
  244. return (
  245. rank[indexs],
  246. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  247. )
  248. class BaiduYiyanRerank(Base):
  249. def __init__(self, key, model_name, base_url=None):
  250. from qianfan.resources import Reranker
  251. key = json.loads(key)
  252. ak = key.get("yiyan_ak", "")
  253. sk = key.get("yiyan_sk", "")
  254. self.client = Reranker(ak=ak, sk=sk)
  255. self.model_name = model_name
  256. def similarity(self, query: str, texts: list):
  257. res = self.client.do(
  258. model=self.model_name,
  259. query=query,
  260. documents=texts,
  261. top_n=len(texts),
  262. ).body
  263. rank = np.array([d["relevance_score"] for d in res["results"]])
  264. indexs = [d["index"] for d in res["results"]]
  265. return rank[indexs], res["usage"]["total_tokens"]
  266. class VoyageRerank(Base):
  267. def __init__(self, key, model_name, base_url=None):
  268. import voyageai
  269. self.client = voyageai.Client(api_key=key)
  270. self.model_name = model_name
  271. def similarity(self, query: str, texts: list):
  272. res = self.client.rerank(
  273. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  274. )
  275. rank = np.array([r.relevance_score for r in res.results])
  276. indexs = [r.index for r in res.results]
  277. return rank[indexs], res.total_tokens