
embedding_model.py

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from abc import ABC
import os

import dashscope
import numpy as np
import torch
from FlagEmbedding import FlagModel
from openai import OpenAI

from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def encode(self, texts: list, batch_size=32):
        raise NotImplementedError("Please implement encode method!")


class HuEmbedding(Base):
    def __init__(self, key="", model_name=""):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-
        """
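        # Hedged note (an addition, not in the original): the mirror can also
        # be set from Python before the first download, which works on any OS:
        #   os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"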
        # The retrieval instruction is Chinese: "Generate a representation for
        # this sentence to retrieve related articles:"
        self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                               query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                               use_fp16=torch.cuda.is_available())
    def encode(self, texts: list, batch_size=32):
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = []
        # Encode in batches so long corpora do not exhaust GPU/CPU memory.
        for i in range(0, len(texts), batch_size):
            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res), token_count
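
# A minimal usage sketch (an addition; the inputs are illustrative):
#
#   embd = HuEmbedding()
#   vecs, tokens = embd.encode(["hello world", "你好"])
#   print(vecs.shape, tokens)   # bge-large-zh-v1.5 yields 1024-dim vectors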


class OpenAIEmbed(Base):
    def __init__(self, key, model_name="text-embedding-ada-002"):
        self.client = OpenAI(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = self.client.embeddings.create(input=texts,
                                            model=self.model_name)
        # The v1 openai client returns typed objects, not dicts.
        return [d.embedding for d in res.data], token_count
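
# A minimal usage sketch (an addition; assumes OPENAI_API_KEY is set in the
# environment, which is this sketch's own convention):
#
#   embd = OpenAIEmbed(key=os.environ["OPENAI_API_KEY"])
#   vecs, tokens = embd.encode(["hello world"])
#   print(len(vecs[0]), tokens)   # text-embedding-ada-002 vectors are 1536-dim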


class QWenEmbed(Base):
    def __init__(self, key, model_name="text_embedding_v2"):
        dashscope.api_key = key
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32, text_type="document"):
        res = []
        token_count = 0
        for txt in texts:
            # DashScope limits input length, so truncate defensively.
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=txt[:2048],
                text_type=text_type
            )
            res.append(resp["output"]["embeddings"][0]["embedding"])
            token_count += resp["usage"]["total_tokens"]
        return res, token_count
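

# A hedged end-to-end sketch (an addition; it assumes DASHSCOPE_API_KEY is set
# in the environment, a convention of this sketch rather than anything the
# classes above read):
if __name__ == "__main__":
    embd = QWenEmbed(key=os.environ.get("DASHSCOPE_API_KEY", ""))
    vecs, used = embd.encode(["What is retrieval-augmented generation?"],
                             text_type="query")
    print(len(vecs[0]), used)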