您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

rerank_model.py 17KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from urllib.parse import urljoin
  19. import requests
  20. import httpx
  21. from huggingface_hub import snapshot_download
  22. import os
  23. from abc import ABC
  24. import numpy as np
  25. from yarl import URL
  26. from api import settings
  27. from api.utils.file_utils import get_home_cache_dir
  28. from rag.utils import num_tokens_from_string, truncate
  29. import json
  def sigmoid(x):
      """Logistic function ``1 / (1 + exp(-x))``.

      Accepts Python scalars or NumPy arrays; arrays are processed element-wise.
      """
      return np.reciprocal(1 + np.exp(-x))
  class Base(ABC):
      """Common interface for the rerank model wrappers in this module.

      Implementations score every entry of ``texts`` against ``query`` in
      :meth:`similarity`.
      """

      def __init__(self, key, model_name):
          # Nothing to configure at this level; subclasses keep their own settings.
          pass

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``; subclasses must override this.

          Raises:
              NotImplementedError: always, on the base class.
          """
          raise NotImplementedError("Please implement similarity method!")
  class DefaultRerank(Base):
      """Local reranker built on FlagEmbedding's ``FlagReranker``.

      The loaded model is stored on the class (``DefaultRerank._model``), so it is
      created at most once and then shared by every instance.
      """

      _model = None
      _model_lock = threading.Lock()

      def __init__(self, key, model_name, **kwargs):
          """
          If you have trouble downloading HuggingFace models, -_^ this might help!!

          For Linux:
          export HF_ENDPOINT=https://hf-mirror.com

          For Windows:
          Good luck
          ^_-

          """
          holder = DefaultRerank
          if settings.LIGHTEN or holder._model:
              # Lightweight deployment, or the shared model is already loaded.
              self._model = holder._model
              return

          import torch
          from FlagEmbedding import FlagReranker

          def cache_path():
              # Drop a leading "<prefix>/" segment of the repo id to name the cache sub-directory.
              return os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name))

          with holder._model_lock:
              # Re-check under the lock: another thread may have loaded the model meanwhile.
              if not holder._model:
                  try:
                      holder._model = FlagReranker(cache_path(), use_fp16=torch.cuda.is_available())
                  except Exception:
                      # Loading from the cache directory failed: download a snapshot into it, then load.
                      downloaded = snapshot_download(repo_id=model_name,
                                                     local_dir=cache_path(),
                                                     local_dir_use_symlinks=False)
                      holder._model = FlagReranker(downloaded, use_fp16=torch.cuda.is_available())
          self._model = holder._model

      def similarity(self, query: str, texts: list):
          """Score each text against ``query`` with the local cross-encoder.

          Texts are truncated to 2048 tokens and raw model scores are mapped
          through :func:`sigmoid`.

          Returns:
              (np.ndarray of scores aligned with ``texts``,
               token count of the truncated texts)
          """
          pairs = [(query, truncate(t, 2048)) for t in texts]
          token_count = sum(num_tokens_from_string(doc) for _, doc in pairs)

          batch_size = 4096
          collected = []
          for start in range(0, len(pairs), batch_size):
              raw = self._model.compute_score(pairs[start:start + batch_size], max_length=2048)
              # compute_score may yield a bare scalar for a one-pair batch;
              # atleast_1d turns both the scalar and list forms into a list.
              collected.extend(np.atleast_1d(sigmoid(np.array(raw))).tolist())
          return np.array(collected), token_count
  class JinaRerank(Base):
      """Reranker backed by the hosted Jina AI rerank HTTP API."""

      def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
                   base_url="https://api.jina.ai/v1/rerank"):
          # NOTE: ``base_url`` is accepted for interface compatibility, but requests
          # always go to the public Jina endpoint below.
          self.base_url = "https://api.jina.ai/v1/rerank"
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` via the Jina rerank endpoint.

          Returns:
              (np.ndarray of relevance scores aligned with ``texts``,
               ``usage.total_tokens`` reported by the API)

          Raises:
              ValueError: if the response has no ``results`` field
                  (presumably an API error payload).
          """
          if len(texts) == 0:
              # Nothing to rank; skip the network round-trip.
              return np.array([]), 0
          texts = [truncate(t, 8196) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts)
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          if "results" not in res:
              raise ValueError("response not contains results\n" + str(res))
          rank = np.zeros(len(texts), dtype=float)
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return rank, res["usage"]["total_tokens"]
  class YoudaoRerank(DefaultRerank):
      """Local reranker using BCEmbedding's ``RerankerModel``.

      The loaded model is kept in the class attribute ``YoudaoRerank._model``
      and shared by all instances.
      """

      _model = None
      _model_lock = threading.Lock()

      def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
          owner = YoudaoRerank
          if settings.LIGHTEN or owner._model:
              # Lightweight deployment, or the shared model is already loaded.
              self._model = owner._model
              return

          from BCEmbedding import RerankerModel

          with owner._model_lock:
              # Re-check under the lock: another thread may have loaded the model meanwhile.
              if not owner._model:
                  try:
                      # First try the cache directory named after the repo id minus its leading "<prefix>/".
                      owner._model = RerankerModel(
                          model_name_or_path=os.path.join(
                              get_home_cache_dir(),
                              re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                  except Exception:
                      # Fall back to the repo id with the "maidalun1020" owner replaced by "InfiniFlow".
                      owner._model = RerankerModel(
                          model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
          self._model = owner._model

      def similarity(self, query: str, texts: list):
          """Score each text against ``query`` in batches of 8.

          Returns:
              (np.ndarray of scores aligned with ``texts``,
               token count of the truncated texts)
          """
          pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
          token_count = sum(num_tokens_from_string(doc) for _, doc in pairs)

          batch_size = 8
          collected = []
          for start in range(0, len(pairs), batch_size):
              raw = self._model.compute_score(pairs[start:start + batch_size],
                                              max_length=self._model.max_length)
              # NOTE(review): BCEmbedding's RerankerModel.compute_score appears to apply a
              # sigmoid itself; if so, this second sigmoid squeezes scores into roughly
              # [0.5, 0.73] (ordering is preserved). Confirm against the installed version.
              collected.extend(np.atleast_1d(sigmoid(np.array(raw))).tolist())
          return np.array(collected), token_count
  class XInferenceRerank(Base):
      """Reranker served by an Xinference ``/v1/rerank`` HTTP endpoint."""

      def __init__(self, key="xxxxxxx", model_name="", base_url=""):
          # Point base_url at /v1/rerank when it lacks either marker. urljoin with an
          # absolute path replaces the whole path component of base_url.
          for marker in ("/v1", "/rerank"):
              if base_url.find(marker) < 0:
                  base_url = urljoin(base_url, "/v1/rerank")
          self.model_name = model_name
          self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "accept": "application/json",
              "Authorization": f"Bearer {key}"
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (np.ndarray of relevance scores aligned with ``texts``,
               input + output tokens reported in the response ``meta``)
          """
          if len(texts) == 0:
              return np.array([]), 0
          payload = {
              "model": self.model_name,
              "query": query,
              "return_documents": "true",
              "return_len": "true",
              "documents": texts
          }
          reply = requests.post(self.base_url, headers=self.headers, json=payload).json()
          scores = np.zeros(len(texts), dtype=float)
          for item in reply["results"]:
              scores[item["index"]] = item["relevance_score"]
          usage = reply["meta"]["tokens"]
          return scores, usage["input_tokens"] + usage["output_tokens"]
  class LocalAIRerank(Base):
      """Reranker for a LocalAI server's ``/rerank`` endpoint.

      Scores returned by the server are min-max normalised into [0, 1].
      """

      def __init__(self, key, model_name, base_url):
          if base_url.find("/rerank") == -1:
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          # Keep only the part of the name before a "___" separator.
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (min-max normalised scores aligned with ``texts``,
               locally counted tokens of the truncated texts)

          Raises:
              ValueError: if the response has no ``results`` field.
          """
          if len(texts) == 0:
              # np.min/np.max below raise on an empty array, so short-circuit here.
              return np.array([]), 0
          # There is no way to configure the truncation length from RAGFlow, so use a fixed limit.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = 0
          for t in texts:
              token_count += num_tokens_from_string(t)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Min-max normalise to [0, 1]. A single document, or identical scores,
          # yields all zeros (this also avoids division by zero).
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          if max_rank - min_rank != 0:
              rank = (rank - min_rank) / (max_rank - min_rank)
          else:
              rank = np.zeros_like(rank)
          return rank, token_count
  class NvidiaRerank(Base):
      """Reranker for NVIDIA's hosted retrieval reranking API.

      Only the two model names handled in ``__init__`` are supported.
      """

      def __init__(
          self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
      ):
          if not base_url:
              base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
          # Build the endpoint with "/" explicitly: os.path.join would insert
          # backslashes on Windows and produce an invalid URL.
          if not base_url.endswith("/"):
              base_url += "/"
          self.model_name = model_name
          if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
              self.base_url = base_url + "nv-rerankqa-mistral-4b-v3/reranking"
          elif self.model_name == "nvidia/rerank-qa-mistral-4b":
              self.base_url = base_url + "reranking"
              # The request payload uses this identifier for this model.
              self.model_name = "nv-rerank-qa-mistral-4b:1"
          else:
              # Fail fast rather than leaving self.base_url unset, which would only
              # surface later as an AttributeError inside similarity().
              raise ValueError(f"Unsupported NVIDIA rerank model: {model_name}")
          self.headers = {
              "accept": "application/json",
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (raw ``logit`` values aligned with ``texts`` -- not normalised,
               local token count of query + texts)
          """
          token_count = num_tokens_from_string(query) + sum(
              [num_tokens_from_string(t) for t in texts]
          )
          data = {
              "model": self.model_name,
              "query": {"text": query},
              "passages": [{"text": text} for text in texts],
              "truncate": "END",
              "top_n": len(texts),
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          for d in res["rankings"]:
              rank[d["index"]] = d["logit"]
          return rank, token_count
  class LmStudioRerank(Base):
      """Placeholder for LM Studio reranking; :meth:`similarity` is not implemented."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments; nothing is stored."""

      def similarity(self, query: str, texts: list):
          """Always raises NotImplementedError."""
          raise NotImplementedError("The LmStudioRerank has not been implement")
  class OpenAI_APIRerank(Base):
      """Reranker for a server exposing a ``/rerank`` endpoint (the
      OpenAI-API-compatible provider option).

      Scores returned by the server are min-max normalised into [0, 1].
      """

      def __init__(self, key, model_name, base_url):
          if base_url.find("/rerank") == -1:
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          # Keep only the part of the name before a "___" separator.
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (min-max normalised scores aligned with ``texts``,
               locally counted tokens of the truncated texts)

          Raises:
              ValueError: if the response has no ``results`` field.
          """
          if len(texts) == 0:
              # np.min/np.max below raise on an empty array, so short-circuit here.
              return np.array([]), 0
          # There is no way to configure the truncation length from RAGFlow, so use a fixed limit.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = 0
          for t in texts:
              token_count += num_tokens_from_string(t)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Min-max normalise to [0, 1]. A single document, or identical scores,
          # yields all zeros (this also avoids division by zero).
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          if max_rank - min_rank != 0:
              rank = (rank - min_rank) / (max_rank - min_rank)
          else:
              rank = np.zeros_like(rank)
          return rank, token_count
  class CoHereRerank(Base):
      """Reranker using the ``cohere`` SDK client."""

      def __init__(self, key, model_name, base_url=None):
          from cohere import Client

          # base_url is accepted for interface compatibility but not passed to the client.
          self.client = Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (relevance scores aligned with ``texts``,
               local token count of query + texts)
          """
          token_count = sum(map(num_tokens_from_string, [query, *texts]))
          response = self.client.rerank(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False,
          )
          scores = np.zeros(len(texts), dtype=float)
          for hit in response.results:
              scores[hit.index] = hit.relevance_score
          return scores, token_count
  class TogetherAIRerank(Base):
      """Placeholder for Together AI reranking; :meth:`similarity` is not implemented."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments; nothing is stored."""

      def similarity(self, query: str, texts: list):
          """Always raises NotImplementedError."""
          raise NotImplementedError("The api has not been implement")
  class SILICONFLOWRerank(Base):
      """Reranker for the SiliconFlow rerank HTTP API."""

      def __init__(
          self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
      ):
          self.model_name = model_name
          # An empty/None base_url falls back to the public endpoint.
          self.base_url = base_url or "https://api.siliconflow.cn/v1/rerank"
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Best-effort: if the response carries no ``results`` field, all-zero
          scores and a token count of 0 are returned instead of raising.

          Returns:
              (relevance scores aligned with ``texts``,
               input + output tokens reported in the response ``meta``)
          """
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
              "return_documents": False,
              "max_chunks_per_doc": 1024,
              "overlap_tokens": 80,
          }
          reply = requests.post(self.base_url, json=payload, headers=self.headers).json()
          scores = np.zeros(len(texts), dtype=float)
          if "results" not in reply:
              return scores, 0
          for item in reply["results"]:
              scores[item["index"]] = item["relevance_score"]
          usage = reply["meta"]["tokens"]
          return scores, usage["input_tokens"] + usage["output_tokens"]
  class BaiduYiyanRerank(Base):
      """Reranker using the ``qianfan`` SDK's ``Reranker`` resource."""

      def __init__(self, key, model_name, base_url=None):
          from qianfan.resources import Reranker

          # ``key`` is a JSON object string with "yiyan_ak" / "yiyan_sk" entries;
          # missing entries default to "".
          credentials = json.loads(key)
          self.client = Reranker(
              ak=credentials.get("yiyan_ak", ""),
              sk=credentials.get("yiyan_sk", ""),
          )
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (relevance scores aligned with ``texts``,
               ``usage.total_tokens`` from the response body)
          """
          body = self.client.do(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
          ).body
          scores = np.zeros(len(texts), dtype=float)
          for item in body["results"]:
              scores[item["index"]] = item["relevance_score"]
          return scores, body["usage"]["total_tokens"]
  class VoyageRerank(Base):
      """Reranker using the ``voyageai`` SDK client."""

      def __init__(self, key, model_name, base_url=None):
          import voyageai

          self.client = voyageai.Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (relevance scores aligned with ``texts``,
               ``total_tokens`` reported by the client; 0 for empty input)
          """
          scores = np.zeros(len(texts), dtype=float)
          if not texts:
              return scores, 0
          outcome = self.client.rerank(
              query=query,
              documents=texts,
              model=self.model_name,
              top_k=len(texts),
          )
          for item in outcome.results:
              scores[item.index] = item.relevance_score
          return scores, outcome.total_tokens
  class QWenRerank(Base):
      """Reranker using DashScope's ``TextReRank`` API."""

      def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
          import dashscope
          self.api_key = key
          # Use the SDK's gte_rerank constant when model_name is explicitly None.
          self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (relevance scores aligned with ``texts``,
               ``usage.total_tokens`` from the response)

          Raises:
              ValueError: if DashScope answers with a non-OK status.
          """
          import dashscope
          from http import HTTPStatus

          if len(texts) == 0:
              # Nothing to rank; skip the remote call.
              return np.array([]), 0
          resp = dashscope.TextReRank.call(
              api_key=self.api_key,
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False
          )
          if resp.status_code != HTTPStatus.OK:
              # NOTE(review): DashScope responses appear to report failures via
              # ``code``/``message`` rather than ``text``; getattr keeps this safe
              # if a given SDK version lacks either field.
              raise ValueError(
                  f"Error calling QWenRerank model {self.model_name}: "
                  f"{resp.status_code} - {getattr(resp, 'code', '')} {getattr(resp, 'message', '')}"
              )
          rank = np.zeros(len(texts), dtype=float)
          for r in resp.output.results:
              rank[r.index] = r.relevance_score
          return rank, resp.usage.total_tokens
  class GPUStackRerank(Base):
      """Reranker served by a GPUStack server at ``<base_url>/v1/rerank``."""

      def __init__(
          self, key, model_name, base_url
      ):
          if not base_url:
              raise ValueError("url cannot be None")
          self.model_name = model_name
          self.base_url = str(URL(base_url) / "v1" / "rerank")
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              (relevance scores aligned with ``texts``,
               local token count of ``texts``); ``(zeros, 0)`` when the
              response carries no ``results`` field.

          Raises:
              ValueError: when the server answers with an HTTP error status.
          """
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          response = requests.post(
              self.base_url, json=payload, headers=self.headers
          )
          try:
              response.raise_for_status()
          except requests.exceptions.HTTPError as e:
              # requests' raise_for_status raises requests.exceptions.HTTPError
              # (not an httpx exception); translate it into a descriptive ValueError.
              raise ValueError(
                  f"Error calling GPUStackRerank model {self.model_name}: "
                  f"{e.response.status_code} - {e.response.text}"
              ) from e
          response_json = response.json()
          rank = np.zeros(len(texts), dtype=float)
          if "results" not in response_json:
              return rank, 0
          token_count = sum(num_tokens_from_string(t) for t in texts)
          for result in response_json["results"]:
              rank[result["index"]] = result["relevance_score"]
          return rank, token_count