Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

rerank_model.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import re
  19. import threading
  20. from abc import ABC
  21. from collections.abc import Iterable
  22. from urllib.parse import urljoin
  23. import httpx
  24. import numpy as np
  25. import requests
  26. from huggingface_hub import snapshot_download
  27. from yarl import URL
  28. from api import settings
  29. from api.utils.file_utils import get_home_cache_dir
  30. from api.utils.log_utils import log_exception
  31. from rag.utils import num_tokens_from_string, truncate
  32. class Base(ABC):
  33. def __init__(self, key, model_name):
  34. pass
  35. def similarity(self, query: str, texts: list):
  36. raise NotImplementedError("Please implement encode method!")
  37. def total_token_count(self, resp):
  38. try:
  39. return resp.usage.total_tokens
  40. except Exception:
  41. pass
  42. try:
  43. return resp["usage"]["total_tokens"]
  44. except Exception:
  45. pass
  46. return 0
  47. class DefaultRerank(Base):
  48. _FACTORY_NAME = "BAAI"
  49. _model = None
  50. _model_lock = threading.Lock()
  51. def __init__(self, key, model_name, **kwargs):
  52. """
  53. If you have trouble downloading HuggingFace models, -_^ this might help!!
  54. For Linux:
  55. export HF_ENDPOINT=https://hf-mirror.com
  56. For Windows:
  57. Good luck
  58. ^_-
  59. """
  60. if not settings.LIGHTEN and not DefaultRerank._model:
  61. import torch
  62. from FlagEmbedding import FlagReranker
  63. with DefaultRerank._model_lock:
  64. if not DefaultRerank._model:
  65. try:
  66. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  67. except Exception:
  68. model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
  69. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  70. self._model = DefaultRerank._model
  71. self._dynamic_batch_size = 8
  72. self._min_batch_size = 1
  73. def torch_empty_cache(self):
  74. try:
  75. import torch
  76. torch.cuda.empty_cache()
  77. except Exception as e:
  78. print(f"Error emptying cache: {e}")
  79. def _process_batch(self, pairs, max_batch_size=None):
  80. """template method for subclass call"""
  81. old_dynamic_batch_size = self._dynamic_batch_size
  82. if max_batch_size is not None:
  83. self._dynamic_batch_size = max_batch_size
  84. res = np.array([], dtype=float)
  85. i = 0
  86. while i < len(pairs):
  87. cur_i = i
  88. current_batch = self._dynamic_batch_size
  89. max_retries = 5
  90. retry_count = 0
  91. while retry_count < max_retries:
  92. try:
  93. # call subclass implemented batch processing calculation
  94. batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
  95. res = np.append(res, batch_scores)
  96. i += current_batch
  97. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  98. break
  99. except RuntimeError as e:
  100. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  101. current_batch = max(current_batch // 2, self._min_batch_size)
  102. self.torch_empty_cache()
  103. i = cur_i # reset i to the start of the current batch
  104. retry_count += 1
  105. else:
  106. raise
  107. if retry_count >= max_retries:
  108. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  109. self.torch_empty_cache()
  110. self._dynamic_batch_size = old_dynamic_batch_size
  111. return np.array(res)
  112. def _compute_batch_scores(self, batch_pairs, max_length=None):
  113. if max_length is None:
  114. scores = self._model.compute_score(batch_pairs, normalize=True)
  115. else:
  116. scores = self._model.compute_score(batch_pairs, max_length=max_length, normalize=True)
  117. if not isinstance(scores, Iterable):
  118. scores = [scores]
  119. return scores
  120. def similarity(self, query: str, texts: list):
  121. pairs = [(query, truncate(t, 2048)) for t in texts]
  122. token_count = 0
  123. for _, t in pairs:
  124. token_count += num_tokens_from_string(t)
  125. batch_size = 4096
  126. res = self._process_batch(pairs, max_batch_size=batch_size)
  127. return np.array(res), token_count
  128. class JinaRerank(Base):
  129. _FACTORY_NAME = "Jina"
  130. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
  131. self.base_url = "https://api.jina.ai/v1/rerank"
  132. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  133. self.model_name = model_name
  134. def similarity(self, query: str, texts: list):
  135. texts = [truncate(t, 8196) for t in texts]
  136. data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
  137. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  138. rank = np.zeros(len(texts), dtype=float)
  139. try:
  140. for d in res["results"]:
  141. rank[d["index"]] = d["relevance_score"]
  142. except Exception as _e:
  143. log_exception(_e, res)
  144. return rank, self.total_token_count(res)
  145. class YoudaoRerank(DefaultRerank):
  146. _FACTORY_NAME = "Youdao"
  147. _model = None
  148. _model_lock = threading.Lock()
  149. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  150. if not settings.LIGHTEN and not YoudaoRerank._model:
  151. from BCEmbedding import RerankerModel
  152. with YoudaoRerank._model_lock:
  153. if not YoudaoRerank._model:
  154. try:
  155. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  156. except Exception:
  157. YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
  158. self._model = YoudaoRerank._model
  159. self._dynamic_batch_size = 8
  160. self._min_batch_size = 1
  161. def similarity(self, query: str, texts: list):
  162. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  163. token_count = 0
  164. for _, t in pairs:
  165. token_count += num_tokens_from_string(t)
  166. batch_size = 8
  167. res = self._process_batch(pairs, max_batch_size=batch_size)
  168. return np.array(res), token_count
  169. class XInferenceRerank(Base):
  170. _FACTORY_NAME = "Xinference"
  171. def __init__(self, key="x", model_name="", base_url=""):
  172. if base_url.find("/v1") == -1:
  173. base_url = urljoin(base_url, "/v1/rerank")
  174. if base_url.find("/rerank") == -1:
  175. base_url = urljoin(base_url, "/v1/rerank")
  176. self.model_name = model_name
  177. self.base_url = base_url
  178. self.headers = {"Content-Type": "application/json", "accept": "application/json"}
  179. if key and key != "x":
  180. self.headers["Authorization"] = f"Bearer {key}"
  181. def similarity(self, query: str, texts: list):
  182. if len(texts) == 0:
  183. return np.array([]), 0
  184. pairs = [(query, truncate(t, 4096)) for t in texts]
  185. token_count = 0
  186. for _, t in pairs:
  187. token_count += num_tokens_from_string(t)
  188. data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
  189. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  190. rank = np.zeros(len(texts), dtype=float)
  191. try:
  192. for d in res["results"]:
  193. rank[d["index"]] = d["relevance_score"]
  194. except Exception as _e:
  195. log_exception(_e, res)
  196. return rank, token_count
  197. class LocalAIRerank(Base):
  198. _FACTORY_NAME = "LocalAI"
  199. def __init__(self, key, model_name, base_url):
  200. if base_url.find("/rerank") == -1:
  201. self.base_url = urljoin(base_url, "/rerank")
  202. else:
  203. self.base_url = base_url
  204. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  205. self.model_name = model_name.split("___")[0]
  206. def similarity(self, query: str, texts: list):
  207. # noway to config Ragflow , use fix setting
  208. texts = [truncate(t, 500) for t in texts]
  209. data = {
  210. "model": self.model_name,
  211. "query": query,
  212. "documents": texts,
  213. "top_n": len(texts),
  214. }
  215. token_count = 0
  216. for t in texts:
  217. token_count += num_tokens_from_string(t)
  218. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  219. rank = np.zeros(len(texts), dtype=float)
  220. try:
  221. for d in res["results"]:
  222. rank[d["index"]] = d["relevance_score"]
  223. except Exception as _e:
  224. log_exception(_e, res)
  225. # Normalize the rank values to the range 0 to 1
  226. min_rank = np.min(rank)
  227. max_rank = np.max(rank)
  228. # Avoid division by zero if all ranks are identical
  229. if max_rank - min_rank != 0:
  230. rank = (rank - min_rank) / (max_rank - min_rank)
  231. else:
  232. rank = np.zeros_like(rank)
  233. return rank, token_count
  234. class NvidiaRerank(Base):
  235. _FACTORY_NAME = "NVIDIA"
  236. def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
  237. if not base_url:
  238. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  239. self.model_name = model_name
  240. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  241. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
  242. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  243. self.base_url = urljoin(base_url, "reranking")
  244. self.model_name = "nv-rerank-qa-mistral-4b:1"
  245. self.headers = {
  246. "accept": "application/json",
  247. "Content-Type": "application/json",
  248. "Authorization": f"Bearer {key}",
  249. }
  250. def similarity(self, query: str, texts: list):
  251. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  252. data = {
  253. "model": self.model_name,
  254. "query": {"text": query},
  255. "passages": [{"text": text} for text in texts],
  256. "truncate": "END",
  257. "top_n": len(texts),
  258. }
  259. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  260. rank = np.zeros(len(texts), dtype=float)
  261. try:
  262. for d in res["rankings"]:
  263. rank[d["index"]] = d["logit"]
  264. except Exception as _e:
  265. log_exception(_e, res)
  266. return rank, token_count
  267. class LmStudioRerank(Base):
  268. _FACTORY_NAME = "LM-Studio"
  269. def __init__(self, key, model_name, base_url):
  270. pass
  271. def similarity(self, query: str, texts: list):
  272. raise NotImplementedError("The LmStudioRerank has not been implement")
  273. class OpenAI_APIRerank(Base):
  274. _FACTORY_NAME = "OpenAI-API-Compatible"
  275. def __init__(self, key, model_name, base_url):
  276. if base_url.find("/rerank") == -1:
  277. self.base_url = urljoin(base_url, "/rerank")
  278. else:
  279. self.base_url = base_url
  280. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  281. self.model_name = model_name.split("___")[0]
  282. def similarity(self, query: str, texts: list):
  283. # noway to config Ragflow , use fix setting
  284. texts = [truncate(t, 500) for t in texts]
  285. data = {
  286. "model": self.model_name,
  287. "query": query,
  288. "documents": texts,
  289. "top_n": len(texts),
  290. }
  291. token_count = 0
  292. for t in texts:
  293. token_count += num_tokens_from_string(t)
  294. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  295. rank = np.zeros(len(texts), dtype=float)
  296. try:
  297. for d in res["results"]:
  298. rank[d["index"]] = d["relevance_score"]
  299. except Exception as _e:
  300. log_exception(_e, res)
  301. # Normalize the rank values to the range 0 to 1
  302. min_rank = np.min(rank)
  303. max_rank = np.max(rank)
  304. # Avoid division by zero if all ranks are identical
  305. if max_rank - min_rank != 0:
  306. rank = (rank - min_rank) / (max_rank - min_rank)
  307. else:
  308. rank = np.zeros_like(rank)
  309. return rank, token_count
  310. class CoHereRerank(Base):
  311. _FACTORY_NAME = ["Cohere", "VLLM"]
  312. def __init__(self, key, model_name, base_url=None):
  313. from cohere import Client
  314. self.client = Client(api_key=key, base_url=base_url)
  315. self.model_name = model_name.split("___")[0]
  316. def similarity(self, query: str, texts: list):
  317. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  318. res = self.client.rerank(
  319. model=self.model_name,
  320. query=query,
  321. documents=texts,
  322. top_n=len(texts),
  323. return_documents=False,
  324. )
  325. rank = np.zeros(len(texts), dtype=float)
  326. try:
  327. for d in res.results:
  328. rank[d.index] = d.relevance_score
  329. except Exception as _e:
  330. log_exception(_e, res)
  331. return rank, token_count
  332. class TogetherAIRerank(Base):
  333. _FACTORY_NAME = "TogetherAI"
  334. def __init__(self, key, model_name, base_url):
  335. pass
  336. def similarity(self, query: str, texts: list):
  337. raise NotImplementedError("The api has not been implement")
  338. class SILICONFLOWRerank(Base):
  339. _FACTORY_NAME = "SILICONFLOW"
  340. def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
  341. if not base_url:
  342. base_url = "https://api.siliconflow.cn/v1/rerank"
  343. self.model_name = model_name
  344. self.base_url = base_url
  345. self.headers = {
  346. "accept": "application/json",
  347. "content-type": "application/json",
  348. "authorization": f"Bearer {key}",
  349. }
  350. def similarity(self, query: str, texts: list):
  351. payload = {
  352. "model": self.model_name,
  353. "query": query,
  354. "documents": texts,
  355. "top_n": len(texts),
  356. "return_documents": False,
  357. "max_chunks_per_doc": 1024,
  358. "overlap_tokens": 80,
  359. }
  360. response = requests.post(self.base_url, json=payload, headers=self.headers).json()
  361. rank = np.zeros(len(texts), dtype=float)
  362. try:
  363. for d in response["results"]:
  364. rank[d["index"]] = d["relevance_score"]
  365. except Exception as _e:
  366. log_exception(_e, response)
  367. return (
  368. rank,
  369. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  370. )
  371. class BaiduYiyanRerank(Base):
  372. _FACTORY_NAME = "BaiduYiyan"
  373. def __init__(self, key, model_name, base_url=None):
  374. from qianfan.resources import Reranker
  375. key = json.loads(key)
  376. ak = key.get("yiyan_ak", "")
  377. sk = key.get("yiyan_sk", "")
  378. self.client = Reranker(ak=ak, sk=sk)
  379. self.model_name = model_name
  380. def similarity(self, query: str, texts: list):
  381. res = self.client.do(
  382. model=self.model_name,
  383. query=query,
  384. documents=texts,
  385. top_n=len(texts),
  386. ).body
  387. rank = np.zeros(len(texts), dtype=float)
  388. try:
  389. for d in res["results"]:
  390. rank[d["index"]] = d["relevance_score"]
  391. except Exception as _e:
  392. log_exception(_e, res)
  393. return rank, self.total_token_count(res)
  394. class VoyageRerank(Base):
  395. _FACTORY_NAME = "Voyage AI"
  396. def __init__(self, key, model_name, base_url=None):
  397. import voyageai
  398. self.client = voyageai.Client(api_key=key)
  399. self.model_name = model_name
  400. def similarity(self, query: str, texts: list):
  401. rank = np.zeros(len(texts), dtype=float)
  402. if not texts:
  403. return rank, 0
  404. res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
  405. try:
  406. for r in res.results:
  407. rank[r.index] = r.relevance_score
  408. except Exception as _e:
  409. log_exception(_e, res)
  410. return rank, res.total_tokens
  411. class QWenRerank(Base):
  412. _FACTORY_NAME = "Tongyi-Qianwen"
  413. def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
  414. import dashscope
  415. self.api_key = key
  416. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  417. def similarity(self, query: str, texts: list):
  418. from http import HTTPStatus
  419. import dashscope
  420. resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
  421. rank = np.zeros(len(texts), dtype=float)
  422. if resp.status_code == HTTPStatus.OK:
  423. try:
  424. for r in resp.output.results:
  425. rank[r.index] = r.relevance_score
  426. except Exception as _e:
  427. log_exception(_e, resp)
  428. return rank, resp.usage.total_tokens
  429. else:
  430. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  431. class HuggingfaceRerank(DefaultRerank):
  432. _FACTORY_NAME = "HuggingFace"
  433. @staticmethod
  434. def post(query: str, texts: list, url="127.0.0.1"):
  435. exc = None
  436. scores = [0 for _ in range(len(texts))]
  437. batch_size = 8
  438. for i in range(0, len(texts), batch_size):
  439. try:
  440. res = requests.post(
  441. f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
  442. )
  443. for o in res.json():
  444. scores[o["index"] + i] = o["score"]
  445. except Exception as e:
  446. exc = e
  447. if exc:
  448. raise exc
  449. return np.array(scores)
  450. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  451. self.model_name = model_name.split("___")[0]
  452. self.base_url = base_url
  453. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  454. if not texts:
  455. return np.array([]), 0
  456. token_count = 0
  457. for t in texts:
  458. token_count += num_tokens_from_string(t)
  459. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  460. class GPUStackRerank(Base):
  461. _FACTORY_NAME = "GPUStack"
  462. def __init__(self, key, model_name, base_url):
  463. if not base_url:
  464. raise ValueError("url cannot be None")
  465. self.model_name = model_name
  466. self.base_url = str(URL(base_url) / "v1" / "rerank")
  467. self.headers = {
  468. "accept": "application/json",
  469. "content-type": "application/json",
  470. "authorization": f"Bearer {key}",
  471. }
  472. def similarity(self, query: str, texts: list):
  473. payload = {
  474. "model": self.model_name,
  475. "query": query,
  476. "documents": texts,
  477. "top_n": len(texts),
  478. }
  479. try:
  480. response = requests.post(self.base_url, json=payload, headers=self.headers)
  481. response.raise_for_status()
  482. response_json = response.json()
  483. rank = np.zeros(len(texts), dtype=float)
  484. token_count = 0
  485. for t in texts:
  486. token_count += num_tokens_from_string(t)
  487. try:
  488. for result in response_json["results"]:
  489. rank[result["index"]] = result["relevance_score"]
  490. except Exception as _e:
  491. log_exception(_e, response)
  492. return (
  493. rank,
  494. token_count,
  495. )
  496. except httpx.HTTPStatusError as e:
  497. raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  498. class NovitaRerank(JinaRerank):
  499. _FACTORY_NAME = "NovitaAI"
  500. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  501. if not base_url:
  502. base_url = "https://api.novita.ai/v3/openai/rerank"
  503. super().__init__(key, model_name, base_url)
  504. class GiteeRerank(JinaRerank):
  505. _FACTORY_NAME = "GiteeAI"
  506. def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
  507. if not base_url:
  508. base_url = "https://ai.gitee.com/v1/rerank"
  509. super().__init__(key, model_name, base_url)