
rerank_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import threading
import requests
import torch
from FlagEmbedding import FlagReranker
from huggingface_hub import snapshot_download
import os
from abc import ABC
import numpy as np
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate


def sigmoid(x):
    return 1 / (1 + np.exp(-x))

class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("Please implement similarity method!")

class DefaultRerank(Base):
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not DefaultRerank._model:
            with DefaultRerank._model_lock:
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(
                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
                            use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(
                            repo_id=model_name,
                            local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
                            local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 4096
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
            scores = sigmoid(np.array(scores)).tolist()
            if isinstance(scores, float):
                res.append(scores)
            else:
                res.extend(scores)
        return np.array(res), token_count


class JinaRerank(Base):
    def __init__(self, key, model_name="jina-reranker-v1-base-en",
                 base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = base_url or "https://api.jina.ai/v1/rerank"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts)
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]


class YoudaoRerank(DefaultRerank):
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        from BCEmbedding import RerankerModel
        if not YoudaoRerank._model:
            with YoudaoRerank._model_lock:
                if not YoudaoRerank._model:
                    try:
                        print("LOADING BCE...")
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
                            get_home_cache_dir(),
                            re.sub(r"^[a-zA-Z]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(
                            model_name_or_path=model_name.replace(
                                "maidalun1020", "InfiniFlow"))
        self._model = YoudaoRerank._model

    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 32
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
            scores = sigmoid(np.array(scores)).tolist()
            if isinstance(scores, float):
                res.append(scores)
            else:
                res.extend(scores)
        return np.array(res), token_count


class XInferenceRerank(Base):
    def __init__(self, key="xxxxxxx", model_name="", base_url=""):
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }

    def similarity(self, query: str, texts: list):
        if len(texts) == 0:
            return np.array([]), 0
        data = {
            "model": self.model_name,
            "query": query,
            "return_documents": "true",
            "return_len": "true",
            "documents": texts
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]


class LocalAIRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("LocalAIRerank has not been implemented yet!")


class NvidiaRerank(Base):
    def __init__(
        self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
    ):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name
        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = os.path.join(
                base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
            )
        if self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = os.path.join(base_url, "reranking")
            self.model_name = "nv-rerank-qa-mistral-4b:1"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum(
            [num_tokens_from_string(t) for t in texts]
        )
        data = {
            "model": self.model_name,
            "query": {"text": query},
            "passages": [{"text": text} for text in texts],
            "truncate": "END",
            "top_n": len(texts),
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        # Map each ranking back to its original passage position so the scores
        # line up with the input `texts` order.
        rank = np.zeros(len(texts), dtype=float)
        for d in res["rankings"]:
            rank[d["index"]] = d["logit"]
        return rank, token_count


class LmStudioRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("LmStudioRerank has not been implemented yet!")


class OpenAI_APIRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("OpenAI_APIRerank has not been implemented yet!")


class CoHereRerank(Base):
    def __init__(self, key, model_name, base_url=None):
        from cohere import Client
        self.client = Client(api_key=key)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum(
            [num_tokens_from_string(t) for t in texts]
        )
        res = self.client.rerank(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
            return_documents=False,
        )
        # Results arrive sorted by relevance; re-order them by `index` so the
        # scores line up with the input `texts` order.
        rank = np.zeros(len(texts), dtype=float)
        for d in res.results:
            rank[d.index] = d.relevance_score
        return rank, token_count


class TogetherAIRerank(Base):
    def __init__(self, key, model_name, base_url):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("TogetherAIRerank has not been implemented yet!")


class SILICONFLOWRerank(Base):
    def __init__(
        self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
    ):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/rerank"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {key}",
        }

    def similarity(self, query: str, texts: list):
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts),
            "return_documents": False,
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(
            self.base_url, json=payload, headers=self.headers
        ).json()
        # Re-order the scores by the returned `index` field so they match the
        # input `texts` order.
        rank = np.zeros(len(texts), dtype=float)
        for d in response["results"]:
            rank[d["index"]] = d["relevance_score"]
        return (
            rank,
            response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
        )
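
A minimal usage sketch (an illustration, not part of rerank_model.py): construct a reranker once, then call similarity() with a query and the candidate texts. The import path assumes the RAGFlow layout (rag/llm/rerank_model.py); the model name, query, and passages are placeholders.

from rag.llm.rerank_model import DefaultRerank

# Loads (or downloads) a FlagEmbedding cross-encoder into the local cache dir;
# "BAAI/bge-reranker-v2-m3" is only an example model name.
ranker = DefaultRerank(key="", model_name="BAAI/bge-reranker-v2-m3")

# Returns a numpy array of sigmoid-squashed relevance scores, one per candidate
# text in input order, plus the token count of the candidate texts.
scores, token_count = ranker.similarity(
    "how do I point HuggingFace at a mirror?",
    ["Set HF_ENDPOINT before starting the service.", "An unrelated passage."],
)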