Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

embedding_model.py 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from typing import Optional
  17. from huggingface_hub import snapshot_download
  18. from zhipuai import ZhipuAI
  19. import os
  20. from abc import ABC
  21. from ollama import Client
  22. import dashscope
  23. from openai import OpenAI
  24. from FlagEmbedding import FlagModel
  25. import torch
  26. import numpy as np
  27. from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
  28. from rag.utils import num_tokens_from_string
  29. try:
  30. flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
  31. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  32. use_fp16=torch.cuda.is_available())
  33. except Exception as e:
  34. model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
  35. local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
  36. local_dir_use_symlinks=False)
  37. flag_model = FlagModel(model_dir,
  38. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  39. use_fp16=torch.cuda.is_available())
  40. class Base(ABC):
  41. def __init__(self, key, model_name):
  42. pass
  43. def encode(self, texts: list, batch_size=32):
  44. raise NotImplementedError("Please implement encode method!")
  45. def encode_queries(self, text: str):
  46. raise NotImplementedError("Please implement encode method!")
  47. class DefaultEmbedding(Base):
  48. def __init__(self, *args, **kwargs):
  49. """
  50. If you have trouble downloading HuggingFace models, -_^ this might help!!
  51. For Linux:
  52. export HF_ENDPOINT=https://hf-mirror.com
  53. For Windows:
  54. Good luck
  55. ^_-
  56. """
  57. self.model = flag_model
  58. def encode(self, texts: list, batch_size=32):
  59. texts = [t[:2000] for t in texts]
  60. token_count = 0
  61. for t in texts:
  62. token_count += num_tokens_from_string(t)
  63. res = []
  64. for i in range(0, len(texts), batch_size):
  65. res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
  66. return np.array(res), token_count
  67. def encode_queries(self, text: str):
  68. token_count = num_tokens_from_string(text)
  69. return self.model.encode_queries([text]).tolist()[0], token_count
  70. class OpenAIEmbed(Base):
  71. def __init__(self, key, model_name="text-embedding-ada-002",
  72. base_url="https://api.openai.com/v1"):
  73. if not base_url:
  74. base_url = "https://api.openai.com/v1"
  75. self.client = OpenAI(api_key=key, base_url=base_url)
  76. self.model_name = model_name
  77. def encode(self, texts: list, batch_size=32):
  78. res = self.client.embeddings.create(input=texts,
  79. model=self.model_name)
  80. return np.array([d.embedding for d in res.data]
  81. ), res.usage.total_tokens
  82. def encode_queries(self, text):
  83. res = self.client.embeddings.create(input=[text],
  84. model=self.model_name)
  85. return np.array(res.data[0].embedding), res.usage.total_tokens
  86. class QWenEmbed(Base):
  87. def __init__(self, key, model_name="text_embedding_v2", **kwargs):
  88. dashscope.api_key = key
  89. self.model_name = model_name
  90. def encode(self, texts: list, batch_size=10):
  91. import dashscope
  92. res = []
  93. token_count = 0
  94. texts = [txt[:2048] for txt in texts]
  95. for i in range(0, len(texts), batch_size):
  96. resp = dashscope.TextEmbedding.call(
  97. model=self.model_name,
  98. input=texts[i:i + batch_size],
  99. text_type="document"
  100. )
  101. embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
  102. for e in resp["output"]["embeddings"]:
  103. embds[e["text_index"]] = e["embedding"]
  104. res.extend(embds)
  105. token_count += resp["usage"]["total_tokens"]
  106. return np.array(res), token_count
  107. def encode_queries(self, text):
  108. resp = dashscope.TextEmbedding.call(
  109. model=self.model_name,
  110. input=text[:2048],
  111. text_type="query"
  112. )
  113. return np.array(resp["output"]["embeddings"][0]
  114. ["embedding"]), resp["usage"]["total_tokens"]
  115. class ZhipuEmbed(Base):
  116. def __init__(self, key, model_name="embedding-2", **kwargs):
  117. self.client = ZhipuAI(api_key=key)
  118. self.model_name = model_name
  119. def encode(self, texts: list, batch_size=32):
  120. arr = []
  121. tks_num = 0
  122. for txt in texts:
  123. res = self.client.embeddings.create(input=txt,
  124. model=self.model_name)
  125. arr.append(res.data[0].embedding)
  126. tks_num += res.usage.total_tokens
  127. return np.array(arr), tks_num
  128. def encode_queries(self, text):
  129. res = self.client.embeddings.create(input=text,
  130. model=self.model_name)
  131. return np.array(res.data[0].embedding), res.usage.total_tokens
  132. class OllamaEmbed(Base):
  133. def __init__(self, key, model_name, **kwargs):
  134. self.client = Client(host=kwargs["base_url"])
  135. self.model_name = model_name
  136. def encode(self, texts: list, batch_size=32):
  137. arr = []
  138. tks_num = 0
  139. for txt in texts:
  140. res = self.client.embeddings(prompt=txt,
  141. model=self.model_name)
  142. arr.append(res["embedding"])
  143. tks_num += 128
  144. return np.array(arr), tks_num
  145. def encode_queries(self, text):
  146. res = self.client.embeddings(prompt=text,
  147. model=self.model_name)
  148. return np.array(res["embedding"]), 128
  149. class FastEmbed(Base):
  150. def __init__(
  151. self,
  152. key: Optional[str] = None,
  153. model_name: str = "BAAI/bge-small-en-v1.5",
  154. cache_dir: Optional[str] = None,
  155. threads: Optional[int] = None,
  156. **kwargs,
  157. ):
  158. from fastembed import TextEmbedding
  159. self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
  160. def encode(self, texts: list, batch_size=32):
  161. # Using the internal tokenizer to encode the texts and get the total
  162. # number of tokens
  163. encodings = self._model.model.tokenizer.encode_batch(texts)
  164. total_tokens = sum(len(e) for e in encodings)
  165. embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
  166. return np.array(embeddings), total_tokens
  167. def encode_queries(self, text: str):
  168. # Using the internal tokenizer to encode the texts and get the total
  169. # number of tokens
  170. encoding = self._model.model.tokenizer.encode(text)
  171. embedding = next(self._model.query_embed(text)).tolist()
  172. return np.array(embedding), len(encoding.ids)
  173. class XinferenceEmbed(Base):
  174. def __init__(self, key, model_name="", base_url=""):
  175. self.client = OpenAI(api_key="xxx", base_url=base_url)
  176. self.model_name = model_name
  177. def encode(self, texts: list, batch_size=32):
  178. res = self.client.embeddings.create(input=texts,
  179. model=self.model_name)
  180. return np.array([d.embedding for d in res.data]
  181. ), res.usage.total_tokens
  182. def encode_queries(self, text):
  183. res = self.client.embeddings.create(input=[text],
  184. model=self.model_name)
  185. return np.array(res.data[0].embedding), res.usage.total_tokens
  186. class YoudaoEmbed(Base):
  187. _client = None
  188. def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
  189. from BCEmbedding import EmbeddingModel as qanthing
  190. if not YoudaoEmbed._client:
  191. try:
  192. print("LOADING BCE...")
  193. YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
  194. get_home_cache_dir(),
  195. "bce-embedding-base_v1"))
  196. except Exception as e:
  197. YoudaoEmbed._client = qanthing(
  198. model_name_or_path=model_name.replace(
  199. "maidalun1020", "InfiniFlow"))
  200. def encode(self, texts: list, batch_size=10):
  201. res = []
  202. token_count = 0
  203. for t in texts:
  204. token_count += num_tokens_from_string(t)
  205. for i in range(0, len(texts), batch_size):
  206. embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
  207. res.extend(embds)
  208. return np.array(res), token_count
  209. def encode_queries(self, text):
  210. embds = YoudaoEmbed._client.encode([text])
  211. return np.array(embds[0]), num_tokens_from_string(text)