Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

rerank_model.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from collections.abc import Iterable
  19. from urllib.parse import urljoin
  20. import requests
  21. import httpx
  22. from huggingface_hub import snapshot_download
  23. import os
  24. from abc import ABC
  25. import numpy as np
  26. from yarl import URL
  27. from api import settings
  28. from api.utils.file_utils import get_home_cache_dir
  29. from rag.utils import num_tokens_from_string, truncate
  30. import json
  31. def sigmoid(x):
  32. return 1 / (1 + np.exp(-x))
  33. class Base(ABC):
  34. def __init__(self, key, model_name):
  35. pass
  36. def similarity(self, query: str, texts: list):
  37. raise NotImplementedError("Please implement encode method!")
  38. def total_token_count(self, resp):
  39. try:
  40. return resp.usage.total_tokens
  41. except Exception:
  42. pass
  43. try:
  44. return resp["usage"]["total_tokens"]
  45. except Exception:
  46. pass
  47. return 0
  48. class DefaultRerank(Base):
  49. _model = None
  50. _model_lock = threading.Lock()
  51. def __init__(self, key, model_name, **kwargs):
  52. """
  53. If you have trouble downloading HuggingFace models, -_^ this might help!!
  54. For Linux:
  55. export HF_ENDPOINT=https://hf-mirror.com
  56. For Windows:
  57. Good luck
  58. ^_-
  59. """
  60. if not settings.LIGHTEN and not DefaultRerank._model:
  61. import torch
  62. from FlagEmbedding import FlagReranker
  63. with DefaultRerank._model_lock:
  64. if not DefaultRerank._model:
  65. try:
  66. DefaultRerank._model = FlagReranker(
  67. os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  68. use_fp16=torch.cuda.is_available())
  69. except Exception:
  70. model_dir = snapshot_download(repo_id=model_name,
  71. local_dir=os.path.join(get_home_cache_dir(),
  72. re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  73. local_dir_use_symlinks=False)
  74. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  75. self._model = DefaultRerank._model
  76. self._dynamic_batch_size = 8
  77. self._min_batch_size = 1
  78. def torch_empty_cache(self):
  79. try:
  80. import torch
  81. torch.cuda.empty_cache()
  82. except Exception as e:
  83. print(f"Error emptying cache: {e}")
  84. def _process_batch(self, pairs, max_batch_size=None):
  85. """template method for subclass call"""
  86. old_dynamic_batch_size = self._dynamic_batch_size
  87. if max_batch_size is not None:
  88. self._dynamic_batch_size = max_batch_size
  89. res = []
  90. i = 0
  91. while i < len(pairs):
  92. current_batch = self._dynamic_batch_size
  93. max_retries = 5
  94. retry_count = 0
  95. while retry_count < max_retries:
  96. try:
  97. # call subclass implemented batch processing calculation
  98. batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
  99. res.extend(batch_scores)
  100. i += current_batch
  101. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  102. break
  103. except RuntimeError as e:
  104. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  105. current_batch = max(current_batch // 2, self._min_batch_size)
  106. self.torch_empty_cache()
  107. retry_count += 1
  108. else:
  109. raise
  110. if retry_count >= max_retries:
  111. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  112. self.torch_empty_cache()
  113. self._dynamic_batch_size = old_dynamic_batch_size
  114. return np.array(res)
  115. def _compute_batch_scores(self, batch_pairs, max_length=None):
  116. if max_length is None:
  117. scores = self._model.compute_score(batch_pairs)
  118. else:
  119. scores = self._model.compute_score(batch_pairs, max_length=max_length)
  120. scores = sigmoid(np.array(scores)).tolist()
  121. if not isinstance(scores, Iterable):
  122. scores = [scores]
  123. return scores
  124. def similarity(self, query: str, texts: list):
  125. pairs = [(query, truncate(t, 2048)) for t in texts]
  126. token_count = 0
  127. for _, t in pairs:
  128. token_count += num_tokens_from_string(t)
  129. batch_size = 4096
  130. res = self._process_batch(pairs, max_batch_size=batch_size)
  131. return np.array(res), token_count
  132. class JinaRerank(Base):
  133. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
  134. base_url="https://api.jina.ai/v1/rerank"):
  135. self.base_url = "https://api.jina.ai/v1/rerank"
  136. self.headers = {
  137. "Content-Type": "application/json",
  138. "Authorization": f"Bearer {key}"
  139. }
  140. self.model_name = model_name
  141. def similarity(self, query: str, texts: list):
  142. texts = [truncate(t, 8196) for t in texts]
  143. data = {
  144. "model": self.model_name,
  145. "query": query,
  146. "documents": texts,
  147. "top_n": len(texts)
  148. }
  149. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  150. rank = np.zeros(len(texts), dtype=float)
  151. for d in res["results"]:
  152. rank[d["index"]] = d["relevance_score"]
  153. return rank, self.total_token_count(res)
  154. class YoudaoRerank(DefaultRerank):
  155. _model = None
  156. _model_lock = threading.Lock()
  157. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  158. if not settings.LIGHTEN and not YoudaoRerank._model:
  159. from BCEmbedding import RerankerModel
  160. with YoudaoRerank._model_lock:
  161. if not YoudaoRerank._model:
  162. try:
  163. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  164. get_home_cache_dir(),
  165. re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  166. except Exception:
  167. YoudaoRerank._model = RerankerModel(
  168. model_name_or_path=model_name.replace(
  169. "maidalun1020", "InfiniFlow"))
  170. self._model = YoudaoRerank._model
  171. self._dynamic_batch_size = 8
  172. self._min_batch_size = 1
  173. def similarity(self, query: str, texts: list):
  174. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  175. token_count = 0
  176. for _, t in pairs:
  177. token_count += num_tokens_from_string(t)
  178. batch_size = 8
  179. res = self._process_batch(pairs, max_batch_size=batch_size)
  180. return np.array(res), token_count
  181. class XInferenceRerank(Base):
  182. def __init__(self, key="x", model_name="", base_url=""):
  183. if base_url.find("/v1") == -1:
  184. base_url = urljoin(base_url, "/v1/rerank")
  185. if base_url.find("/rerank") == -1:
  186. base_url = urljoin(base_url, "/v1/rerank")
  187. self.model_name = model_name
  188. self.base_url = base_url
  189. self.headers = {
  190. "Content-Type": "application/json",
  191. "accept": "application/json"
  192. }
  193. if key and key != "x":
  194. self.headers["Authorization"] = f"Bearer {key}"
  195. def similarity(self, query: str, texts: list):
  196. if len(texts) == 0:
  197. return np.array([]), 0
  198. pairs = [(query, truncate(t, 4096)) for t in texts]
  199. token_count = 0
  200. for _, t in pairs:
  201. token_count += num_tokens_from_string(t)
  202. data = {
  203. "model": self.model_name,
  204. "query": query,
  205. "return_documents": "true",
  206. "return_len": "true",
  207. "documents": texts
  208. }
  209. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  210. rank = np.zeros(len(texts), dtype=float)
  211. for d in res["results"]:
  212. rank[d["index"]] = d["relevance_score"]
  213. return rank, token_count
  214. class LocalAIRerank(Base):
  215. def __init__(self, key, model_name, base_url):
  216. if base_url.find("/rerank") == -1:
  217. self.base_url = urljoin(base_url, "/rerank")
  218. else:
  219. self.base_url = base_url
  220. self.headers = {
  221. "Content-Type": "application/json",
  222. "Authorization": f"Bearer {key}"
  223. }
  224. self.model_name = model_name.split("___")[0]
  225. def similarity(self, query: str, texts: list):
  226. # noway to config Ragflow , use fix setting
  227. texts = [truncate(t, 500) for t in texts]
  228. data = {
  229. "model": self.model_name,
  230. "query": query,
  231. "documents": texts,
  232. "top_n": len(texts),
  233. }
  234. token_count = 0
  235. for t in texts:
  236. token_count += num_tokens_from_string(t)
  237. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  238. rank = np.zeros(len(texts), dtype=float)
  239. if 'results' not in res:
  240. raise ValueError("response not contains results\n" + str(res))
  241. for d in res["results"]:
  242. rank[d["index"]] = d["relevance_score"]
  243. # Normalize the rank values to the range 0 to 1
  244. min_rank = np.min(rank)
  245. max_rank = np.max(rank)
  246. # Avoid division by zero if all ranks are identical
  247. if max_rank - min_rank != 0:
  248. rank = (rank - min_rank) / (max_rank - min_rank)
  249. else:
  250. rank = np.zeros_like(rank)
  251. return rank, token_count
  252. class NvidiaRerank(Base):
  253. def __init__(
  254. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  255. ):
  256. if not base_url:
  257. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  258. self.model_name = model_name
  259. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  260. self.base_url = os.path.join(
  261. base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
  262. )
  263. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  264. self.base_url = os.path.join(base_url, "reranking")
  265. self.model_name = "nv-rerank-qa-mistral-4b:1"
  266. self.headers = {
  267. "accept": "application/json",
  268. "Content-Type": "application/json",
  269. "Authorization": f"Bearer {key}",
  270. }
  271. def similarity(self, query: str, texts: list):
  272. token_count = num_tokens_from_string(query) + sum(
  273. [num_tokens_from_string(t) for t in texts]
  274. )
  275. data = {
  276. "model": self.model_name,
  277. "query": {"text": query},
  278. "passages": [{"text": text} for text in texts],
  279. "truncate": "END",
  280. "top_n": len(texts),
  281. }
  282. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  283. rank = np.zeros(len(texts), dtype=float)
  284. for d in res["rankings"]:
  285. rank[d["index"]] = d["logit"]
  286. return rank, token_count
  287. class LmStudioRerank(Base):
  288. def __init__(self, key, model_name, base_url):
  289. pass
  290. def similarity(self, query: str, texts: list):
  291. raise NotImplementedError("The LmStudioRerank has not been implement")
  292. class OpenAI_APIRerank(Base):
  293. def __init__(self, key, model_name, base_url):
  294. if base_url.find("/rerank") == -1:
  295. self.base_url = urljoin(base_url, "/rerank")
  296. else:
  297. self.base_url = base_url
  298. self.headers = {
  299. "Content-Type": "application/json",
  300. "Authorization": f"Bearer {key}"
  301. }
  302. self.model_name = model_name.split("___")[0]
  303. def similarity(self, query: str, texts: list):
  304. # noway to config Ragflow , use fix setting
  305. texts = [truncate(t, 500) for t in texts]
  306. data = {
  307. "model": self.model_name,
  308. "query": query,
  309. "documents": texts,
  310. "top_n": len(texts),
  311. }
  312. token_count = 0
  313. for t in texts:
  314. token_count += num_tokens_from_string(t)
  315. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  316. rank = np.zeros(len(texts), dtype=float)
  317. if 'results' not in res:
  318. raise ValueError("response not contains results\n" + str(res))
  319. for d in res["results"]:
  320. rank[d["index"]] = d["relevance_score"]
  321. # Normalize the rank values to the range 0 to 1
  322. min_rank = np.min(rank)
  323. max_rank = np.max(rank)
  324. # Avoid division by zero if all ranks are identical
  325. if max_rank - min_rank != 0:
  326. rank = (rank - min_rank) / (max_rank - min_rank)
  327. else:
  328. rank = np.zeros_like(rank)
  329. return rank, token_count
  330. class CoHereRerank(Base):
  331. def __init__(self, key, model_name, base_url=None):
  332. from cohere import Client
  333. self.client = Client(api_key=key, base_url=base_url)
  334. self.model_name = model_name.split("___")[0]
  335. def similarity(self, query: str, texts: list):
  336. token_count = num_tokens_from_string(query) + sum(
  337. [num_tokens_from_string(t) for t in texts]
  338. )
  339. res = self.client.rerank(
  340. model=self.model_name,
  341. query=query,
  342. documents=texts,
  343. top_n=len(texts),
  344. return_documents=False,
  345. )
  346. rank = np.zeros(len(texts), dtype=float)
  347. for d in res.results:
  348. rank[d.index] = d.relevance_score
  349. return rank, token_count
  350. class TogetherAIRerank(Base):
  351. def __init__(self, key, model_name, base_url):
  352. pass
  353. def similarity(self, query: str, texts: list):
  354. raise NotImplementedError("The api has not been implement")
  355. class SILICONFLOWRerank(Base):
  356. def __init__(
  357. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  358. ):
  359. if not base_url:
  360. base_url = "https://api.siliconflow.cn/v1/rerank"
  361. self.model_name = model_name
  362. self.base_url = base_url
  363. self.headers = {
  364. "accept": "application/json",
  365. "content-type": "application/json",
  366. "authorization": f"Bearer {key}",
  367. }
  368. def similarity(self, query: str, texts: list):
  369. payload = {
  370. "model": self.model_name,
  371. "query": query,
  372. "documents": texts,
  373. "top_n": len(texts),
  374. "return_documents": False,
  375. "max_chunks_per_doc": 1024,
  376. "overlap_tokens": 80,
  377. }
  378. response = requests.post(
  379. self.base_url, json=payload, headers=self.headers
  380. ).json()
  381. rank = np.zeros(len(texts), dtype=float)
  382. if "results" not in response:
  383. return rank, 0
  384. for d in response["results"]:
  385. rank[d["index"]] = d["relevance_score"]
  386. return (
  387. rank,
  388. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  389. )
  390. class BaiduYiyanRerank(Base):
  391. def __init__(self, key, model_name, base_url=None):
  392. from qianfan.resources import Reranker
  393. key = json.loads(key)
  394. ak = key.get("yiyan_ak", "")
  395. sk = key.get("yiyan_sk", "")
  396. self.client = Reranker(ak=ak, sk=sk)
  397. self.model_name = model_name
  398. def similarity(self, query: str, texts: list):
  399. res = self.client.do(
  400. model=self.model_name,
  401. query=query,
  402. documents=texts,
  403. top_n=len(texts),
  404. ).body
  405. rank = np.zeros(len(texts), dtype=float)
  406. for d in res["results"]:
  407. rank[d["index"]] = d["relevance_score"]
  408. return rank, self.total_token_count(res)
  409. class VoyageRerank(Base):
  410. def __init__(self, key, model_name, base_url=None):
  411. import voyageai
  412. self.client = voyageai.Client(api_key=key)
  413. self.model_name = model_name
  414. def similarity(self, query: str, texts: list):
  415. rank = np.zeros(len(texts), dtype=float)
  416. if not texts:
  417. return rank, 0
  418. res = self.client.rerank(
  419. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  420. )
  421. for r in res.results:
  422. rank[r.index] = r.relevance_score
  423. return rank, res.total_tokens
  424. class QWenRerank(Base):
  425. def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
  426. import dashscope
  427. self.api_key = key
  428. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  429. def similarity(self, query: str, texts: list):
  430. import dashscope
  431. from http import HTTPStatus
  432. resp = dashscope.TextReRank.call(
  433. api_key=self.api_key,
  434. model=self.model_name,
  435. query=query,
  436. documents=texts,
  437. top_n=len(texts),
  438. return_documents=False
  439. )
  440. rank = np.zeros(len(texts), dtype=float)
  441. if resp.status_code == HTTPStatus.OK:
  442. for r in resp.output.results:
  443. rank[r.index] = r.relevance_score
  444. return rank, resp.usage.total_tokens
  445. else:
  446. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  447. class HuggingfaceRerank(DefaultRerank):
  448. @staticmethod
  449. def post(query: str, texts: list, url="127.0.0.1"):
  450. exc = None
  451. scores = [0 for _ in range(len(texts))]
  452. batch_size = 8
  453. for i in range(0, len(texts), batch_size):
  454. try:
  455. res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
  456. json={"query": query, "texts": texts[i: i + batch_size],
  457. "raw_scores": False, "truncate": True})
  458. for o in res.json():
  459. scores[o["index"] + i] = o["score"]
  460. except Exception as e:
  461. exc = e
  462. if exc:
  463. raise exc
  464. return np.array(scores)
  465. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  466. self.model_name = model_name.split("___")[0]
  467. self.base_url = base_url
  468. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  469. if not texts:
  470. return np.array([]), 0
  471. token_count = 0
  472. for t in texts:
  473. token_count += num_tokens_from_string(t)
  474. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  475. class GPUStackRerank(Base):
  476. def __init__(
  477. self, key, model_name, base_url
  478. ):
  479. if not base_url:
  480. raise ValueError("url cannot be None")
  481. self.model_name = model_name
  482. self.base_url = str(URL(base_url) / "v1" / "rerank")
  483. self.headers = {
  484. "accept": "application/json",
  485. "content-type": "application/json",
  486. "authorization": f"Bearer {key}",
  487. }
  488. def similarity(self, query: str, texts: list):
  489. payload = {
  490. "model": self.model_name,
  491. "query": query,
  492. "documents": texts,
  493. "top_n": len(texts),
  494. }
  495. try:
  496. response = requests.post(
  497. self.base_url, json=payload, headers=self.headers
  498. )
  499. response.raise_for_status()
  500. response_json = response.json()
  501. rank = np.zeros(len(texts), dtype=float)
  502. if "results" not in response_json:
  503. return rank, 0
  504. token_count = 0
  505. for t in texts:
  506. token_count += num_tokens_from_string(t)
  507. for result in response_json["results"]:
  508. rank[result["index"]] = result["relevance_score"]
  509. return (
  510. rank,
  511. token_count,
  512. )
  513. except httpx.HTTPStatusError as e:
  514. raise ValueError(
  515. f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")