Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

rerank_model.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import re
  19. import threading
  20. from abc import ABC
  21. from collections.abc import Iterable
  22. from urllib.parse import urljoin
  23. import httpx
  24. import numpy as np
  25. import requests
  26. from huggingface_hub import snapshot_download
  27. from yarl import URL
  28. from api import settings
  29. from api.utils.file_utils import get_home_cache_dir
  30. from api.utils.log_utils import log_exception
  31. from rag.utils import num_tokens_from_string, truncate
  32. class Base(ABC):
  33. def __init__(self, key, model_name, **kwargs):
  34. """
  35. Abstract base class constructor.
  36. Parameters are not stored; initialization is left to subclasses.
  37. """
  38. pass
  39. def similarity(self, query: str, texts: list):
  40. raise NotImplementedError("Please implement encode method!")
  41. def total_token_count(self, resp):
  42. if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
  43. try:
  44. return resp.usage.total_tokens
  45. except Exception:
  46. pass
  47. if 'usage' in resp and 'total_tokens' in resp['usage']:
  48. try:
  49. return resp["usage"]["total_tokens"]
  50. except Exception:
  51. pass
  52. return 0
  53. class DefaultRerank(Base):
  54. _FACTORY_NAME = "BAAI"
  55. _model = None
  56. _model_lock = threading.Lock()
  57. def __init__(self, key, model_name, **kwargs):
  58. """
  59. If you have trouble downloading HuggingFace models, -_^ this might help!!
  60. For Linux:
  61. export HF_ENDPOINT=https://hf-mirror.com
  62. For Windows:
  63. Good luck
  64. ^_-
  65. """
  66. if not settings.LIGHTEN and not DefaultRerank._model:
  67. import torch
  68. from FlagEmbedding import FlagReranker
  69. with DefaultRerank._model_lock:
  70. if not DefaultRerank._model:
  71. try:
  72. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  73. except Exception:
  74. model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
  75. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  76. self._model = DefaultRerank._model
  77. self._dynamic_batch_size = 8
  78. self._min_batch_size = 1
  79. def torch_empty_cache(self):
  80. try:
  81. import torch
  82. torch.cuda.empty_cache()
  83. except Exception as e:
  84. log_exception(e)
  85. def _process_batch(self, pairs, max_batch_size=None):
  86. """template method for subclass call"""
  87. old_dynamic_batch_size = self._dynamic_batch_size
  88. if max_batch_size is not None:
  89. self._dynamic_batch_size = max_batch_size
  90. res = np.array(len(pairs), dtype=float)
  91. i = 0
  92. while i < len(pairs):
  93. cur_i = i
  94. current_batch = self._dynamic_batch_size
  95. max_retries = 5
  96. retry_count = 0
  97. while retry_count < max_retries:
  98. try:
  99. # call subclass implemented batch processing calculation
  100. batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
  101. res[i : i + current_batch] = batch_scores
  102. i += current_batch
  103. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  104. break
  105. except RuntimeError as e:
  106. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  107. current_batch = max(current_batch // 2, self._min_batch_size)
  108. self.torch_empty_cache()
  109. i = cur_i # reset i to the start of the current batch
  110. retry_count += 1
  111. else:
  112. raise
  113. if retry_count >= max_retries:
  114. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  115. self.torch_empty_cache()
  116. self._dynamic_batch_size = old_dynamic_batch_size
  117. return np.array(res)
  118. def _compute_batch_scores(self, batch_pairs, max_length=None):
  119. if max_length is None:
  120. scores = self._model.compute_score(batch_pairs, normalize=True)
  121. else:
  122. scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
  123. if not isinstance(scores, Iterable):
  124. scores = [scores]
  125. return scores
  126. def similarity(self, query: str, texts: list):
  127. pairs = [(query, truncate(t, 2048)) for t in texts]
  128. token_count = 0
  129. for _, t in pairs:
  130. token_count += num_tokens_from_string(t)
  131. batch_size = 4096
  132. res = self._process_batch(pairs, max_batch_size=batch_size)
  133. return np.array(res), token_count
  134. class JinaRerank(Base):
  135. _FACTORY_NAME = "Jina"
  136. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
  137. self.base_url = "https://api.jina.ai/v1/rerank"
  138. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  139. self.model_name = model_name
  140. def similarity(self, query: str, texts: list):
  141. texts = [truncate(t, 8196) for t in texts]
  142. data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
  143. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  144. rank = np.zeros(len(texts), dtype=float)
  145. try:
  146. for d in res["results"]:
  147. rank[d["index"]] = d["relevance_score"]
  148. except Exception as _e:
  149. log_exception(_e, res)
  150. return rank, self.total_token_count(res)
  151. class YoudaoRerank(DefaultRerank):
  152. _FACTORY_NAME = "Youdao"
  153. _model = None
  154. _model_lock = threading.Lock()
  155. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  156. if not settings.LIGHTEN and not YoudaoRerank._model:
  157. from BCEmbedding import RerankerModel
  158. with YoudaoRerank._model_lock:
  159. if not YoudaoRerank._model:
  160. try:
  161. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  162. except Exception:
  163. YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
  164. self._model = YoudaoRerank._model
  165. self._dynamic_batch_size = 8
  166. self._min_batch_size = 1
  167. def similarity(self, query: str, texts: list):
  168. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  169. token_count = 0
  170. for _, t in pairs:
  171. token_count += num_tokens_from_string(t)
  172. batch_size = 8
  173. res = self._process_batch(pairs, max_batch_size=batch_size)
  174. return np.array(res), token_count
  175. class XInferenceRerank(Base):
  176. _FACTORY_NAME = "Xinference"
  177. def __init__(self, key="x", model_name="", base_url=""):
  178. if base_url.find("/v1") == -1:
  179. base_url = urljoin(base_url, "/v1/rerank")
  180. if base_url.find("/rerank") == -1:
  181. base_url = urljoin(base_url, "/v1/rerank")
  182. self.model_name = model_name
  183. self.base_url = base_url
  184. self.headers = {"Content-Type": "application/json", "accept": "application/json"}
  185. if key and key != "x":
  186. self.headers["Authorization"] = f"Bearer {key}"
  187. def similarity(self, query: str, texts: list):
  188. if len(texts) == 0:
  189. return np.array([]), 0
  190. pairs = [(query, truncate(t, 4096)) for t in texts]
  191. token_count = 0
  192. for _, t in pairs:
  193. token_count += num_tokens_from_string(t)
  194. data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
  195. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  196. rank = np.zeros(len(texts), dtype=float)
  197. try:
  198. for d in res["results"]:
  199. rank[d["index"]] = d["relevance_score"]
  200. except Exception as _e:
  201. log_exception(_e, res)
  202. return rank, token_count
  203. class LocalAIRerank(Base):
  204. _FACTORY_NAME = "LocalAI"
  205. def __init__(self, key, model_name, base_url):
  206. if base_url.find("/rerank") == -1:
  207. self.base_url = urljoin(base_url, "/rerank")
  208. else:
  209. self.base_url = base_url
  210. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  211. self.model_name = model_name.split("___")[0]
  212. def similarity(self, query: str, texts: list):
  213. # noway to config Ragflow , use fix setting
  214. texts = [truncate(t, 500) for t in texts]
  215. data = {
  216. "model": self.model_name,
  217. "query": query,
  218. "documents": texts,
  219. "top_n": len(texts),
  220. }
  221. token_count = 0
  222. for t in texts:
  223. token_count += num_tokens_from_string(t)
  224. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  225. rank = np.zeros(len(texts), dtype=float)
  226. try:
  227. for d in res["results"]:
  228. rank[d["index"]] = d["relevance_score"]
  229. except Exception as _e:
  230. log_exception(_e, res)
  231. # Normalize the rank values to the range 0 to 1
  232. min_rank = np.min(rank)
  233. max_rank = np.max(rank)
  234. # Avoid division by zero if all ranks are identical
  235. if not np.isclose(min_rank, max_rank, atol=1e-3):
  236. rank = (rank - min_rank) / (max_rank - min_rank)
  237. else:
  238. rank = np.zeros_like(rank)
  239. return rank, token_count
  240. class NvidiaRerank(Base):
  241. _FACTORY_NAME = "NVIDIA"
  242. def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
  243. if not base_url:
  244. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  245. self.model_name = model_name
  246. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  247. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
  248. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  249. self.base_url = urljoin(base_url, "reranking")
  250. self.model_name = "nv-rerank-qa-mistral-4b:1"
  251. self.headers = {
  252. "accept": "application/json",
  253. "Content-Type": "application/json",
  254. "Authorization": f"Bearer {key}",
  255. }
  256. def similarity(self, query: str, texts: list):
  257. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  258. data = {
  259. "model": self.model_name,
  260. "query": {"text": query},
  261. "passages": [{"text": text} for text in texts],
  262. "truncate": "END",
  263. "top_n": len(texts),
  264. }
  265. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  266. rank = np.zeros(len(texts), dtype=float)
  267. try:
  268. for d in res["rankings"]:
  269. rank[d["index"]] = d["logit"]
  270. except Exception as _e:
  271. log_exception(_e, res)
  272. return rank, token_count
  273. class LmStudioRerank(Base):
  274. _FACTORY_NAME = "LM-Studio"
  275. def __init__(self, key, model_name, base_url, **kwargs):
  276. pass
  277. def similarity(self, query: str, texts: list):
  278. raise NotImplementedError("The LmStudioRerank has not been implement")
  279. class OpenAI_APIRerank(Base):
  280. _FACTORY_NAME = "OpenAI-API-Compatible"
  281. def __init__(self, key, model_name, base_url):
  282. if base_url.find("/rerank") == -1:
  283. self.base_url = urljoin(base_url, "/rerank")
  284. else:
  285. self.base_url = base_url
  286. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  287. self.model_name = model_name.split("___")[0]
  288. def similarity(self, query: str, texts: list):
  289. # noway to config Ragflow , use fix setting
  290. texts = [truncate(t, 500) for t in texts]
  291. data = {
  292. "model": self.model_name,
  293. "query": query,
  294. "documents": texts,
  295. "top_n": len(texts),
  296. }
  297. token_count = 0
  298. for t in texts:
  299. token_count += num_tokens_from_string(t)
  300. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  301. rank = np.zeros(len(texts), dtype=float)
  302. try:
  303. for d in res["results"]:
  304. rank[d["index"]] = d["relevance_score"]
  305. except Exception as _e:
  306. log_exception(_e, res)
  307. # Normalize the rank values to the range 0 to 1
  308. min_rank = np.min(rank)
  309. max_rank = np.max(rank)
  310. # Avoid division by zero if all ranks are identical
  311. if np.isclose(min_rank, max_rank, atol=1e-3):
  312. rank = (rank - min_rank) / (max_rank - min_rank)
  313. else:
  314. rank = np.zeros_like(rank)
  315. return rank, token_count
  316. class CoHereRerank(Base):
  317. _FACTORY_NAME = ["Cohere", "VLLM"]
  318. def __init__(self, key, model_name, base_url=None):
  319. from cohere import Client
  320. self.client = Client(api_key=key, base_url=base_url)
  321. self.model_name = model_name.split("___")[0]
  322. def similarity(self, query: str, texts: list):
  323. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  324. res = self.client.rerank(
  325. model=self.model_name,
  326. query=query,
  327. documents=texts,
  328. top_n=len(texts),
  329. return_documents=False,
  330. )
  331. rank = np.zeros(len(texts), dtype=float)
  332. try:
  333. for d in res.results:
  334. rank[d.index] = d.relevance_score
  335. except Exception as _e:
  336. log_exception(_e, res)
  337. return rank, token_count
  338. class TogetherAIRerank(Base):
  339. _FACTORY_NAME = "TogetherAI"
  340. def __init__(self, key, model_name, base_url, **kwargs):
  341. pass
  342. def similarity(self, query: str, texts: list):
  343. raise NotImplementedError("The api has not been implement")
  344. class SILICONFLOWRerank(Base):
  345. _FACTORY_NAME = "SILICONFLOW"
  346. def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
  347. if not base_url:
  348. base_url = "https://api.siliconflow.cn/v1/rerank"
  349. self.model_name = model_name
  350. self.base_url = base_url
  351. self.headers = {
  352. "accept": "application/json",
  353. "content-type": "application/json",
  354. "authorization": f"Bearer {key}",
  355. }
  356. def similarity(self, query: str, texts: list):
  357. payload = {
  358. "model": self.model_name,
  359. "query": query,
  360. "documents": texts,
  361. "top_n": len(texts),
  362. "return_documents": False,
  363. "max_chunks_per_doc": 1024,
  364. "overlap_tokens": 80,
  365. }
  366. response = requests.post(self.base_url, json=payload, headers=self.headers).json()
  367. rank = np.zeros(len(texts), dtype=float)
  368. try:
  369. for d in response["results"]:
  370. rank[d["index"]] = d["relevance_score"]
  371. except Exception as _e:
  372. log_exception(_e, response)
  373. return (
  374. rank,
  375. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  376. )
  377. class BaiduYiyanRerank(Base):
  378. _FACTORY_NAME = "BaiduYiyan"
  379. def __init__(self, key, model_name, base_url=None):
  380. from qianfan.resources import Reranker
  381. key = json.loads(key)
  382. ak = key.get("yiyan_ak", "")
  383. sk = key.get("yiyan_sk", "")
  384. self.client = Reranker(ak=ak, sk=sk)
  385. self.model_name = model_name
  386. def similarity(self, query: str, texts: list):
  387. res = self.client.do(
  388. model=self.model_name,
  389. query=query,
  390. documents=texts,
  391. top_n=len(texts),
  392. ).body
  393. rank = np.zeros(len(texts), dtype=float)
  394. try:
  395. for d in res["results"]:
  396. rank[d["index"]] = d["relevance_score"]
  397. except Exception as _e:
  398. log_exception(_e, res)
  399. return rank, self.total_token_count(res)
  400. class VoyageRerank(Base):
  401. _FACTORY_NAME = "Voyage AI"
  402. def __init__(self, key, model_name, base_url=None):
  403. import voyageai
  404. self.client = voyageai.Client(api_key=key)
  405. self.model_name = model_name
  406. def similarity(self, query: str, texts: list):
  407. if not texts:
  408. return np.array([]), 0
  409. rank = np.zeros(len(texts), dtype=float)
  410. res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
  411. try:
  412. for r in res.results:
  413. rank[r.index] = r.relevance_score
  414. except Exception as _e:
  415. log_exception(_e, res)
  416. return rank, res.total_tokens
  417. class QWenRerank(Base):
  418. _FACTORY_NAME = "Tongyi-Qianwen"
  419. def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
  420. import dashscope
  421. self.api_key = key
  422. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  423. def similarity(self, query: str, texts: list):
  424. from http import HTTPStatus
  425. import dashscope
  426. resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
  427. rank = np.zeros(len(texts), dtype=float)
  428. if resp.status_code == HTTPStatus.OK:
  429. try:
  430. for r in resp.output.results:
  431. rank[r.index] = r.relevance_score
  432. except Exception as _e:
  433. log_exception(_e, resp)
  434. return rank, resp.usage.total_tokens
  435. else:
  436. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  437. class HuggingfaceRerank(DefaultRerank):
  438. _FACTORY_NAME = "HuggingFace"
  439. @staticmethod
  440. def post(query: str, texts: list, url="127.0.0.1"):
  441. exc = None
  442. scores = [0 for _ in range(len(texts))]
  443. batch_size = 8
  444. for i in range(0, len(texts), batch_size):
  445. try:
  446. res = requests.post(
  447. f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
  448. )
  449. for o in res.json():
  450. scores[o["index"] + i] = o["score"]
  451. except Exception as e:
  452. exc = e
  453. if exc:
  454. raise exc
  455. return np.array(scores)
  456. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  457. self.model_name = model_name.split("___")[0]
  458. self.base_url = base_url
  459. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  460. if not texts:
  461. return np.array([]), 0
  462. token_count = 0
  463. for t in texts:
  464. token_count += num_tokens_from_string(t)
  465. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  466. class GPUStackRerank(Base):
  467. _FACTORY_NAME = "GPUStack"
  468. def __init__(self, key, model_name, base_url):
  469. if not base_url:
  470. raise ValueError("url cannot be None")
  471. self.model_name = model_name
  472. self.base_url = str(URL(base_url) / "v1" / "rerank")
  473. self.headers = {
  474. "accept": "application/json",
  475. "content-type": "application/json",
  476. "authorization": f"Bearer {key}",
  477. }
  478. def similarity(self, query: str, texts: list):
  479. payload = {
  480. "model": self.model_name,
  481. "query": query,
  482. "documents": texts,
  483. "top_n": len(texts),
  484. }
  485. try:
  486. response = requests.post(self.base_url, json=payload, headers=self.headers)
  487. response.raise_for_status()
  488. response_json = response.json()
  489. rank = np.zeros(len(texts), dtype=float)
  490. token_count = 0
  491. for t in texts:
  492. token_count += num_tokens_from_string(t)
  493. try:
  494. for result in response_json["results"]:
  495. rank[result["index"]] = result["relevance_score"]
  496. except Exception as _e:
  497. log_exception(_e, response)
  498. return (
  499. rank,
  500. token_count,
  501. )
  502. except httpx.HTTPStatusError as e:
  503. raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  504. class NovitaRerank(JinaRerank):
  505. _FACTORY_NAME = "NovitaAI"
  506. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  507. if not base_url:
  508. base_url = "https://api.novita.ai/v3/openai/rerank"
  509. super().__init__(key, model_name, base_url)
  510. class GiteeRerank(JinaRerank):
  511. _FACTORY_NAME = "GiteeAI"
  512. def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
  513. if not base_url:
  514. base_url = "https://ai.gitee.com/v1/rerank"
  515. super().__init__(key, model_name, base_url)
  516. class Ai302Rerank(Base):
  517. _FACTORY_NAME = "302.AI"
  518. def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
  519. if not base_url:
  520. base_url = "https://api.302.ai/v1/rerank"
  521. super().__init__(key, model_name, base_url)