Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

rerank_model.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. import requests
  19. from huggingface_hub import snapshot_download
  20. import os
  21. from abc import ABC
  22. import numpy as np
  23. from api.settings import LIGHTEN
  24. from api.utils.file_utils import get_home_cache_dir
  25. from rag.utils import num_tokens_from_string, truncate
  26. import json
  27. def sigmoid(x):
  28. return 1 / (1 + np.exp(-x))
  29. class Base(ABC):
  30. def __init__(self, key, model_name):
  31. pass
  32. def similarity(self, query: str, texts: list):
  33. raise NotImplementedError("Please implement encode method!")
  34. class DefaultRerank(Base):
  35. _model = None
  36. _model_lock = threading.Lock()
  37. def __init__(self, key, model_name, **kwargs):
  38. """
  39. If you have trouble downloading HuggingFace models, -_^ this might help!!
  40. For Linux:
  41. export HF_ENDPOINT=https://hf-mirror.com
  42. For Windows:
  43. Good luck
  44. ^_-
  45. """
  46. if not LIGHTEN and not DefaultRerank._model:
  47. import torch
  48. from FlagEmbedding import FlagReranker
  49. with DefaultRerank._model_lock:
  50. if not DefaultRerank._model:
  51. try:
  52. DefaultRerank._model = FlagReranker(
  53. os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
  54. use_fp16=torch.cuda.is_available())
  55. except Exception as e:
  56. model_dir = snapshot_download(repo_id=model_name,
  57. local_dir=os.path.join(get_home_cache_dir(),
  58. re.sub(r"^[a-zA-Z]+/", "", model_name)),
  59. local_dir_use_symlinks=False)
  60. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  61. self._model = DefaultRerank._model
  62. def similarity(self, query: str, texts: list):
  63. pairs = [(query, truncate(t, 2048)) for t in texts]
  64. token_count = 0
  65. for _, t in pairs:
  66. token_count += num_tokens_from_string(t)
  67. batch_size = 4096
  68. res = []
  69. for i in range(0, len(pairs), batch_size):
  70. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
  71. scores = sigmoid(np.array(scores)).tolist()
  72. if isinstance(scores, float):
  73. res.append(scores)
  74. else:
  75. res.extend(scores)
  76. return np.array(res), token_count
  77. class JinaRerank(Base):
  78. def __init__(self, key, model_name="jina-reranker-v1-base-en",
  79. base_url="https://api.jina.ai/v1/rerank"):
  80. self.base_url = "https://api.jina.ai/v1/rerank"
  81. self.headers = {
  82. "Content-Type": "application/json",
  83. "Authorization": f"Bearer {key}"
  84. }
  85. self.model_name = model_name
  86. def similarity(self, query: str, texts: list):
  87. texts = [truncate(t, 8196) for t in texts]
  88. data = {
  89. "model": self.model_name,
  90. "query": query,
  91. "documents": texts,
  92. "top_n": len(texts)
  93. }
  94. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  95. rank = np.zeros(len(texts), dtype=float)
  96. for d in res["results"]:
  97. rank[d["index"]] = d["relevance_score"]
  98. return rank, res["usage"]["total_tokens"]
  99. class YoudaoRerank(DefaultRerank):
  100. _model = None
  101. _model_lock = threading.Lock()
  102. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  103. if not LIGHTEN and not YoudaoRerank._model:
  104. from BCEmbedding import RerankerModel
  105. with YoudaoRerank._model_lock:
  106. if not YoudaoRerank._model:
  107. try:
  108. print("LOADING BCE...")
  109. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  110. get_home_cache_dir(),
  111. re.sub(r"^[a-zA-Z]+/", "", model_name)))
  112. except Exception as e:
  113. YoudaoRerank._model = RerankerModel(
  114. model_name_or_path=model_name.replace(
  115. "maidalun1020", "InfiniFlow"))
  116. self._model = YoudaoRerank._model
  117. def similarity(self, query: str, texts: list):
  118. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  119. token_count = 0
  120. for _, t in pairs:
  121. token_count += num_tokens_from_string(t)
  122. batch_size = 32
  123. res = []
  124. for i in range(0, len(pairs), batch_size):
  125. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
  126. scores = sigmoid(np.array(scores)).tolist()
  127. if isinstance(scores, float):
  128. res.append(scores)
  129. else:
  130. res.extend(scores)
  131. return np.array(res), token_count
  132. class XInferenceRerank(Base):
  133. def __init__(self, key="xxxxxxx", model_name="", base_url=""):
  134. if base_url.split("/")[-1] != "v1":
  135. base_url = os.path.join(base_url, "v1")
  136. self.model_name = model_name
  137. self.base_url = base_url
  138. self.headers = {
  139. "Content-Type": "application/json",
  140. "accept": "application/json"
  141. }
  142. def similarity(self, query: str, texts: list):
  143. if len(texts) == 0:
  144. return np.array([]), 0
  145. data = {
  146. "model": self.model_name,
  147. "query": query,
  148. "return_documents": "true",
  149. "return_len": "true",
  150. "documents": texts
  151. }
  152. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  153. rank = np.zeros(len(texts), dtype=float)
  154. for d in res["results"]:
  155. rank[d["index"]] = d["relevance_score"]
  156. return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
  157. class LocalAIRerank(Base):
  158. def __init__(self, key, model_name, base_url):
  159. pass
  160. def similarity(self, query: str, texts: list):
  161. raise NotImplementedError("The LocalAIRerank has not been implement")
  162. class NvidiaRerank(Base):
  163. def __init__(
  164. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  165. ):
  166. if not base_url:
  167. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  168. self.model_name = model_name
  169. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  170. self.base_url = os.path.join(
  171. base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
  172. )
  173. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  174. self.base_url = os.path.join(base_url, "reranking")
  175. self.model_name = "nv-rerank-qa-mistral-4b:1"
  176. self.headers = {
  177. "accept": "application/json",
  178. "Content-Type": "application/json",
  179. "Authorization": f"Bearer {key}",
  180. }
  181. def similarity(self, query: str, texts: list):
  182. token_count = num_tokens_from_string(query) + sum(
  183. [num_tokens_from_string(t) for t in texts]
  184. )
  185. data = {
  186. "model": self.model_name,
  187. "query": {"text": query},
  188. "passages": [{"text": text} for text in texts],
  189. "truncate": "END",
  190. "top_n": len(texts),
  191. }
  192. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  193. rank = np.zeros(len(texts), dtype=float)
  194. for d in res["rankings"]:
  195. rank[d["index"]] = d["logit"]
  196. return rank, token_count
  197. class LmStudioRerank(Base):
  198. def __init__(self, key, model_name, base_url):
  199. pass
  200. def similarity(self, query: str, texts: list):
  201. raise NotImplementedError("The LmStudioRerank has not been implement")
  202. class OpenAI_APIRerank(Base):
  203. def __init__(self, key, model_name, base_url):
  204. pass
  205. def similarity(self, query: str, texts: list):
  206. raise NotImplementedError("The api has not been implement")
  207. class CoHereRerank(Base):
  208. def __init__(self, key, model_name, base_url=None):
  209. from cohere import Client
  210. self.client = Client(api_key=key)
  211. self.model_name = model_name
  212. def similarity(self, query: str, texts: list):
  213. token_count = num_tokens_from_string(query) + sum(
  214. [num_tokens_from_string(t) for t in texts]
  215. )
  216. res = self.client.rerank(
  217. model=self.model_name,
  218. query=query,
  219. documents=texts,
  220. top_n=len(texts),
  221. return_documents=False,
  222. )
  223. rank = np.zeros(len(texts), dtype=float)
  224. for d in res.results:
  225. rank[d.index] = d.relevance_score
  226. return rank, token_count
  227. class TogetherAIRerank(Base):
  228. def __init__(self, key, model_name, base_url):
  229. pass
  230. def similarity(self, query: str, texts: list):
  231. raise NotImplementedError("The api has not been implement")
  232. class SILICONFLOWRerank(Base):
  233. def __init__(
  234. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  235. ):
  236. if not base_url:
  237. base_url = "https://api.siliconflow.cn/v1/rerank"
  238. self.model_name = model_name
  239. self.base_url = base_url
  240. self.headers = {
  241. "accept": "application/json",
  242. "content-type": "application/json",
  243. "authorization": f"Bearer {key}",
  244. }
  245. def similarity(self, query: str, texts: list):
  246. payload = {
  247. "model": self.model_name,
  248. "query": query,
  249. "documents": texts,
  250. "top_n": len(texts),
  251. "return_documents": False,
  252. "max_chunks_per_doc": 1024,
  253. "overlap_tokens": 80,
  254. }
  255. response = requests.post(
  256. self.base_url, json=payload, headers=self.headers
  257. ).json()
  258. rank = np.zeros(len(texts), dtype=float)
  259. for d in response["results"]:
  260. rank[d["index"]] = d["relevance_score"]
  261. return (
  262. rank,
  263. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  264. )
  265. class BaiduYiyanRerank(Base):
  266. def __init__(self, key, model_name, base_url=None):
  267. from qianfan.resources import Reranker
  268. key = json.loads(key)
  269. ak = key.get("yiyan_ak", "")
  270. sk = key.get("yiyan_sk", "")
  271. self.client = Reranker(ak=ak, sk=sk)
  272. self.model_name = model_name
  273. def similarity(self, query: str, texts: list):
  274. res = self.client.do(
  275. model=self.model_name,
  276. query=query,
  277. documents=texts,
  278. top_n=len(texts),
  279. ).body
  280. rank = np.zeros(len(texts), dtype=float)
  281. for d in res["results"]:
  282. rank[d["index"]] = d["relevance_score"]
  283. return rank, res["usage"]["total_tokens"]
  284. class VoyageRerank(Base):
  285. def __init__(self, key, model_name, base_url=None):
  286. import voyageai
  287. self.client = voyageai.Client(api_key=key)
  288. self.model_name = model_name
  289. def similarity(self, query: str, texts: list):
  290. res = self.client.rerank(
  291. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  292. )
  293. rank = np.zeros(len(texts), dtype=float)
  294. for r in res.results:
  295. rank[r.index] = r.relevance_score
  296. return rank, res.total_tokens