您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rerank_model.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import re
  19. import threading
  20. from abc import ABC
  21. from collections.abc import Iterable
  22. from urllib.parse import urljoin
  23. import httpx
  24. import numpy as np
  25. import requests
  26. from huggingface_hub import snapshot_download
  27. from yarl import URL
  28. from api import settings
  29. from api.utils.file_utils import get_home_cache_dir
  30. from api.utils.log_utils import log_exception
  31. from rag.utils import num_tokens_from_string, truncate
  32. def sigmoid(x):
  33. return 1 / (1 + np.exp(-x))
  34. class Base(ABC):
  35. def __init__(self, key, model_name):
  36. pass
  37. def similarity(self, query: str, texts: list):
  38. raise NotImplementedError("Please implement encode method!")
  39. def total_token_count(self, resp):
  40. try:
  41. return resp.usage.total_tokens
  42. except Exception:
  43. pass
  44. try:
  45. return resp["usage"]["total_tokens"]
  46. except Exception:
  47. pass
  48. return 0
class DefaultRerank(Base):
    """Local cross-encoder reranker built on FlagEmbedding's ``FlagReranker``.

    The model is loaded once and stored on the class (``DefaultRerank._model``),
    so every instance reuses it. Scores are computed batch by batch in
    :meth:`_process_batch`, which shrinks the batch and retries when a
    "CUDA out of memory" ``RuntimeError`` is raised.
    """

    _FACTORY_NAME = "BAAI"
    # Shared model instance, loaded lazily, and the lock guarding its loading.
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """Load (or reuse) the shared FlagReranker for *model_name*.

        Nothing is loaded when ``settings.LIGHTEN`` is true; ``self._model`` is
        then ``None``. Otherwise the model is first opened from the local cache
        directory (``get_home_cache_dir()`` joined with *model_name* minus a
        leading ``"<org>/"`` segment). If that fails, it is downloaded into that
        same directory with ``snapshot_download``. fp16 is used when CUDA is
        available. *key* and *kwargs* are unused.

        If you have trouble downloading HuggingFace models, pointing the Hub
        client at a mirror may help, e.g. on Linux:
            export HF_ENDPOINT=https://hf-mirror.com
        """
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                # Check again under the lock so only one thread performs the load.
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        # Not loadable from the local cache: download from the Hub into that directory.
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        # Default batch size used by _process_batch, and the floor it may shrink to on OOM.
        self._dynamic_batch_size = 8
        self._min_batch_size = 1

    def torch_empty_cache(self):
        """Release cached CUDA memory; best-effort, any error is only printed."""
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error emptying cache: {e}")

    def _process_batch(self, pairs, max_batch_size=None):
        """template method for subclass call

        Score *pairs* batch by batch via ``self._compute_batch_scores`` and return
        all scores, in input order, as a float ndarray.

        If *max_batch_size* is given it is the size of the first batch. After each
        successful batch, ``self._dynamic_batch_size`` becomes
        ``min(2 * self._dynamic_batch_size, 8)``, so later batches hold at most 8
        pairs. When a batch raises a ``RuntimeError`` containing
        "CUDA out of memory", that batch is retried at half its size (never
        below ``self._min_batch_size``), up to 5 times. Any other error, or an
        OOM at the minimum size, is re-raised. On normal return the instance's
        ``_dynamic_batch_size`` is restored to its value on entry.

        Raises:
            RuntimeError: when a batch still fails after 5 OOM retries.
        """
        # Remember the instance-level batch size so it can be restored at the end.
        old_dynamic_batch_size = self._dynamic_batch_size
        if max_batch_size is not None:
            self._dynamic_batch_size = max_batch_size
        res = np.array([], dtype=float)
        i = 0
        while i < len(pairs):
            cur_i = i
            current_batch = self._dynamic_batch_size
            max_retries = 5
            retry_count = 0
            while retry_count < max_retries:
                try:
                    # call subclass implemented batch processing calculation
                    batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                    res = np.append(res, batch_scores)
                    i += current_batch
                    # Double the size for the next batch, capped at 8 (this shrinks it after a large first batch).
                    self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
                    break
                except RuntimeError as e:
                    # On CUDA OOM: halve this batch (not below the minimum), free cache, and retry it.
                    if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
                        current_batch = max(current_batch // 2, self._min_batch_size)
                        self.torch_empty_cache()
                        i = cur_i  # reset i to the start of the current batch
                        retry_count += 1
                    else:
                        raise
            if retry_count >= max_retries:
                raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
            self.torch_empty_cache()

        self._dynamic_batch_size = old_dynamic_batch_size
        return np.array(res)

    def _compute_batch_scores(self, batch_pairs, max_length=None):
        """Score one batch with the shared model and map the raw scores through ``sigmoid``.

        *max_length* is forwarded to ``compute_score`` only when given. The result
        is always iterable: a non-iterable (scalar) score is wrapped in a list.
        """
        if max_length is None:
            scores = self._model.compute_score(batch_pairs)
        else:
            scores = self._model.compute_score(batch_pairs, max_length=max_length)
        scores = sigmoid(np.array(scores))
        # compute_score can evidently yield a bare scalar (presumably for a one-pair batch — unverified).
        if not isinstance(scores, Iterable):
            scores = [scores]
        return scores

    def similarity(self, query: str, texts: list):
        """Return ``(scores, token_count)`` for *texts* against *query*.

        Each text is cut with ``truncate(t, 2048)`` before scoring. ``token_count``
        sums ``num_tokens_from_string`` over the truncated texts; the query is not counted.
        """
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        # Size of the first batch; _process_batch caps later batches at 8.
        batch_size = 4096
        res = self._process_batch(pairs, max_batch_size=batch_size)
        return np.array(res), token_count
  131. class JinaRerank(Base):
  132. _FACTORY_NAME = "Jina"
  133. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
  134. self.base_url = "https://api.jina.ai/v1/rerank"
  135. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  136. self.model_name = model_name
  137. def similarity(self, query: str, texts: list):
  138. texts = [truncate(t, 8196) for t in texts]
  139. data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
  140. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  141. rank = np.zeros(len(texts), dtype=float)
  142. try:
  143. for d in res["results"]:
  144. rank[d["index"]] = d["relevance_score"]
  145. except Exception as _e:
  146. log_exception(_e, res)
  147. return rank, self.total_token_count(res)
  148. class YoudaoRerank(DefaultRerank):
  149. _FACTORY_NAME = "Youdao"
  150. _model = None
  151. _model_lock = threading.Lock()
  152. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  153. if not settings.LIGHTEN and not YoudaoRerank._model:
  154. from BCEmbedding import RerankerModel
  155. with YoudaoRerank._model_lock:
  156. if not YoudaoRerank._model:
  157. try:
  158. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  159. except Exception:
  160. YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
  161. self._model = YoudaoRerank._model
  162. self._dynamic_batch_size = 8
  163. self._min_batch_size = 1
  164. def similarity(self, query: str, texts: list):
  165. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  166. token_count = 0
  167. for _, t in pairs:
  168. token_count += num_tokens_from_string(t)
  169. batch_size = 8
  170. res = self._process_batch(pairs, max_batch_size=batch_size)
  171. return np.array(res), token_count
  172. class XInferenceRerank(Base):
  173. _FACTORY_NAME = "Xinference"
  174. def __init__(self, key="x", model_name="", base_url=""):
  175. if base_url.find("/v1") == -1:
  176. base_url = urljoin(base_url, "/v1/rerank")
  177. if base_url.find("/rerank") == -1:
  178. base_url = urljoin(base_url, "/v1/rerank")
  179. self.model_name = model_name
  180. self.base_url = base_url
  181. self.headers = {"Content-Type": "application/json", "accept": "application/json"}
  182. if key and key != "x":
  183. self.headers["Authorization"] = f"Bearer {key}"
  184. def similarity(self, query: str, texts: list):
  185. if len(texts) == 0:
  186. return np.array([]), 0
  187. pairs = [(query, truncate(t, 4096)) for t in texts]
  188. token_count = 0
  189. for _, t in pairs:
  190. token_count += num_tokens_from_string(t)
  191. data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
  192. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  193. rank = np.zeros(len(texts), dtype=float)
  194. try:
  195. for d in res["results"]:
  196. rank[d["index"]] = d["relevance_score"]
  197. except Exception as _e:
  198. log_exception(_e, res)
  199. return rank, token_count
  200. class LocalAIRerank(Base):
  201. _FACTORY_NAME = "LocalAI"
  202. def __init__(self, key, model_name, base_url):
  203. if base_url.find("/rerank") == -1:
  204. self.base_url = urljoin(base_url, "/rerank")
  205. else:
  206. self.base_url = base_url
  207. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  208. self.model_name = model_name.split("___")[0]
  209. def similarity(self, query: str, texts: list):
  210. # noway to config Ragflow , use fix setting
  211. texts = [truncate(t, 500) for t in texts]
  212. data = {
  213. "model": self.model_name,
  214. "query": query,
  215. "documents": texts,
  216. "top_n": len(texts),
  217. }
  218. token_count = 0
  219. for t in texts:
  220. token_count += num_tokens_from_string(t)
  221. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  222. rank = np.zeros(len(texts), dtype=float)
  223. try:
  224. for d in res["results"]:
  225. rank[d["index"]] = d["relevance_score"]
  226. except Exception as _e:
  227. log_exception(_e, res)
  228. # Normalize the rank values to the range 0 to 1
  229. min_rank = np.min(rank)
  230. max_rank = np.max(rank)
  231. # Avoid division by zero if all ranks are identical
  232. if max_rank - min_rank != 0:
  233. rank = (rank - min_rank) / (max_rank - min_rank)
  234. else:
  235. rank = np.zeros_like(rank)
  236. return rank, token_count
  237. class NvidiaRerank(Base):
  238. _FACTORY_NAME = "NVIDIA"
  239. def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
  240. if not base_url:
  241. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  242. self.model_name = model_name
  243. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  244. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
  245. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  246. self.base_url = urljoin(base_url, "reranking")
  247. self.model_name = "nv-rerank-qa-mistral-4b:1"
  248. self.headers = {
  249. "accept": "application/json",
  250. "Content-Type": "application/json",
  251. "Authorization": f"Bearer {key}",
  252. }
  253. def similarity(self, query: str, texts: list):
  254. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  255. data = {
  256. "model": self.model_name,
  257. "query": {"text": query},
  258. "passages": [{"text": text} for text in texts],
  259. "truncate": "END",
  260. "top_n": len(texts),
  261. }
  262. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  263. rank = np.zeros(len(texts), dtype=float)
  264. try:
  265. for d in res["rankings"]:
  266. rank[d["index"]] = d["logit"]
  267. except Exception as _e:
  268. log_exception(_e, res)
  269. return rank, token_count
  270. class LmStudioRerank(Base):
  271. _FACTORY_NAME = "LM-Studio"
  272. def __init__(self, key, model_name, base_url):
  273. pass
  274. def similarity(self, query: str, texts: list):
  275. raise NotImplementedError("The LmStudioRerank has not been implement")
  276. class OpenAI_APIRerank(Base):
  277. _FACTORY_NAME = "OpenAI-API-Compatible"
  278. def __init__(self, key, model_name, base_url):
  279. if base_url.find("/rerank") == -1:
  280. self.base_url = urljoin(base_url, "/rerank")
  281. else:
  282. self.base_url = base_url
  283. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  284. self.model_name = model_name.split("___")[0]
  285. def similarity(self, query: str, texts: list):
  286. # noway to config Ragflow , use fix setting
  287. texts = [truncate(t, 500) for t in texts]
  288. data = {
  289. "model": self.model_name,
  290. "query": query,
  291. "documents": texts,
  292. "top_n": len(texts),
  293. }
  294. token_count = 0
  295. for t in texts:
  296. token_count += num_tokens_from_string(t)
  297. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  298. rank = np.zeros(len(texts), dtype=float)
  299. try:
  300. for d in res["results"]:
  301. rank[d["index"]] = d["relevance_score"]
  302. except Exception as _e:
  303. log_exception(_e, res)
  304. # Normalize the rank values to the range 0 to 1
  305. min_rank = np.min(rank)
  306. max_rank = np.max(rank)
  307. # Avoid division by zero if all ranks are identical
  308. if max_rank - min_rank != 0:
  309. rank = (rank - min_rank) / (max_rank - min_rank)
  310. else:
  311. rank = np.zeros_like(rank)
  312. return rank, token_count
  313. class CoHereRerank(Base):
  314. _FACTORY_NAME = ["Cohere", "VLLM"]
  315. def __init__(self, key, model_name, base_url=None):
  316. from cohere import Client
  317. self.client = Client(api_key=key, base_url=base_url)
  318. self.model_name = model_name.split("___")[0]
  319. def similarity(self, query: str, texts: list):
  320. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  321. res = self.client.rerank(
  322. model=self.model_name,
  323. query=query,
  324. documents=texts,
  325. top_n=len(texts),
  326. return_documents=False,
  327. )
  328. rank = np.zeros(len(texts), dtype=float)
  329. try:
  330. for d in res.results:
  331. rank[d.index] = d.relevance_score
  332. except Exception as _e:
  333. log_exception(_e, res)
  334. return rank, token_count
  335. class TogetherAIRerank(Base):
  336. _FACTORY_NAME = "TogetherAI"
  337. def __init__(self, key, model_name, base_url):
  338. pass
  339. def similarity(self, query: str, texts: list):
  340. raise NotImplementedError("The api has not been implement")
  341. class SILICONFLOWRerank(Base):
  342. _FACTORY_NAME = "SILICONFLOW"
  343. def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
  344. if not base_url:
  345. base_url = "https://api.siliconflow.cn/v1/rerank"
  346. self.model_name = model_name
  347. self.base_url = base_url
  348. self.headers = {
  349. "accept": "application/json",
  350. "content-type": "application/json",
  351. "authorization": f"Bearer {key}",
  352. }
  353. def similarity(self, query: str, texts: list):
  354. payload = {
  355. "model": self.model_name,
  356. "query": query,
  357. "documents": texts,
  358. "top_n": len(texts),
  359. "return_documents": False,
  360. "max_chunks_per_doc": 1024,
  361. "overlap_tokens": 80,
  362. }
  363. response = requests.post(self.base_url, json=payload, headers=self.headers).json()
  364. rank = np.zeros(len(texts), dtype=float)
  365. try:
  366. for d in response["results"]:
  367. rank[d["index"]] = d["relevance_score"]
  368. except Exception as _e:
  369. log_exception(_e, response)
  370. return (
  371. rank,
  372. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  373. )
  374. class BaiduYiyanRerank(Base):
  375. _FACTORY_NAME = "BaiduYiyan"
  376. def __init__(self, key, model_name, base_url=None):
  377. from qianfan.resources import Reranker
  378. key = json.loads(key)
  379. ak = key.get("yiyan_ak", "")
  380. sk = key.get("yiyan_sk", "")
  381. self.client = Reranker(ak=ak, sk=sk)
  382. self.model_name = model_name
  383. def similarity(self, query: str, texts: list):
  384. res = self.client.do(
  385. model=self.model_name,
  386. query=query,
  387. documents=texts,
  388. top_n=len(texts),
  389. ).body
  390. rank = np.zeros(len(texts), dtype=float)
  391. try:
  392. for d in res["results"]:
  393. rank[d["index"]] = d["relevance_score"]
  394. except Exception as _e:
  395. log_exception(_e, res)
  396. return rank, self.total_token_count(res)
  397. class VoyageRerank(Base):
  398. _FACTORY_NAME = "Voyage AI"
  399. def __init__(self, key, model_name, base_url=None):
  400. import voyageai
  401. self.client = voyageai.Client(api_key=key)
  402. self.model_name = model_name
  403. def similarity(self, query: str, texts: list):
  404. rank = np.zeros(len(texts), dtype=float)
  405. if not texts:
  406. return rank, 0
  407. res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
  408. try:
  409. for r in res.results:
  410. rank[r.index] = r.relevance_score
  411. except Exception as _e:
  412. log_exception(_e, res)
  413. return rank, res.total_tokens
  414. class QWenRerank(Base):
  415. _FACTORY_NAME = "Tongyi-Qianwen"
  416. def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
  417. import dashscope
  418. self.api_key = key
  419. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  420. def similarity(self, query: str, texts: list):
  421. from http import HTTPStatus
  422. import dashscope
  423. resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
  424. rank = np.zeros(len(texts), dtype=float)
  425. if resp.status_code == HTTPStatus.OK:
  426. try:
  427. for r in resp.output.results:
  428. rank[r.index] = r.relevance_score
  429. except Exception as _e:
  430. log_exception(_e, resp)
  431. return rank, resp.usage.total_tokens
  432. else:
  433. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  434. class HuggingfaceRerank(DefaultRerank):
  435. _FACTORY_NAME = "HuggingFace"
  436. @staticmethod
  437. def post(query: str, texts: list, url="127.0.0.1"):
  438. exc = None
  439. scores = [0 for _ in range(len(texts))]
  440. batch_size = 8
  441. for i in range(0, len(texts), batch_size):
  442. try:
  443. res = requests.post(
  444. f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
  445. )
  446. for o in res.json():
  447. scores[o["index"] + i] = o["score"]
  448. except Exception as e:
  449. exc = e
  450. if exc:
  451. raise exc
  452. return np.array(scores)
  453. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  454. self.model_name = model_name.split("___")[0]
  455. self.base_url = base_url
  456. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  457. if not texts:
  458. return np.array([]), 0
  459. token_count = 0
  460. for t in texts:
  461. token_count += num_tokens_from_string(t)
  462. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  463. class GPUStackRerank(Base):
  464. _FACTORY_NAME = "GPUStack"
  465. def __init__(self, key, model_name, base_url):
  466. if not base_url:
  467. raise ValueError("url cannot be None")
  468. self.model_name = model_name
  469. self.base_url = str(URL(base_url) / "v1" / "rerank")
  470. self.headers = {
  471. "accept": "application/json",
  472. "content-type": "application/json",
  473. "authorization": f"Bearer {key}",
  474. }
  475. def similarity(self, query: str, texts: list):
  476. payload = {
  477. "model": self.model_name,
  478. "query": query,
  479. "documents": texts,
  480. "top_n": len(texts),
  481. }
  482. try:
  483. response = requests.post(self.base_url, json=payload, headers=self.headers)
  484. response.raise_for_status()
  485. response_json = response.json()
  486. rank = np.zeros(len(texts), dtype=float)
  487. token_count = 0
  488. for t in texts:
  489. token_count += num_tokens_from_string(t)
  490. try:
  491. for result in response_json["results"]:
  492. rank[result["index"]] = result["relevance_score"]
  493. except Exception as _e:
  494. log_exception(_e, response)
  495. return (
  496. rank,
  497. token_count,
  498. )
  499. except httpx.HTTPStatusError as e:
  500. raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  501. class NovitaRerank(JinaRerank):
  502. _FACTORY_NAME = "NovitaAI"
  503. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  504. if not base_url:
  505. base_url = "https://api.novita.ai/v3/openai/rerank"
  506. super().__init__(key, model_name, base_url)
  507. class GiteeRerank(JinaRerank):
  508. _FACTORY_NAME = "GiteeAI"
  509. def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
  510. if not base_url:
  511. base_url = "https://ai.gitee.com/v1/rerank"
  512. super().__init__(key, model_name, base_url)