您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rerank_model.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import re
  19. import threading
  20. from abc import ABC
  21. from collections.abc import Iterable
  22. from urllib.parse import urljoin
  23. import httpx
  24. import numpy as np
  25. import requests
  26. from huggingface_hub import snapshot_download
  27. from yarl import URL
  28. from api import settings
  29. from api.utils.file_utils import get_home_cache_dir
  30. from api.utils.log_utils import log_exception
  31. from rag.utils import num_tokens_from_string, truncate
  32. class Base(ABC):
  33. def __init__(self, key, model_name, **kwargs):
  34. """
  35. Abstract base class constructor.
  36. Parameters are not stored; initialization is left to subclasses.
  37. """
  38. pass
  39. def similarity(self, query: str, texts: list):
  40. raise NotImplementedError("Please implement encode method!")
  41. def total_token_count(self, resp):
  42. try:
  43. return resp.usage.total_tokens
  44. except Exception:
  45. pass
  46. try:
  47. return resp["usage"]["total_tokens"]
  48. except Exception:
  49. pass
  50. return 0
  51. class DefaultRerank(Base):
  52. _FACTORY_NAME = "BAAI"
  53. _model = None
  54. _model_lock = threading.Lock()
  55. def __init__(self, key, model_name, **kwargs):
  56. """
  57. If you have trouble downloading HuggingFace models, -_^ this might help!!
  58. For Linux:
  59. export HF_ENDPOINT=https://hf-mirror.com
  60. For Windows:
  61. Good luck
  62. ^_-
  63. """
  64. if not settings.LIGHTEN and not DefaultRerank._model:
  65. import torch
  66. from FlagEmbedding import FlagReranker
  67. with DefaultRerank._model_lock:
  68. if not DefaultRerank._model:
  69. try:
  70. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  71. except Exception:
  72. model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
  73. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  74. self._model = DefaultRerank._model
  75. self._dynamic_batch_size = 8
  76. self._min_batch_size = 1
  77. def torch_empty_cache(self):
  78. try:
  79. import torch
  80. torch.cuda.empty_cache()
  81. except Exception as e:
  82. log_exception(e)
  83. def _process_batch(self, pairs, max_batch_size=None):
  84. """template method for subclass call"""
  85. old_dynamic_batch_size = self._dynamic_batch_size
  86. if max_batch_size is not None:
  87. self._dynamic_batch_size = max_batch_size
  88. res = np.array([], dtype=float)
  89. i = 0
  90. while i < len(pairs):
  91. cur_i = i
  92. current_batch = self._dynamic_batch_size
  93. max_retries = 5
  94. retry_count = 0
  95. while retry_count < max_retries:
  96. try:
  97. # call subclass implemented batch processing calculation
  98. batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
  99. res = np.append(res, batch_scores)
  100. i += current_batch
  101. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  102. break
  103. except RuntimeError as e:
  104. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  105. current_batch = max(current_batch // 2, self._min_batch_size)
  106. self.torch_empty_cache()
  107. i = cur_i # reset i to the start of the current batch
  108. retry_count += 1
  109. else:
  110. raise
  111. if retry_count >= max_retries:
  112. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  113. self.torch_empty_cache()
  114. self._dynamic_batch_size = old_dynamic_batch_size
  115. return np.array(res)
  116. def _compute_batch_scores(self, batch_pairs, max_length=None):
  117. if max_length is None:
  118. scores = self._model.compute_score(batch_pairs, normalize=True)
  119. else:
  120. scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
  121. if not isinstance(scores, Iterable):
  122. scores = [scores]
  123. return scores
  124. def similarity(self, query: str, texts: list):
  125. pairs = [(query, truncate(t, 2048)) for t in texts]
  126. token_count = 0
  127. for _, t in pairs:
  128. token_count += num_tokens_from_string(t)
  129. batch_size = 4096
  130. res = self._process_batch(pairs, max_batch_size=batch_size)
  131. return np.array(res), token_count
  132. class JinaRerank(Base):
  133. _FACTORY_NAME = "Jina"
  134. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
  135. self.base_url = "https://api.jina.ai/v1/rerank"
  136. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  137. self.model_name = model_name
  138. def similarity(self, query: str, texts: list):
  139. texts = [truncate(t, 8196) for t in texts]
  140. data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
  141. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  142. rank = np.zeros(len(texts), dtype=float)
  143. try:
  144. for d in res["results"]:
  145. rank[d["index"]] = d["relevance_score"]
  146. except Exception as _e:
  147. log_exception(_e, res)
  148. return rank, self.total_token_count(res)
  149. class YoudaoRerank(DefaultRerank):
  150. _FACTORY_NAME = "Youdao"
  151. _model = None
  152. _model_lock = threading.Lock()
  153. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  154. if not settings.LIGHTEN and not YoudaoRerank._model:
  155. from BCEmbedding import RerankerModel
  156. with YoudaoRerank._model_lock:
  157. if not YoudaoRerank._model:
  158. try:
  159. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  160. except Exception:
  161. YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
  162. self._model = YoudaoRerank._model
  163. self._dynamic_batch_size = 8
  164. self._min_batch_size = 1
  165. def similarity(self, query: str, texts: list):
  166. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  167. token_count = 0
  168. for _, t in pairs:
  169. token_count += num_tokens_from_string(t)
  170. batch_size = 8
  171. res = self._process_batch(pairs, max_batch_size=batch_size)
  172. return np.array(res), token_count
  173. class XInferenceRerank(Base):
  174. _FACTORY_NAME = "Xinference"
  175. def __init__(self, key="x", model_name="", base_url=""):
  176. if base_url.find("/v1") == -1:
  177. base_url = urljoin(base_url, "/v1/rerank")
  178. if base_url.find("/rerank") == -1:
  179. base_url = urljoin(base_url, "/v1/rerank")
  180. self.model_name = model_name
  181. self.base_url = base_url
  182. self.headers = {"Content-Type": "application/json", "accept": "application/json"}
  183. if key and key != "x":
  184. self.headers["Authorization"] = f"Bearer {key}"
  185. def similarity(self, query: str, texts: list):
  186. if len(texts) == 0:
  187. return np.array([]), 0
  188. pairs = [(query, truncate(t, 4096)) for t in texts]
  189. token_count = 0
  190. for _, t in pairs:
  191. token_count += num_tokens_from_string(t)
  192. data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
  193. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  194. rank = np.zeros(len(texts), dtype=float)
  195. try:
  196. for d in res["results"]:
  197. rank[d["index"]] = d["relevance_score"]
  198. except Exception as _e:
  199. log_exception(_e, res)
  200. return rank, token_count
  201. class LocalAIRerank(Base):
  202. _FACTORY_NAME = "LocalAI"
  203. def __init__(self, key, model_name, base_url):
  204. if base_url.find("/rerank") == -1:
  205. self.base_url = urljoin(base_url, "/rerank")
  206. else:
  207. self.base_url = base_url
  208. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  209. self.model_name = model_name.split("___")[0]
  210. def similarity(self, query: str, texts: list):
  211. # noway to config Ragflow , use fix setting
  212. texts = [truncate(t, 500) for t in texts]
  213. data = {
  214. "model": self.model_name,
  215. "query": query,
  216. "documents": texts,
  217. "top_n": len(texts),
  218. }
  219. token_count = 0
  220. for t in texts:
  221. token_count += num_tokens_from_string(t)
  222. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  223. rank = np.zeros(len(texts), dtype=float)
  224. try:
  225. for d in res["results"]:
  226. rank[d["index"]] = d["relevance_score"]
  227. except Exception as _e:
  228. log_exception(_e, res)
  229. # Normalize the rank values to the range 0 to 1
  230. min_rank = np.min(rank)
  231. max_rank = np.max(rank)
  232. # Avoid division by zero if all ranks are identical
  233. if max_rank - min_rank != 0:
  234. rank = (rank - min_rank) / (max_rank - min_rank)
  235. else:
  236. rank = np.zeros_like(rank)
  237. return rank, token_count
  238. class NvidiaRerank(Base):
  239. _FACTORY_NAME = "NVIDIA"
  240. def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
  241. if not base_url:
  242. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  243. self.model_name = model_name
  244. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  245. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
  246. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  247. self.base_url = urljoin(base_url, "reranking")
  248. self.model_name = "nv-rerank-qa-mistral-4b:1"
  249. self.headers = {
  250. "accept": "application/json",
  251. "Content-Type": "application/json",
  252. "Authorization": f"Bearer {key}",
  253. }
  254. def similarity(self, query: str, texts: list):
  255. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  256. data = {
  257. "model": self.model_name,
  258. "query": {"text": query},
  259. "passages": [{"text": text} for text in texts],
  260. "truncate": "END",
  261. "top_n": len(texts),
  262. }
  263. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  264. rank = np.zeros(len(texts), dtype=float)
  265. try:
  266. for d in res["rankings"]:
  267. rank[d["index"]] = d["logit"]
  268. except Exception as _e:
  269. log_exception(_e, res)
  270. return rank, token_count
  271. class LmStudioRerank(Base):
  272. _FACTORY_NAME = "LM-Studio"
  273. def __init__(self, key, model_name, base_url, **kwargs):
  274. pass
  275. def similarity(self, query: str, texts: list):
  276. raise NotImplementedError("The LmStudioRerank has not been implement")
  277. class OpenAI_APIRerank(Base):
  278. _FACTORY_NAME = "OpenAI-API-Compatible"
  279. def __init__(self, key, model_name, base_url):
  280. if base_url.find("/rerank") == -1:
  281. self.base_url = urljoin(base_url, "/rerank")
  282. else:
  283. self.base_url = base_url
  284. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  285. self.model_name = model_name.split("___")[0]
  286. def similarity(self, query: str, texts: list):
  287. # noway to config Ragflow , use fix setting
  288. texts = [truncate(t, 500) for t in texts]
  289. data = {
  290. "model": self.model_name,
  291. "query": query,
  292. "documents": texts,
  293. "top_n": len(texts),
  294. }
  295. token_count = 0
  296. for t in texts:
  297. token_count += num_tokens_from_string(t)
  298. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  299. rank = np.zeros(len(texts), dtype=float)
  300. try:
  301. for d in res["results"]:
  302. rank[d["index"]] = d["relevance_score"]
  303. except Exception as _e:
  304. log_exception(_e, res)
  305. # Normalize the rank values to the range 0 to 1
  306. min_rank = np.min(rank)
  307. max_rank = np.max(rank)
  308. # Avoid division by zero if all ranks are identical
  309. if np.isclose(min_rank, max_rank, atol=1e-3):
  310. rank = (rank - min_rank) / (max_rank - min_rank)
  311. else:
  312. rank = np.zeros_like(rank)
  313. return rank, token_count
  314. class CoHereRerank(Base):
  315. _FACTORY_NAME = ["Cohere", "VLLM"]
  316. def __init__(self, key, model_name, base_url=None):
  317. from cohere import Client
  318. self.client = Client(api_key=key, base_url=base_url)
  319. self.model_name = model_name.split("___")[0]
  320. def similarity(self, query: str, texts: list):
  321. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  322. res = self.client.rerank(
  323. model=self.model_name,
  324. query=query,
  325. documents=texts,
  326. top_n=len(texts),
  327. return_documents=False,
  328. )
  329. rank = np.zeros(len(texts), dtype=float)
  330. try:
  331. for d in res.results:
  332. rank[d.index] = d.relevance_score
  333. except Exception as _e:
  334. log_exception(_e, res)
  335. return rank, token_count
  336. class TogetherAIRerank(Base):
  337. _FACTORY_NAME = "TogetherAI"
  338. def __init__(self, key, model_name, base_url, **kwargs):
  339. pass
  340. def similarity(self, query: str, texts: list):
  341. raise NotImplementedError("The api has not been implement")
  342. class SILICONFLOWRerank(Base):
  343. _FACTORY_NAME = "SILICONFLOW"
  344. def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
  345. if not base_url:
  346. base_url = "https://api.siliconflow.cn/v1/rerank"
  347. self.model_name = model_name
  348. self.base_url = base_url
  349. self.headers = {
  350. "accept": "application/json",
  351. "content-type": "application/json",
  352. "authorization": f"Bearer {key}",
  353. }
  354. def similarity(self, query: str, texts: list):
  355. payload = {
  356. "model": self.model_name,
  357. "query": query,
  358. "documents": texts,
  359. "top_n": len(texts),
  360. "return_documents": False,
  361. "max_chunks_per_doc": 1024,
  362. "overlap_tokens": 80,
  363. }
  364. response = requests.post(self.base_url, json=payload, headers=self.headers).json()
  365. rank = np.zeros(len(texts), dtype=float)
  366. try:
  367. for d in response["results"]:
  368. rank[d["index"]] = d["relevance_score"]
  369. except Exception as _e:
  370. log_exception(_e, response)
  371. return (
  372. rank,
  373. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  374. )
  375. class BaiduYiyanRerank(Base):
  376. _FACTORY_NAME = "BaiduYiyan"
  377. def __init__(self, key, model_name, base_url=None):
  378. from qianfan.resources import Reranker
  379. key = json.loads(key)
  380. ak = key.get("yiyan_ak", "")
  381. sk = key.get("yiyan_sk", "")
  382. self.client = Reranker(ak=ak, sk=sk)
  383. self.model_name = model_name
  384. def similarity(self, query: str, texts: list):
  385. res = self.client.do(
  386. model=self.model_name,
  387. query=query,
  388. documents=texts,
  389. top_n=len(texts),
  390. ).body
  391. rank = np.zeros(len(texts), dtype=float)
  392. try:
  393. for d in res["results"]:
  394. rank[d["index"]] = d["relevance_score"]
  395. except Exception as _e:
  396. log_exception(_e, res)
  397. return rank, self.total_token_count(res)
  398. class VoyageRerank(Base):
  399. _FACTORY_NAME = "Voyage AI"
  400. def __init__(self, key, model_name, base_url=None):
  401. import voyageai
  402. self.client = voyageai.Client(api_key=key)
  403. self.model_name = model_name
  404. def similarity(self, query: str, texts: list):
  405. rank = np.zeros(len(texts), dtype=float)
  406. if not texts:
  407. return rank, 0
  408. res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
  409. try:
  410. for r in res.results:
  411. rank[r.index] = r.relevance_score
  412. except Exception as _e:
  413. log_exception(_e, res)
  414. return rank, res.total_tokens
  415. class QWenRerank(Base):
  416. _FACTORY_NAME = "Tongyi-Qianwen"
  417. def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
  418. import dashscope
  419. self.api_key = key
  420. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  421. def similarity(self, query: str, texts: list):
  422. from http import HTTPStatus
  423. import dashscope
  424. resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
  425. rank = np.zeros(len(texts), dtype=float)
  426. if resp.status_code == HTTPStatus.OK:
  427. try:
  428. for r in resp.output.results:
  429. rank[r.index] = r.relevance_score
  430. except Exception as _e:
  431. log_exception(_e, resp)
  432. return rank, resp.usage.total_tokens
  433. else:
  434. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  435. class HuggingfaceRerank(DefaultRerank):
  436. _FACTORY_NAME = "HuggingFace"
  437. @staticmethod
  438. def post(query: str, texts: list, url="127.0.0.1"):
  439. exc = None
  440. scores = [0 for _ in range(len(texts))]
  441. batch_size = 8
  442. for i in range(0, len(texts), batch_size):
  443. try:
  444. res = requests.post(
  445. f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
  446. )
  447. for o in res.json():
  448. scores[o["index"] + i] = o["score"]
  449. except Exception as e:
  450. exc = e
  451. if exc:
  452. raise exc
  453. return np.array(scores)
  454. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  455. self.model_name = model_name.split("___")[0]
  456. self.base_url = base_url
  457. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  458. if not texts:
  459. return np.array([]), 0
  460. token_count = 0
  461. for t in texts:
  462. token_count += num_tokens_from_string(t)
  463. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  464. class GPUStackRerank(Base):
  465. _FACTORY_NAME = "GPUStack"
  466. def __init__(self, key, model_name, base_url):
  467. if not base_url:
  468. raise ValueError("url cannot be None")
  469. self.model_name = model_name
  470. self.base_url = str(URL(base_url) / "v1" / "rerank")
  471. self.headers = {
  472. "accept": "application/json",
  473. "content-type": "application/json",
  474. "authorization": f"Bearer {key}",
  475. }
  476. def similarity(self, query: str, texts: list):
  477. payload = {
  478. "model": self.model_name,
  479. "query": query,
  480. "documents": texts,
  481. "top_n": len(texts),
  482. }
  483. try:
  484. response = requests.post(self.base_url, json=payload, headers=self.headers)
  485. response.raise_for_status()
  486. response_json = response.json()
  487. rank = np.zeros(len(texts), dtype=float)
  488. token_count = 0
  489. for t in texts:
  490. token_count += num_tokens_from_string(t)
  491. try:
  492. for result in response_json["results"]:
  493. rank[result["index"]] = result["relevance_score"]
  494. except Exception as _e:
  495. log_exception(_e, response)
  496. return (
  497. rank,
  498. token_count,
  499. )
  500. except httpx.HTTPStatusError as e:
  501. raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  502. class NovitaRerank(JinaRerank):
  503. _FACTORY_NAME = "NovitaAI"
  504. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  505. if not base_url:
  506. base_url = "https://api.novita.ai/v3/openai/rerank"
  507. super().__init__(key, model_name, base_url)
  508. class GiteeRerank(JinaRerank):
  509. _FACTORY_NAME = "GiteeAI"
  510. def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
  511. if not base_url:
  512. base_url = "https://ai.gitee.com/v1/rerank"
  513. super().__init__(key, model_name, base_url)
  514. class Ai302Rerank(Base):
  515. _FACTORY_NAME = "302.AI"
  516. def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"):
  517. if not base_url:
  518. base_url = "https://api.302.ai/v1/rerank"
  519. super().__init__(key, model_name, base_url)