選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

rerank_model.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from urllib.parse import urljoin
  19. import requests
  20. import httpx
  21. from huggingface_hub import snapshot_download
  22. import os
  23. from abc import ABC
  24. import numpy as np
  25. from yarl import URL
  26. from api import settings
  27. from api.utils.file_utils import get_home_cache_dir
  28. from rag.utils import num_tokens_from_string, truncate
  29. import json
  def sigmoid(x):
      """Map raw scores into the (0, 1) interval with the logistic function.

      Operates element-wise on numpy arrays; numeric scalars are accepted too.
      """
      denominator = 1 + np.exp(-x)
      return 1 / denominator
  class Base(ABC):
      """Common interface for the rerank backends in this module.

      Subclasses implement ``similarity(query, texts)`` and return a pair
      ``(scores, token_count)`` where ``scores`` is aligned with ``texts``.
      """

      def __init__(self, key, model_name):
          pass

      def similarity(self, query: str, texts: list):
          """Score every text in ``texts`` against ``query``; must be overridden.

          Raises:
              NotImplementedError: always, on the base class.
          """
          raise NotImplementedError("Please implement similarity method!")

      def total_token_count(self, resp):
          """Best-effort extraction of total token usage from an API response.

          Tries attribute access (``resp.usage.total_tokens``) first, then
          mapping access (``resp["usage"]["total_tokens"]``).

          Returns:
              The token count, or 0 when neither shape is present.
          """
          # SDK response objects usually expose usage as attributes ...
          try:
              return resp.usage.total_tokens
          except Exception:
              pass
          # ... while decoded JSON payloads are plain dicts.
          try:
              return resp["usage"]["total_tokens"]
          except Exception:
              pass
          return 0
  class DefaultRerank(Base):
      """Local cross-encoder reranker backed by ``FlagEmbedding.FlagReranker``.

      The model is loaded at most once and shared by all instances through the
      class attribute ``_model`` (creation guarded by ``_model_lock``). When
      ``settings.LIGHTEN`` is set, no model is loaded.
      """
      _model = None
      _model_lock = threading.Lock()
      # Class-level defaults for the adaptive batching state, so subclasses
      # that do not call DefaultRerank.__init__ can still use _process_batch.
      _dynamic_batch_size = 8
      _min_batch_size = 1

      def __init__(self, key, model_name, **kwargs):
          """Load (or reuse) the shared FlagReranker model.

          If downloading HuggingFace models is a problem, setting
          ``HF_ENDPOINT=https://hf-mirror.com`` (Linux: ``export ...``) may help.

          Args:
              key: unused; kept for the common constructor interface.
              model_name: HuggingFace repo id; the part after the owner prefix
                  names the directory under the home cache dir.
          """
          if not settings.LIGHTEN and not DefaultRerank._model:
              import torch
              from FlagEmbedding import FlagReranker
              with DefaultRerank._model_lock:
                  # Re-check under the lock: another thread may have loaded it.
                  if not DefaultRerank._model:
                      local_dir = os.path.join(get_home_cache_dir(),
                                               re.sub(r"^[a-zA-Z0-9]+/", "", model_name))
                      try:
                          DefaultRerank._model = FlagReranker(
                              local_dir, use_fp16=torch.cuda.is_available())
                      except Exception:
                          # Not usable from the cache: download it, then load.
                          model_dir = snapshot_download(repo_id=model_name,
                                                        local_dir=local_dir,
                                                        local_dir_use_symlinks=False)
                          DefaultRerank._model = FlagReranker(
                              model_dir, use_fp16=torch.cuda.is_available())
          self._model = DefaultRerank._model
          self._dynamic_batch_size = 8
          self._min_batch_size = 1

      def torch_empty_cache(self):
          """Best-effort release of cached CUDA memory; failures are only printed."""
          try:
              import torch
              torch.cuda.empty_cache()
          except Exception as e:
              print(f"Error emptying cache: {e}")

      def _process_batch(self, pairs, max_batch_size=None):
          """Score ``pairs`` in batches, halving the batch size on CUDA OOM.

          Template method: the per-batch work is delegated to
          ``_compute_batch_scores``, which subclasses may override.

          Args:
              pairs: list of (query, text) tuples.
              max_batch_size: starting and maximum batch size for this call;
                  defaults to the instance's current ``_dynamic_batch_size``.

          Returns:
              numpy array with the scores of all pairs, in input order.

          Raises:
              RuntimeError: a RuntimeError from the model that is not a
                  recoverable OOM (including OOM at the minimum batch size), or
                  more than ``max_retries`` OOM retries for a single batch.
          """
          old_dynamic_batch_size = self._dynamic_batch_size
          if max_batch_size is not None:
              # A batch size below the minimum (e.g. 0) would never advance.
              self._dynamic_batch_size = max(max_batch_size, self._min_batch_size)
          # Regrowth after a success is bounded by the size requested for this call.
          batch_cap = max(self._dynamic_batch_size, self._min_batch_size)
          max_retries = 5
          res = []
          i = 0
          try:
              while i < len(pairs):
                  current_batch = self._dynamic_batch_size
                  retry_count = 0
                  while retry_count < max_retries:
                      try:
                          # call subclass implemented batch processing calculation
                          batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
                          res.extend(batch_scores)
                          i += current_batch
                          self._dynamic_batch_size = min(current_batch * 2, batch_cap)
                          break
                      except RuntimeError as e:
                          if "CUDA out of memory" in str(e) and current_batch > self._min_batch_size:
                              current_batch = max(current_batch // 2, self._min_batch_size)
                              self.torch_empty_cache()
                              retry_count += 1
                          else:
                              raise
                  if retry_count >= max_retries:
                      raise RuntimeError("max retry times, still cannot process batch, please check your GPU memory")
                  self.torch_empty_cache()
          finally:
              # Restore the caller-visible batch size even if an error propagates.
              self._dynamic_batch_size = old_dynamic_batch_size
          return np.array(res)

      def _compute_batch_scores(self, batch_pairs, max_length=None):
          """Return sigmoid-squashed model scores for one batch, as a list."""
          if max_length is None:
              scores = self._model.compute_score(batch_pairs)
          else:
              scores = self._model.compute_score(batch_pairs, max_length=max_length)
          # A single-pair batch may come back as a bare number rather than a
          # list; force 1-D so the caller can always ``extend`` with the result.
          scores = np.atleast_1d(np.array(scores))
          return sigmoid(scores).tolist()

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query`` with the local model.

          Each text is passed through ``truncate(t, 2048)`` before scoring.

          Returns:
              (scores, token_count): numpy array with one score per text and
              the summed ``num_tokens_from_string`` of the truncated texts.
          """
          pairs = [(query, truncate(t, 2048)) for t in texts]
          token_count = sum(num_tokens_from_string(t) for _, t in pairs)
          batch_size = 4096
          res = self._process_batch(pairs, max_batch_size=batch_size)
          return np.array(res), token_count
  class JinaRerank(Base):
      """Reranker backed by the Jina AI hosted rerank HTTP endpoint."""

      def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
                   base_url="https://api.jina.ai/v1/rerank"):
          self.model_name = model_name
          # NOTE(review): the ``base_url`` argument is ignored; requests always
          # go to the public Jina endpoint. Confirm whether that is intended.
          self.base_url = "https://api.jina.ai/v1/rerank"
          self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Each text is passed through ``truncate(t, 8196)`` first.

          Returns:
              (scores, tokens): numpy array aligned with ``texts`` and the usage
              extracted by ``total_token_count`` (0 if the response has none).
          """
          docs = [truncate(doc, 8196) for doc in texts]
          payload = dict(model=self.model_name, query=query, documents=docs, top_n=len(docs))
          body = requests.post(self.base_url, headers=self.headers, json=payload).json()
          scores = np.zeros(len(docs), dtype=float)
          for hit in body["results"]:
              scores[hit["index"]] = hit["relevance_score"]
          return scores, self.total_token_count(body)
  class YoudaoRerank(DefaultRerank):
      """Local reranker backed by BCEmbedding's ``RerankerModel``.

      Reuses DefaultRerank's batching (``_process_batch``) but keeps its own
      shared model and lock, and does not call ``DefaultRerank.__init__``.
      """
      _model = None
      _model_lock = threading.Lock()

      def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
          if not settings.LIGHTEN and not YoudaoRerank._model:
              from BCEmbedding import RerankerModel
              with YoudaoRerank._model_lock:
                  # Re-check under the lock in case another thread loaded it.
                  if not YoudaoRerank._model:
                      try:
                          # First try the copy in the local cache directory.
                          YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
                              get_home_cache_dir(),
                              re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                      except Exception:
                          # Fall back to the repo id with the owner swapped to InfiniFlow.
                          YoudaoRerank._model = RerankerModel(
                              model_name_or_path=model_name.replace(
                                  "maidalun1020", "InfiniFlow"))
          self._model = YoudaoRerank._model
          # _process_batch reads these; they are normally set by
          # DefaultRerank.__init__, which is deliberately not called here.
          self._dynamic_batch_size = 8
          self._min_batch_size = 1

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query`` in batches of 8.

          Texts are truncated to the model's ``max_length``.

          Returns:
              (scores, token_count): numpy array with one score per text and
              the summed ``num_tokens_from_string`` of the truncated texts.
          """
          pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
          token_count = sum(num_tokens_from_string(t) for _, t in pairs)
          batch_size = 8
          # NOTE(review): the inherited _compute_batch_scores applies sigmoid to
          # compute_score's output; confirm BCEmbedding does not already return
          # probabilities, which would squash the scores twice.
          res = self._process_batch(pairs, max_batch_size=batch_size)
          return np.array(res), token_count
  class XInferenceRerank(Base):
      """Reranker served by an Xinference server's ``/v1/rerank`` endpoint."""

      def __init__(self, key="xxxxxxx", model_name="", base_url=""):
          # Normalise the URL so it points at the /v1/rerank route.
          if "/v1" not in base_url:
              base_url = urljoin(base_url, "/v1/rerank")
          if "/rerank" not in base_url:
              base_url = urljoin(base_url, "/v1/rerank")
          self.model_name = model_name
          self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "accept": "application/json",
              "Authorization": f"Bearer {key}"
          }

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): numpy array aligned with ``texts`` and the
              summed token count of the (truncated) documents actually sent.
              Empty input returns ``(np.array([]), 0)`` without a request.
          """
          if not texts:
              return np.array([]), 0
          # Send the truncated documents so the request matches the token
          # count reported below (previously the truncation was discarded).
          docs = [truncate(t, 4096) for t in texts]
          token_count = sum(num_tokens_from_string(t) for t in docs)
          data = {
              "model": self.model_name,
              "query": query,
              "return_documents": "true",
              "return_len": "true",
              "documents": docs
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return rank, token_count
  class LocalAIRerank(Base):
      """Reranker served by a LocalAI ``/rerank`` endpoint.

      Scores are min-max normalised to [0, 1] within each request.
      """

      def __init__(self, key, model_name, base_url):
          if base_url.find("/rerank") == -1:
              # urljoin with an absolute path replaces any path already in base_url.
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          # Only the part before a "___" separator is sent as the model name.
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): scores min-max normalised to [0, 1] (all
              zeros when every score is equal) and the summed token count of
              the truncated texts. Empty input returns ``(np.array([]), 0)``.

          Raises:
              ValueError: the response has no ``results`` field.
          """
          # Nothing to rank: skip the request, since np.min/np.max below
          # raise on an empty array.
          if not texts:
              return np.array([]), 0
          # Not configurable from RAGFlow, so a fixed truncation length is used.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = sum(num_tokens_from_string(t) for t in texts)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Normalize the rank values to the range 0 to 1
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          # Avoid division by zero if all ranks are identical
          if max_rank - min_rank != 0:
              rank = (rank - min_rank) / (max_rank - min_rank)
          else:
              rank = np.zeros_like(rank)
          return rank, token_count
  class NvidiaRerank(Base):
      """Reranker backed by NVIDIA's hosted retrieval API.

      Only the two model names handled in ``__init__`` are supported; each maps
      to its own endpoint under ``base_url``.
      """

      def __init__(
              self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
      ):
          """
          Raises:
              ValueError: ``model_name`` is not one of the supported models
                  (no endpoint would be configured for it).
          """
          if not base_url:
              base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
          self.model_name = model_name
          # Build URLs with "/" explicitly: os.path.join would use "\" on Windows.
          root = base_url.rstrip("/")
          if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
              self.base_url = f"{root}/nv-rerankqa-mistral-4b-v3/reranking"
          elif self.model_name == "nvidia/rerank-qa-mistral-4b":
              self.base_url = f"{root}/reranking"
              # This endpoint is sent the versioned model id.
              self.model_name = "nv-rerank-qa-mistral-4b:1"
          else:
              raise ValueError(f"Unsupported NVIDIA rerank model: {model_name}")
          self.headers = {
              "accept": "application/json",
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): the ``logit`` of each passage (aligned with
              ``texts``) and the estimated tokens of the query plus all texts.
          """
          token_count = num_tokens_from_string(query) + sum(
              num_tokens_from_string(t) for t in texts
          )
          data = {
              "model": self.model_name,
              "query": {"text": query},
              "passages": [{"text": text} for text in texts],
              "truncate": "END",
              "top_n": len(texts),
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          for d in res["rankings"]:
              rank[d["index"]] = d["logit"]
          return rank, token_count
  class LmStudioRerank(Base):
      """Placeholder for an LM Studio reranker; scoring is not supported yet."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments and ignore them."""

      def similarity(self, query: str, texts: list):
          """Always raises NotImplementedError."""
          message = "The LmStudioRerank has not been implement"
          raise NotImplementedError(message)
  class OpenAI_APIRerank(Base):
      """Reranker for OpenAI-API-compatible servers exposing a ``/rerank`` route.

      Scores are min-max normalised to [0, 1] within each request.
      """

      def __init__(self, key, model_name, base_url):
          if base_url.find("/rerank") == -1:
              # urljoin with an absolute path replaces any path already in base_url.
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          # Only the part before a "___" separator is sent as the model name.
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): scores min-max normalised to [0, 1] (all
              zeros when every score is equal) and the summed token count of
              the truncated texts. Empty input returns ``(np.array([]), 0)``.

          Raises:
              ValueError: the response has no ``results`` field.
          """
          # Nothing to rank: skip the request, since np.min/np.max below
          # raise on an empty array.
          if not texts:
              return np.array([]), 0
          # Not configurable from RAGFlow, so a fixed truncation length is used.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = sum(num_tokens_from_string(t) for t in texts)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Normalize the rank values to the range 0 to 1
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          # Avoid division by zero if all ranks are identical
          if max_rank - min_rank != 0:
              rank = (rank - min_rank) / (max_rank - min_rank)
          else:
              rank = np.zeros_like(rank)
          return rank, token_count
  class CoHereRerank(Base):
      """Reranker using the ``cohere`` SDK client."""

      def __init__(self, key, model_name, base_url=None):
          from cohere import Client

          # base_url is accepted for the common interface but not used here.
          self.client = Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): numpy array aligned with ``texts`` and the
              estimated tokens of the query plus every text.
          """
          token_count = num_tokens_from_string(query)
          for text in texts:
              token_count += num_tokens_from_string(text)
          response = self.client.rerank(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False,
          )
          scores = np.zeros(len(texts), dtype=float)
          for item in response.results:
              scores[item.index] = item.relevance_score
          return scores, token_count
  class TogetherAIRerank(Base):
      """Placeholder for a Together AI reranker; scoring is not supported yet."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments and ignore them."""

      def similarity(self, query: str, texts: list):
          """Always raises NotImplementedError."""
          reason = "The api has not been implement"
          raise NotImplementedError(reason)
  class SILICONFLOWRerank(Base):
      """Reranker backed by the SiliconFlow ``/v1/rerank`` HTTP API."""

      def __init__(
              self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
      ):
          if not base_url:
              base_url = "https://api.siliconflow.cn/v1/rerank"
          self.model_name = model_name
          self.base_url = base_url
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, tokens): numpy array aligned with ``texts`` and the sum of
              input and output tokens from ``meta.tokens`` (missing values count
              as 0). If the response has no ``results``, returns zeros and 0.
          """
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
              "return_documents": False,
              "max_chunks_per_doc": 1024,
              "overlap_tokens": 80,
          }
          response = requests.post(
              self.base_url, json=payload, headers=self.headers
          ).json()
          rank = np.zeros(len(texts), dtype=float)
          if "results" not in response:
              return rank, 0
          for d in response["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Token usage is optional metadata: a missing or partial block must not
          # discard the scores computed above.
          tokens = (response.get("meta") or {}).get("tokens") or {}
          return (
              rank,
              (tokens.get("input_tokens") or 0) + (tokens.get("output_tokens") or 0),
          )
  class BaiduYiyanRerank(Base):
      """Reranker using Baidu Qianfan's ``Reranker`` resource."""

      def __init__(self, key, model_name, base_url=None):
          from qianfan.resources import Reranker

          # ``key`` is a JSON string; its "yiyan_ak"/"yiyan_sk" entries are the
          # access/secret keys (empty string when absent).
          credentials = json.loads(key)
          self.client = Reranker(
              ak=credentials.get("yiyan_ak", ""),
              sk=credentials.get("yiyan_sk", ""),
          )
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, tokens): numpy array aligned with ``texts`` and the usage
              extracted by ``total_token_count`` (0 if absent).
          """
          reply = self.client.do(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
          )
          body = reply.body
          scores = np.zeros(len(texts), dtype=float)
          for item in body["results"]:
              scores[item["index"]] = item["relevance_score"]
          return scores, self.total_token_count(body)
  class VoyageRerank(Base):
      """Reranker using the ``voyageai`` SDK client."""

      def __init__(self, key, model_name, base_url=None):
          import voyageai

          self.client = voyageai.Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, tokens): numpy array aligned with ``texts`` and the
              response's ``total_tokens``. Empty input returns an empty array
              and 0 without calling the API.
          """
          scores = np.zeros(len(texts), dtype=float)
          if texts:
              response = self.client.rerank(
                  query=query, documents=texts, model=self.model_name, top_k=len(texts)
              )
              for item in response.results:
                  scores[item.index] = item.relevance_score
              return scores, response.total_tokens
          return scores, 0
  class QWenRerank(Base):
      """Reranker using Alibaba DashScope's ``TextReRank`` API."""

      def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
          import dashscope
          self.api_key = key
          # An explicit None falls back to the SDK's gte-rerank model constant.
          self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, tokens): numpy array aligned with ``texts`` and
              ``resp.usage.total_tokens``.

          Raises:
              ValueError: the API returned a non-200 status.
          """
          import dashscope
          from http import HTTPStatus
          resp = dashscope.TextReRank.call(
              api_key=self.api_key,
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False
          )
          if resp.status_code != HTTPStatus.OK:
              # NOTE(review): DashScope responses carry error detail in
              # ``code``/``message`` (there is no ``text`` field, so reading it
              # masked this error); format the whole response instead.
              raise ValueError(f"Error calling QWenRerank model {self.model_name}: {resp.status_code} - {resp}")
          rank = np.zeros(len(texts), dtype=float)
          for r in resp.output.results:
              rank[r.index] = r.relevance_score
          return rank, resp.usage.total_tokens
  class HuggingfaceRerank(DefaultRerank):
      """Reranker calling a remote ``/rerank`` HTTP endpoint.

      The payload (``texts``, ``raw_scores``, ``truncate``) presumably targets a
      HuggingFace text-embeddings-inference server — confirm. No local model
      is loaded: ``DefaultRerank.__init__`` is not called.
      """

      @staticmethod
      def post(query: str, texts: list, url="127.0.0.1"):
          """POST ``texts`` in batches of 8 to ``<url>/rerank`` and collect scores.

          ``url`` may be a bare ``host[:port]`` (``http://`` is prepended) or a
          full ``http://``/``https://`` URL. Every batch is attempted; if any
          failed, the last exception is re-raised after all batches are sent.

          Returns:
              numpy array of scores aligned with ``texts`` (0 where unscored).
          """
          endpoint = (url or "127.0.0.1").rstrip("/")
          # Avoid "http://http://..." when a full URL (like the __init__
          # default) is passed in.
          if not endpoint.startswith(("http://", "https://")):
              endpoint = f"http://{endpoint}"
          endpoint = f"{endpoint}/rerank"
          exc = None
          scores = [0 for _ in range(len(texts))]
          batch_size = 8
          for i in range(0, len(texts), batch_size):
              try:
                  res = requests.post(endpoint, headers={"Content-Type": "application/json"},
                                      json={"query": query, "texts": texts[i: i + batch_size],
                                            "raw_scores": False, "truncate": True})
                  # Fail with the HTTP error rather than on an error payload below.
                  res.raise_for_status()
                  for o in res.json():
                      # indices in the reply are relative to this batch
                      scores[o["index"] + i] = o["score"]
              except Exception as e:
                  exc = e
          if exc:
              raise exc
          return np.array(scores)

      def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://127.0.0.1"):
          self.model_name = model_name
          self.base_url = base_url

      def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]:
          """Rerank ``texts`` against ``query`` via the remote endpoint.

          Returns:
              (scores, token_count): scores aligned with ``texts`` and the summed
              ``num_tokens_from_string`` of the texts; empty input short-circuits.
          """
          if not texts:
              return np.array([]), 0
          token_count = sum(num_tokens_from_string(t) for t in texts)
          return HuggingfaceRerank.post(query, texts, self.base_url), token_count
  class GPUStackRerank(Base):
      """Reranker served by a GPUStack server at ``<base_url>/v1/rerank``."""

      def __init__(
              self, key, model_name, base_url
      ):
          if not base_url:
              raise ValueError("url cannot be None")
          self.model_name = model_name
          self.base_url = str(URL(base_url) / "v1" / "rerank")
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Rerank ``texts`` against ``query``.

          Returns:
              (scores, token_count): numpy array aligned with ``texts`` and the
              summed token count of the texts; zeros and 0 if the response has
              no ``results``.

          Raises:
              ValueError: the server answered with an HTTP error status.
          """
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          try:
              response = requests.post(
                  self.base_url, json=payload, headers=self.headers
              )
              response.raise_for_status()
          except requests.exceptions.HTTPError as e:
              # raise_for_status() raises requests' HTTPError (not httpx's);
              # translate it into a descriptive ValueError.
              raise ValueError(
                  f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") from e
          response_json = response.json()
          rank = np.zeros(len(texts), dtype=float)
          if "results" not in response_json:
              return rank, 0
          token_count = sum(num_tokens_from_string(t) for t in texts)
          for result in response_json["results"]:
              rank[result["index"]] = result["relevance_score"]
          return (
              rank,
              token_count,
          )