
rerank_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
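
# Reranker wrappers used by RAGFlow. Local models (FlagEmbedding, BCEmbedding) and
# hosted rerank services (Jina, NVIDIA, Cohere, SiliconFlow, DashScope, ...) are all
# exposed through the same similarity(query, texts) interface defined on Base below.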
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin

import numpy as np
import requests
from huggingface_hub import snapshot_download

from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
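
# sigmoid squashes the raw logits produced by the local cross-encoder rerankers
# into the (0, 1) range so they line up with the scores returned by hosted APIs.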
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
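
# Abstract base: concrete rerankers implement similarity(query, texts) and return
# (scores, token_count), where scores is a NumPy array aligned with the input texts.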
class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("Please implement similarity method!")
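
# Local reranker backed by FlagEmbedding's FlagReranker. The model is loaded lazily,
# cached as a class attribute and guarded by a lock so concurrent callers share one
# copy; FP16 is used when CUDA is available, and missing weights are pulled from
# HuggingFace via snapshot_download.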
class DefaultRerank(Base):
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker
            with DefaultRerank._model_lock:
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(
                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                            use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name,
                                                      local_dir=os.path.join(get_home_cache_dir(),
                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                      local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 4096
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
            scores = sigmoid(np.array(scores)).tolist()
            if isinstance(scores, float):
                res.append(scores)
            else:
                res.extend(scores)
        return np.array(res), token_count
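
# Remote reranker that calls the Jina AI rerank REST endpoint; token usage is
# taken from the API response rather than counted locally.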
class JinaRerank(Base):
    def __init__(self, key, model_name="jina-reranker-v1-base-en",
                 base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = base_url or "https://api.jina.ai/v1/rerank"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts)
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]
        return rank, res["usage"]["total_tokens"]
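
# Youdao BCE reranker loaded through the BCEmbedding package. If the locally cached
# weights cannot be loaded, it falls back to the "InfiniFlow" copy of the model repo.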
class YoudaoRerank(DefaultRerank):
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoRerank._model:
            from BCEmbedding import RerankerModel
            with YoudaoRerank._model_lock:
                if not YoudaoRerank._model:
                    try:
                        logging.info("LOADING BCE...")
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
                            get_home_cache_dir(),
                            re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(
                            model_name_or_path=model_name.replace(
                                "maidalun1020", "InfiniFlow"))
        self._model = YoudaoRerank._model

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 8
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
            scores = sigmoid(np.array(scores)).tolist()
            if isinstance(scores, float):
                res.append(scores)
            else:
                res.extend(scores)
        return np.array(res), token_count
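
# Reranker served by an Xinference instance; the constructor normalizes the
# base URL so requests go to the server's /v1/rerank endpoint.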
class XInferenceRerank(Base):
    def __init__(self, key="xxxxxxx", model_name="", base_url=""):
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        if base_url.find("/rerank") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {key}"
        }

    def similarity(self, query: str, texts: list):
        if len(texts) == 0:
            return np.array([]), 0
        data = {
            "model": self.model_name,
            "query": query,
            "return_documents": "true",
            "return_len": "true",
            "documents": texts
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]
        return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
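
# Reranker exposed through a LocalAI-compatible /rerank endpoint. Scores are
# min-max normalized into [0, 1]; tokens are counted locally from the truncated inputs.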
class LocalAIRerank(Base):
    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name.replace("___LocalAI", "")

    def similarity(self, query: str, texts: list):
        # The truncation length is not configurable in RAGFlow, so a fixed setting is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        if 'results' not in res:
            raise ValueError("response does not contain results\n" + str(res))
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]

        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)

        # Avoid division by zero if all ranks are identical
        if max_rank - min_rank != 0:
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)

        return rank, token_count
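
# NVIDIA retrieval reranking API. The per-model endpoint is derived from the model
# name; the response carries raw logits, and the token count is estimated locally.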
class NvidiaRerank(Base):
    def __init__(
            self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
    ):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name

        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = os.path.join(
                base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
            )

        if self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = os.path.join(base_url, "reranking")
            self.model_name = "nv-rerank-qa-mistral-4b:1"

        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum(
            [num_tokens_from_string(t) for t in texts]
        )
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        for d in res["rankings"]:
            rank[d["index"]] = d["logit"]
        return rank, token_count

class LmStudioRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("The LmStudioRerank has not been implemented")
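
# Generic reranker for OpenAI-compatible /rerank endpoints; it behaves like
# LocalAIRerank above, including the min-max normalization of scores.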
class OpenAI_APIRerank(Base):
    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        # The truncation length is not configurable in RAGFlow, so a fixed setting is used.
        texts = [truncate(t, 500) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
        }
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        if 'results' not in res:
            raise ValueError("response does not contain results\n" + str(res))
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]

        # Normalize the rank values to the range 0 to 1
        min_rank = np.min(rank)
        max_rank = np.max(rank)

        # Avoid division by zero if all ranks are identical
        if max_rank - min_rank != 0:
            rank = (rank - min_rank) / (max_rank - min_rank)
        else:
            rank = np.zeros_like(rank)

        return rank, token_count
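
# Cohere reranker accessed through the official cohere SDK; token usage is
# estimated locally with num_tokens_from_string.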
class CoHereRerank(Base):
    def __init__(self, key, model_name, base_url=None):
        from cohere import Client
        self.client = Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum(
            [num_tokens_from_string(t) for t in texts]
        )
        res = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        rank = np.zeros(len(texts), dtype=float)
        for d in res.results:
            rank[d.index] = d.relevance_score
        return rank, token_count

class TogetherAIRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("The API has not been implemented")
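
# SiliconFlow rerank API; if the response carries no "results" field, an all-zero
# score vector and a zero token count are returned instead of raising.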
class SILICONFLOWRerank(Base):
    def __init__(
            self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
    ):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/rerank"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(
            self.base_url, json=payload, headers=self.headers
        ).json()
        rank = np.zeros(len(texts), dtype=float)
        if "results" not in response:
            return rank, 0
        for d in response["results"]:
            rank[d["index"]] = d["relevance_score"]
        return (
            rank,
            response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
        )
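
# Baidu ERNIE (Yiyan) reranker via the qianfan SDK; the key is a JSON string
# holding the "yiyan_ak" / "yiyan_sk" credential pair.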
class BaiduYiyanRerank(Base):
    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = Reranker(ak=ak, sk=sk)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        res = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        rank = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]
        return rank, res["usage"]["total_tokens"]
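
# Voyage AI reranker via the voyageai SDK; token usage comes from the response.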
class VoyageRerank(Base):
    def __init__(self, key, model_name, base_url=None):
        import voyageai

        self.client = voyageai.Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        res = self.client.rerank(
            query=query, documents=texts, model=self.model_name, top_k=len(texts)
        )
        rank = np.zeros(len(texts), dtype=float)
        for r in res.results:
            rank[r.index] = r.relevance_score
        return rank, res.total_tokens
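
# Tongyi Qianwen (DashScope) reranker using dashscope.TextReRank; non-OK HTTP
# statuses are surfaced as a ValueError.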
class QWenRerank(Base):
    def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
        import dashscope
        self.api_key = key
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

    def similarity(self, query: str, texts: list):
        import dashscope
        from http import HTTPStatus
        resp = dashscope.TextReRank.call(
            api_key=self.api_key,
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False
        )
        rank = np.zeros(len(texts), dtype=float)
        if resp.status_code == HTTPStatus.OK:
            for r in resp.output.results:
                rank[r.index] = r.relevance_score
            return rank, resp.usage.total_tokens
        else:
            raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
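
# Minimal usage sketch (not part of the library code): any of the wrappers above is
# used the same way. The API key and document texts below are placeholders, and the
# call requires valid Jina AI credentials to actually run.
if __name__ == "__main__":
    _demo_texts = [
        "RAGFlow is an open-source RAG engine based on deep document understanding.",
        "Completely unrelated text about cooking pasta.",
    ]
    # Swap in any other subclass (e.g. CoHereRerank) with its own credentials to get
    # the same (scores, token_count) pair back.
    reranker = JinaRerank(key="YOUR_JINA_API_KEY")
    scores, used_tokens = reranker.similarity("What is RAGFlow?", _demo_texts)
    print(scores, used_tokens)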