
Add bce-embedding and fastembed (#383)

### What problem does this PR solve?


Issue link: #326

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.2.0 · KevinHuSh committed 1 year ago · commit 890561703b

README.md (+2 -0)

  ## 📌 Latest Features
+ - 2024-04-16 Add an embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+ - 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), which is designed for light and fast embedding.
  - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
  - 2024-04-10 Add a new layout recognition model for analyzing Laws documentation.
  - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.

README_ja.md (+2 -0)

  ## 📌 Latest Features
+ - 2024-04-16 Add the embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+ - 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), designed for light and fast embedding.
  - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
  - 2024-04-10 Add a new layout recognition model for the 'Laws' method.
  - 2024-04-08 Support localized deployment of LLMs with [Ollama](./docs/ollama.md).

README_zh.md (+2 -0)



  ## 📌 Latest Features
+ - 2024-04-16 Add the embedding model [bce-embedding-base_v1 from QAnything](https://github.com/netease-youdao/QAnything).
+ - 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), designed for light and fast embedding.
  - 2024-04-11 Support local LLM deployment with [Xinference](./docs/xinference.md).
  - 2024-04-10 Add an underlying model for 'Laws' layout analysis.
  - 2024-04-08 Support local LLM deployment with [Ollama](./docs/ollama.md).

api/apps/chunk_app.py (+1 -1)

      return get_data_error_result(retmsg="Knowledgebase not found!")
  embd_mdl = TenantLLMService.model_instance(
-     kb.tenant_id, LLMType.EMBEDDING.value)
+     kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
  ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
                                vector_similarity_weight, top, doc_ids)
  for c in ranks["chunks"]:
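The extra `llm_name=kb.embd_id` argument is the heart of this PR: a query vector is only comparable to chunk vectors produced by the same model, so retrieval must embed the question with the knowledge base's own embedding model rather than the tenant-wide default. A toy sketch of what goes wrong otherwise (the stub embedders are hypothetical; the dimensions match the real models):

```python
import numpy as np

# Hypothetical stand-ins for two embedding models a tenant might own.
def embed_bge(text: str) -> np.ndarray:
    return np.zeros(1024)   # bge-large-zh-v1.5 produces 1024-d vectors

def embed_bce(text: str) -> np.ndarray:
    return np.zeros(768)    # bce-embedding-base_v1 produces 768-d vectors

doc_vec = embed_bce("indexed chunk")     # the KB was indexed with bce
query_vec = embed_bge("user question")   # the tenant default would pick bge

# Vectors from different models live in different spaces; here even the
# dot product fails outright because the dimensions disagree.
try:
    print(float(doc_vec @ query_vec))
except ValueError as e:
    print("dimension mismatch:", e)
```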

api/apps/document_app.py (+1 -0)

  #
  import base64
+ import os
  import pathlib
  import re

api/apps/llm_app.py (+2 -2)

  def factories():
      try:
          fac = LLMFactoriesService.get_all()
-         return get_json_result(data=[f.to_dict() for f in fac])
+         return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
      except Exception as e:
          return server_error_response(e)

      llms = [m.to_dict()
              for m in llms if m.status == StatusEnum.VALID.value]
      for m in llms:
-         m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
+         m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"]
      llm_set = set([m["llm_name"] for m in llms])
      for o in objs:
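Both tweaks encode the same rule: QAnything and FastEmbed run locally, so they are hidden from the list of factories a user can configure an API key for, and their models are always marked available. A minimal sketch of that rule (`LOCAL_FACTORIES` is a name invented here; the repo inlines the list):

```python
# Sketch only: the availability rule from the hunk above, factored out.
LOCAL_FACTORIES = {"QAnything", "FastEmbed"}

def is_available(llm: dict, configured_factories: set) -> bool:
    # Usable if its factory has an API key configured, or if the model
    # ships with the deployment and therefore needs no key at all.
    return (llm["fid"] in configured_factories
            or llm["llm_name"].lower() == "flag-embedding"
            or llm["fid"] in LOCAL_FACTORIES)

print(is_available({"fid": "QAnything",
                    "llm_name": "maidalun1020/bce-embedding-base_v1"}, set()))  # True
```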

api/db/init_data.py (+19 -11)

  import uuid
  from api.db import LLMType, UserTenantRole
- from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
+ from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
  from api.db.services import UserService
  from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
  from api.db.services.user_service import TenantService, UserTenantService

      "logo": "",
      "tags": "TEXT EMBEDDING",
      "status": "1",
- },
- {
+ }, {
      "name": "Xinference",
      "logo": "",
      "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
      "status": "1",
+ },{
+     "name": "QAnything",
+     "logo": "",
+     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+     "status": "1",
  },
  # {
  #     "name": "文心一言",

      "tags": "LLM,CHAT,",
      "max_tokens": 7900,
      "model_type": LLMType.CHAT.value
- }, {
-     "fid": factory_infos[4]["name"],
-     "llm_name": "flag-embedding",
-     "tags": "TEXT EMBEDDING,",
-     "max_tokens": 128 * 1000,
-     "model_type": LLMType.EMBEDDING.value
  }, {
      "fid": factory_infos[4]["name"],
      "llm_name": "moonshot-v1-32k",
      "max_tokens": 2147483648,
      "model_type": LLMType.EMBEDDING.value
  },
+ # ------------------------ QAnything -----------------------
+ {
+     "fid": factory_infos[7]["name"],
+     "llm_name": "maidalun1020/bce-embedding-base_v1",
+     "tags": "TEXT EMBEDDING,",
+     "max_tokens": 512,
+     "model_type": LLMType.EMBEDDING.value
+ },
  ]
  for info in factory_infos:
      try:
      except Exception as e:
          pass
- LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
- LLMService.filter_delete([LLM.fid=="Local"])
+ LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+ LLMService.filter_delete([LLM.fid == "Local"])
+ LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
+ TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
  """
  drop table llm;
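`init_data.py` runs on every startup, so besides seeding the new QAnything factory and its `bce-embedding-base_v1` entry it must also delete rows that older releases created (the `flag-embedding` placeholder that used to hang off the Moonshot factory). A sketch of what a peewee-style `filter_delete` presumably reduces to (the helper below is illustrative, not the repo's implementation):

```python
# Illustrative only: AND a list of column expressions into one DELETE,
# the way the filter_delete calls above are used.
def filter_delete(model, filters):
    query = model.delete()
    for f in filters:
        query = query.where(f)
    return query.execute()  # number of deleted rows

# e.g. filter_delete(LLM, [LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
```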

api/db/services/dialog_service.py (+5 -1)

          raise LookupError("LLM(%s) not found" % dialog.llm_id)
      max_tokens = 1024
  else: max_tokens = llm[0].max_tokens
+ kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+ embd_nms = list(set([kb.embd_id for kb in kbs]))
+ assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
  questions = [m["content"] for m in messages if m["role"] == "user"]
- embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+ embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
  chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
  prompt_config = dialog.prompt_config
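The new assertion guards a real invariant: a dialog may draw on several knowledge bases, but their chunk vectors are only mutually comparable if every KB was indexed with the same embedding model. A toy illustration (the KB dicts are hypothetical):

```python
# Two KBs indexed with different models cannot share one retrieval pass.
kbs = [{"id": "kb1", "embd_id": "maidalun1020/bce-embedding-base_v1"},
       {"id": "kb2", "embd_id": "BAAI/bge-large-zh-v1.5"}]

embd_nms = list(set(kb["embd_id"] for kb in kbs))
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."  # raises here
```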

api/db/services/llm_service.py (+8 -3)

      raise LookupError("Tenant not found")
  if llm_type == LLMType.EMBEDDING.value:
-     mdlnm = tenant.embd_id
+     mdlnm = tenant.embd_id if not llm_name else llm_name
  elif llm_type == LLMType.SPEECH2TEXT.value:
      mdlnm = tenant.asr_id
  elif llm_type == LLMType.IMAGE2TEXT.value:
      assert False, "LLM type error"
  model_config = cls.get_api_key(tenant_id, mdlnm)
+ if model_config: model_config = model_config.to_dict()
  if not model_config:
-     raise LookupError("Model({}) not authorized".format(mdlnm))
- model_config = model_config.to_dict()
+     if llm_type == LLMType.EMBEDDING.value:
+         llm = LLMService.query(llm_name=llm_name)
+         if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
+             model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
+     if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
  if llm_type == LLMType.EMBEDDING.value:
      if model_config["llm_factory"] not in EmbeddingModel:
          return
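Condensed, the new resolution order is: use the tenant's configured credentials if any; otherwise, for embedding models belonging to the keyless local factories, fabricate a config with an empty `api_key`; only then give up. A paraphrased sketch (the helper names stand in for the service calls above):

```python
# Paraphrase of model_instance's config lookup; get_api_key and query_llm
# stand in for TenantLLMService.get_api_key and LLMService.query.
def resolve_model_config(tenant_id: str, llm_name: str) -> dict:
    config = get_api_key(tenant_id, llm_name)        # 1. tenant-configured key
    if config:
        return config.to_dict()
    llm = query_llm(llm_name)                        # 2. keyless local factories
    if llm and llm.fid in ("QAnything", "FastEmbed"):
        return {"llm_factory": llm.fid, "api_key": "",
                "llm_name": llm_name, "api_base": ""}
    raise LookupError("Model({}) not authorized".format(llm_name))
```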

api/db/services/task_service.py (+1 -1)

  Document.size,
  Knowledgebase.tenant_id,
  Knowledgebase.language,
- Tenant.embd_id,
+ Knowledgebase.embd_id,
  Tenant.img2txt_id,
  Tenant.asr_id,
  cls.model.update_time]

rag/llm/__init__.py (+2 -2)

"Xinference": XinferenceEmbed, "Xinference": XinferenceEmbed,
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed, "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
"ZHIPU-AI": ZhipuEmbed, "ZHIPU-AI": ZhipuEmbed,
"Moonshot": HuEmbedding,
"FastEmbed": FastEmbed
"FastEmbed": FastEmbed,
"QAnything": QAnythingEmbed
} }
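`EmbeddingModel` maps a factory name straight to an embedder class, which is why `llm_service.py` above can bail out with `if model_config["llm_factory"] not in EmbeddingModel`. Consumption looks roughly like this (paraphrased, not quoted from the repo):

```python
# Paraphrased dispatch: pick the class by factory name, instantiate it with
# the resolved credentials, then embed.
factory = model_config["llm_factory"]            # e.g. "QAnything"
if factory not in EmbeddingModel:
    raise LookupError("No embedding backend for %s" % factory)

mdl = EmbeddingModel[factory](model_config["api_key"], model_config["llm_name"])
vecs, used_tokens = mdl.encode(["hello", "world"])
```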





rag/llm/embedding_model.py (+49 -14)

  from ollama import Client
  import dashscope
  from openai import OpenAI
- from fastembed import TextEmbedding
  from FlagEmbedding import FlagModel
  import torch
  import numpy as np
  from api.utils.file_utils import get_project_base_directory
  from rag.utils import num_tokens_from_string

  try:
      flag_model = FlagModel(os.path.join(
          get_project_base_directory(),
          "rag/res/bge-large-zh-v1.5"),
          query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
          use_fp16=torch.cuda.is_available())
  except Exception as e:
      flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                             use_fp16=torch.cuda.is_available())

  class Base(ABC):

  class OpenAIEmbed(Base):
-     def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
-         if not base_url: base_url="https://api.openai.com/v1"
+     def __init__(self, key, model_name="text-embedding-ada-002",
+                  base_url="https://api.openai.com/v1"):
+         if not base_url:
+             base_url = "https://api.openai.com/v1"
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name

          tks_num = 0
          for txt in texts:
              res = self.client.embeddings.create(input=txt,
                                                  model=self.model_name)
              arr.append(res.data[0].embedding)
              tks_num += res.usage.total_tokens
          return np.array(arr), tks_num

          tks_num = 0
          for txt in texts:
              res = self.client.embeddings(prompt=txt,
                                           model=self.model_name)
              arr.append(res["embedding"])
              tks_num += 128
          return np.array(arr), tks_num

      def encode_queries(self, text):
          res = self.client.embeddings(prompt=text,
                                       model=self.model_name)
          return np.array(res["embedding"]), 128

          threads: Optional[int] = None,
          **kwargs,
      ):
+         from fastembed import TextEmbedding
          self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

      def encode(self, texts: list, batch_size=32):
          # Using the internal tokenizer to encode the texts and get the total
          # number of tokens
          encodings = self._model.model.tokenizer.encode_batch(texts)
          total_tokens = sum(len(e) for e in encodings)

          return np.array(embeddings), total_tokens

      def encode_queries(self, text: str):
          # Using the internal tokenizer to encode the texts and get the total
          # number of tokens
          encoding = self._model.model.tokenizer.encode(text)
          embedding = next(self._model.query_embed(text)).tolist()

                                       model=self.model_name)
          return np.array(res.data[0].embedding), res.usage.total_tokens

+ class QAnythingEmbed(Base):
+     _client = None
+
+     def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
+         from BCEmbedding import EmbeddingModel as qanthing
+         if not QAnythingEmbed._client:
+             try:
+                 print("LOADING BCE...")
+                 QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
+                     get_project_base_directory(),
+                     "rag/res/bce-embedding-base_v1"))
+             except Exception as e:
+                 QAnythingEmbed._client = qanthing(
+                     model_name_or_path=model_name.replace(
+                         "maidalun1020", "InfiniFlow"))
+
+     def encode(self, texts: list, batch_size=10):
+         res = []
+         token_count = 0
+         for t in texts:
+             token_count += num_tokens_from_string(t)
+         for i in range(0, len(texts), batch_size):
+             embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
+             res.extend(embds)
+         return np.array(res), token_count
+
+     def encode_queries(self, text):
+         embds = QAnythingEmbed._client.encode([text])
+         return np.array(embds[0]), num_tokens_from_string(text)
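Both new backends are keyless and lazy: `FastEmbed` defers its `fastembed` import to `__init__`, and `QAnythingEmbed` caches one `BCEmbedding` model per process, preferring a bundled copy under `rag/res/` and falling back to the `InfiniFlow/bce-embedding-base_v1` download. A hedged usage sketch (the first call downloads weights if none are bundled; 768 is the model's documented vector size):

```python
from rag.llm.embedding_model import QAnythingEmbed

bce = QAnythingEmbed()  # key is unused; loads bce-embedding-base_v1 once per process

doc_vecs, doc_tokens = bce.encode(["RAGFlow is a RAG engine.", "它支持本地嵌入模型。"])
query_vec, query_tokens = bce.encode_queries("What is RAGFlow?")

print(doc_vecs.shape)   # (2, 768)
print(query_vec.shape)  # (768,)
```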

rag/nlp/search.py (+1 -1)

"k": topk, "k": topk,
"similarity": sim, "similarity": sim,
"num_candidates": topk * 2, "num_candidates": topk * 2,
"query_vector": list(qv)
"query_vector": [float(v) for v in qv]
} }


def search(self, req, idxnm, emb_mdl=None): def search(self, req, idxnm, emb_mdl=None):
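The one-line change in `search.py` is a serialization fix: `qv` comes back from the embedder as a numpy array, and `list(qv)` yields numpy scalars that the JSON encoder used for the Elasticsearch kNN request cannot handle. Casting each element to a built-in `float` fixes it:

```python
import json
import numpy as np

qv = np.array([0.1, 0.2, 0.3], dtype=np.float32)

try:
    json.dumps({"query_vector": list(qv)})               # old form: numpy scalars
except TypeError as e:
    print(e)   # Object of type float32 is not JSON serializable

print(json.dumps({"query_vector": [float(v) for v in qv]}))  # new form serializes cleanly
```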

rag/svr/task_executor.py (+2 -1)

  for _, r in rows.iterrows():
      callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
      try:
-         embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+         embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
      except Exception as e:
+         traceback.print_exc()
          callback(prog=-1, msg=str(e))
          continue
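This closes the loop: together with the `task_service.py` change, each parsing task row now carries its knowledge base's `embd_id` and language, so documents are chunked and embedded by the model the KB was created with. Roughly (the row values below are hypothetical; the `LLMBundle` signature is the one used above):

```python
# Hypothetical task row; tenant_id/embd_id/language now come from the task
# query joined through Knowledgebase rather than from tenant defaults.
r = {"tenant_id": "t0",
     "embd_id": "maidalun1020/bce-embedding-base_v1",
     "language": "English"}

embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING,
                     llm_name=r["embd_id"], lang=r["language"])
vecs, used_tokens = embd_mdl.encode(["chunk one", "chunk two"])
```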



requirements.txt (+2 -0)

  xxhash==3.4.1
  yarl==1.9.4
  zhipuai==2.0.1
+ BCEmbedding
+ loguru==0.7.2
