You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import re
  17. import requests
  18. import torch
  19. from FlagEmbedding import FlagReranker
  20. from huggingface_hub import snapshot_download
  21. import os
  22. from abc import ABC
  23. import numpy as np
  24. from api.utils.file_utils import get_home_cache_dir
  25. from rag.utils import num_tokens_from_string, truncate
  26. class Base(ABC):
  27. def __init__(self, key, model_name):
  28. pass
  29. def similarity(self, query: str, texts: list):
  30. raise NotImplementedError("Please implement encode method!")
  31. class DefaultRerank(Base):
  32. _model = None
  33. def __init__(self, key, model_name, **kwargs):
  34. """
  35. If you have trouble downloading HuggingFace models, -_^ this might help!!
  36. For Linux:
  37. export HF_ENDPOINT=https://hf-mirror.com
  38. For Windows:
  39. Good luck
  40. ^_-
  41. """
  42. if not DefaultRerank._model:
  43. try:
  44. self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
  45. use_fp16=torch.cuda.is_available())
  46. except Exception as e:
  47. self._model = snapshot_download(repo_id=model_name,
  48. local_dir=os.path.join(get_home_cache_dir(),
  49. re.sub(r"^[a-zA-Z]+/", "", model_name)),
  50. local_dir_use_symlinks=False)
  51. self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name),
  52. use_fp16=torch.cuda.is_available())
  53. def similarity(self, query: str, texts: list):
  54. pairs = [(query,truncate(t, 2048)) for t in texts]
  55. token_count = 0
  56. for _, t in pairs:
  57. token_count += num_tokens_from_string(t)
  58. batch_size = 32
  59. res = []
  60. for i in range(0, len(pairs), batch_size):
  61. scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
  62. res.extend(scores)
  63. return np.array(res), token_count
  64. class JinaRerank(Base):
  65. def __init__(self, key, model_name="jina-reranker-v1-base-en",
  66. base_url="https://api.jina.ai/v1/rerank"):
  67. self.base_url = "https://api.jina.ai/v1/rerank"
  68. self.headers = {
  69. "Content-Type": "application/json",
  70. "Authorization": f"Bearer {key}"
  71. }
  72. self.model_name = model_name
  73. def similarity(self, query: str, texts: list):
  74. texts = [truncate(t, 8196) for t in texts]
  75. data = {
  76. "model": self.model_name,
  77. "query": query,
  78. "documents": texts,
  79. "top_n": len(texts)
  80. }
  81. res = requests.post(self.base_url, headers=self.headers, json=data)
  82. return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"]
  83. class YoudaoRerank(DefaultRerank):
  84. _model = None
  85. def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
  86. from BCEmbedding import RerankerModel
  87. if not YoudaoRerank._model:
  88. try:
  89. print("LOADING BCE...")
  90. YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
  91. get_home_cache_dir(),
  92. re.sub(r"^[a-zA-Z]+/", "", model_name)))
  93. except Exception as e:
  94. YoudaoRerank._model = RerankerModel(
  95. model_name_or_path=model_name.replace(
  96. "maidalun1020", "InfiniFlow"))