Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

rerank_model.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from collections.abc import Iterable
  19. from urllib.parse import urljoin
  20. import requests
  21. import httpx
  22. from huggingface_hub import snapshot_download
  23. import os
  24. from abc import ABC
  25. import numpy as np
  26. from yarl import URL
  27. from api import settings
  28. from api.utils.file_utils import get_home_cache_dir
  29. from api.utils.log_utils import log_exception
  30. from rag.utils import num_tokens_from_string, truncate
  31. import json
  32. def sigmoid(x):
  33. return 1 / (1 + np.exp(-x))
  34. class Base(ABC):
  35. def __init__(self, key, model_name):
  36. pass
  37. def similarity(self, query: str, texts: list):
  38. raise NotImplementedError("Please implement encode method!")
  39. def total_token_count(self, resp):
  40. try:
  41. return resp.usage.total_tokens
  42. except Exception:
  43. pass
  44. try:
  45. return resp["usage"]["total_tokens"]
  46. except Exception:
  47. pass
  48. return 0
  49. class DefaultRerank(Base):
  50. _model = None
  51. _model_lock = threading.Lock()
  52. def __init__(self, key, model_name, **kwargs):
  53. """
  54. If you have trouble downloading HuggingFace models, -_^ this might help!!
  55. For Linux:
  56. export HF_ENDPOINT=https://hf-mirror.com
  57. For Windows:
  58. Good luck
  59. ^_-
  60. """
  61. if not settings.LIGHTEN and not DefaultRerank._model:
  62. import torch
  63. from FlagEmbedding import FlagReranker
  64. with DefaultRerank._model_lock:
  65. if not DefaultRerank._model:
  66. try:
  67. DefaultRerank._model = FlagReranker(
  68. os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  69. use_fp16=torch.cuda.is_available())
  70. except Exception:
  71. model_dir = snapshot_download(repo_id=model_name,
  72. local_dir=os.path.join(get_home_cache_dir(),
  73. re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  74. local_dir_use_symlinks=False)
  75. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  76. self._model = DefaultRerank._model
  77. self._dynamic_batch_size = 8
  78. self._min_batch_size = 1
  79. def torch_empty_cache(self):
  80. try:
  81. import torch
  82. torch.cuda.empty_cache()
  83. except Exception as e:
  84. print(f"Error emptying cache: {e}")
  85. def _process_batch(self, pairs, max_batch_size=None):
  86. """template method for subclass call"""
  87. old_dynamic_batch_size = self._dynamic_batch_size
  88. if max_batch_size is not None:
  89. self._dynamic_batch_size = max_batch_size
  90. res = []
  91. i = 0
  92. while i < len(pairs):
  93. current_batch = self._dynamic_batch_size
  94. max_retries = 5
  95. retry_count = 0
  96. while retry_count < max_retries:
  97. try:
  98. # call subclass implemented batch processing calculation
  99. batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
  100. res.extend(batch_scores)
  101. i += current_batch
  102. self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
  103. break
  104. except RuntimeError as e:
  105. if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
  106. current_batch = max(current_batch // 2, self._min_batch_size)
  107. self.torch_empty_cache()
  108. retry_count += 1
  109. else:
  110. raise
  111. if retry_count >= max_retries:
  112. raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
  113. self.torch_empty_cache()
  114. self._dynamic_batch_size = old_dynamic_batch_size
  115. return np.array(res)
  116. def _compute_batch_scores(self, batch_pairs, max_length=None):
  117. if max_length is None:
  118. scores = self._model.compute_score(batch_pairs)
  119. else:
  120. scores = self._model.compute_score(batch_pairs, max_length=max_length)
  121. scores = sigmoid(np.array(scores)).tolist()
  122. if not isinstance(scores, Iterable):
  123. scores = [scores]
  124. return scores
  125. def similarity(self, query: str, texts: list):
  126. pairs = [(query, truncate(t, 2048)) for t in texts]
  127. token_count = 0
  128. for _, t in pairs:
  129. token_count += num_tokens_from_string(t)
  130. batch_size = 4096
  131. res = self._process_batch(pairs, max_batch_size=batch_size)
  132. return np.array(res), token_count
  133. class JinaRerank(Base):
  134. def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
  135. base_url="https://api.jina.ai/v1/rerank"):
  136. self.base_url = "https://api.jina.ai/v1/rerank"
  137. self.headers = {
  138. "Content-Type": "application/json",
  139. "Authorization": f"Bearer {key}"
  140. }
  141. self.model_name = model_name
  142. def similarity(self, query: str, texts: list):
  143. texts = [truncate(t, 8196) for t in texts]
  144. data = {
  145. "model": self.model_name,
  146. "query": query,
  147. "documents": texts,
  148. "top_n": len(texts)
  149. }
  150. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  151. rank = np.zeros(len(texts), dtype=float)
  152. try:
  153. for d in res["results"]:
  154. rank[d["index"]] = d["relevance_score"]
  155. except Exception as _e:
  156. log_exception(_e, res)
  157. return rank, self.total_token_count(res)
  158. class YoudaoRerank(DefaultRerank):
  159. _model = None
  160. _model_lock = threading.Lock()
  161. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  162. if not settings.LIGHTEN and not YoudaoRerank._model:
  163. from BCEmbedding import RerankerModel
  164. with YoudaoRerank._model_lock:
  165. if not YoudaoRerank._model:
  166. try:
  167. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  168. get_home_cache_dir(),
  169. re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
  170. except Exception:
  171. YoudaoRerank._model = RerankerModel(
  172. model_name_or_path=model_name.replace(
  173. "maidalun1020", "InfiniFlow"))
  174. self._model = YoudaoRerank._model
  175. self._dynamic_batch_size = 8
  176. self._min_batch_size = 1
  177. def similarity(self, query: str, texts: list):
  178. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  179. token_count = 0
  180. for _, t in pairs:
  181. token_count += num_tokens_from_string(t)
  182. batch_size = 8
  183. res = self._process_batch(pairs, max_batch_size=batch_size)
  184. return np.array(res), token_count
  185. class XInferenceRerank(Base):
  186. def __init__(self, key="x", model_name="", base_url=""):
  187. if base_url.find("/v1") == -1:
  188. base_url = urljoin(base_url, "/v1/rerank")
  189. if base_url.find("/rerank") == -1:
  190. base_url = urljoin(base_url, "/v1/rerank")
  191. self.model_name = model_name
  192. self.base_url = base_url
  193. self.headers = {
  194. "Content-Type": "application/json",
  195. "accept": "application/json"
  196. }
  197. if key and key != "x":
  198. self.headers["Authorization"] = f"Bearer {key}"
  199. def similarity(self, query: str, texts: list):
  200. if len(texts) == 0:
  201. return np.array([]), 0
  202. pairs = [(query, truncate(t, 4096)) for t in texts]
  203. token_count = 0
  204. for _, t in pairs:
  205. token_count += num_tokens_from_string(t)
  206. data = {
  207. "model": self.model_name,
  208. "query": query,
  209. "return_documents": "true",
  210. "return_len": "true",
  211. "documents": texts
  212. }
  213. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  214. rank = np.zeros(len(texts), dtype=float)
  215. try:
  216. for d in res["results"]:
  217. rank[d["index"]] = d["relevance_score"]
  218. except Exception as _e:
  219. log_exception(_e, res)
  220. return rank, token_count
  221. class LocalAIRerank(Base):
  222. def __init__(self, key, model_name, base_url):
  223. if base_url.find("/rerank") == -1:
  224. self.base_url = urljoin(base_url, "/rerank")
  225. else:
  226. self.base_url = base_url
  227. self.headers = {
  228. "Content-Type": "application/json",
  229. "Authorization": f"Bearer {key}"
  230. }
  231. self.model_name = model_name.split("___")[0]
  232. def similarity(self, query: str, texts: list):
  233. # noway to config Ragflow , use fix setting
  234. texts = [truncate(t, 500) for t in texts]
  235. data = {
  236. "model": self.model_name,
  237. "query": query,
  238. "documents": texts,
  239. "top_n": len(texts),
  240. }
  241. token_count = 0
  242. for t in texts:
  243. token_count += num_tokens_from_string(t)
  244. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  245. rank = np.zeros(len(texts), dtype=float)
  246. try:
  247. for d in res["results"]:
  248. rank[d["index"]] = d["relevance_score"]
  249. except Exception as _e:
  250. log_exception(_e, res)
  251. # Normalize the rank values to the range 0 to 1
  252. min_rank = np.min(rank)
  253. max_rank = np.max(rank)
  254. # Avoid division by zero if all ranks are identical
  255. if max_rank - min_rank != 0:
  256. rank = (rank - min_rank) / (max_rank - min_rank)
  257. else:
  258. rank = np.zeros_like(rank)
  259. return rank, token_count
  260. class NvidiaRerank(Base):
  261. def __init__(
  262. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  263. ):
  264. if not base_url:
  265. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  266. self.model_name = model_name
  267. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  268. self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
  269. )
  270. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  271. self.base_url = urljoin(base_url, "reranking")
  272. self.model_name = "nv-rerank-qa-mistral-4b:1"
  273. self.headers = {
  274. "accept": "application/json",
  275. "Content-Type": "application/json",
  276. "Authorization": f"Bearer {key}",
  277. }
  278. def similarity(self, query: str, texts: list):
  279. token_count = num_tokens_from_string(query) + sum(
  280. [num_tokens_from_string(t) for t in texts]
  281. )
  282. data = {
  283. "model": self.model_name,
  284. "query": {"text": query},
  285. "passages": [{"text": text} for text in texts],
  286. "truncate": "END",
  287. "top_n": len(texts),
  288. }
  289. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  290. rank = np.zeros(len(texts), dtype=float)
  291. try:
  292. for d in res["rankings"]:
  293. rank[d["index"]] = d["logit"]
  294. except Exception as _e:
  295. log_exception(_e, res)
  296. return rank, token_count
  297. class LmStudioRerank(Base):
  298. def __init__(self, key, model_name, base_url):
  299. pass
  300. def similarity(self, query: str, texts: list):
  301. raise NotImplementedError("The LmStudioRerank has not been implement")
  302. class OpenAI_APIRerank(Base):
  303. def __init__(self, key, model_name, base_url):
  304. if base_url.find("/rerank") == -1:
  305. self.base_url = urljoin(base_url, "/rerank")
  306. else:
  307. self.base_url = base_url
  308. self.headers = {
  309. "Content-Type": "application/json",
  310. "Authorization": f"Bearer {key}"
  311. }
  312. self.model_name = model_name.split("___")[0]
  313. def similarity(self, query: str, texts: list):
  314. # noway to config Ragflow , use fix setting
  315. texts = [truncate(t, 500) for t in texts]
  316. data = {
  317. "model": self.model_name,
  318. "query": query,
  319. "documents": texts,
  320. "top_n": len(texts),
  321. }
  322. token_count = 0
  323. for t in texts:
  324. token_count += num_tokens_from_string(t)
  325. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  326. rank = np.zeros(len(texts), dtype=float)
  327. try:
  328. for d in res["results"]:
  329. rank[d["index"]] = d["relevance_score"]
  330. except Exception as _e:
  331. log_exception(_e, res)
  332. # Normalize the rank values to the range 0 to 1
  333. min_rank = np.min(rank)
  334. max_rank = np.max(rank)
  335. # Avoid division by zero if all ranks are identical
  336. if max_rank - min_rank != 0:
  337. rank = (rank - min_rank) / (max_rank - min_rank)
  338. else:
  339. rank = np.zeros_like(rank)
  340. return rank, token_count
  341. class CoHereRerank(Base):
  342. def __init__(self, key, model_name, base_url=None):
  343. from cohere import Client
  344. self.client = Client(api_key=key, base_url=base_url)
  345. self.model_name = model_name.split("___")[0]
  346. def similarity(self, query: str, texts: list):
  347. token_count = num_tokens_from_string(query) + sum(
  348. [num_tokens_from_string(t) for t in texts]
  349. )
  350. res = self.client.rerank(
  351. model=self.model_name,
  352. query=query,
  353. documents=texts,
  354. top_n=len(texts),
  355. return_documents=False,
  356. )
  357. rank = np.zeros(len(texts), dtype=float)
  358. try:
  359. for d in res.results:
  360. rank[d.index] = d.relevance_score
  361. except Exception as _e:
  362. log_exception(_e, res)
  363. return rank, token_count
  364. class TogetherAIRerank(Base):
  365. def __init__(self, key, model_name, base_url):
  366. pass
  367. def similarity(self, query: str, texts: list):
  368. raise NotImplementedError("The api has not been implement")
  369. class SILICONFLOWRerank(Base):
  370. def __init__(
  371. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  372. ):
  373. if not base_url:
  374. base_url = "https://api.siliconflow.cn/v1/rerank"
  375. self.model_name = model_name
  376. self.base_url = base_url
  377. self.headers = {
  378. "accept": "application/json",
  379. "content-type": "application/json",
  380. "authorization": f"Bearer {key}",
  381. }
  382. def similarity(self, query: str, texts: list):
  383. payload = {
  384. "model": self.model_name,
  385. "query": query,
  386. "documents": texts,
  387. "top_n": len(texts),
  388. "return_documents": False,
  389. "max_chunks_per_doc": 1024,
  390. "overlap_tokens": 80,
  391. }
  392. response = requests.post(
  393. self.base_url, json=payload, headers=self.headers
  394. ).json()
  395. rank = np.zeros(len(texts), dtype=float)
  396. try:
  397. for d in response["results"]:
  398. rank[d["index"]] = d["relevance_score"]
  399. except Exception as _e:
  400. log_exception(_e, response)
  401. return (
  402. rank,
  403. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  404. )
  405. class BaiduYiyanRerank(Base):
  406. def __init__(self, key, model_name, base_url=None):
  407. from qianfan.resources import Reranker
  408. key = json.loads(key)
  409. ak = key.get("yiyan_ak", "")
  410. sk = key.get("yiyan_sk", "")
  411. self.client = Reranker(ak=ak, sk=sk)
  412. self.model_name = model_name
  413. def similarity(self, query: str, texts: list):
  414. res = self.client.do(
  415. model=self.model_name,
  416. query=query,
  417. documents=texts,
  418. top_n=len(texts),
  419. ).body
  420. rank = np.zeros(len(texts), dtype=float)
  421. try:
  422. for d in res["results"]:
  423. rank[d["index"]] = d["relevance_score"]
  424. except Exception as _e:
  425. log_exception(_e, res)
  426. return rank, self.total_token_count(res)
  427. class VoyageRerank(Base):
  428. def __init__(self, key, model_name, base_url=None):
  429. import voyageai
  430. self.client = voyageai.Client(api_key=key)
  431. self.model_name = model_name
  432. def similarity(self, query: str, texts: list):
  433. rank = np.zeros(len(texts), dtype=float)
  434. if not texts:
  435. return rank, 0
  436. res = self.client.rerank(
  437. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  438. )
  439. try:
  440. for r in res.results:
  441. rank[r.index] = r.relevance_score
  442. except Exception as _e:
  443. log_exception(_e, res)
  444. return rank, res.total_tokens
  445. class QWenRerank(Base):
  446. def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
  447. import dashscope
  448. self.api_key = key
  449. self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
  450. def similarity(self, query: str, texts: list):
  451. import dashscope
  452. from http import HTTPStatus
  453. resp = dashscope.TextReRank.call(
  454. api_key=self.api_key,
  455. model=self.model_name,
  456. query=query,
  457. documents=texts,
  458. top_n=len(texts),
  459. return_documents=False
  460. )
  461. rank = np.zeros(len(texts), dtype=float)
  462. if resp.status_code == HTTPStatus.OK:
  463. try:
  464. for r in resp.output.results:
  465. rank[r.index] = r.relevance_score
  466. except Exception as _e:
  467. log_exception(_e, resp)
  468. return rank, resp.usage.total_tokens
  469. else:
  470. raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp.text}")
  471. class HuggingfaceRerank(DefaultRerank):
  472. @staticmethod
  473. def post(query: str, texts: list, url="127.0.0.1"):
  474. exc = None
  475. scores = [0 for _ in range(len(texts))]
  476. batch_size = 8
  477. for i in range(0, len(texts), batch_size):
  478. try:
  479. res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
  480. json={"query": query, "texts": texts[i: i + batch_size],
  481. "raw_scores": False, "truncate": True})
  482. for o in res.json():
  483. scores[o["index"] + i] = o["score"]
  484. except Exception as e:
  485. exc = e
  486. if exc:
  487. raise exc
  488. return np.array(scores)
  489. def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
  490. self.model_name = model_name.split("___")[0]
  491. self.base_url = base_url
  492. def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
  493. if not texts:
  494. return np.array([]), 0
  495. token_count = 0
  496. for t in texts:
  497. token_count += num_tokens_from_string(t)
  498. return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  499. class GPUStackRerank(Base):
  500. def __init__(
  501. self, key, model_name, base_url
  502. ):
  503. if not base_url:
  504. raise ValueError("url cannot be None")
  505. self.model_name = model_name
  506. self.base_url = str(URL(base_url) / "v1" / "rerank")
  507. self.headers = {
  508. "accept": "application/json",
  509. "content-type": "application/json",
  510. "authorization": f"Bearer {key}",
  511. }
  512. def similarity(self, query: str, texts: list):
  513. payload = {
  514. "model": self.model_name,
  515. "query": query,
  516. "documents": texts,
  517. "top_n": len(texts),
  518. }
  519. try:
  520. response = requests.post(
  521. self.base_url, json=payload, headers=self.headers
  522. )
  523. response.raise_for_status()
  524. response_json = response.json()
  525. rank = np.zeros(len(texts), dtype=float)
  526. token_count = 0
  527. for t in texts:
  528. token_count += num_tokens_from_string(t)
  529. try:
  530. for result in response_json["results"]:
  531. rank[result["index"]] = result["relevance_score"]
  532. except Exception as _e:
  533. log_exception(_e, response)
  534. return (
  535. rank,
  536. token_count,
  537. )
  538. except httpx.HTTPStatusError as e:
  539. raise ValueError(
  540. f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
  541. class NovitaRerank(JinaRerank):
  542. def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
  543. super().__init__(key, model_name, base_url)