Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import re
  18. import threading
  19. import requests
  20. from huggingface_hub import snapshot_download
  21. from zhipuai import ZhipuAI
  22. import os
  23. from abc import ABC
  24. from ollama import Client
  25. import dashscope
  26. from openai import OpenAI
  27. import numpy as np
  28. import asyncio
  29. from api import settings
  30. from api.utils.file_utils import get_home_cache_dir
  31. from rag.utils import num_tokens_from_string, truncate
  32. import google.generativeai as genai
  33. import json
  34. class Base(ABC):
  35. def __init__(self, key, model_name):
  36. pass
  37. def encode(self, texts: list):
  38. raise NotImplementedError("Please implement encode method!")
  39. def encode_queries(self, text: str):
  40. raise NotImplementedError("Please implement encode method!")
  41. class DefaultEmbedding(Base):
  42. _model = None
  43. _model_lock = threading.Lock()
  44. def __init__(self, key, model_name, **kwargs):
  45. """
  46. If you have trouble downloading HuggingFace models, -_^ this might help!!
  47. For Linux:
  48. export HF_ENDPOINT=https://hf-mirror.com
  49. For Windows:
  50. Good luck
  51. ^_-
  52. """
  53. if not settings.LIGHTEN and not DefaultEmbedding._model:
  54. with DefaultEmbedding._model_lock:
  55. from FlagEmbedding import FlagModel
  56. import torch
  57. if not DefaultEmbedding._model:
  58. try:
  59. DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  60. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  61. use_fp16=torch.cuda.is_available())
  62. except Exception:
  63. model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
  64. local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
  65. local_dir_use_symlinks=False)
  66. DefaultEmbedding._model = FlagModel(model_dir,
  67. query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
  68. use_fp16=torch.cuda.is_available())
  69. self._model = DefaultEmbedding._model
  70. def encode(self, texts: list):
  71. batch_size = 16
  72. texts = [truncate(t, 2048) for t in texts]
  73. token_count = 0
  74. for t in texts:
  75. token_count += num_tokens_from_string(t)
  76. ress = []
  77. for i in range(0, len(texts), batch_size):
  78. ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
  79. return np.array(ress), token_count
  80. def encode_queries(self, text: str):
  81. token_count = num_tokens_from_string(text)
  82. return self._model.encode_queries([text]).tolist()[0], token_count
  83. class OpenAIEmbed(Base):
  84. def __init__(self, key, model_name="text-embedding-ada-002",
  85. base_url="https://api.openai.com/v1"):
  86. if not base_url:
  87. base_url = "https://api.openai.com/v1"
  88. self.client = OpenAI(api_key=key, base_url=base_url)
  89. self.model_name = model_name
  90. def encode(self, texts: list):
  91. # OpenAI requires batch size <=16
  92. batch_size = 16
  93. texts = [truncate(t, 8191) for t in texts]
  94. ress = []
  95. total_tokens = 0
  96. for i in range(0, len(texts), batch_size):
  97. res = self.client.embeddings.create(input=texts[i:i + batch_size],
  98. model=self.model_name)
  99. ress.extend([d.embedding for d in res.data])
  100. total_tokens += res.usage.total_tokens
  101. return np.array(ress), total_tokens
  102. def encode_queries(self, text):
  103. res = self.client.embeddings.create(input=[truncate(text, 8191)],
  104. model=self.model_name)
  105. return np.array(res.data[0].embedding), res.usage.total_tokens
  106. class LocalAIEmbed(Base):
  107. def __init__(self, key, model_name, base_url):
  108. if not base_url:
  109. raise ValueError("Local embedding model url cannot be None")
  110. if base_url.split("/")[-1] != "v1":
  111. base_url = os.path.join(base_url, "v1")
  112. self.client = OpenAI(api_key="empty", base_url=base_url)
  113. self.model_name = model_name.split("___")[0]
  114. def encode(self, texts: list):
  115. batch_size = 16
  116. ress = []
  117. for i in range(0, len(texts), batch_size):
  118. res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
  119. ress.extend([d.embedding for d in res.data])
  120. # local embedding for LmStudio donot count tokens
  121. return np.array(ress), 1024
  122. def encode_queries(self, text):
  123. embds, cnt = self.encode([text])
  124. return np.array(embds[0]), cnt
  125. class AzureEmbed(OpenAIEmbed):
  126. def __init__(self, key, model_name, **kwargs):
  127. from openai.lib.azure import AzureOpenAI
  128. api_key = json.loads(key).get('api_key', '')
  129. api_version = json.loads(key).get('api_version', '2024-02-01')
  130. self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
  131. self.model_name = model_name
  132. class BaiChuanEmbed(OpenAIEmbed):
  133. def __init__(self, key,
  134. model_name='Baichuan-Text-Embedding',
  135. base_url='https://api.baichuan-ai.com/v1'):
  136. if not base_url:
  137. base_url = "https://api.baichuan-ai.com/v1"
  138. super().__init__(key, model_name, base_url)
  139. class QWenEmbed(Base):
  140. def __init__(self, key, model_name="text_embedding_v2", **kwargs):
  141. self.key = key
  142. self.model_name = model_name
  143. def encode(self, texts: list):
  144. import dashscope
  145. batch_size = 4
  146. try:
  147. res = []
  148. token_count = 0
  149. texts = [truncate(t, 2048) for t in texts]
  150. for i in range(0, len(texts), batch_size):
  151. resp = dashscope.TextEmbedding.call(
  152. model=self.model_name,
  153. input=texts[i:i + batch_size],
  154. api_key=self.key,
  155. text_type="document"
  156. )
  157. embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
  158. for e in resp["output"]["embeddings"]:
  159. embds[e["text_index"]] = e["embedding"]
  160. res.extend(embds)
  161. token_count += resp["usage"]["total_tokens"]
  162. return np.array(res), token_count
  163. except Exception as e:
  164. raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
  165. return np.array([]), 0
  166. def encode_queries(self, text):
  167. try:
  168. resp = dashscope.TextEmbedding.call(
  169. model=self.model_name,
  170. input=text[:2048],
  171. api_key=self.key,
  172. text_type="query"
  173. )
  174. return np.array(resp["output"]["embeddings"][0]
  175. ["embedding"]), resp["usage"]["total_tokens"]
  176. except Exception:
  177. raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
  178. return np.array([]), 0
  179. class ZhipuEmbed(Base):
  180. def __init__(self, key, model_name="embedding-2", **kwargs):
  181. self.client = ZhipuAI(api_key=key)
  182. self.model_name = model_name
  183. def encode(self, texts: list):
  184. arr = []
  185. tks_num = 0
  186. for txt in texts:
  187. res = self.client.embeddings.create(input=txt,
  188. model=self.model_name)
  189. arr.append(res.data[0].embedding)
  190. tks_num += res.usage.total_tokens
  191. return np.array(arr), tks_num
  192. def encode_queries(self, text):
  193. res = self.client.embeddings.create(input=text,
  194. model=self.model_name)
  195. return np.array(res.data[0].embedding), res.usage.total_tokens
  196. class OllamaEmbed(Base):
  197. def __init__(self, key, model_name, **kwargs):
  198. self.client = Client(host=kwargs["base_url"])
  199. self.model_name = model_name
  200. def encode(self, texts: list):
  201. arr = []
  202. tks_num = 0
  203. for txt in texts:
  204. res = self.client.embeddings(prompt=txt,
  205. model=self.model_name)
  206. arr.append(res["embedding"])
  207. tks_num += 128
  208. return np.array(arr), tks_num
  209. def encode_queries(self, text):
  210. res = self.client.embeddings(prompt=text,
  211. model=self.model_name)
  212. return np.array(res["embedding"]), 128
  213. class FastEmbed(Base):
  214. _model = None
  215. def __init__(
  216. self,
  217. key: str | None = None,
  218. model_name: str = "BAAI/bge-small-en-v1.5",
  219. cache_dir: str | None = None,
  220. threads: int | None = None,
  221. **kwargs,
  222. ):
  223. if not settings.LIGHTEN and not FastEmbed._model:
  224. from fastembed import TextEmbedding
  225. self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
  226. def encode(self, texts: list):
  227. # Using the internal tokenizer to encode the texts and get the total
  228. # number of tokens
  229. encodings = self._model.model.tokenizer.encode_batch(texts)
  230. total_tokens = sum(len(e) for e in encodings)
  231. embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
  232. return np.array(embeddings), total_tokens
  233. def encode_queries(self, text: str):
  234. # Using the internal tokenizer to encode the texts and get the total
  235. # number of tokens
  236. encoding = self._model.model.tokenizer.encode(text)
  237. embedding = next(self._model.query_embed(text)).tolist()
  238. return np.array(embedding), len(encoding.ids)
  239. class XinferenceEmbed(Base):
  240. def __init__(self, key, model_name="", base_url=""):
  241. if base_url.split("/")[-1] != "v1":
  242. base_url = os.path.join(base_url, "v1")
  243. self.client = OpenAI(api_key=key, base_url=base_url)
  244. self.model_name = model_name
  245. def encode(self, texts: list):
  246. batch_size = 16
  247. ress = []
  248. total_tokens = 0
  249. for i in range(0, len(texts), batch_size):
  250. res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
  251. ress.extend([d.embedding for d in res.data])
  252. total_tokens += res.usage.total_tokens
  253. return np.array(ress), total_tokens
  254. def encode_queries(self, text):
  255. res = self.client.embeddings.create(input=[text],
  256. model=self.model_name)
  257. return np.array(res.data[0].embedding), res.usage.total_tokens
  258. class YoudaoEmbed(Base):
  259. _client = None
  260. def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
  261. if not settings.LIGHTEN and not YoudaoEmbed._client:
  262. from BCEmbedding import EmbeddingModel as qanthing
  263. try:
  264. logging.info("LOADING BCE...")
  265. YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
  266. get_home_cache_dir(),
  267. "bce-embedding-base_v1"))
  268. except Exception:
  269. YoudaoEmbed._client = qanthing(
  270. model_name_or_path=model_name.replace(
  271. "maidalun1020", "InfiniFlow"))
  272. def encode(self, texts: list):
  273. batch_size = 10
  274. res = []
  275. token_count = 0
  276. for t in texts:
  277. token_count += num_tokens_from_string(t)
  278. for i in range(0, len(texts), batch_size):
  279. embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
  280. res.extend(embds)
  281. return np.array(res), token_count
  282. def encode_queries(self, text):
  283. embds = YoudaoEmbed._client.encode([text])
  284. return np.array(embds[0]), num_tokens_from_string(text)
  285. class JinaEmbed(Base):
  286. def __init__(self, key, model_name="jina-embeddings-v3",
  287. base_url="https://api.jina.ai/v1/embeddings"):
  288. self.base_url = "https://api.jina.ai/v1/embeddings"
  289. self.headers = {
  290. "Content-Type": "application/json",
  291. "Authorization": f"Bearer {key}"
  292. }
  293. self.model_name = model_name
  294. def encode(self, texts: list):
  295. texts = [truncate(t, 8196) for t in texts]
  296. batch_size = 16
  297. ress = []
  298. token_count = 0
  299. for i in range(0, len(texts), batch_size):
  300. data = {
  301. "model": self.model_name,
  302. "input": texts[i:i + batch_size],
  303. 'encoding_type': 'float'
  304. }
  305. res = requests.post(self.base_url, headers=self.headers, json=data).json()
  306. ress.extend([d["embedding"] for d in res["data"]])
  307. token_count += res["usage"]["total_tokens"]
  308. return np.array(ress), token_count
  309. def encode_queries(self, text):
  310. embds, cnt = self.encode([text])
  311. return np.array(embds[0]), cnt
  312. class InfinityEmbed(Base):
  313. _model = None
  314. def __init__(
  315. self,
  316. model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
  317. engine_kwargs: dict = {},
  318. key = None,
  319. ):
  320. from infinity_emb import EngineArgs
  321. from infinity_emb.engine import AsyncEngineArray
  322. self._default_model = model_names[0]
  323. self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
  324. async def _embed(self, sentences: list[str], model_name: str = ""):
  325. if not model_name:
  326. model_name = self._default_model
  327. engine = self.engine_array[model_name]
  328. was_already_running = engine.is_running
  329. if not was_already_running:
  330. await engine.astart()
  331. embeddings, usage = await engine.embed(sentences=sentences)
  332. if not was_already_running:
  333. await engine.astop()
  334. return embeddings, usage
  335. def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
  336. # Using the internal tokenizer to encode the texts and get the total
  337. # number of tokens
  338. embeddings, usage = asyncio.run(self._embed(texts, model_name))
  339. return np.array(embeddings), usage
  340. def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
  341. # Using the internal tokenizer to encode the texts and get the total
  342. # number of tokens
  343. return self.encode([text])
  344. class MistralEmbed(Base):
  345. def __init__(self, key, model_name="mistral-embed",
  346. base_url=None):
  347. from mistralai.client import MistralClient
  348. self.client = MistralClient(api_key=key)
  349. self.model_name = model_name
  350. def encode(self, texts: list):
  351. texts = [truncate(t, 8196) for t in texts]
  352. batch_size = 16
  353. ress = []
  354. token_count = 0
  355. for i in range(0, len(texts), batch_size):
  356. res = self.client.embeddings(input=texts[i:i + batch_size],
  357. model=self.model_name)
  358. ress.extend([d.embedding for d in res.data])
  359. token_count += res.usage.total_tokens
  360. return np.array(ress), token_count
  361. def encode_queries(self, text):
  362. res = self.client.embeddings(input=[truncate(text, 8196)],
  363. model=self.model_name)
  364. return np.array(res.data[0].embedding), res.usage.total_tokens
  365. class BedrockEmbed(Base):
  366. def __init__(self, key, model_name,
  367. **kwargs):
  368. import boto3
  369. self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
  370. self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
  371. self.bedrock_region = json.loads(key).get('bedrock_region', '')
  372. self.model_name = model_name
  373. self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
  374. aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
  375. def encode(self, texts: list):
  376. texts = [truncate(t, 8196) for t in texts]
  377. embeddings = []
  378. token_count = 0
  379. for text in texts:
  380. if self.model_name.split('.')[0] == 'amazon':
  381. body = {"inputText": text}
  382. elif self.model_name.split('.')[0] == 'cohere':
  383. body = {"texts": [text], "input_type": 'search_document'}
  384. response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
  385. model_response = json.loads(response["body"].read())
  386. embeddings.extend([model_response["embedding"]])
  387. token_count += num_tokens_from_string(text)
  388. return np.array(embeddings), token_count
  389. def encode_queries(self, text):
  390. embeddings = []
  391. token_count = num_tokens_from_string(text)
  392. if self.model_name.split('.')[0] == 'amazon':
  393. body = {"inputText": truncate(text, 8196)}
  394. elif self.model_name.split('.')[0] == 'cohere':
  395. body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
  396. response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
  397. model_response = json.loads(response["body"].read())
  398. embeddings.extend(model_response["embedding"])
  399. return np.array(embeddings), token_count
  400. class GeminiEmbed(Base):
  401. def __init__(self, key, model_name='models/text-embedding-004',
  402. **kwargs):
  403. self.key = key
  404. self.model_name = 'models/' + model_name
  405. def encode(self, texts: list):
  406. texts = [truncate(t, 2048) for t in texts]
  407. token_count = sum(num_tokens_from_string(text) for text in texts)
  408. genai.configure(api_key=self.key)
  409. batch_size = 16
  410. ress = []
  411. for i in range(0, len(texts), batch_size):
  412. result = genai.embed_content(
  413. model=self.model_name,
  414. content=texts[i, i + batch_size],
  415. task_type="retrieval_document",
  416. title="Embedding of single string")
  417. ress.extend(result['embedding'])
  418. return np.array(ress),token_count
  419. def encode_queries(self, text):
  420. genai.configure(api_key=self.key)
  421. result = genai.embed_content(
  422. model=self.model_name,
  423. content=truncate(text,2048),
  424. task_type="retrieval_document",
  425. title="Embedding of single string")
  426. token_count = num_tokens_from_string(text)
  427. return np.array(result['embedding']),token_count
  428. class NvidiaEmbed(Base):
  429. def __init__(
  430. self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
  431. ):
  432. if not base_url:
  433. base_url = "https://integrate.api.nvidia.com/v1/embeddings"
  434. self.api_key = key
  435. self.base_url = base_url
  436. self.headers = {
  437. "accept": "application/json",
  438. "Content-Type": "application/json",
  439. "authorization": f"Bearer {self.api_key}",
  440. }
  441. self.model_name = model_name
  442. if model_name == "nvidia/embed-qa-4":
  443. self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
  444. self.model_name = "NV-Embed-QA"
  445. if model_name == "snowflake/arctic-embed-l":
  446. self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
  447. def encode(self, texts: list):
  448. batch_size = 16
  449. ress = []
  450. token_count = 0
  451. for i in range(0, len(texts), batch_size):
  452. payload = {
  453. "input": texts[i : i + batch_size],
  454. "input_type": "query",
  455. "model": self.model_name,
  456. "encoding_format": "float",
  457. "truncate": "END",
  458. }
  459. res = requests.post(self.base_url, headers=self.headers, json=payload).json()
  460. ress.extend([d["embedding"] for d in res["data"]])
  461. token_count += res["usage"]["total_tokens"]
  462. return np.array(ress), token_count
  463. def encode_queries(self, text):
  464. embds, cnt = self.encode([text])
  465. return np.array(embds[0]), cnt
  466. class LmStudioEmbed(LocalAIEmbed):
  467. def __init__(self, key, model_name, base_url):
  468. if not base_url:
  469. raise ValueError("Local llm url cannot be None")
  470. if base_url.split("/")[-1] != "v1":
  471. base_url = os.path.join(base_url, "v1")
  472. self.client = OpenAI(api_key="lm-studio", base_url=base_url)
  473. self.model_name = model_name
  474. class OpenAI_APIEmbed(OpenAIEmbed):
  475. def __init__(self, key, model_name, base_url):
  476. if not base_url:
  477. raise ValueError("url cannot be None")
  478. if base_url.split("/")[-1] != "v1":
  479. base_url = os.path.join(base_url, "v1")
  480. self.client = OpenAI(api_key=key, base_url=base_url)
  481. self.model_name = model_name.split("___")[0]
  482. class CoHereEmbed(Base):
  483. def __init__(self, key, model_name, base_url=None):
  484. from cohere import Client
  485. self.client = Client(api_key=key)
  486. self.model_name = model_name
  487. def encode(self, texts: list):
  488. batch_size = 16
  489. ress = []
  490. token_count = 0
  491. for i in range(0, len(texts), batch_size):
  492. res = self.client.embed(
  493. texts=texts[i : i + batch_size],
  494. model=self.model_name,
  495. input_type="search_document",
  496. embedding_types=["float"],
  497. )
  498. ress.extend([d for d in res.embeddings.float])
  499. token_count += res.meta.billed_units.input_tokens
  500. return np.array(ress), token_count
  501. def encode_queries(self, text):
  502. res = self.client.embed(
  503. texts=[text],
  504. model=self.model_name,
  505. input_type="search_query",
  506. embedding_types=["float"],
  507. )
  508. return np.array(res.embeddings.float[0]), int(
  509. res.meta.billed_units.input_tokens
  510. )
  511. class TogetherAIEmbed(OllamaEmbed):
  512. def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
  513. if not base_url:
  514. base_url = "https://api.together.xyz/v1"
  515. super().__init__(key, model_name, base_url=base_url)
  516. class PerfXCloudEmbed(OpenAIEmbed):
  517. def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
  518. if not base_url:
  519. base_url = "https://cloud.perfxlab.cn/v1"
  520. super().__init__(key, model_name, base_url)
  521. class UpstageEmbed(OpenAIEmbed):
  522. def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
  523. if not base_url:
  524. base_url = "https://api.upstage.ai/v1/solar"
  525. super().__init__(key, model_name, base_url)
  526. class SILICONFLOWEmbed(Base):
  527. def __init__(
  528. self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"
  529. ):
  530. if not base_url:
  531. base_url = "https://api.siliconflow.cn/v1/embeddings"
  532. self.headers = {
  533. "accept": "application/json",
  534. "content-type": "application/json",
  535. "authorization": f"Bearer {key}",
  536. }
  537. self.base_url = base_url
  538. self.model_name = model_name
  539. def encode(self, texts: list):
  540. batch_size = 16
  541. ress = []
  542. token_count = 0
  543. for i in range(0, len(texts), batch_size):
  544. texts_batch = texts[i : i + batch_size]
  545. payload = {
  546. "model": self.model_name,
  547. "input": texts_batch,
  548. "encoding_format": "float",
  549. }
  550. res = requests.post(self.base_url, json=payload, headers=self.headers).json()
  551. if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
  552. raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
  553. ress.extend([d["embedding"] for d in res["data"]])
  554. token_count += res["usage"]["total_tokens"]
  555. return np.array(ress), token_count
  556. def encode_queries(self, text):
  557. payload = {
  558. "model": self.model_name,
  559. "input": text,
  560. "encoding_format": "float",
  561. }
  562. res = requests.post(self.base_url, json=payload, headers=self.headers).json()
  563. if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
  564. raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
  565. return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
  566. class ReplicateEmbed(Base):
  567. def __init__(self, key, model_name, base_url=None):
  568. from replicate.client import Client
  569. self.model_name = model_name
  570. self.client = Client(api_token=key)
  571. def encode(self, texts: list):
  572. batch_size = 16
  573. token_count = sum([num_tokens_from_string(text) for text in texts])
  574. ress = []
  575. for i in range(0, len(texts), batch_size):
  576. res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
  577. ress.extend(res)
  578. return np.array(ress), token_count
  579. def encode_queries(self, text):
  580. res = self.client.embed(self.model_name, input={"texts": [text]})
  581. return np.array(res), num_tokens_from_string(text)
  582. class BaiduYiyanEmbed(Base):
  583. def __init__(self, key, model_name, base_url=None):
  584. import qianfan
  585. key = json.loads(key)
  586. ak = key.get("yiyan_ak", "")
  587. sk = key.get("yiyan_sk", "")
  588. self.client = qianfan.Embedding(ak=ak, sk=sk)
  589. self.model_name = model_name
  590. def encode(self, texts: list, batch_size=16):
  591. res = self.client.do(model=self.model_name, texts=texts).body
  592. return (
  593. np.array([r["embedding"] for r in res["data"]]),
  594. res["usage"]["total_tokens"],
  595. )
  596. def encode_queries(self, text):
  597. res = self.client.do(model=self.model_name, texts=[text]).body
  598. return (
  599. np.array([r["embedding"] for r in res["data"]]),
  600. res["usage"]["total_tokens"],
  601. )
  602. class VoyageEmbed(Base):
  603. def __init__(self, key, model_name, base_url=None):
  604. import voyageai
  605. self.client = voyageai.Client(api_key=key)
  606. self.model_name = model_name
  607. def encode(self, texts: list):
  608. batch_size = 16
  609. ress = []
  610. token_count = 0
  611. for i in range(0, len(texts), batch_size):
  612. res = self.client.embed(
  613. texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
  614. )
  615. ress.extend(res.embeddings)
  616. token_count += res.total_tokens
  617. return np.array(ress), token_count
  618. def encode_queries(self, text):
  619. res = self.client.embed(
  620. texts=text, model=self.model_name, input_type="query"
  621. )
  622. return np.array(res.embeddings)[0], res.total_tokens
  623. class HuggingFaceEmbed(Base):
  624. def __init__(self, key, model_name, base_url=None):
  625. if not model_name:
  626. raise ValueError("Model name cannot be None")
  627. self.key = key
  628. self.model_name = model_name.split("___")[0]
  629. self.base_url = base_url or "http://127.0.0.1:8080"
  630. def encode(self, texts: list):
  631. embeddings = []
  632. for text in texts:
  633. response = requests.post(
  634. f"{self.base_url}/embed",
  635. json={"inputs": text},
  636. headers={'Content-Type': 'application/json'}
  637. )
  638. if response.status_code == 200:
  639. embedding = response.json()
  640. embeddings.append(embedding[0])
  641. else:
  642. raise Exception(f"Error: {response.status_code} - {response.text}")
  643. return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
  644. def encode_queries(self, text):
  645. response = requests.post(
  646. f"{self.base_url}/embed",
  647. json={"inputs": text},
  648. headers={'Content-Type': 'application/json'}
  649. )
  650. if response.status_code == 200:
  651. embedding = response.json()
  652. return np.array(embedding[0]), num_tokens_from_string(text)
  653. else:
  654. raise Exception(f"Error: {response.status_code} - {response.text}")
  655. class VolcEngineEmbed(OpenAIEmbed):
  656. def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
  657. if not base_url:
  658. base_url = "https://ark.cn-beijing.volces.com/api/v3"
  659. ark_api_key = json.loads(key).get('ark_api_key', '')
  660. model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
  661. super().__init__(ark_api_key,model_name,base_url)