Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны быть не длиннее 35 символов.

embedding_model.py 9.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from typing import Optional
  17. from huggingface_hub import snapshot_download
  18. from zhipuai import ZhipuAI
  19. import os
  20. from abc import ABC
  21. from ollama import Client
  22. import dashscope
  23. from openai import OpenAI
  24. from FlagEmbedding import FlagModel
  25. import torch
  26. import numpy as np
  27. from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
  28. from rag.utils import num_tokens_from_string, truncate
  # Local cache directory and retrieval instruction for the default BGE model.
  # The instruction means "Represent this sentence for retrieving related articles:".
  _BGE_MODEL_DIR = os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5")
  _BGE_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


  def _load_bge_model(model_dir):
      """Build the default ``FlagModel`` from ``model_dir``.

      fp16 is enabled only when ``torch.cuda.is_available()`` is true.
      """
      return FlagModel(model_dir,
                       query_instruction_for_retrieval=_BGE_QUERY_INSTRUCTION,
                       use_fp16=torch.cuda.is_available())


  # Loaded eagerly at import time and shared by DefaultEmbedding instances.
  try:
      flag_model = _load_bge_model(_BGE_MODEL_DIR)
  except Exception:
      # Loading from the local cache failed (e.g. the model was never
      # downloaded): fetch it from the Hugging Face Hub into that same
      # directory and load the downloaded copy.
      model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                    local_dir=_BGE_MODEL_DIR,
                                    local_dir_use_symlinks=False)
      flag_model = _load_bge_model(model_dir)
  class Base(ABC):
      """Common interface for the embedding back-ends in this module.

      Subclasses override ``encode`` (a list of documents) and
      ``encode_queries`` (a single query). Every implementation in this
      module returns an ``(embedding(s), token_count)`` pair.
      """

      def __init__(self, key, model_name):
          pass

      def encode(self, texts: list, batch_size=32):
          """Embed a list of documents. Must be overridden by subclasses."""
          raise NotImplementedError("Please implement encode method!")

      def encode_queries(self, text: str):
          """Embed a single query string. Must be overridden by subclasses."""
          raise NotImplementedError("Please implement encode_queries method!")
  class DefaultEmbedding(Base):
      """Embeddings from the module-level BGE ``flag_model``."""

      def __init__(self, *args, **kwargs):
          """Attach the shared ``flag_model``; all arguments are ignored.

          If downloading HuggingFace models fails, pointing the hub client at
          a mirror may help, e.g. on Linux:
              export HF_ENDPOINT=https://hf-mirror.com
          (set the same environment variable on other platforms).
          """
          self.model = flag_model

      def encode(self, texts: list, batch_size=32):
          """Embed ``texts`` in chunks of ``batch_size``.

          Each text is first passed through ``truncate(t, 2048)``.

          Returns:
              ``(np.ndarray of embeddings, token count of the truncated texts)``
          """
          clipped = [truncate(t, 2048) for t in texts]
          n_tokens = sum(num_tokens_from_string(t) for t in clipped)
          vectors = []
          for start in range(0, len(clipped), batch_size):
              chunk = clipped[start:start + batch_size]
              vectors += self.model.encode(chunk).tolist()
          return np.array(vectors), n_tokens

      def encode_queries(self, text: str):
          """Embed one query with the model's query encoder.

          Returns:
              ``(embedding as a Python list, token count of text)``
          """
          n_tokens = num_tokens_from_string(text)
          vector = self.model.encode_queries([text]).tolist()[0]
          return vector, n_tokens
  class OpenAIEmbed(Base):
      """Embeddings from the OpenAI embeddings API (or the endpoint at ``base_url``)."""

      def __init__(self, key, model_name="text-embedding-ada-002",
                   base_url="https://api.openai.com/v1"):
          # An empty/None base_url falls back to the public OpenAI endpoint.
          if not base_url:
              base_url = "https://api.openai.com/v1"
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name

      def encode(self, texts: list, batch_size=32):
          """Embed ``texts``, sending at most ``batch_size`` texts per request.

          Each text is truncated to 8191 tokens, the documented input limit of
          text-embedding-ada-002. This assumes ``truncate`` counts tokens
          compatibly with the API tokenizer (TODO confirm).

          Returns:
              ``(np.ndarray of embeddings in input order, summed usage.total_tokens)``
          """
          texts = [truncate(t, 8191) for t in texts]
          embeddings = []
          token_count = 0
          # Batching keeps each request within the API's per-request input
          # limits and skips the request entirely for an empty list.
          for i in range(0, len(texts), batch_size):
              res = self.client.embeddings.create(input=texts[i:i + batch_size],
                                                  model=self.model_name)
              embeddings.extend(d.embedding for d in res.data)
              token_count += res.usage.total_tokens
          return np.array(embeddings), token_count

      def encode_queries(self, text):
          """Embed a single query, truncated like documents in ``encode``.

          Returns:
              ``(np.ndarray embedding, usage.total_tokens)``
          """
          res = self.client.embeddings.create(input=[truncate(text, 8191)],
                                              model=self.model_name)
          return np.array(res.data[0].embedding), res.usage.total_tokens
  class QWenEmbed(Base):
      """Embeddings from DashScope ``TextEmbedding``."""

      def __init__(self, key, model_name="text_embedding_v2", **kwargs):
          # NOTE: this sets the module-level dashscope API key, so it affects
          # every QWenEmbed instance in the process.
          dashscope.api_key = key
          self.model_name = model_name

      def encode(self, texts: list, batch_size=10):
          """Embed ``texts`` as documents, ``batch_size`` texts per request.

          Each text is first passed through ``truncate(t, 2048)``.

          Returns:
              ``(np.ndarray of embeddings in input order, summed usage total_tokens)``
          """
          res = []
          token_count = 0
          texts = [truncate(t, 2048) for t in texts]
          for i in range(0, len(texts), batch_size):
              resp = dashscope.TextEmbedding.call(
                  model=self.model_name,
                  input=texts[i:i + batch_size],
                  text_type="document"
              )
              # Each result carries the index of its text within the batch;
              # place vectors by that index rather than by response order.
              embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
              for e in resp["output"]["embeddings"]:
                  embds[e["text_index"]] = e["embedding"]
              res.extend(embds)
              token_count += resp["usage"]["total_tokens"]
          return np.array(res), token_count

      def encode_queries(self, text):
          """Embed one query, truncated with the same ``truncate(..., 2048)`` as ``encode``.

          Returns:
              ``(np.ndarray embedding, usage total_tokens)``
          """
          resp = dashscope.TextEmbedding.call(
              model=self.model_name,
              input=truncate(text, 2048),
              text_type="query"
          )
          return (np.array(resp["output"]["embeddings"][0]["embedding"]),
                  resp["usage"]["total_tokens"])
  class ZhipuEmbed(Base):
      """Embeddings from the ZhipuAI API, one request per text."""

      def __init__(self, key, model_name="embedding-2", **kwargs):
          self.client = ZhipuAI(api_key=key)
          self.model_name = model_name

      def _embed_one(self, text):
          """Send one text to the API; return ``(embedding, usage.total_tokens)``."""
          reply = self.client.embeddings.create(input=text, model=self.model_name)
          return reply.data[0].embedding, reply.usage.total_tokens

      def encode(self, texts: list, batch_size=32):
          """Embed each text with its own request (``batch_size`` is unused).

          Returns:
              ``(np.ndarray of embeddings, summed token usage)``
          """
          vectors, total = [], 0
          for txt in texts:
              vec, used = self._embed_one(txt)
              vectors.append(vec)
              total += used
          return np.array(vectors), total

      def encode_queries(self, text):
          """Embed a single query; returns ``(np.ndarray embedding, token usage)``."""
          vec, used = self._embed_one(text)
          return np.array(vec), used
  class OllamaEmbed(Base):
      """Embeddings from an Ollama server reached at ``kwargs["base_url"]``."""

      def __init__(self, key, model_name, **kwargs):
          # base_url is required; omitting it raises KeyError.
          self.client = Client(host=kwargs["base_url"])
          self.model_name = model_name

      def encode(self, texts: list, batch_size=32):
          """Embed each text with its own request (``batch_size`` is unused).

          NOTE: the token count is a fixed placeholder of 128 per text, not a
          real count.
          """
          vectors = [
              self.client.embeddings(prompt=txt, model=self.model_name)["embedding"]
              for txt in texts
          ]
          return np.array(vectors), 128 * len(texts)

      def encode_queries(self, text):
          """Embed a single query; the token count is the fixed placeholder 128."""
          reply = self.client.embeddings(prompt=text, model=self.model_name)
          return np.array(reply["embedding"]), 128
  class FastEmbed(Base):
      """Local embeddings computed with ``fastembed.TextEmbedding``."""

      def __init__(
          self,
          key: Optional[str] = None,
          model_name: str = "BAAI/bge-small-en-v1.5",
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          **kwargs,
      ):
          # fastembed is only imported when an instance is created.
          from fastembed import TextEmbedding
          self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

      def encode(self, texts: list, batch_size=32):
          """Embed ``texts``; tokens are counted with the model's own tokenizer.

          Returns:
              ``(np.ndarray of embeddings, total token count)``
          """
          # Relies on fastembed internals (``.model.tokenizer``) for counting.
          tokenizer = self._model.model.tokenizer
          total_tokens = sum(len(enc) for enc in tokenizer.encode_batch(texts))
          vectors = [vec.tolist() for vec in self._model.embed(texts, batch_size)]
          return np.array(vectors), total_tokens

      def encode_queries(self, text: str):
          """Embed one query via ``query_embed``; returns ``(embedding, token count)``."""
          ids = self._model.model.tokenizer.encode(text).ids
          vector = next(self._model.query_embed(text)).tolist()
          return np.array(vector), len(ids)
  class XinferenceEmbed(Base):
      """Embeddings from an Xinference server through the OpenAI client."""

      def __init__(self, key, model_name="", base_url=""):
          # ``key`` is ignored; the client gets a placeholder API key.
          self.client = OpenAI(api_key="xxx", base_url=base_url)
          self.model_name = model_name

      def _create(self, inputs):
          """Issue one embeddings request for ``inputs``."""
          return self.client.embeddings.create(input=inputs, model=self.model_name)

      def encode(self, texts: list, batch_size=32):
          """Embed all ``texts`` in one request (``batch_size`` is unused).

          Returns:
              ``(np.ndarray of embeddings, usage.total_tokens)``
          """
          reply = self._create(texts)
          vectors = [item.embedding for item in reply.data]
          return np.array(vectors), reply.usage.total_tokens

      def encode_queries(self, text):
          """Embed a single query; returns ``(np.ndarray embedding, usage.total_tokens)``."""
          reply = self._create([text])
          return np.array(reply.data[0].embedding), reply.usage.total_tokens
  class YoudaoEmbed(Base):
      """Embeddings from a ``BCEmbedding.EmbeddingModel`` shared by all instances."""

      # Shared model object; created by the first constructor call.
      _client = None

      def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
          from BCEmbedding import EmbeddingModel as qanthing
          if YoudaoEmbed._client:
              return
          try:
              print("LOADING BCE...")
              local_path = os.path.join(get_home_cache_dir(), "bce-embedding-base_v1")
              YoudaoEmbed._client = qanthing(model_name_or_path=local_path)
          except Exception:
              # Local load failed: fall back to model_name with the
              # "maidalun1020" prefix swapped for "InfiniFlow".
              fallback = model_name.replace("maidalun1020", "InfiniFlow")
              YoudaoEmbed._client = qanthing(model_name_or_path=fallback)

      def encode(self, texts: list, batch_size=10):
          """Embed ``texts`` in chunks of ``batch_size``.

          Returns:
              ``(np.ndarray of embeddings, num_tokens_from_string total over texts)``
          """
          token_count = sum(num_tokens_from_string(t) for t in texts)
          vectors = []
          for start in range(0, len(texts), batch_size):
              vectors.extend(YoudaoEmbed._client.encode(texts[start:start + batch_size]))
          return np.array(vectors), token_count

      def encode_queries(self, text):
          """Embed a single query; returns ``(np.ndarray embedding, token count)``."""
          first = YoudaoEmbed._client.encode([text])[0]
          return np.array(first), num_tokens_from_string(text)
  208. res.extend(embds)
  209. return np.array(res), token_count
  210. def encode_queries(self, text):
  211. embds = YoudaoEmbed._client.encode([text])
  212. return np.array(embds[0]), num_tokens_from_string(text)