Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

rerank_model.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from urllib.parse import urljoin
  19. import requests
  20. from huggingface_hub import snapshot_download
  21. import os
  22. from abc import ABC
  23. import numpy as np
  24. from api.settings import LIGHTEN
  25. from api.utils.file_utils import get_home_cache_dir
  26. from rag.utils import num_tokens_from_string, truncate
  27. import json
  28. from api.utils.log_utils import logger
  29. def sigmoid(x):
  30. return 1 / (1 + np.exp(-x))
  31. class Base(ABC):
  32. def __init__(self, key, model_name):
  33. pass
  34. def similarity(self, query: str, texts: list):
  35. raise NotImplementedError("Please implement encode method!")
  36. class DefaultRerank(Base):
  37. _model = None
  38. _model_lock = threading.Lock()
  39. def __init__(self, key, model_name, **kwargs):
  40. """
  41. If you have trouble downloading HuggingFace models, -_^ this might help!!
  42. For Linux:
  43. export HF_ENDPOINT=https://hf-mirror.com
  44. For Windows:
  45. Good luck
  46. ^_-
  47. """
  48. if not LIGHTEN and not DefaultRerank._model:
  49. import torch
  50. from FlagEmbedding import FlagReranker
  51. with DefaultRerank._model_lock:
  52. if not DefaultRerank._model:
  53. try:
  54. DefaultRerank._model = FlagReranker(
  55. os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
  56. use_fp16=torch.cuda.is_available())
  57. except Exception:
  58. model_dir = snapshot_download(repo_id=model_name,
  59. local_dir=os.path.join(get_home_cache_dir(),
  60. re.sub(r"^[a-zA-Z]+/", "", model_name)),
  61. local_dir_use_symlinks=False)
  62. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  63. self._model = DefaultRerank._model
  64. def similarity(self, query: str, texts: list):
  65. pairs = [(query, truncate(t, 2048)) for t in texts]
  66. token_count = 0
  67. for _, t in pairs:
  68. token_count += num_tokens_from_string(t)
  69. batch_size = 4096
  70. res = []
  71. for i in range(0, len(pairs), batch_size):
  72. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
  73. scores = sigmoid(np.array(scores)).tolist()
  74. if isinstance(scores, float):
  75. res.append(scores)
  76. else:
  77. res.extend(scores)
  78. return np.array(res), token_count
  79. class JinaRerank(Base):
  80. def __init__(self, key, model_name="jina-reranker-v1-base-en",
  81. base_url="https://api.jina.ai/v1/rerank"):
  82. self.base_url = "https://api.jina.ai/v1/rerank"
  83. self.headers = {
  84. "Content-Type": "application/json",
  85. "Authorization": f"Bearer {key}"
  86. }
  87. self.model_name = model_name
  88. def similarity(self, query: str, texts: list):
  89. texts = [truncate(t, 8196) for t in texts]
  90. data = {
  91. "model": self.model_name,
  92. "query": query,
  93. "documents": texts,
  94. "top_n": len(texts)
  95. }
  96. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  97. rank = np.zeros(len(texts), dtype=float)
  98. for d in res["results"]:
  99. rank[d["index"]] = d["relevance_score"]
  100. return rank, res["usage"]["total_tokens"]
  101. class YoudaoRerank(DefaultRerank):
  102. _model = None
  103. _model_lock = threading.Lock()
  104. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  105. if not LIGHTEN and not YoudaoRerank._model:
  106. from BCEmbedding import RerankerModel
  107. with YoudaoRerank._model_lock:
  108. if not YoudaoRerank._model:
  109. try:
  110. logger.info("LOADING BCE...")
  111. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  112. get_home_cache_dir(),
  113. re.sub(r"^[a-zA-Z]+/", "", model_name)))
  114. except Exception:
  115. YoudaoRerank._model = RerankerModel(
  116. model_name_or_path=model_name.replace(
  117. "maidalun1020", "InfiniFlow"))
  118. self._model = YoudaoRerank._model
  119. def similarity(self, query: str, texts: list):
  120. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  121. token_count = 0
  122. for _, t in pairs:
  123. token_count += num_tokens_from_string(t)
  124. batch_size = 8
  125. res = []
  126. for i in range(0, len(pairs), batch_size):
  127. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
  128. scores = sigmoid(np.array(scores)).tolist()
  129. if isinstance(scores, float):
  130. res.append(scores)
  131. else:
  132. res.extend(scores)
  133. return np.array(res), token_count
  134. class XInferenceRerank(Base):
  135. def __init__(self, key="xxxxxxx", model_name="", base_url=""):
  136. if base_url.find("/v1") == -1:
  137. base_url = urljoin(base_url, "/v1/rerank")
  138. self.model_name = model_name
  139. self.base_url = base_url
  140. self.headers = {
  141. "Content-Type": "application/json",
  142. "accept": "application/json",
  143. "Authorization": f"Bearer {key}"
  144. }
  145. def similarity(self, query: str, texts: list):
  146. if len(texts) == 0:
  147. return np.array([]), 0
  148. data = {
  149. "model": self.model_name,
  150. "query": query,
  151. "return_documents": "true",
  152. "return_len": "true",
  153. "documents": texts
  154. }
  155. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  156. rank = np.zeros(len(texts), dtype=float)
  157. for d in res["results"]:
  158. rank[d["index"]] = d["relevance_score"]
  159. return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
  160. class LocalAIRerank(Base):
  161. def __init__(self, key, model_name, base_url):
  162. pass
  163. def similarity(self, query: str, texts: list):
  164. raise NotImplementedError("The LocalAIRerank has not been implement")
  165. class NvidiaRerank(Base):
  166. def __init__(
  167. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  168. ):
  169. if not base_url:
  170. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  171. self.model_name = model_name
  172. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  173. self.base_url = os.path.join(
  174. base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
  175. )
  176. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  177. self.base_url = os.path.join(base_url, "reranking")
  178. self.model_name = "nv-rerank-qa-mistral-4b:1"
  179. self.headers = {
  180. "accept": "application/json",
  181. "Content-Type": "application/json",
  182. "Authorization": f"Bearer {key}",
  183. }
  184. def similarity(self, query: str, texts: list):
  185. token_count = num_tokens_from_string(query) + sum(
  186. [num_tokens_from_string(t) for t in texts]
  187. )
  188. data = {
  189. "model": self.model_name,
  190. "query": {"text": query},
  191. "passages": [{"text": text} for text in texts],
  192. "truncate": "END",
  193. "top_n": len(texts),
  194. }
  195. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  196. rank = np.zeros(len(texts), dtype=float)
  197. for d in res["rankings"]:
  198. rank[d["index"]] = d["logit"]
  199. return rank, token_count
  200. class LmStudioRerank(Base):
  201. def __init__(self, key, model_name, base_url):
  202. pass
  203. def similarity(self, query: str, texts: list):
  204. raise NotImplementedError("The LmStudioRerank has not been implement")
  205. class OpenAI_APIRerank(Base):
  206. def __init__(self, key, model_name, base_url):
  207. if base_url.find("/rerank") == -1:
  208. self.base_url = urljoin(base_url, "/rerank")
  209. else:
  210. self.base_url = base_url
  211. self.headers = {
  212. "Content-Type": "application/json",
  213. "Authorization": f"Bearer {key}"
  214. }
  215. self.model_name = model_name
  216. def similarity(self, query: str, texts: list):
  217. # noway to config Ragflow , use fix setting
  218. texts = [truncate(t, 500) for t in texts]
  219. data = {
  220. "model": self.model_name,
  221. "query": query,
  222. "documents": texts,
  223. "top_n": len(texts),
  224. }
  225. token_count = 0
  226. for t in texts:
  227. token_count += num_tokens_from_string(t)
  228. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  229. rank = np.zeros(len(texts), dtype=float)
  230. if 'results' not in res:
  231. raise ValueError("response not contains results\n" + str(res))
  232. for d in res["results"]:
  233. rank[d["index"]] = d["relevance_score"]
  234. # Normalize the rank values to the range 0 to 1
  235. min_rank = np.min(rank)
  236. max_rank = np.max(rank)
  237. # Avoid division by zero if all ranks are identical
  238. if max_rank - min_rank != 0:
  239. rank = (rank - min_rank) / (max_rank - min_rank)
  240. else:
  241. rank = np.zeros_like(rank)
  242. return rank, token_count
  243. class CoHereRerank(Base):
  244. def __init__(self, key, model_name, base_url=None):
  245. from cohere import Client
  246. self.client = Client(api_key=key)
  247. self.model_name = model_name
  248. def similarity(self, query: str, texts: list):
  249. token_count = num_tokens_from_string(query) + sum(
  250. [num_tokens_from_string(t) for t in texts]
  251. )
  252. res = self.client.rerank(
  253. model=self.model_name,
  254. query=query,
  255. documents=texts,
  256. top_n=len(texts),
  257. return_documents=False,
  258. )
  259. rank = np.zeros(len(texts), dtype=float)
  260. for d in res.results:
  261. rank[d.index] = d.relevance_score
  262. return rank, token_count
  263. class TogetherAIRerank(Base):
  264. def __init__(self, key, model_name, base_url):
  265. pass
  266. def similarity(self, query: str, texts: list):
  267. raise NotImplementedError("The api has not been implement")
  268. class SILICONFLOWRerank(Base):
  269. def __init__(
  270. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  271. ):
  272. if not base_url:
  273. base_url = "https://api.siliconflow.cn/v1/rerank"
  274. self.model_name = model_name
  275. self.base_url = base_url
  276. self.headers = {
  277. "accept": "application/json",
  278. "content-type": "application/json",
  279. "authorization": f"Bearer {key}",
  280. }
  281. def similarity(self, query: str, texts: list):
  282. payload = {
  283. "model": self.model_name,
  284. "query": query,
  285. "documents": texts,
  286. "top_n": len(texts),
  287. "return_documents": False,
  288. "max_chunks_per_doc": 1024,
  289. "overlap_tokens": 80,
  290. }
  291. response = requests.post(
  292. self.base_url, json=payload, headers=self.headers
  293. ).json()
  294. rank = np.zeros(len(texts), dtype=float)
  295. if "results" not in response:
  296. return rank, 0
  297. for d in response["results"]:
  298. rank[d["index"]] = d["relevance_score"]
  299. return (
  300. rank,
  301. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  302. )
  303. class BaiduYiyanRerank(Base):
  304. def __init__(self, key, model_name, base_url=None):
  305. from qianfan.resources import Reranker
  306. key = json.loads(key)
  307. ak = key.get("yiyan_ak", "")
  308. sk = key.get("yiyan_sk", "")
  309. self.client = Reranker(ak=ak, sk=sk)
  310. self.model_name = model_name
  311. def similarity(self, query: str, texts: list):
  312. res = self.client.do(
  313. model=self.model_name,
  314. query=query,
  315. documents=texts,
  316. top_n=len(texts),
  317. ).body
  318. rank = np.zeros(len(texts), dtype=float)
  319. for d in res["results"]:
  320. rank[d["index"]] = d["relevance_score"]
  321. return rank, res["usage"]["total_tokens"]
  322. class VoyageRerank(Base):
  323. def __init__(self, key, model_name, base_url=None):
  324. import voyageai
  325. self.client = voyageai.Client(api_key=key)
  326. self.model_name = model_name
  327. def similarity(self, query: str, texts: list):
  328. res = self.client.rerank(
  329. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  330. )
  331. rank = np.zeros(len(texts), dtype=float)
  332. for r in res.results:
  333. rank[r.index] = r.relevance_score
  334. return rank, res.total_tokens
  335. class QWenRerank(Base):
  336. def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
  337. import dashscope
  338. self.api_key = key
  339. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  340. def similarity(self, query: str, texts: list):
  341. import dashscope
  342. from http import HTTPStatus
  343. resp = dashscope.TextReRank.call(
  344. api_key=self.api_key,
  345. model=self.model_name,
  346. query=query,
  347. documents=texts,
  348. top_n=len(texts),
  349. return_documents=False
  350. )
  351. rank = np.zeros(len(texts), dtype=float)
  352. if resp.status_code == HTTPStatus.OK:
  353. for r in resp.output.results:
  354. rank[r.index] = r.relevance_score
  355. return rank, resp.usage.total_tokens
  356. return rank, 0