Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

embedding_model.py 9.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from typing import Optional
  17. from zhipuai import ZhipuAI
  18. import os
  19. from abc import ABC
  20. from ollama import Client
  21. import dashscope
  22. from openai import OpenAI
  23. from FlagEmbedding import FlagModel
  24. import torch
  25. import numpy as np
  26. from api.utils.file_utils import get_project_base_directory
  27. from rag.utils import num_tokens_from_string
  28. try:
  29. flag_model = FlagModel(os.path.join(
  30. get_project_base_directory(),
  31. "rag/res/bge-large-zh-v1.5"),
  32. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  33. use_fp16=torch.cuda.is_available())
  34. except Exception as e:
  35. flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
  36. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  37. use_fp16=torch.cuda.is_available())
  38. class Base(ABC):
  39. def __init__(self, key, model_name):
  40. pass
  41. def encode(self, texts: list, batch_size=32):
  42. raise NotImplementedError("Please implement encode method!")
  43. def encode_queries(self, text: str):
  44. raise NotImplementedError("Please implement encode method!")
  45. class HuEmbedding(Base):
  46. def __init__(self, *args, **kwargs):
  47. """
  48. If you have trouble downloading HuggingFace models, -_^ this might help!!
  49. For Linux:
  50. export HF_ENDPOINT=https://hf-mirror.com
  51. For Windows:
  52. Good luck
  53. ^_-
  54. """
  55. self.model = flag_model
  56. def encode(self, texts: list, batch_size=32):
  57. texts = [t[:2000] for t in texts]
  58. token_count = 0
  59. for t in texts:
  60. token_count += num_tokens_from_string(t)
  61. res = []
  62. for i in range(0, len(texts), batch_size):
  63. res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
  64. return np.array(res), token_count
  65. def encode_queries(self, text: str):
  66. token_count = num_tokens_from_string(text)
  67. return self.model.encode_queries([text]).tolist()[0], token_count
  68. class OpenAIEmbed(Base):
  69. def __init__(self, key, model_name="text-embedding-ada-002",
  70. base_url="https://api.openai.com/v1"):
  71. if not base_url:
  72. base_url = "https://api.openai.com/v1"
  73. self.client = OpenAI(api_key=key, base_url=base_url)
  74. self.model_name = model_name
  75. def encode(self, texts: list, batch_size=32):
  76. res = self.client.embeddings.create(input=texts,
  77. model=self.model_name)
  78. return np.array([d.embedding for d in res.data]
  79. ), res.usage.total_tokens
  80. def encode_queries(self, text):
  81. res = self.client.embeddings.create(input=[text],
  82. model=self.model_name)
  83. return np.array(res.data[0].embedding), res.usage.total_tokens
  84. class QWenEmbed(Base):
  85. def __init__(self, key, model_name="text_embedding_v2", **kwargs):
  86. dashscope.api_key = key
  87. self.model_name = model_name
  88. def encode(self, texts: list, batch_size=10):
  89. import dashscope
  90. res = []
  91. token_count = 0
  92. texts = [txt[:2048] for txt in texts]
  93. for i in range(0, len(texts), batch_size):
  94. resp = dashscope.TextEmbedding.call(
  95. model=self.model_name,
  96. input=texts[i:i + batch_size],
  97. text_type="document"
  98. )
  99. embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
  100. for e in resp["output"]["embeddings"]:
  101. embds[e["text_index"]] = e["embedding"]
  102. res.extend(embds)
  103. token_count += resp["usage"]["total_tokens"]
  104. return np.array(res), token_count
  105. def encode_queries(self, text):
  106. resp = dashscope.TextEmbedding.call(
  107. model=self.model_name,
  108. input=text[:2048],
  109. text_type="query"
  110. )
  111. return np.array(resp["output"]["embeddings"][0]
  112. ["embedding"]), resp["usage"]["total_tokens"]
  113. class ZhipuEmbed(Base):
  114. def __init__(self, key, model_name="embedding-2", **kwargs):
  115. self.client = ZhipuAI(api_key=key)
  116. self.model_name = model_name
  117. def encode(self, texts: list, batch_size=32):
  118. arr = []
  119. tks_num = 0
  120. for txt in texts:
  121. res = self.client.embeddings.create(input=txt,
  122. model=self.model_name)
  123. arr.append(res.data[0].embedding)
  124. tks_num += res.usage.total_tokens
  125. return np.array(arr), tks_num
  126. def encode_queries(self, text):
  127. res = self.client.embeddings.create(input=text,
  128. model=self.model_name)
  129. return np.array(res.data[0].embedding), res.usage.total_tokens
  130. class OllamaEmbed(Base):
  131. def __init__(self, key, model_name, **kwargs):
  132. self.client = Client(host=kwargs["base_url"])
  133. self.model_name = model_name
  134. def encode(self, texts: list, batch_size=32):
  135. arr = []
  136. tks_num = 0
  137. for txt in texts:
  138. res = self.client.embeddings(prompt=txt,
  139. model=self.model_name)
  140. arr.append(res["embedding"])
  141. tks_num += 128
  142. return np.array(arr), tks_num
  143. def encode_queries(self, text):
  144. res = self.client.embeddings(prompt=text,
  145. model=self.model_name)
  146. return np.array(res["embedding"]), 128
  147. class FastEmbed(Base):
  148. def __init__(
  149. self,
  150. key: Optional[str] = None,
  151. model_name: str = "BAAI/bge-small-en-v1.5",
  152. cache_dir: Optional[str] = None,
  153. threads: Optional[int] = None,
  154. **kwargs,
  155. ):
  156. from fastembed import TextEmbedding
  157. self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
  158. def encode(self, texts: list, batch_size=32):
  159. # Using the internal tokenizer to encode the texts and get the total
  160. # number of tokens
  161. encodings = self._model.model.tokenizer.encode_batch(texts)
  162. total_tokens = sum(len(e) for e in encodings)
  163. embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
  164. return np.array(embeddings), total_tokens
  165. def encode_queries(self, text: str):
  166. # Using the internal tokenizer to encode the texts and get the total
  167. # number of tokens
  168. encoding = self._model.model.tokenizer.encode(text)
  169. embedding = next(self._model.query_embed(text)).tolist()
  170. return np.array(embedding), len(encoding.ids)
  171. class XinferenceEmbed(Base):
  172. def __init__(self, key, model_name="", base_url=""):
  173. self.client = OpenAI(api_key="xxx", base_url=base_url)
  174. self.model_name = model_name
  175. def encode(self, texts: list, batch_size=32):
  176. res = self.client.embeddings.create(input=texts,
  177. model=self.model_name)
  178. return np.array([d.embedding for d in res.data]
  179. ), res.usage.total_tokens
  180. def encode_queries(self, text):
  181. res = self.client.embeddings.create(input=[text],
  182. model=self.model_name)
  183. return np.array(res.data[0].embedding), res.usage.total_tokens
  184. class QAnythingEmbed(Base):
  185. _client = None
  186. def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
  187. from BCEmbedding import EmbeddingModel as qanthing
  188. if not QAnythingEmbed._client:
  189. try:
  190. print("LOADING BCE...")
  191. QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
  192. get_project_base_directory(),
  193. "rag/res/bce-embedding-base_v1"))
  194. except Exception as e:
  195. QAnythingEmbed._client = qanthing(
  196. model_name_or_path=model_name.replace(
  197. "maidalun1020", "InfiniFlow"))
  198. def encode(self, texts: list, batch_size=10):
  199. res = []
  200. token_count = 0
  201. for t in texts:
  202. token_count += num_tokens_from_string(t)
  203. for i in range(0, len(texts), batch_size):
  204. embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
  205. res.extend(embds)
  206. return np.array(res), token_count
  207. def encode_queries(self, text):
  208. embds = QAnythingEmbed._client.encode([text])
  209. return np.array(embds[0]), num_tokens_from_string(text)