Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

rerank_model.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import re
  19. import threading
  20. from abc import ABC
  21. from collections.abc import Iterable
  22. from urllib.parse import urljoin
  23. import httpx
  24. import numpy as np
  25. import requests
  26. from huggingface_hub import snapshot_download
  27. from yarl import URL
  28. from api import settings
  29. from api.utils.file_utils import get_home_cache_dir
  30. from api.utils.log_utils import log_exception
  31. from rag.utils import num_tokens_from_string, truncate
  32. def sigmoid(x):
  33. return 1 / (1 + np.exp(-x))
  34. class Base(ABC):
  35. def __init__(self, key, model_name):
  36. pass
  37. def similarity(self, query: str, texts: list):
  38. raise NotImplementedError("Please implement encode method!")
  39. def total_token_count(self, resp):
  40. try:
  41. return resp.usage.total_tokens
  42. except Exception:
  43. pass
  44. try:
  45. return resp["usage"]["total_tokens"]
  46. except Exception:
  47. pass
  48. return 0
  49. class DefaultRerank(Base):
  50. _FACTORY_NAME = "BAAI"
  51. _model = None
  52. _model_lock = threading.Lock()
  53. def __init__(self, key, model_name, **kwargs):
  54. """
  55. If you have trouble downloading HuggingFace models, -_^ this might help!!
  56. For Linux:
  57. export HF_ENDPOINT=https://hf-mirror.com
  58. For Windows:
  59. Good luck
  60. ^_-
  61. """
  62. if not settings.LIGHTEN and not DefaultRerank._model:
  63. import torch
  64. from FlagEmbedding import FlagReranker
  65. with DefaultRerank._model_lock:
  66. if not DefaultRerank._model:
  67. try:
  68. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  69. except Exception:
  70. model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
  71. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  72. self._model = DefaultRerank._model
  73. self._dynamic_batch_size = 8
  74. self._min_batch_size = 1
  75. def torch_empty_cache(self):
  76. try:
  77. import torch
  78. torch.cuda.empty_cache()
  79. except Exception as e:
  80. print(f"Error emptying cache: {e}")
  81. def _process_batch(self, pairs, max_batch_size=None):
  82. """template method for subclass call"""
  83. old_dynamic_batch_size = self._dynamic_batch_size
  84. if max_batch_size is not None:
  85. self._dynamic_batch_size = max_batch_size
  86. res = []
  87. i = 0
  88. while i < len(pairs):
  89. current_batch = self._dynamic_batch_size
  90. max_retries = 5
  91. retry_count = 0
  92. while retry_count < max_retries:
  93. try:
  94. # call subclass implemented batch processing calculation
  95. batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
  96. res.extend(batch_scores)
  97. i += current_batch
  98. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  99. break
  100. except RuntimeError as e:
  101. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  102. current_batch = max(current_batch // 2, self._min_batch_size)
  103. self.torch_empty_cache()
  104. retry_count += 1
  105. else:
  106. raise
  107. if retry_count >= max_retries:
  108. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  109. self.torch_empty_cache()
  110. self._dynamic_batch_size = old_dynamic_batch_size
  111. return np.array(res)
  112. def _compute_batch_scores(self, batch_pairs, max_length=None):
  113. if max_length is None:
  114. scores = self._model.compute_score(batch_pairs)
  115. else:
  116. scores = self._model.compute_score(batch_pairs, max_length=max_length)
  117. scores = sigmoid(np.array(scores)).tolist()
  118. if not isinstance(scores, Iterable):
  119. scores = [scores]
  120. return scores
  121. def similarity(self, query: str, texts: list):
  122. pairs = [(query, truncate(t, 2048)) for t in texts]
  123. token_count = 0
  124. for _, t in pairs:
  125. token_count += num_tokens_from_string(t)
  126. batch_size = 4096
  127. res = self._process_batch(pairs, max_batch_size=batch_size)
  128. return np.array(res), token_count
  129. class JinaRerank(Base):
  130. _FACTORY_NAME = "Jina"
  131. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
  132. self.base_url = "https://api.jina.ai/v1/rerank"
  133. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  134. self.model_name = model_name
  135. def similarity(self, query: str, texts: list):
  136. texts = [truncate(t, 8196) for t in texts]
  137. data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
  138. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  139. rank = np.zeros(len(texts), dtype=float)
  140. try:
  141. for d in res["results"]:
  142. rank[d["index"]] = d["relevance_score"]
  143. except Exception as _e:
  144. log_exception(_e, res)
  145. return rank, self.total_token_count(res)
  146. class YoudaoRerank(DefaultRerank):
  147. _FACTORY_NAME = "Youdao"
  148. _model = None
  149. _model_lock = threading.Lock()
  150. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  151. if not settings.LIGHTEN and not YoudaoRerank._model:
  152. from BCEmbedding import RerankerModel
  153. with YoudaoRerank._model_lock:
  154. if not YoudaoRerank._model:
  155. try:
  156. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  157. except Exception:
  158. YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
  159. self._model = YoudaoRerank._model
  160. self._dynamic_batch_size = 8
  161. self._min_batch_size = 1
  162. def similarity(self, query: str, texts: list):
  163. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  164. token_count = 0
  165. for _, t in pairs:
  166. token_count += num_tokens_from_string(t)
  167. batch_size = 8
  168. res = self._process_batch(pairs, max_batch_size=batch_size)
  169. return np.array(res), token_count
  170. class XInferenceRerank(Base):
  171. _FACTORY_NAME = "Xinference"
  172. def __init__(self, key="x", model_name="", base_url=""):
  173. if base_url.find("/v1") == -1:
  174. base_url = urljoin(base_url, "/v1/rerank")
  175. if base_url.find("/rerank") == -1:
  176. base_url = urljoin(base_url, "/v1/rerank")
  177. self.model_name = model_name
  178. self.base_url = base_url
  179. self.headers = {"Content-Type": "application/json", "accept": "application/json"}
  180. if key and key != "x":
  181. self.headers["Authorization"] = f"Bearer {key}"
  182. def similarity(self, query: str, texts: list):
  183. if len(texts) == 0:
  184. return np.array([]), 0
  185. pairs = [(query, truncate(t, 4096)) for t in texts]
  186. token_count = 0
  187. for _, t in pairs:
  188. token_count += num_tokens_from_string(t)
  189. data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
  190. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  191. rank = np.zeros(len(texts), dtype=float)
  192. try:
  193. for d in res["results"]:
  194. rank[d["index"]] = d["relevance_score"]
  195. except Exception as _e:
  196. log_exception(_e, res)
  197. return rank, token_count
  198. class LocalAIRerank(Base):
  199. _FACTORY_NAME = "LocalAI"
  200. def __init__(self, key, model_name, base_url):
  201. if base_url.find("/rerank") == -1:
  202. self.base_url = urljoin(base_url, "/rerank")
  203. else:
  204. self.base_url = base_url
  205. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  206. self.model_name = model_name.split("___")[0]
  207. def similarity(self, query: str, texts: list):
  208. # noway to config Ragflow , use fix setting
  209. texts = [truncate(t, 500) for t in texts]
  210. data = {
  211. "model": self.model_name,
  212. "query": query,
  213. "documents": texts,
  214. "top_n": len(texts),
  215. }
  216. token_count = 0
  217. for t in texts:
  218. token_count += num_tokens_from_string(t)
  219. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  220. rank = np.zeros(len(texts), dtype=float)
  221. try:
  222. for d in res["results"]:
  223. rank[d["index"]] = d["relevance_score"]
  224. except Exception as _e:
  225. log_exception(_e, res)
  226. # Normalize the rank values to the range 0 to 1
  227. min_rank = np.min(rank)
  228. max_rank = np.max(rank)
  229. # Avoid division by zero if all ranks are identical
  230. if max_rank - min_rank != 0:
  231. rank = (rank - min_rank) / (max_rank - min_rank)
  232. else:
  233. rank = np.zeros_like(rank)
  234. return rank, token_count
  235. class NvidiaRerank(Base):
  236. _FACTORY_NAME = "NVIDIA"
  237. def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
  238. if not base_url:
  239. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  240. self.model_name = model_name
  241. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  242. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
  243. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  244. self.base_url = urljoin(base_url, "reranking")
  245. self.model_name = "nv-rerank-qa-mistral-4b:1"
  246. self.headers = {
  247. "accept": "application/json",
  248. "Content-Type": "application/json",
  249. "Authorization": f"Bearer {key}",
  250. }
  251. def similarity(self, query: str, texts: list):
  252. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  253. data = {
  254. "model": self.model_name,
  255. "query": {"text": query},
  256. "passages": [{"text": text} for text in texts],
  257. "truncate": "END",
  258. "top_n": len(texts),
  259. }
  260. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  261. rank = np.zeros(len(texts), dtype=float)
  262. try:
  263. for d in res["rankings"]:
  264. rank[d["index"]] = d["logit"]
  265. except Exception as _e:
  266. log_exception(_e, res)
  267. return rank, token_count
  268. class LmStudioRerank(Base):
  269. _FACTORY_NAME = "LM-Studio"
  270. def __init__(self, key, model_name, base_url):
  271. pass
  272. def similarity(self, query: str, texts: list):
  273. raise NotImplementedError("The LmStudioRerank has not been implement")
  274. class OpenAI_APIRerank(Base):
  275. _FACTORY_NAME = "OpenAI-API-Compatible"
  276. def __init__(self, key, model_name, base_url):
  277. if base_url.find("/rerank") == -1:
  278. self.base_url = urljoin(base_url, "/rerank")
  279. else:
  280. self.base_url = base_url
  281. self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
  282. self.model_name = model_name.split("___")[0]
  283. def similarity(self, query: str, texts: list):
  284. # noway to config Ragflow , use fix setting
  285. texts = [truncate(t, 500) for t in texts]
  286. data = {
  287. "model": self.model_name,
  288. "query": query,
  289. "documents": texts,
  290. "top_n": len(texts),
  291. }
  292. token_count = 0
  293. for t in texts:
  294. token_count += num_tokens_from_string(t)
  295. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  296. rank = np.zeros(len(texts), dtype=float)
  297. try:
  298. for d in res["results"]:
  299. rank[d["index"]] = d["relevance_score"]
  300. except Exception as _e:
  301. log_exception(_e, res)
  302. # Normalize the rank values to the range 0 to 1
  303. min_rank = np.min(rank)
  304. max_rank = np.max(rank)
  305. # Avoid division by zero if all ranks are identical
  306. if max_rank - min_rank != 0:
  307. rank = (rank - min_rank) / (max_rank - min_rank)
  308. else:
  309. rank = np.zeros_like(rank)
  310. return rank, token_count
  311. class CoHereRerank(Base):
  312. _FACTORY_NAME = ["Cohere", "VLLM"]
  313. def __init__(self, key, model_name, base_url=None):
  314. from cohere import Client
  315. self.client = Client(api_key=key, base_url=base_url)
  316. self.model_name = model_name.split("___")[0]
  317. def similarity(self, query: str, texts: list):
  318. token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
  319. res = self.client.rerank(
  320. model=self.model_name,
  321. query=query,
  322. documents=texts,
  323. top_n=len(texts),
  324. return_documents=False,
  325. )
  326. rank = np.zeros(len(texts), dtype=float)
  327. try:
  328. for d in res.results:
  329. rank[d.index] = d.relevance_score
  330. except Exception as _e:
  331. log_exception(_e, res)
  332. return rank, token_count
  333. class TogetherAIRerank(Base):
  334. _FACTORY_NAME = "TogetherAI"
  335. def __init__(self, key, model_name, base_url):
  336. pass
  337. def similarity(self, query: str, texts: list):
  338. raise NotImplementedError("The api has not been implement")
  339. class SILICONFLOWRerank(Base):
  340. _FACTORY_NAME = "SILICONFLOW"
  341. def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
  342. if not base_url:
  343. base_url = "https://api.siliconflow.cn/v1/rerank"
  344. self.model_name = model_name
  345. self.base_url = base_url
  346. self.headers = {
  347. "accept": "application/json",
  348. "content-type": "application/json",
  349. "authorization": f"Bearer {key}",
  350. }
  351. def similarity(self, query: str, texts: list):
  352. payload = {
  353. "model": self.model_name,
  354. "query": query,
  355. "documents": texts,
  356. "top_n": len(texts),
  357. "return_documents": False,
  358. "max_chunks_per_doc": 1024,
  359. "overlap_tokens": 80,
  360. }
  361. response = requests.post(self.base_url, json=payload, headers=self.headers).json()
  362. rank = np.zeros(len(texts), dtype=float)
  363. try:
  364. for d in response["results"]:
  365. rank[d["index"]] = d["relevance_score"]
  366. except Exception as _e:
  367. log_exception(_e, response)
  368. return (
  369. rank,
  370. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  371. )
  372. class BaiduYiyanRerank(Base):
  373. _FACTORY_NAME = "BaiduYiyan"
  374. def __init__(self, key, model_name, base_url=None):
  375. from qianfan.resources import Reranker
  376. key = json.loads(key)
  377. ak = key.get("yiyan_ak", "")
  378. sk = key.get("yiyan_sk", "")
  379. self.client = Reranker(ak=ak, sk=sk)
  380. self.model_name = model_name
  381. def similarity(self, query: str, texts: list):
  382. res = self.client.do(
  383. model=self.model_name,
  384. query=query,
  385. documents=texts,
  386. top_n=len(texts),
  387. ).body
  388. rank = np.zeros(len(texts), dtype=float)
  389. try:
  390. for d in res["results"]:
  391. rank[d["index"]] = d["relevance_score"]
  392. except Exception as _e:
  393. log_exception(_e, res)
  394. return rank, self.total_token_count(res)
  395. class VoyageRerank(Base):
  396. _FACTORY_NAME = "Voyage AI"
  397. def __init__(self, key, model_name, base_url=None):
  398. import voyageai
  399. self.client = voyageai.Client(api_key=key)
  400. self.model_name = model_name
  401. def similarity(self, query: str, texts: list):
  402. rank = np.zeros(len(texts), dtype=float)
  403. if not texts:
  404. return rank, 0
  405. res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
  406. try:
  407. for r in res.results:
  408. rank[r.index] = r.relevance_score
  409. except Exception as _e:
  410. log_exception(_e, res)
  411. return rank, res.total_tokens
  412. class QWenRerank(Base):
  413. _FACTORY_NAME = "Tongyi-Qianwen"
  414. def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
  415. import dashscope
  416. self.api_key = key
  417. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  418. def similarity(self, query: str, texts: list):
  419. from http import HTTPStatus
  420. import dashscope
  421. resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
  422. rank = np.zeros(len(texts), dtype=float)
  423. if resp.status_code == HTTPStatus.OK:
  424. try:
  425. for r in resp.output.results:
  426. rank[r.index] = r.relevance_score
  427. except Exception as _e:
  428. log_exception(_e, resp)
  429. return rank, resp.usage.total_tokens
  430. else:
  431. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  432. class HuggingfaceRerank(DefaultRerank):
  433. _FACTORY_NAME = "HuggingFace"
  434. @staticmethod
  435. def post(query: str, texts: list, url="127.0.0.1"):
  436. exc = None
  437. scores = [0 for _ in range(len(texts))]
  438. batch_size = 8
  439. for i in range(0, len(texts), batch_size):
  440. try:
  441. res = requests.post(
  442. f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
  443. )
  444. for o in res.json():
  445. scores[o["index"] + i] = o["score"]
  446. except Exception as e:
  447. exc = e
  448. if exc:
  449. raise exc
  450. return np.array(scores)
  451. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  452. self.model_name = model_name.split("___")[0]
  453. self.base_url = base_url
  454. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  455. if not texts:
  456. return np.array([]), 0
  457. token_count = 0
  458. for t in texts:
  459. token_count += num_tokens_from_string(t)
  460. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  461. class GPUStackRerank(Base):
  462. _FACTORY_NAME = "GPUStack"
  463. def __init__(self, key, model_name, base_url):
  464. if not base_url:
  465. raise ValueError("url cannot be None")
  466. self.model_name = model_name
  467. self.base_url = str(URL(base_url) / "v1" / "rerank")
  468. self.headers = {
  469. "accept": "application/json",
  470. "content-type": "application/json",
  471. "authorization": f"Bearer {key}",
  472. }
  473. def similarity(self, query: str, texts: list):
  474. payload = {
  475. "model": self.model_name,
  476. "query": query,
  477. "documents": texts,
  478. "top_n": len(texts),
  479. }
  480. try:
  481. response = requests.post(self.base_url, json=payload, headers=self.headers)
  482. response.raise_for_status()
  483. response_json = response.json()
  484. rank = np.zeros(len(texts), dtype=float)
  485. token_count = 0
  486. for t in texts:
  487. token_count += num_tokens_from_string(t)
  488. try:
  489. for result in response_json["results"]:
  490. rank[result["index"]] = result["relevance_score"]
  491. except Exception as _e:
  492. log_exception(_e, response)
  493. return (
  494. rank,
  495. token_count,
  496. )
  497. except httpx.HTTPStatusError as e:
  498. raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  499. class NovitaRerank(JinaRerank):
  500. _FACTORY_NAME = "NovitaAI"
  501. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  502. if not base_url:
  503. base_url = "https://api.novita.ai/v3/openai/rerank"
  504. super().__init__(key, model_name, base_url)
  505. class GiteeRerank(JinaRerank):
  506. _FACTORY_NAME = "GiteeAI"
  507. def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
  508. if not base_url:
  509. base_url = "https://ai.gitee.com/v1/rerank"
  510. super().__init__(key, model_name, base_url)