You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

rerank_model.py 12KB

fix the tokens error that occurred when adding the xinference model (#1527) ### What problem does this PR solve? fix the tokens error that occurred when adding the xinference model #1522 root@pc-gpu-86-41:~# curl -X 'POST' 'http://127.0.0.1:9997/v1/rerank' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{ "model": "bge-reranker-v2-m3", "query": "A man is eating pasta.", "return_documents":"true", "return_len":"true", "documents": [ "A man is eating food.", "A man is eating a piece of bread.", "The girl is carrying a baby.", "A man is riding a horse.", "A woman is playing violin." ] }' {"id":"610a8724-3e96-11ef-81ce-08bfb886c012","results":[{"index":0,"relevance_score":0.999574601650238,"document":{"text":"A man is eating food."}},{"index":1,"relevance_score":0.07814773917198181,"document":{"text":"A man is eating a piece of bread."}},{"index":3,"relevance_score":0.000017700713215162978,"document":{"text":"A man is riding a horse."}},{"index":2,"relevance_score":0.0000163753629749408,"document":{"text":"The girl is carrying a baby."}},{"index":4,"relevance_score":0.00001631895975151565,"document":{"text":"A woman is playing violin."}}],"meta":{"api_version":null,"billed_units":null,"tokens":{"input_tokens":38,"output_tokens":38},"warnings":null}} ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [ ] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe):
1 年之前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import threading
  18. import requests
  19. from huggingface_hub import snapshot_download
  20. import os
  21. from abc import ABC
  22. import numpy as np
  23. from api.settings import LIGHTEN
  24. from api.utils.file_utils import get_home_cache_dir
  25. from rag.utils import num_tokens_from_string, truncate
  26. import json
  27. def sigmoid(x):
  28. return 1 / (1 + np.exp(-x))
  29. class Base(ABC):
  30. def __init__(self, key, model_name):
  31. pass
  32. def similarity(self, query: str, texts: list):
  33. raise NotImplementedError("Please implement encode method!")
  34. class DefaultRerank(Base):
  35. _model = None
  36. _model_lock = threading.Lock()
  37. def __init__(self, key, model_name, **kwargs):
  38. """
  39. If you have trouble downloading HuggingFace models, -_^ this might help!!
  40. For Linux:
  41. export HF_ENDPOINT=https://hf-mirror.com
  42. For Windows:
  43. Good luck
  44. ^_-
  45. """
  46. if not LIGHTEN and not DefaultRerank._model:
  47. import torch
  48. from FlagEmbedding import FlagReranker
  49. with DefaultRerank._model_lock:
  50. if not DefaultRerank._model:
  51. try:
  52. DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available())
  53. except Exception as e:
  54. model_dir = snapshot_download(repo_id= model_name,
  55. local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
  56. local_dir_use_symlinks=False)
  57. DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
  58. self._model = DefaultRerank._model
  59. def similarity(self, query: str, texts: list):
  60. pairs = [(query,truncate(t, 2048)) for t in texts]
  61. token_count = 0
  62. for _, t in pairs:
  63. token_count += num_tokens_from_string(t)
  64. batch_size = 4096
  65. res = []
  66. for i in range(0, len(pairs), batch_size):
  67. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
  68. scores = sigmoid(np.array(scores)).tolist()
  69. if isinstance(scores, float): res.append(scores)
  70. else: res.extend(scores)
  71. return np.array(res), token_count
  72. class JinaRerank(Base):
  73. def __init__(self, key, model_name="jina-reranker-v1-base-en",
  74. base_url="https://api.jina.ai/v1/rerank"):
  75. self.base_url = "https://api.jina.ai/v1/rerank"
  76. self.headers = {
  77. "Content-Type": "application/json",
  78. "Authorization": f"Bearer {key}"
  79. }
  80. self.model_name = model_name
  81. def similarity(self, query: str, texts: list):
  82. texts = [truncate(t, 8196) for t in texts]
  83. data = {
  84. "model": self.model_name,
  85. "query": query,
  86. "documents": texts,
  87. "top_n": len(texts)
  88. }
  89. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  90. return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
  91. class YoudaoRerank(DefaultRerank):
  92. _model = None
  93. _model_lock = threading.Lock()
  94. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  95. if not LIGHTEN and not YoudaoRerank._model:
  96. from BCEmbedding import RerankerModel
  97. with YoudaoRerank._model_lock:
  98. if not YoudaoRerank._model:
  99. try:
  100. print("LOADING BCE...")
  101. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  102. get_home_cache_dir(),
  103. re.sub(r"^[a-zA-Z]+/", "", model_name)))
  104. except Exception as e:
  105. YoudaoRerank._model = RerankerModel(
  106. model_name_or_path=model_name.replace(
  107. "maidalun1020", "InfiniFlow"))
  108. self._model = YoudaoRerank._model
  109. def similarity(self, query: str, texts: list):
  110. pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
  111. token_count = 0
  112. for _, t in pairs:
  113. token_count += num_tokens_from_string(t)
  114. batch_size = 32
  115. res = []
  116. for i in range(0, len(pairs), batch_size):
  117. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
  118. scores = sigmoid(np.array(scores)).tolist()
  119. if isinstance(scores, float): res.append(scores)
  120. else: res.extend(scores)
  121. return np.array(res), token_count
  122. class XInferenceRerank(Base):
  123. def __init__(self, key="xxxxxxx", model_name="", base_url=""):
  124. if base_url.split("/")[-1] != "v1":
  125. base_url = os.path.join(base_url, "v1")
  126. self.model_name = model_name
  127. self.base_url = base_url
  128. self.headers = {
  129. "Content-Type": "application/json",
  130. "accept": "application/json"
  131. }
  132. def similarity(self, query: str, texts: list):
  133. if len(texts) == 0:
  134. return np.array([]), 0
  135. data = {
  136. "model": self.model_name,
  137. "query": query,
  138. "return_documents": "true",
  139. "return_len": "true",
  140. "documents": texts
  141. }
  142. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  143. return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"]+res["meta"]["tokens"]["output_tokens"]
  144. class LocalAIRerank(Base):
  145. def __init__(self, key, model_name, base_url):
  146. pass
  147. def similarity(self, query: str, texts: list):
  148. raise NotImplementedError("The LocalAIRerank has not been implement")
  149. class NvidiaRerank(Base):
  150. def __init__(
  151. self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  152. ):
  153. if not base_url:
  154. base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
  155. self.model_name = model_name
  156. if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
  157. self.base_url = os.path.join(
  158. base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
  159. )
  160. if self.model_name == "nvidia/rerank-qa-mistral-4b":
  161. self.base_url = os.path.join(base_url, "reranking")
  162. self.model_name = "nv-rerank-qa-mistral-4b:1"
  163. self.headers = {
  164. "accept": "application/json",
  165. "Content-Type": "application/json",
  166. "Authorization": f"Bearer {key}",
  167. }
  168. def similarity(self, query: str, texts: list):
  169. token_count = num_tokens_from_string(query) + sum(
  170. [num_tokens_from_string(t) for t in texts]
  171. )
  172. data = {
  173. "model": self.model_name,
  174. "query": {"text": query},
  175. "passages": [{"text": text} for text in texts],
  176. "truncate": "END",
  177. "top_n": len(texts),
  178. }
  179. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  180. rank = np.array([d["logit"] for d in res["rankings"]])
  181. indexs = [d["index"] for d in res["rankings"]]
  182. return rank[indexs], token_count
  183. class LmStudioRerank(Base):
  184. def __init__(self, key, model_name, base_url):
  185. pass
  186. def similarity(self, query: str, texts: list):
  187. raise NotImplementedError("The LmStudioRerank has not been implement")
  188. class OpenAI_APIRerank(Base):
  189. def __init__(self, key, model_name, base_url):
  190. pass
  191. def similarity(self, query: str, texts: list):
  192. raise NotImplementedError("The api has not been implement")
  193. class CoHereRerank(Base):
  194. def __init__(self, key, model_name, base_url=None):
  195. from cohere import Client
  196. self.client = Client(api_key=key)
  197. self.model_name = model_name
  198. def similarity(self, query: str, texts: list):
  199. token_count = num_tokens_from_string(query) + sum(
  200. [num_tokens_from_string(t) for t in texts]
  201. )
  202. res = self.client.rerank(
  203. model=self.model_name,
  204. query=query,
  205. documents=texts,
  206. top_n=len(texts),
  207. return_documents=False,
  208. )
  209. rank = np.array([d.relevance_score for d in res.results])
  210. indexs = [d.index for d in res.results]
  211. return rank[indexs], token_count
  212. class TogetherAIRerank(Base):
  213. def __init__(self, key, model_name, base_url):
  214. pass
  215. def similarity(self, query: str, texts: list):
  216. raise NotImplementedError("The api has not been implement")
  217. class SILICONFLOWRerank(Base):
  218. def __init__(
  219. self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
  220. ):
  221. if not base_url:
  222. base_url = "https://api.siliconflow.cn/v1/rerank"
  223. self.model_name = model_name
  224. self.base_url = base_url
  225. self.headers = {
  226. "accept": "application/json",
  227. "content-type": "application/json",
  228. "authorization": f"Bearer {key}",
  229. }
  230. def similarity(self, query: str, texts: list):
  231. payload = {
  232. "model": self.model_name,
  233. "query": query,
  234. "documents": texts,
  235. "top_n": len(texts),
  236. "return_documents": False,
  237. "max_chunks_per_doc": 1024,
  238. "overlap_tokens": 80,
  239. }
  240. response = requests.post(
  241. self.base_url, json=payload, headers=self.headers
  242. ).json()
  243. rank = np.array([d["relevance_score"] for d in response["results"]])
  244. indexs = [d["index"] for d in response["results"]]
  245. return (
  246. rank[indexs],
  247. response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
  248. )
  249. class BaiduYiyanRerank(Base):
  250. def __init__(self, key, model_name, base_url=None):
  251. from qianfan.resources import Reranker
  252. key = json.loads(key)
  253. ak = key.get("yiyan_ak", "")
  254. sk = key.get("yiyan_sk", "")
  255. self.client = Reranker(ak=ak, sk=sk)
  256. self.model_name = model_name
  257. def similarity(self, query: str, texts: list):
  258. res = self.client.do(
  259. model=self.model_name,
  260. query=query,
  261. documents=texts,
  262. top_n=len(texts),
  263. ).body
  264. rank = np.array([d["relevance_score"] for d in res["results"]])
  265. indexs = [d["index"] for d in res["results"]]
  266. return rank[indexs], res["usage"]["total_tokens"]
  267. class VoyageRerank(Base):
  268. def __init__(self, key, model_name, base_url=None):
  269. import voyageai
  270. self.client = voyageai.Client(api_key=key)
  271. self.model_name = model_name
  272. def similarity(self, query: str, texts: list):
  273. res = self.client.rerank(
  274. query=query, documents=texts, model=self.model_name, top_k=len(texts)
  275. )
  276. rank = np.array([r.relevance_score for r in res.results])
  277. indexs = [r.index for r in res.results]
  278. return rank[indexs], res.total_tokens