您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

embedding_model.py 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from typing import Optional
  17. from huggingface_hub import snapshot_download
  18. from zhipuai import ZhipuAI
  19. import os
  20. from abc import ABC
  21. from ollama import Client
  22. import dashscope
  23. from openai import OpenAI
  24. from FlagEmbedding import FlagModel
  25. import torch
  26. import numpy as np
  27. from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
  28. from rag.utils import num_tokens_from_string
  # Local cache directory holding the shared BGE model used by DefaultEmbedding.
  _BGE_LOCAL_DIR = os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5")
  # Passed to FlagModel as ``query_instruction_for_retrieval``. The Chinese text
  # reads roughly "Generate a representation of this sentence for retrieving
  # related articles:". It must stay byte-identical to keep query embeddings stable.
  _BGE_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


  def _load_flag_model(model_dir):
      """Build a FlagModel from *model_dir*, using fp16 only when CUDA is available."""
      return FlagModel(model_dir,
                       query_instruction_for_retrieval=_BGE_QUERY_INSTRUCTION,
                       use_fp16=torch.cuda.is_available())


  # Loaded once, at import time, and shared by every DefaultEmbedding instance.
  try:
      # Fast path: load the model straight from the local cache directory.
      flag_model = _load_flag_model(_BGE_LOCAL_DIR)
  except Exception:
      # Loading from the cache failed (presumably the model has not been
      # downloaded yet). Fetch the snapshot from the Hugging Face Hub into the
      # same directory as real files rather than symlinks, then load it.
      model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                    local_dir=_BGE_LOCAL_DIR,
                                    local_dir_use_symlinks=False)
      flag_model = _load_flag_model(model_dir)
  class Base(ABC):
      """Common interface for embedding backends.

      Subclasses override :meth:`encode`, which embeds a list of texts, and
      :meth:`encode_queries`, which embeds a single query text. Both base
      implementations only raise ``NotImplementedError``.
      """

      def __init__(self, key, model_name):
          # The base class keeps no state; subclasses set up their own clients.
          pass

      def encode(self, texts: list, batch_size=32):
          """Embed a list of texts. Must be overridden by subclasses.

          Raises:
              NotImplementedError: always, in the base class.
          """
          raise NotImplementedError("Please implement encode method!")

      def encode_queries(self, text: str):
          """Embed a single query text. Must be overridden by subclasses.

          Raises:
              NotImplementedError: always, in the base class.
          """
          # The message names this method, not ``encode``, so the missing
          # override is identified correctly.
          raise NotImplementedError("Please implement encode_queries method!")
  class DefaultEmbedding(Base):
      """Local embedding backend built on the module-level BGE ``flag_model``."""

      def __init__(self, *args, **kwargs):
          """Attach the shared, already-loaded BGE model. All arguments are ignored.

          The model is fetched from Hugging Face when this module is imported.
          If that download is troublesome, pointing ``HF_ENDPOINT`` at a mirror
          may help, e.g. on Linux::

              export HF_ENDPOINT=https://hf-mirror.com
          """
          self.model = flag_model

      def encode(self, texts: list, batch_size=32):
          """Embed *texts* in chunks of *batch_size*.

          Each text is first clipped to its first 2000 characters. Token usage
          is counted on the clipped texts with ``num_tokens_from_string``.

          Returns:
              ``(vectors, token_count)`` — a NumPy array with one row per text,
              plus the summed token count.
          """
          clipped = [text[:2000] for text in texts]
          total_tokens = sum(num_tokens_from_string(text) for text in clipped)
          vectors = []
          for start in range(0, len(clipped), batch_size):
              chunk = clipped[start:start + batch_size]
              vectors.extend(self.model.encode(chunk).tolist())
          return np.array(vectors), total_tokens

      def encode_queries(self, text: str):
          """Embed one query text.

          The text goes through the model's query path (``encode_queries``)
          without clipping.

          Returns:
              ``(vector, token_count)``, where *vector* is a plain Python list.
          """
          n_tokens = num_tokens_from_string(text)
          vector = self.model.encode_queries([text]).tolist()[0]
          return vector, n_tokens
  class OpenAIEmbed(Base):
      """Embeddings from the OpenAI ``embeddings`` endpoint or an OpenAI-compatible server."""

      def __init__(self, key, model_name="text-embedding-ada-002",
                   base_url="https://api.openai.com/v1"):
          # An empty or None base_url falls back to the public OpenAI endpoint.
          if not base_url:
              base_url = "https://api.openai.com/v1"
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name

      def encode(self, texts: list, batch_size=32):
          """Embed *texts*, sending at most *batch_size* inputs per request.

          The embeddings endpoint limits how many inputs a single request may
          carry, so long lists are split into consecutive batches. A falsy or
          non-positive *batch_size* sends all texts in a single request. An
          empty *texts* list issues no request.

          Returns:
              ``(vectors, total_tokens)`` — a NumPy array with one row per
              returned embedding, in the order of each response's ``data``, and
              the sum of ``usage.total_tokens`` over all requests.
          """
          if not batch_size or batch_size < 1:
              batch_size = max(len(texts), 1)
          vectors = []
          total_tokens = 0
          for start in range(0, len(texts), batch_size):
              res = self.client.embeddings.create(input=texts[start:start + batch_size],
                                                  model=self.model_name)
              vectors.extend(d.embedding for d in res.data)
              total_tokens += res.usage.total_tokens
          return np.array(vectors), total_tokens

      def encode_queries(self, text):
          """Embed a single query text.

          Returns:
              ``(vector, total_tokens)`` for the one-element request.
          """
          res = self.client.embeddings.create(input=[text],
                                              model=self.model_name)
          return np.array(res.data[0].embedding), res.usage.total_tokens
  class QWenEmbed(Base):
      """Embeddings through DashScope's ``TextEmbedding`` API."""

      def __init__(self, key, model_name="text_embedding_v2", **kwargs):
          # NOTE: the key is stored on the dashscope module itself, so it is
          # shared process-wide. The most recently constructed instance wins.
          dashscope.api_key = key
          self.model_name = model_name

      def encode(self, texts: list, batch_size=10):
          """Embed *texts* as documents, *batch_size* texts per API call.

          Each text is clipped to its first 2048 characters before sending.

          Returns:
              ``(vectors, token_count)`` — a NumPy array with one row per text,
              and the summed ``usage.total_tokens`` of every call.
          """
          clipped = [text[:2048] for text in texts]
          vectors = []
          used_tokens = 0
          for start in range(0, len(clipped), batch_size):
              resp = dashscope.TextEmbedding.call(
                  model=self.model_name,
                  input=clipped[start:start + batch_size],
                  text_type="document"
              )
              items = resp["output"]["embeddings"]
              # Each returned item carries the position of its input text
              # (``text_index``). Place the vectors back in request order.
              ordered = [[] for _ in items]
              for item in items:
                  ordered[item["text_index"]] = item["embedding"]
              vectors.extend(ordered)
              used_tokens += resp["usage"]["total_tokens"]
          return np.array(vectors), used_tokens

      def encode_queries(self, text):
          """Embed one query text, clipped to 2048 characters.

          Returns:
              ``(vector, total_tokens)``.
          """
          resp = dashscope.TextEmbedding.call(
              model=self.model_name,
              input=text[:2048],
              text_type="query"
          )
          vector = resp["output"]["embeddings"][0]["embedding"]
          return np.array(vector), resp["usage"]["total_tokens"]
  class ZhipuEmbed(Base):
      """Embeddings through the ZhipuAI SDK, issuing one request per text."""

      def __init__(self, key, model_name="embedding-2", **kwargs):
          self.client = ZhipuAI(api_key=key)
          self.model_name = model_name

      def _embed_one(self, text):
          """Request the embedding of a single text; return ``(vector, total_tokens)``."""
          res = self.client.embeddings.create(input=text,
                                              model=self.model_name)
          return res.data[0].embedding, res.usage.total_tokens

      def encode(self, texts: list, batch_size=32):
          """Embed *texts* one request at a time. *batch_size* is not used.

          Returns:
              ``(vectors, token_count)`` — a NumPy array with one row per text,
              and the summed ``usage.total_tokens``.
          """
          vectors, total = [], 0
          for text in texts:
              vec, used = self._embed_one(text)
              vectors.append(vec)
              total += used
          return np.array(vectors), total

      def encode_queries(self, text):
          """Embed one query text; return ``(vector, total_tokens)``."""
          vec, used = self._embed_one(text)
          return np.array(vec), used
  class OllamaEmbed(Base):
      """Embeddings served by an Ollama instance at ``kwargs["base_url"]``."""

      def __init__(self, key, model_name, **kwargs):
          # ``base_url`` is required; a missing key raises KeyError here.
          host = kwargs["base_url"]
          self.client = Client(host=host)
          self.model_name = model_name

      def encode(self, texts: list, batch_size=32):
          """Embed *texts* one request at a time. *batch_size* is not used.

          Token usage is reported as a flat 128 per text rather than counted.
          NOTE(review): confirm that this placeholder figure is intended.
          """
          vectors = [
              self.client.embeddings(prompt=text, model=self.model_name)["embedding"]
              for text in texts
          ]
          return np.array(vectors), 128 * len(texts)

      def encode_queries(self, text):
          """Embed one query text; the token figure is the same flat 128."""
          reply = self.client.embeddings(prompt=text, model=self.model_name)
          return np.array(reply["embedding"]), 128
  class FastEmbed(Base):
      """Local embeddings computed with the ``fastembed`` library."""

      def __init__(
          self,
          key: Optional[str] = None,
          model_name: str = "BAAI/bge-small-en-v1.5",
          cache_dir: Optional[str] = None,
          threads: Optional[int] = None,
          **kwargs,
      ):
          # Imported here rather than at module level, so fastembed is only
          # needed once a FastEmbed is actually constructed.
          from fastembed import TextEmbedding
          self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

      def encode(self, texts: list, batch_size=32):
          """Embed *texts*, passing *batch_size* through to fastembed.

          Token usage is the total length of the encodings produced by the
          model's own tokenizer, which is reached through fastembed's internal
          ``model.tokenizer`` attribute.

          Returns:
              ``(vectors, total_tokens)``.
          """
          tokenizer = self._model.model.tokenizer
          total_tokens = sum(len(enc) for enc in tokenizer.encode_batch(texts))
          vectors = [vec.tolist() for vec in self._model.embed(texts, batch_size)]
          return np.array(vectors), total_tokens

      def encode_queries(self, text: str):
          """Embed one query via fastembed's query path; return ``(vector, n_tokens)``."""
          token_ids = self._model.model.tokenizer.encode(text).ids
          vector = next(self._model.query_embed(text)).tolist()
          return np.array(vector), len(token_ids)
  class XinferenceEmbed(Base):
      """Embeddings from a Xinference server, accessed via the OpenAI client."""

      def __init__(self, key, model_name="", base_url=""):
          # ``key`` is not forwarded; the client always gets the placeholder "xxx".
          self.client = OpenAI(api_key="xxx", base_url=base_url)
          self.model_name = model_name

      def _create(self, inputs):
          """Issue one embeddings request for the given input list."""
          return self.client.embeddings.create(input=inputs,
                                               model=self.model_name)

      def encode(self, texts: list, batch_size=32):
          """Embed all *texts* in a single request. *batch_size* is not used.

          Returns:
              ``(vectors, total_tokens)``.
          """
          res = self._create(texts)
          vectors = [item.embedding for item in res.data]
          return np.array(vectors), res.usage.total_tokens

      def encode_queries(self, text):
          """Embed one query text; return ``(vector, total_tokens)``."""
          res = self._create([text])
          first = res.data[0].embedding
          return np.array(first), res.usage.total_tokens
  class YoudaoEmbed(Base):
      """Local embeddings with the BCEmbedding ``EmbeddingModel``, shared by all instances."""

      # Model instance stored on the class; created by the first constructor call.
      _client = None

      def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
          from BCEmbedding import EmbeddingModel as qanthing
          if YoudaoEmbed._client:
              return
          try:
              print("LOADING BCE...")
              YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
                  get_home_cache_dir(),
                  "bce-embedding-base_v1"))
          except Exception:
              # The local copy could not be loaded. Fall back to loading by name,
              # with the "maidalun1020" prefix swapped for "InfiniFlow".
              fallback_name = model_name.replace("maidalun1020", "InfiniFlow")
              YoudaoEmbed._client = qanthing(model_name_or_path=fallback_name)

      def encode(self, texts: list, batch_size=10):
          """Embed *texts* in chunks of *batch_size* with the shared model.

          Returns:
              ``(vectors, token_count)``, where token usage is counted with
              ``num_tokens_from_string``.
          """
          total_tokens = sum(num_tokens_from_string(text) for text in texts)
          rows = []
          for start in range(0, len(texts), batch_size):
              rows.extend(YoudaoEmbed._client.encode(texts[start:start + batch_size]))
          return np.array(rows), total_tokens

      def encode_queries(self, text):
          """Embed one query text; return ``(vector, token_count)``."""
          first = YoudaoEmbed._client.encode([text])[0]
          return np.array(first), num_tokens_from_string(text)