
embedding_model.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from typing import Optional
from huggingface_hub import snapshot_download
from zhipuai import ZhipuAI
import os
from abc import ABC
from ollama import Client
import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import numpy as np

from api.utils.file_utils import get_project_base_directory
from rag.utils import num_tokens_from_string

# Load the default BGE embedding model from the local project cache; if it is
# missing, download it from Hugging Face first. The Chinese query instruction
# tells the model to "generate a representation of this sentence for retrieving
# related passages".
try:
    flag_model = FlagModel(
        os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
        use_fp16=torch.cuda.is_available())
except Exception:
    model_dir = snapshot_download(
        repo_id="BAAI/bge-large-zh-v1.5",
        local_dir=os.path.join(get_project_base_directory(), "rag/res/bge-large-zh-v1.5"),
        local_dir_use_symlinks=False)
    flag_model = FlagModel(
        model_dir,
        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
        use_fp16=torch.cuda.is_available())


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def encode(self, texts: list, batch_size=32):
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        raise NotImplementedError("Please implement encode_queries method!")


class HuEmbedding(Base):
    def __init__(self, *args, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-
        """
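        # The same mirror can also be selected from inside Python; this is a
        # sketch of the workaround suggested in the docstring, not something
        # this module does itself, and it must run before huggingface_hub
        # starts any download:
        #   os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"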
        self.model = flag_model

    def encode(self, texts: list, batch_size=32):
        texts = [t[:2000] for t in texts]
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = []
        for i in range(0, len(texts), batch_size):
            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res), token_count

    def encode_queries(self, text: str):
        token_count = num_tokens_from_string(text)
        return self.model.encode_queries([text]).tolist()[0], token_count


class OpenAIEmbed(Base):
    def __init__(self, key, model_name="text-embedding-ada-002",
                 base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        res = self.client.embeddings.create(input=texts,
                                            model=self.model_name)
        return np.array([d.embedding for d in res.data]), res.usage.total_tokens

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=[text],
                                            model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens


class QWenEmbed(Base):
    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        dashscope.api_key = key
        self.model_name = model_name

    def encode(self, texts: list, batch_size=10):
        res = []
        token_count = 0
        texts = [txt[:2048] for txt in texts]
        for i in range(0, len(texts), batch_size):
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=texts[i:i + batch_size],
                text_type="document"
            )
            # The API may return embeddings out of order, so put each one back
            # at its original position using "text_index".
            embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
            for e in resp["output"]["embeddings"]:
                embds[e["text_index"]] = e["embedding"]
            res.extend(embds)
            token_count += resp["usage"]["total_tokens"]
        return np.array(res), token_count

    def encode_queries(self, text):
        resp = dashscope.TextEmbedding.call(
            model=self.model_name,
            input=text[:2048],
            text_type="query"
        )
        return np.array(resp["output"]["embeddings"][0]["embedding"]), \
            resp["usage"]["total_tokens"]


class ZhipuEmbed(Base):
    def __init__(self, key, model_name="embedding-2", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings.create(input=txt,
                                                model=self.model_name)
            arr.append(res.data[0].embedding)
            tks_num += res.usage.total_tokens
        return np.array(arr), tks_num

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=text,
                                            model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens


class OllamaEmbed(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings(prompt=txt,
                                         model=self.model_name)
            arr.append(res["embedding"])
            # Ollama does not report token usage, so use a fixed estimate.
            tks_num += 128
        return np.array(arr), tks_num

    def encode_queries(self, text):
        res = self.client.embeddings(prompt=text,
                                     model=self.model_name)
        return np.array(res["embedding"]), 128


class FastEmbed(Base):
    def __init__(
        self,
        key: Optional[str] = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        from fastembed import TextEmbedding
        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

    def encode(self, texts: list, batch_size=32):
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)

        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        # Using the internal tokenizer to encode the texts and get the total
        # number of tokens
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text)).tolist()
        return np.array(embedding), len(encoding.ids)


class XinferenceEmbed(Base):
    def __init__(self, key, model_name="", base_url=""):
        # Xinference serves an OpenAI-compatible endpoint and ignores the API
        # key, so a placeholder value is passed to the client.
        self.client = OpenAI(api_key="xxx", base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        res = self.client.embeddings.create(input=texts,
                                            model=self.model_name)
        return np.array([d.embedding for d in res.data]), res.usage.total_tokens

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=[text],
                                            model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens


class YoudaoEmbed(Base):
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        from BCEmbedding import EmbeddingModel as qanthing
        if not YoudaoEmbed._client:
            try:
                print("LOADING BCE...")
                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
                    get_project_base_directory(),
                    "rag/res/bce-embedding-base_v1"))
            except Exception:
                # Fall back to the InfiniFlow-hosted copy of the model.
                YoudaoEmbed._client = qanthing(
                    model_name_or_path=model_name.replace(
                        "maidalun1020", "InfiniFlow"))

    def encode(self, texts: list, batch_size=10):
        res = []
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        for i in range(0, len(texts), batch_size):
            embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
            res.extend(embds)
        return np.array(res), token_count

    def encode_queries(self, text):
        embds = YoudaoEmbed._client.encode([text])
        return np.array(embds[0]), num_tokens_from_string(text)
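

# A minimal usage sketch of the shared contract implemented by every class
# above: encode(texts) returns (embedding matrix, token count) and
# encode_queries(text) returns (query vector, token count), so callers can
# swap embedding backends without changing code. The model name and the
# OPENAI_API_KEY environment variable below are illustrative assumptions only.
if __name__ == "__main__":
    embedder = OpenAIEmbed(key=os.environ.get("OPENAI_API_KEY"),
                           model_name="text-embedding-ada-002")
    vectors, used_tokens = embedder.encode(["hello world", "embedding models"])
    print(vectors.shape, used_tokens)

    query_vec, query_tokens = embedder.encode_queries("hello world")
    print(len(query_vec), query_tokens)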