Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

rerank_model.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. from urllib.parse import urljoin
  19. import requests
  20. import httpx
  21. from huggingface_hub import snapshot_download
  22. import os
  23. from abc import ABC
  24. import numpy as np
  25. from yarl import URL
  26. from api import settings
  27. from api.utils.file_utils import get_home_cache_dir
  28. from rag.utils import num_tokens_from_string, truncate
  29. import json
  def sigmoid(x):
      """Apply the logistic function 1 / (1 + e^-x), element-wise for arrays.

      Used below to squash raw cross-encoder scores into the (0, 1) range.
      """
      exp_neg = np.exp(-x)
      return 1 / (1 + exp_neg)
  class Base(ABC):
      """Common interface for the rerank backends in this module.

      Concrete backends override ``similarity`` and return a pair
      ``(scores, token_count)`` with one score per input text.
      """

      def __init__(self, key, model_name):
          # Base keeps no state; each backend stores its own client/model.
          pass

      def similarity(self, query: str, texts: list):
          """Score every text in ``texts`` against ``query``.

          Raises:
              NotImplementedError: always; concrete rerankers must override this.
          """
          raise NotImplementedError("Please implement similarity method!")

      def total_token_count(self, resp):
          """Best-effort extraction of the total token usage from an API response.

          Tries attribute access (``resp.usage.total_tokens``) first, then
          mapping access (``resp["usage"]["total_tokens"]``).

          Returns:
              The reported token count, or 0 when neither form is available.
          """
          try:
              return resp.usage.total_tokens
          except Exception:
              pass  # not an attribute-style response; try mapping access next
          try:
              return resp["usage"]["total_tokens"]
          except Exception:
              pass  # usage information missing or malformed
          return 0
  class DefaultRerank(Base):
      """Local cross-encoder reranker built on FlagEmbedding's ``FlagReranker``.

      The loaded model is stored on the class, so every instance in the process
      reuses the same model object.
      """

      # Process-wide FlagReranker instance, filled in lazily by the first constructor.
      _model = None
      # Guards the one-time model load against concurrent constructors.
      _model_lock = threading.Lock()

      def __init__(self, key, model_name, **kwargs):
          """Load the local reranker model, or reuse the one already loaded.

          The model is first loaded from ``get_home_cache_dir()/<model name>``.
          If that fails, it is fetched with ``snapshot_download`` into the same
          directory. When HuggingFace is hard to reach, pointing ``HF_ENDPOINT``
          at a mirror may help (Linux: ``export HF_ENDPOINT=https://hf-mirror.com``).

          Nothing is loaded when ``settings.LIGHTEN`` is set.
          """
          # NOTE(review): the cache lives on the class and ignores model_name, so a
          # later instance with a different model_name reuses the first model — confirm.
          if settings.LIGHTEN or DefaultRerank._model:
              self._model = DefaultRerank._model
              return

          import torch
          from FlagEmbedding import FlagReranker

          with DefaultRerank._model_lock:
              # Re-check under the lock: another thread may have finished loading.
              if not DefaultRerank._model:
                  # Drop a leading "<org>/" so the cache directory is just the model name.
                  short_name = re.sub(r"^[a-zA-Z0-9]+/", "", model_name)
                  try:
                      DefaultRerank._model = FlagReranker(
                          os.path.join(get_home_cache_dir(), short_name),
                          use_fp16=torch.cuda.is_available(),
                      )
                  except Exception:
                      downloaded_dir = snapshot_download(
                          repo_id=model_name,
                          local_dir=os.path.join(get_home_cache_dir(), short_name),
                          local_dir_use_symlinks=False,
                      )
                      DefaultRerank._model = FlagReranker(
                          downloaded_dir, use_fp16=torch.cuda.is_available()
                      )
          self._model = DefaultRerank._model

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` with the local cross-encoder.

          Each text is shortened with ``truncate(t, 2048)``. Pairs are scored in
          batches of 4096, and raw scores are passed through ``sigmoid``.

          Returns:
              ``(scores, token_count)``: a NumPy array aligned with ``texts`` and
              the summed token count of the truncated texts (the query is not counted).
          """
          pairs = [(query, truncate(t, 2048)) for t in texts]
          token_count = sum(num_tokens_from_string(doc) for _, doc in pairs)

          batch_size = 4096
          scores = []
          for start in range(0, len(pairs), batch_size):
              raw = self._model.compute_score(pairs[start:start + batch_size], max_length=2048)
              squashed = sigmoid(np.array(raw)).tolist()
              # compute_score may yield a scalar (presumably for a one-pair batch);
              # .tolist() then gives a bare float instead of a list.
              if isinstance(squashed, float):
                  scores.append(squashed)
              else:
                  scores.extend(squashed)
          return np.array(scores), token_count
  class JinaRerank(Base):
      """Reranker backed by the Jina AI rerank HTTP endpoint."""

      def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
                   base_url="https://api.jina.ai/v1/rerank"):
          """Store the request settings.

          Args:
              key: API key, sent as a ``Bearer`` token.
              model_name: reranker model id placed in the request body.
              base_url: full rerank endpoint URL. An empty or ``None`` value falls
                  back to ``https://api.jina.ai/v1/rerank``.
          """
          # Honor an explicitly configured endpoint (e.g. a proxy); the public API
          # is only the fallback.
          self.base_url = base_url or "https://api.jina.ai/v1/rerank"
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` through the remote API.

          Returns:
              ``(rank, tokens)``. ``rank[i]`` is the ``relevance_score`` reported
              for ``texts[i]``, or 0.0 if the response omits it. ``tokens`` is read
              via ``total_token_count`` and is 0 when usage is missing.

          Raises:
              ValueError: if the response has no ``results`` field, e.g. an error payload.
          """
          # NOTE(review): 8196 looks like it may be a typo for 8192 — confirm.
          texts = [truncate(t, 8196) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts)
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          if "results" not in res:
              raise ValueError(f"Jina rerank response has no 'results': {res}")
          rank = np.zeros(len(texts), dtype=float)
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return rank, self.total_token_count(res)
  class YoudaoRerank(DefaultRerank):
      """Local reranker built on ``BCEmbedding.RerankerModel``.

      It caches its model on the class like ``DefaultRerank``, but in its own
      ``_model`` slot, so it never shares the FlagEmbedding model.
      """

      # Process-wide RerankerModel instance, filled in lazily by the first constructor.
      _model = None
      # Guards the one-time model load against concurrent constructors.
      _model_lock = threading.Lock()

      def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
          """Load the BCE reranker, or reuse the one already loaded.

          The model is first loaded from ``get_home_cache_dir()/<model name>``. On
          failure, it is loaded by name with the ``maidalun1020`` owner replaced by
          ``InfiniFlow`` (presumably a hub repo id). Nothing is loaded when
          ``settings.LIGHTEN`` is set.
          """
          if settings.LIGHTEN or YoudaoRerank._model:
              self._model = YoudaoRerank._model
              return

          from BCEmbedding import RerankerModel

          with YoudaoRerank._model_lock:
              # Re-check under the lock: another thread may have finished loading.
              if not YoudaoRerank._model:
                  try:
                      local_path = os.path.join(
                          get_home_cache_dir(),
                          re.sub(r"^[a-zA-Z0-9]+/", "", model_name),
                      )
                      YoudaoRerank._model = RerankerModel(model_name_or_path=local_path)
                  except Exception:
                      fallback_name = model_name.replace("maidalun1020", "InfiniFlow")
                      YoudaoRerank._model = RerankerModel(model_name_or_path=fallback_name)
          self._model = YoudaoRerank._model

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` with the BCE reranker.

          Texts are truncated to the model's ``max_length`` and scored 8 pairs at a
          time. Raw scores are passed through ``sigmoid``.

          Returns:
              ``(scores, token_count)``: a NumPy array aligned with ``texts`` and
              the summed token count of the truncated texts.
          """
          pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
          token_count = sum(num_tokens_from_string(doc) for _, doc in pairs)

          batch_size = 8
          collected = []
          for start in range(0, len(pairs), batch_size):
              raw = self._model.compute_score(
                  pairs[start:start + batch_size], max_length=self._model.max_length
              )
              squashed = sigmoid(np.array(raw)).tolist()
              # A scalar result becomes a bare float after .tolist().
              if isinstance(squashed, float):
                  collected.append(squashed)
              else:
                  collected.extend(squashed)
          return np.array(collected), token_count
  class XInferenceRerank(Base):
      """Reranker served by an Xinference deployment's rerank endpoint."""

      def __init__(self, key="xxxxxxx", model_name="", base_url=""):
          """Normalize ``base_url`` to a ``/v1/rerank`` endpoint and store settings.

          Args:
              key: sent as a ``Bearer`` token.
              model_name: model id placed in the request body.
              base_url: server URL. ``/v1/rerank`` is resolved against it unless
                  it already contains both ``/v1`` and ``/rerank``.
          """
          # NOTE(review): urljoin with an absolute path ("/v1/rerank") drops any
          # path prefix in base_url (e.g. behind a reverse proxy) — confirm.
          if "/v1" not in base_url:
              base_url = urljoin(base_url, "/v1/rerank")
          if "/rerank" not in base_url:
              base_url = urljoin(base_url, "/v1/rerank")
          self.model_name = model_name
          self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "accept": "application/json",
              "Authorization": f"Bearer {key}"
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Each text is truncated with ``truncate(t, 4096)``. The same truncated
          texts are both sent and counted.

          Returns:
              ``(rank, token_count)``. ``rank[i]`` is the ``relevance_score`` for
              ``texts[i]``; an empty input returns ``(np.array([]), 0)``.

          Raises:
              ValueError: if the response has no ``results`` field.
          """
          if len(texts) == 0:
              return np.array([]), 0
          # Send the truncated documents: they were previously computed only for the
          # token count, while the full-length texts went over the wire.
          truncated = [truncate(t, 4096) for t in texts]
          token_count = sum(num_tokens_from_string(t) for t in truncated)
          data = {
              "model": self.model_name,
              "query": query,
              "return_documents": "true",
              "return_len": "true",
              "documents": truncated
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          if "results" not in res:
              raise ValueError(f"Xinference rerank response has no 'results': {res}")
          rank = np.zeros(len(texts), dtype=float)
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return rank, token_count
  class LocalAIRerank(Base):
      """Reranker served by a LocalAI instance; scores are min-max scaled to [0, 1]."""

      def __init__(self, key, model_name, base_url):
          """Resolve the rerank endpoint and store request settings.

          Args:
              key: sent as a ``Bearer`` token.
              model_name: only the part before the first ``"___"`` is sent.
              base_url: server URL. ``/rerank`` is resolved against it unless it
                  already contains ``/rerank``.
          """
          # NOTE(review): urljoin with "/rerank" resolves from the host root, so a
          # path such as "/v1" in base_url is dropped — confirm this is intended.
          if "/rerank" not in base_url:
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` and min-max normalize the scores.

          Returns:
              ``(rank, token_count)``. ``rank`` is scaled to [0, 1]; an empty input
              returns an empty array and 0.

          Raises:
              ValueError: if the response has no ``results`` field.
          """
          # An empty batch would make np.min/np.max fail on a zero-size array.
          if not texts:
              return np.zeros(0, dtype=float), 0
          # RAGFlow offers no way to configure this, so use a fixed truncation length.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = sum(num_tokens_from_string(t) for t in texts)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return self._normalize(rank), token_count

      @staticmethod
      def _normalize(rank):
          """Min-max scale ``rank`` to [0, 1]; all-equal scores map to zeros."""
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          # Avoid division by zero when every score is identical.
          if max_rank - min_rank != 0:
              return (rank - min_rank) / (max_rank - min_rank)
          # NOTE(review): this also zeroes a single-document result — confirm intended.
          return np.zeros_like(rank)
  class NvidiaRerank(Base):
      """Reranker using NVIDIA's hosted retrieval reranking API."""

      def __init__(
          self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
      ):
          """Pick the model-specific endpoint and store request settings.

          Only ``nvidia/nv-rerankqa-mistral-4b-v3`` and ``nvidia/rerank-qa-mistral-4b``
          have a known endpoint. For any other model, ``similarity`` raises.
          """
          if not base_url:
              base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
          self.model_name = model_name
          # None means "no endpoint known for this model"; similarity() reports it.
          self.base_url = None
          if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
              self.base_url = self._join_url(base_url, "nv-rerankqa-mistral-4b-v3", "reranking")
          if self.model_name == "nvidia/rerank-qa-mistral-4b":
              self.base_url = self._join_url(base_url, "reranking")
              self.model_name = "nv-rerank-qa-mistral-4b:1"
          self.headers = {
              "accept": "application/json",
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}",
          }

      @staticmethod
      def _join_url(base, *segments):
          """Join URL path segments with '/' (os.path.join would use '\\' on Windows)."""
          return "/".join([base.rstrip("/"), *segments])

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(rank, token_count)``. ``rank[i]`` is the raw ``logit`` reported for
              ``texts[i]`` (not normalized). ``token_count`` counts the query plus
              all texts locally.

          Raises:
              ValueError: for a model with no known endpoint, or a response
                  without ``rankings``.
          """
          if self.base_url is None:
              raise ValueError(f"Unsupported NVIDIA rerank model: {self.model_name}")
          token_count = num_tokens_from_string(query) + sum(
              num_tokens_from_string(t) for t in texts
          )
          data = {
              "model": self.model_name,
              "query": {"text": query},
              "passages": [{"text": text} for text in texts],
              "truncate": "END",
              "top_n": len(texts),
          }
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          if "rankings" not in res:
              raise ValueError(f"NVIDIA rerank response has no 'rankings': {res}")
          rank = np.zeros(len(texts), dtype=float)
          for d in res["rankings"]:
              rank[d["index"]] = d["logit"]
          return rank, token_count
  class LmStudioRerank(Base):
      """Placeholder for LM Studio reranking; scoring is not implemented."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments without storing them."""

      def similarity(self, query: str, texts: list):
          """Always raise, because LM Studio reranking is unsupported."""
          raise NotImplementedError("The LmStudioRerank has not been implement")
  class OpenAI_APIRerank(Base):
      """Reranker for OpenAI-API-compatible servers; scores are min-max scaled to [0, 1]."""

      def __init__(self, key, model_name, base_url):
          """Resolve the rerank endpoint and store request settings.

          Args:
              key: sent as a ``Bearer`` token.
              model_name: only the part before the first ``"___"`` is sent.
              base_url: server URL. ``/rerank`` is resolved against it unless it
                  already contains ``/rerank``.
          """
          # NOTE(review): urljoin with "/rerank" resolves from the host root, so a
          # path such as "/v1" in base_url is dropped — confirm this is intended.
          if "/rerank" not in base_url:
              self.base_url = urljoin(base_url, "/rerank")
          else:
              self.base_url = base_url
          self.headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
          }
          self.model_name = model_name.split("___")[0]

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query`` and min-max normalize the scores.

          Returns:
              ``(rank, token_count)``. ``rank`` is scaled to [0, 1]; an empty input
              returns an empty array and 0.

          Raises:
              ValueError: if the response has no ``results`` field.
          """
          # An empty batch would make np.min/np.max fail on a zero-size array.
          if not texts:
              return np.zeros(0, dtype=float), 0
          # RAGFlow offers no way to configure this, so use a fixed truncation length.
          texts = [truncate(t, 500) for t in texts]
          data = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          token_count = sum(num_tokens_from_string(t) for t in texts)
          res = requests.post(self.base_url, headers=self.headers, json=data).json()
          rank = np.zeros(len(texts), dtype=float)
          if 'results' not in res:
              raise ValueError("response not contains results\n" + str(res))
          for d in res["results"]:
              rank[d["index"]] = d["relevance_score"]
          return self._normalize(rank), token_count

      @staticmethod
      def _normalize(rank):
          """Min-max scale ``rank`` to [0, 1]; all-equal scores map to zeros."""
          min_rank = np.min(rank)
          max_rank = np.max(rank)
          # Avoid division by zero when every score is identical.
          if max_rank - min_rank != 0:
              return (rank - min_rank) / (max_rank - min_rank)
          # NOTE(review): this also zeroes a single-document result — confirm intended.
          return np.zeros_like(rank)
  class CoHereRerank(Base):
      """Reranker using the Cohere SDK's ``Client.rerank``."""

      def __init__(self, key, model_name, base_url=None):
          # base_url is accepted for interface compatibility but not used here.
          from cohere import Client
          self.client = Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(rank, token_count)``. ``rank[i]`` is the ``relevance_score`` for
              ``texts[i]``; ``token_count`` counts the query plus all texts locally.
              An empty input returns an empty array and 0 without calling the API.
          """
          rank = np.zeros(len(texts), dtype=float)
          # Nothing to score; skip the request instead of sending top_n=0
          # (presumably rejected by the API — same guard as VoyageRerank).
          if not texts:
              return rank, 0
          token_count = num_tokens_from_string(query) + sum(
              num_tokens_from_string(t) for t in texts
          )
          res = self.client.rerank(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False,
          )
          for d in res.results:
              rank[d.index] = d.relevance_score
          return rank, token_count
  class TogetherAIRerank(Base):
      """Placeholder for Together AI reranking; scoring is not implemented."""

      def __init__(self, key, model_name, base_url):
          """Accept the standard constructor arguments without storing them."""

      def similarity(self, query: str, texts: list):
          """Always raise, because this backend is unsupported."""
          raise NotImplementedError("The api has not been implement")
  class SILICONFLOWRerank(Base):
      """Reranker using the SiliconFlow rerank HTTP API."""

      def __init__(
          self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
      ):
          """Store request settings; an empty ``base_url`` falls back to the public endpoint."""
          if not base_url:
              base_url = "https://api.siliconflow.cn/v1/rerank"
          self.model_name = model_name
          self.base_url = base_url
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(rank, tokens)``. ``rank[i]`` is the ``relevance_score`` for
              ``texts[i]``. ``tokens`` is input + output tokens from
              ``meta.tokens``, with missing parts counted as 0. A response
              without ``results`` yields all-zero scores and 0 tokens.
          """
          rank = np.zeros(len(texts), dtype=float)
          if not texts:
              return rank, 0
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
              "return_documents": False,
              "max_chunks_per_doc": 1024,
              "overlap_tokens": 80,
          }
          response = requests.post(
              self.base_url, json=payload, headers=self.headers
          ).json()
          if "results" not in response:
              return rank, 0
          for d in response["results"]:
              rank[d["index"]] = d["relevance_score"]
          # Usage metadata is informational; don't fail scoring if it is absent.
          tokens = (response.get("meta") or {}).get("tokens") or {}
          return rank, tokens.get("input_tokens", 0) + tokens.get("output_tokens", 0)
  class BaiduYiyanRerank(Base):
      """Reranker using Baidu Qianfan's ``qianfan.resources.Reranker``."""

      def __init__(self, key, model_name, base_url=None):
          """Build the Qianfan client.

          Args:
              key: JSON string holding ``yiyan_ak`` and ``yiyan_sk``; missing
                  entries default to empty strings.
              model_name: reranker model id passed on each call.
          """
          from qianfan.resources import Reranker
          credentials = json.loads(key)
          self.client = Reranker(
              ak=credentials.get("yiyan_ak", ""),
              sk=credentials.get("yiyan_sk", ""),
          )
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(scores, tokens)``. ``scores[i]`` is the ``relevance_score`` for
              ``texts[i]``, and ``tokens`` is read via ``total_token_count``.
          """
          reply = self.client.do(
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
          )
          body = reply.body
          scores = np.zeros(len(texts), dtype=float)
          for item in body["results"]:
              scores[item["index"]] = item["relevance_score"]
          return scores, self.total_token_count(body)
  class VoyageRerank(Base):
      """Reranker using the ``voyageai`` SDK client."""

      def __init__(self, key, model_name, base_url=None):
          import voyageai
          self.client = voyageai.Client(api_key=key)
          self.model_name = model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(scores, tokens)``. ``scores[i]`` is the ``relevance_score`` for
              ``texts[i]``, and ``tokens`` is the SDK's ``total_tokens``. An empty
              input skips the API call and returns an empty array and 0.
          """
          scores = np.zeros(len(texts), dtype=float)
          used_tokens = 0
          if texts:
              reply = self.client.rerank(
                  query=query, documents=texts, model=self.model_name, top_k=len(texts)
              )
              for hit in reply.results:
                  scores[hit.index] = hit.relevance_score
              used_tokens = reply.total_tokens
          return scores, used_tokens
  class QWenRerank(Base):
      """Reranker using DashScope's ``TextReRank`` API."""

      def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
          """Store credentials.

          Args:
              key: DashScope API key, passed on each call.
              model_name: model id; ``None`` selects ``TextReRank.Models.gte_rerank``.
          """
          import dashscope
          self.api_key = key
          self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(rank, tokens)``. ``rank[i]`` is the ``relevance_score`` for
              ``texts[i]``, and ``tokens`` is ``resp.usage.total_tokens``. An empty
              input skips the call and returns an empty array and 0.

          Raises:
              ValueError: if DashScope answers with a non-OK status.
          """
          import dashscope
          from http import HTTPStatus

          rank = np.zeros(len(texts), dtype=float)
          # Nothing to rank; avoid requesting top_n=0.
          if not texts:
              return rank, 0
          resp = dashscope.TextReRank.call(
              api_key=self.api_key,
              model=self.model_name,
              query=query,
              documents=texts,
              top_n=len(texts),
              return_documents=False
          )
          if resp.status_code != HTTPStatus.OK:
              # DashScope responses carry `code`/`message`; there is no `text` field,
              # so reading it would mask the real error with a lookup failure.
              raise ValueError(
                  f"Error calling QWenRerank model {self.model_name}: "
                  f"{resp.status_code} - {resp.code}: {resp.message}"
              )
          for r in resp.output.results:
              rank[r.index] = r.relevance_score
          return rank, resp.usage.total_tokens
  class GPUStackRerank(Base):
      """Reranker served by a GPUStack deployment at ``<base_url>/v1/rerank``."""

      def __init__(
          self, key, model_name, base_url
      ):
          """Store request settings.

          Raises:
              ValueError: if ``base_url`` is empty or ``None``.
          """
          if not base_url:
              raise ValueError("url cannot be None")

          self.model_name = model_name
          self.base_url = str(URL(base_url) / "v1" / "rerank")
          self.headers = {
              "accept": "application/json",
              "content-type": "application/json",
              "authorization": f"Bearer {key}",
          }

      def similarity(self, query: str, texts: list):
          """Score ``texts`` against ``query``.

          Returns:
              ``(rank, token_count)``. ``rank[i]`` is the ``relevance_score`` for
              ``texts[i]``, and ``token_count`` counts the texts locally. A response
              without ``results`` yields all-zero scores and 0.

          Raises:
              ValueError: wrapping any HTTP error status returned by the server.
          """
          payload = {
              "model": self.model_name,
              "query": query,
              "documents": texts,
              "top_n": len(texts),
          }
          try:
              response = requests.post(
                  self.base_url, json=payload, headers=self.headers
              )
              response.raise_for_status()
          except requests.exceptions.HTTPError as e:
              # raise_for_status() raises requests' HTTPError, not httpx's
              # HTTPStatusError, so this is the exception that must be wrapped.
              raise ValueError(
                  f"Error calling GPUStackRerank model {self.model_name}: "
                  f"{e.response.status_code} - {e.response.text}"
              ) from e

          response_json = response.json()
          rank = np.zeros(len(texts), dtype=float)
          if "results" not in response_json:
              return rank, 0

          token_count = sum(num_tokens_from_string(t) for t in texts)
          for result in response_json["results"]:
              rank[result["index"]] = result["relevance_score"]
          return rank, token_count