### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
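In short: this PR introduces a RERANK model type end to end — the `LLMType` enum, the tenant and dialog schema, the factory registry, and provider wrappers — and threads an optional rerank model through retrieval so the vector-cosine half of the hybrid score can be replaced by a cross-encoder. From the API side, a retrieval-test request can now carry a `rerank_id`; a hedged sketch of such a payload (field names follow the diff, values are illustrative):

```python
# Illustrative request body for the retrieval-test endpoint; "kb_0001" is a
# made-up knowledge-base id. Omit rerank_id (or send "") to keep the
# original cosine-only ranking.
payload = {
    "kb_id": "kb_0001",
    "question": "How do I enable reranking?",
    "page": 1,
    "size": 30,
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,
    "rerank_id": "BAAI/bge-reranker-v2-m3",
}
```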
| @@ -257,8 +257,15 @@ def retrieval_test(): | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | |||
| ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold, | |||
| vector_similarity_weight, top, doc_ids) | |||
| rerank_mdl = None | |||
| if req.get("rerank_id"): | |||
| rerank_mdl = TenantLLMService.model_instance( | |||
| kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) | |||
| ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, | |||
| similarity_threshold, vector_similarity_weight, top, | |||
| doc_ids, rerank_mdl=rerank_mdl) | |||
| for c in ranks["chunks"]: | |||
| if "vector" in c: | |||
| del c["vector"] | |||
| @@ -33,6 +33,9 @@ def set_dialog(): | |||
| name = req.get("name", "New Dialog") | |||
| description = req.get("description", "A helpful Dialog") | |||
| top_n = req.get("top_n", 6) | |||
| top_k = req.get("top_k", 1024) | |||
| rerank_id = req.get("rerank_id", "") | |||
| if not rerank_id: req["rerank_id"] = "" | |||
| similarity_threshold = req.get("similarity_threshold", 0.1) | |||
| vector_similarity_weight = req.get("vector_similarity_weight", 0.3) | |||
| llm_setting = req.get("llm_setting", {}) | |||
| @@ -83,6 +86,8 @@ def set_dialog(): | |||
| "llm_setting": llm_setting, | |||
| "prompt_config": prompt_config, | |||
| "top_n": top_n, | |||
| "top_k": top_k, | |||
| "rerank_id": rerank_id, | |||
| "similarity_threshold": similarity_threshold, | |||
| "vector_similarity_weight": vector_similarity_weight | |||
| } | |||
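Dialog settings persist the same choice: `top_k` bounds the candidate pool considered before reranking, while `top_n` remains the number of chunks handed to the chat model. A hedged slice of a `set_dialog` request under those assumptions (values illustrative):

```python
# Illustrative set_dialog fields; defaults mirror the handler above.
req = {
    "name": "Support bot",
    "top_n": 6,                               # chunks passed to the LLM
    "top_k": 1024,                            # candidates scored before the cut
    "rerank_id": "BAAI/bge-reranker-v2-m3",   # "" disables model-based reranking
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,
}
```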
| @@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va | |||
| from api.db import StatusEnum, LLMType | |||
| from api.db.db_models import TenantLLM | |||
| from api.utils.api_utils import get_json_result | |||
| from rag.llm import EmbeddingModel, ChatModel | |||
| from rag.llm import EmbeddingModel, ChatModel, RerankModel | |||
| @manager.route('/factories', methods=['GET']) | |||
| @@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel | |||
| def factories(): | |||
| try: | |||
| fac = LLMFactoriesService.get_all() | |||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]]) | |||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -64,6 +64,16 @@ def set_api_key(): | |||
| except Exception as e: | |||
| msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( | |||
| e) | |||
| elif llm.model_type == LLMType.RERANK: | |||
| mdl = RerankModel[factory]( | |||
| req["api_key"], llm.llm_name, base_url=req.get("base_url")) | |||
| try: | |||
| m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) | |||
| if len(m) == 0 or tc == 0: | |||
| raise Exception("Fail") | |||
| except Exception as e: | |||
| msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( | |||
| e) | |||
| if msg: | |||
| return get_data_error_result(retmsg=msg) | |||
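The rerank smoke test mirrors the embedding one: score a trivial query/passage pair and treat an empty score array or a zero token count as failure. A standalone sketch of what that check exercises (the key and model name are placeholders):

```python
from rag.llm import RerankModel

mdl = RerankModel["Jina"]("jina_xxx_placeholder", "jina-reranker-v1-base-en")
scores, token_count = mdl.similarity("What's the weather?", ["Is it sunny today?"])
# One score per candidate text; a zero token count would signal a dead key.
assert len(scores) == 1 and token_count > 0
```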
| @@ -199,7 +209,7 @@ def list_app(): | |||
| llms = [m.to_dict() | |||
| for m in llms if m.status == StatusEnum.VALID.value] | |||
| for m in llms: | |||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"] | |||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"] | |||
| llm_set = set([m["llm_name"] for m in llms]) | |||
| for o in objs: | |||
| @@ -26,8 +26,9 @@ from api.db.services.llm_service import TenantLLMService, LLMService | |||
| from api.utils.api_utils import server_error_response, validate_request | |||
| from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format | |||
| from api.db import UserTenantRole, LLMType, FileType | |||
| from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \ | |||
| LLM_FACTORY, LLM_BASE_URL | |||
| from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \ | |||
| API_KEY, \ | |||
| LLM_FACTORY, LLM_BASE_URL, RERANK_MDL | |||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from api.db.services.file_service import FileService | |||
| from api.settings import stat_logger | |||
| @@ -288,7 +289,8 @@ def user_register(user_id, user): | |||
| "embd_id": EMBEDDING_MDL, | |||
| "asr_id": ASR_MDL, | |||
| "parser_ids": PARSERS, | |||
| "img2txt_id": IMAGE2TEXT_MDL | |||
| "img2txt_id": IMAGE2TEXT_MDL, | |||
| "rerank_id": RERANK_MDL | |||
| } | |||
| usr_tenant = { | |||
| "tenant_id": user_id, | |||
| @@ -54,6 +54,7 @@ class LLMType(StrEnum): | |||
| EMBEDDING = 'embedding' | |||
| SPEECH2TEXT = 'speech2text' | |||
| IMAGE2TEXT = 'image2text' | |||
| RERANK = 'rerank' | |||
| class ChatStyle(StrEnum): | |||
| @@ -437,6 +437,10 @@ class Tenant(DataBaseModel): | |||
| max_length=128, | |||
| null=False, | |||
| help_text="default image to text model ID") | |||
| rerank_id = CharField( | |||
| max_length=128, | |||
| null=False, | |||
| help_text="default rerank model ID") | |||
| parser_ids = CharField( | |||
| max_length=256, | |||
| null=False, | |||
| @@ -771,11 +775,16 @@ class Dialog(DataBaseModel): | |||
| similarity_threshold = FloatField(default=0.2) | |||
| vector_similarity_weight = FloatField(default=0.3) | |||
| top_n = IntegerField(default=6) | |||
| top_k = IntegerField(default=1024) | |||
| do_refer = CharField( | |||
| max_length=1, | |||
| null=False, | |||
| help_text="it needs to insert reference index into answer or not", | |||
| default="1") | |||
| rerank_id = CharField( | |||
| max_length=128, | |||
| null=False, | |||
| help_text="default rerank model ID") | |||
| kb_ids = JSONField(null=False, default=[]) | |||
| status = CharField( | |||
| @@ -825,11 +834,29 @@ class API4Conversation(DataBaseModel): | |||
| def migrate_db(): | |||
| try: | |||
| with DB.transaction(): | |||
| migrator = MySQLMigrator(DB) | |||
| migrate( | |||
| migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where does this document come from")) | |||
| ) | |||
| except Exception as e: | |||
| pass | |||
| try: | |||
| migrate( | |||
| migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where does this document come from")) | |||
| ) | |||
| except Exception as e: | |||
| pass | |||
| try: | |||
| migrate( | |||
| migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")) | |||
| ) | |||
| except Exception as e: | |||
| pass | |||
| try: | |||
| migrate( | |||
| migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID")) | |||
| ) | |||
| except Exception as e: | |||
| pass | |||
| try: | |||
| migrate( | |||
| migrator.add_column('dialog', 'top_k', IntegerField(default=1024)) | |||
| ) | |||
| except Exception as e: | |||
| pass | |||
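Each column addition sits in its own try/except so that re-running `migrate_db()` against an already-migrated schema is a no-op: peewee raises when the column exists and the error is swallowed. A condensed sketch of the pattern (column list abridged):

```python
from peewee import CharField, IntegerField
from playhouse.migrate import MySQLMigrator, migrate

migrator = MySQLMigrator(DB)  # DB: the peewee handle from api.db.db_models
for table, column, field in [
    ("tenant", "rerank_id", CharField(max_length=128, null=False,
                                      default="BAAI/bge-reranker-v2-m3")),
    ("dialog", "top_k", IntegerField(default=1024)),
]:
    try:
        migrate(migrator.add_column(table, column, field))
    except Exception:
        pass  # column already present; adding it again raises
```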
| @@ -142,7 +142,17 @@ factory_infos = [{ | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING", | |||
| "status": "1", | |||
| }, | |||
| },{ | |||
| "name": "Jina", | |||
| "logo": "", | |||
| "tags": "TEXT EMBEDDING, TEXT RE-RANK", | |||
| "status": "1", | |||
| },{ | |||
| "name": "BAAI", | |||
| "logo": "", | |||
| "tags": "TEXT EMBEDDING, TEXT RE-RANK", | |||
| "status": "1", | |||
| } | |||
| # { | |||
| # "name": "文心一言", | |||
| # "logo": "", | |||
| @@ -367,6 +377,13 @@ def init_llm_factory(): | |||
| "max_tokens": 512, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[7]["name"], | |||
| "llm_name": "maidalun1020/bce-reranker-base_v1", | |||
| "tags": "RE-RANK, 8K", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| # ------------------------ DeepSeek ----------------------- | |||
| { | |||
| "fid": factory_infos[8]["name"], | |||
| @@ -440,6 +457,85 @@ def init_llm_factory(): | |||
| "max_tokens": 512, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| # ------------------------ Jina ----------------------- | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-reranker-v1-base-en", | |||
| "tags": "RE-RANK,8k", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-reranker-v1-turbo-en", | |||
| "tags": "RE-RANK,8k", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-reranker-v1-tiny-en", | |||
| "tags": "RE-RANK,8k", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-colbert-v1-en", | |||
| "tags": "RE-RANK,8k", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-embeddings-v2-base-en", | |||
| "tags": "TEXT EMBEDDING", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-embeddings-v2-base-de", | |||
| "tags": "TEXT EMBEDDING", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-embeddings-v2-base-es", | |||
| "tags": "TEXT EMBEDDING", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-embeddings-v2-base-code", | |||
| "tags": "TEXT EMBEDDING", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[11]["name"], | |||
| "llm_name": "jina-embeddings-v2-base-zh", | |||
| "tags": "TEXT EMBEDDING", | |||
| "max_tokens": 8196, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| # ------------------------ BAAI ----------------------- | |||
| { | |||
| "fid": factory_infos[12]["name"], | |||
| "llm_name": "BAAI/bge-large-zh-v1.5", | |||
| "tags": "TEXT EMBEDDING,", | |||
| "max_tokens": 1024, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| { | |||
| "fid": factory_infos[12]["name"], | |||
| "llm_name": "BAAI/bge-reranker-v2-m3", | |||
| "tags": "LLM,CHAT,", | |||
| "max_tokens": 16385, | |||
| "model_type": LLMType.RERANK.value | |||
| }, | |||
| ] | |||
| for info in factory_infos: | |||
| try: | |||
| @@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs): | |||
| if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: | |||
| kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} | |||
| else: | |||
| rerank_mdl = None | |||
| if dialog.rerank_id: | |||
| rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) | |||
| kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, | |||
| dialog.similarity_threshold, | |||
| dialog.vector_similarity_weight, | |||
| doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, | |||
| top=1024, aggs=False) | |||
| top=1024, aggs=False, rerank_mdl=rerank_mdl) | |||
| knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | |||
| chat_logger.info( | |||
| "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) | |||
| @@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs): | |||
| kwargs["knowledge"] = "\n".join(knowledges) | |||
| gen_conf = dialog.llm_setting | |||
| msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] | |||
| msg.extend([{"role": m["role"], "content": m["content"]} | |||
| for m in messages if m["role"] != "system"]) | |||
| @@ -15,7 +15,7 @@ | |||
| # | |||
| from api.db.services.user_service import TenantService | |||
| from api.settings import database_logger | |||
| from rag.llm import EmbeddingModel, CvModel, ChatModel | |||
| from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import LLMFactories, LLM, TenantLLM | |||
| @@ -73,21 +73,25 @@ class TenantLLMService(CommonService): | |||
| mdlnm = tenant.img2txt_id | |||
| elif llm_type == LLMType.CHAT.value: | |||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.RERANK: | |||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||
| else: | |||
| assert False, "LLM type error" | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| if model_config: model_config = model_config.to_dict() | |||
| if not model_config: | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: | |||
| llm = LLMService.query(llm_name=llm_name) | |||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]: | |||
| if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: | |||
| model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} | |||
| if not model_config: | |||
| if llm_name == "flag-embedding": | |||
| model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", | |||
| "llm_name": llm_name, "api_base": ""} | |||
| else: | |||
| if not mdlnm: | |||
| raise LookupError(f"Type of {llm_type} model is not set.") | |||
| raise LookupError("Model({}) not authorized".format(mdlnm)) | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| @@ -96,6 +100,12 @@ class TenantLLMService(CommonService): | |||
| return EmbeddingModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.RERANK: | |||
| if model_config["llm_factory"] not in RerankModel: | |||
| return | |||
| return RerankModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) | |||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||
| if model_config["llm_factory"] not in CvModel: | |||
| return | |||
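Resolution order for a rerank model is therefore: explicit `llm_name` (or the tenant's `rerank_id`) → the tenant's stored API-key record → a synthesized zero-key config for local factories (Youdao, FastEmbed, BAAI), which run in-process and need no credentials. For example, with nothing configured (a sketch):

```python
# With no API key on file, the BCE reranker still resolves because Youdao is
# in the local-factory allowlist; the empty key and base_url are ignored.
mdl = TenantLLMService.model_instance(
    tenant_id, LLMType.RERANK, llm_name="maidalun1020/bce-reranker-base_v1")
# -> YoudaoRerank("", "maidalun1020/bce-reranker-base_v1", base_url="")
```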
| @@ -125,14 +135,20 @@ class TenantLLMService(CommonService): | |||
| mdlnm = tenant.img2txt_id | |||
| elif llm_type == LLMType.CHAT.value: | |||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||
| elif llm_type == LLMType.RERANK: | |||
| mdlnm = tenant.rerank_id if not llm_name else llm_name | |||
| else: | |||
| assert False, "LLM type error" | |||
| num = 0 | |||
| for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm): | |||
| num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\ | |||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ | |||
| .execute() | |||
| try: | |||
| for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm): | |||
| num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\ | |||
| .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ | |||
| .execute() | |||
| except Exception as e: | |||
| print(e) | |||
| return num | |||
| @classmethod | |||
| @@ -176,6 +192,14 @@ class LLMBundle(object): | |||
| "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | |||
| return emd, used_tokens | |||
| def similarity(self, query: str, texts: list): | |||
| sim, used_tokens = self.mdl.similarity(query, texts) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error( | |||
| "Can't update token usage for {}/RERANK".format(self.tenant_id)) | |||
| return sim, used_tokens | |||
| def describe(self, image, max_tokens=300): | |||
| txt, used_tokens = self.mdl.describe(image, max_tokens) | |||
| if not TenantLLMService.increase_usage( | |||
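With the bundle, callers get token accounting for free: the new `LLMBundle.similarity` above forwards to the provider and books `used_tokens` against the tenant via `TenantLLMService.increase_usage`, exactly as the embedding path does. A hedged usage sketch (the tenant id is illustrative):

```python
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, "BAAI/bge-reranker-v2-m3")
scores, used = rerank_mdl.similarity(
    "query text", ["candidate one", "candidate two"])
# scores: one relevance score per candidate; `used` was already added to the
# tenant's used_tokens before this returns.
```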
| @@ -93,6 +93,7 @@ class TenantService(CommonService): | |||
| cls.model.name, | |||
| cls.model.llm_id, | |||
| cls.model.embd_id, | |||
| cls.model.rerank_id, | |||
| cls.model.asr_id, | |||
| cls.model.img2txt_id, | |||
| cls.model.parser_ids, | |||
| @@ -89,9 +89,22 @@ default_llm = { | |||
| }, | |||
| "DeepSeek": { | |||
| "chat_model": "deepseek-chat", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "VolcEngine": { | |||
| "chat_model": "", | |||
| "embedding_model": "", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| }, | |||
| "BAAI": { | |||
| "chat_model": "", | |||
| "embedding_model": "BAAI/bge-large-zh-v1.5", | |||
| "image2text_model": "", | |||
| "asr_model": "", | |||
| "rerank_model": "BAAI/bge-reranker-v2-m3", | |||
| } | |||
| } | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| @@ -104,7 +117,8 @@ if LLM_FACTORY not in default_llm: | |||
| f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") | |||
| LLM_FACTORY = "Tongyi-Qianwen" | |||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] | |||
| EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"] | |||
| EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] | |||
| RERANK_MDL = default_llm["BAAI"]["rerank_model"] | |||
| ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] | |||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] | |||
| @@ -16,18 +16,19 @@ | |||
| from .embedding_model import * | |||
| from .chat_model import * | |||
| from .cv_model import * | |||
| from .rerank_model import * | |||
| EmbeddingModel = { | |||
| "Ollama": OllamaEmbed, | |||
| "OpenAI": OpenAIEmbed, | |||
| "Xinference": XinferenceEmbed, | |||
| "Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed, | |||
| "Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed, | |||
| "ZHIPU-AI": ZhipuEmbed, | |||
| "FastEmbed": FastEmbed, | |||
| "Youdao": YoudaoEmbed, | |||
| "DeepSeek": DefaultEmbedding, | |||
| "BaiChuan": BaiChuanEmbed | |||
| "BaiChuan": BaiChuanEmbed, | |||
| "BAAI": DefaultEmbedding | |||
| } | |||
| @@ -52,3 +53,9 @@ ChatModel = { | |||
| "BaiChuan": BaiChuanChat | |||
| } | |||
| RerankModel = { | |||
| "BAAI": DefaultRerank, | |||
| "Jina": JinaRerank, | |||
| "Youdao": YoudaoRerank, | |||
| } | |||
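The mapping gives every provider the same constructor shape — `(api_key, model_name, base_url=...)` — so call sites can resolve purely by factory name. A sketch under that assumption:

```python
factory, llm_name = "BAAI", "BAAI/bge-reranker-v2-m3"
if factory in RerankModel:
    mdl = RerankModel[factory]("", llm_name)  # local factories ignore the key
    scores, tc = mdl.similarity(
        "what is a rerank model?",
        ["a cross-encoder that scores query/passage pairs", "a type of pasta"])
```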
| @@ -13,8 +13,10 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import re | |||
| from typing import Optional | |||
| import requests | |||
| from huggingface_hub import snapshot_download | |||
| from zhipuai import ZhipuAI | |||
| import os | |||
| @@ -26,21 +28,9 @@ from FlagEmbedding import FlagModel | |||
| import torch | |||
| import numpy as np | |||
| from api.utils.file_utils import get_project_base_directory, get_home_cache_dir | |||
| from api.utils.file_utils import get_home_cache_dir | |||
| from rag.utils import num_tokens_from_string, truncate | |||
| try: | |||
| flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"), | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", | |||
| local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"), | |||
| local_dir_use_symlinks=False) | |||
| flag_model = FlagModel(model_dir, | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| class Base(ABC): | |||
| def __init__(self, key, model_name): | |||
| @@ -54,7 +44,9 @@ class Base(ABC): | |||
| class DefaultEmbedding(Base): | |||
| def __init__(self, *args, **kwargs): | |||
| _model = None | |||
| def __init__(self, key, model_name, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -66,7 +58,18 @@ class DefaultEmbedding(Base): | |||
| ^_- | |||
| """ | |||
| self.model = flag_model | |||
| if not DefaultEmbedding._model: | |||
| try: | |||
| DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", | |||
| local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| DefaultEmbedding._model = FlagModel(model_dir, | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| self._model = DefaultEmbedding._model | |||
| def encode(self, texts: list, batch_size=32): | |||
| texts = [truncate(t, 2048) for t in texts] | |||
| @@ -75,12 +78,12 @@ class DefaultEmbedding(Base): | |||
| token_count += num_tokens_from_string(t) | |||
| res = [] | |||
| for i in range(0, len(texts), batch_size): | |||
| res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) | |||
| res.extend(self._model.encode(texts[i:i + batch_size]).tolist()) | |||
| return np.array(res), token_count | |||
| def encode_queries(self, text: str): | |||
| token_count = num_tokens_from_string(text) | |||
| return self.model.encode_queries([text]).tolist()[0], token_count | |||
| return self._model.encode_queries([text]).tolist()[0], token_count | |||
| class OpenAIEmbed(Base): | |||
| @@ -189,16 +192,19 @@ class OllamaEmbed(Base): | |||
| class FastEmbed(Base): | |||
| _model = None | |||
| def __init__( | |||
| self, | |||
| key: Optional[str] = None, | |||
| model_name: str = "BAAI/bge-small-en-v1.5", | |||
| cache_dir: Optional[str] = None, | |||
| threads: Optional[int] = None, | |||
| **kwargs, | |||
| self, | |||
| key: Optional[str] = None, | |||
| model_name: str = "BAAI/bge-small-en-v1.5", | |||
| cache_dir: Optional[str] = None, | |||
| threads: Optional[int] = None, | |||
| **kwargs, | |||
| ): | |||
| from fastembed import TextEmbedding | |||
| self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | |||
| if not FastEmbed._model: | |||
| FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) | |||
| def encode(self, texts: list, batch_size=32): | |||
| # Using the internal tokenizer to encode the texts and get the total | |||
| @@ -265,3 +271,29 @@ class YoudaoEmbed(Base): | |||
| def encode_queries(self, text): | |||
| embds = YoudaoEmbed._client.encode([text]) | |||
| return np.array(embds[0]), num_tokens_from_string(text) | |||
| class JinaEmbed(Base): | |||
| def __init__(self, key, model_name="jina-embeddings-v2-base-zh", | |||
| base_url="https://api.jina.ai/v1/embeddings"): | |||
| self.base_url = "https://api.jina.ai/v1/embeddings" | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.model_name = model_name | |||
| def encode(self, texts: list, batch_size=None): | |||
| texts = [truncate(t, 8196) for t in texts] | |||
| data = { | |||
| "model": self.model_name, | |||
| "input": texts, | |||
| 'encoding_type': 'float' | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"] | |||
| def encode_queries(self, text): | |||
| embds, cnt = self.encode([text]) | |||
| return np.array(embds[0]), cnt | |||
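Unlike `DefaultEmbedding`, `JinaEmbed` calls a hosted endpoint, so there is no model download and the token count comes back in the response's `usage` field. A hedged usage sketch (the key is a placeholder; 768 is the documented dimension of the v2 base models):

```python
mdl = JinaEmbed("jina_xxx_placeholder", "jina-embeddings-v2-base-en")
vecs, total_tokens = mdl.encode(["hello world", "bonjour"])
# vecs.shape == (2, 768) if the API call succeeds
```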
| @@ -0,0 +1,113 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import re | |||
| import requests | |||
| import torch | |||
| from FlagEmbedding import FlagReranker | |||
| from huggingface_hub import snapshot_download | |||
| import os | |||
| from abc import ABC | |||
| import numpy as np | |||
| from api.utils.file_utils import get_home_cache_dir | |||
| from rag.utils import num_tokens_from_string, truncate | |||
| class Base(ABC): | |||
| def __init__(self, key, model_name): | |||
| pass | |||
| def similarity(self, query: str, texts: list): | |||
| raise NotImplementedError("Please implement similarity method!") | |||
| class DefaultRerank(Base): | |||
| _model = None | |||
| def __init__(self, key, model_name, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| For Linux: | |||
| export HF_ENDPOINT=https://hf-mirror.com | |||
| For Windows: | |||
| Good luck | |||
| ^_- | |||
| """ | |||
| if not DefaultRerank._model: | |||
| try: | |||
| DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| use_fp16=torch.cuda.is_available()) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id=model_name, | |||
| local_dir=os.path.join(get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z]+/", "", model_name)), | |||
| local_dir_use_symlinks=False) | |||
| DefaultRerank._model = FlagReranker(model_dir, | |||
| use_fp16=torch.cuda.is_available()) | |||
| self._model = DefaultRerank._model | |||
| def similarity(self, query: str, texts: list): | |||
| pairs = [(query,truncate(t, 2048)) for t in texts] | |||
| token_count = 0 | |||
| for _, t in pairs: | |||
| token_count += num_tokens_from_string(t) | |||
| batch_size = 32 | |||
| res = [] | |||
| for i in range(0, len(pairs), batch_size): | |||
| scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048) | |||
| res.extend(scores) | |||
| return np.array(res), token_count | |||
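`DefaultRerank` batches query/passage pairs through `FlagReranker.compute_score`, which returns raw cross-encoder logits (unbounded, higher is better) rather than probabilities, and the token count is computed locally since no remote API meters usage. For example (a sketch, assuming the weights are already cached):

```python
mdl = DefaultRerank("", "BAAI/bge-reranker-v2-m3")
scores, tc = mdl.similarity(
    "what is a panda?",
    ["hi there", "The giant panda is a bear species endemic to China."])
# scores are raw logits; the second passage should score far higher.
```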
| class JinaRerank(Base): | |||
| def __init__(self, key, model_name="jina-reranker-v1-base-en", | |||
| base_url="https://api.jina.ai/v1/rerank"): | |||
| self.base_url = "https://api.jina.ai/v1/rerank" | |||
| self.headers = { | |||
| "Content-Type": "application/json", | |||
| "Authorization": f"Bearer {key}" | |||
| } | |||
| self.model_name = model_name | |||
| def similarity(self, query: str, texts: list): | |||
| texts = [truncate(t, 8196) for t in texts] | |||
| data = { | |||
| "model": self.model_name, | |||
| "query": query, | |||
| "documents": texts, | |||
| "top_n": len(texts) | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"] | |||
| class YoudaoRerank(DefaultRerank): | |||
| _model = None | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): | |||
| from BCEmbedding import RerankerModel | |||
| if not YoudaoRerank._model: | |||
| try: | |||
| print("LOADING BCE...") | |||
| YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( | |||
| get_home_cache_dir(), | |||
| re.sub(r"^[a-zA-Z]+/", "", model_name))) | |||
| except Exception as e: | |||
| YoudaoRerank._model = RerankerModel( | |||
| model_name_or_path=model_name.replace( | |||
| "maidalun1020", "InfiniFlow")) | |||
| @@ -54,7 +54,8 @@ class EsQueryer: | |||
| if not self.isChinese(txt): | |||
| tks = rag_tokenizer.tokenize(txt).split(" ") | |||
| tks_w = self.tw.weights(tks) | |||
| q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w] | |||
| tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w] | |||
| q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk] | |||
| for i in range(1, len(tks_w)): | |||
| q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2)) | |||
| if not q: | |||
| @@ -136,7 +137,11 @@ class EsQueryer: | |||
| from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity | |||
| import numpy as np | |||
| sims = CosineSimilarity([avec], bvecs) | |||
| tksim = self.token_similarity(atks, btkss) | |||
| return np.array(sims[0]) * vtweight + \ | |||
| np.array(tksim) * tkweight, tksim, sims[0] | |||
| def token_similarity(self, atks, btkss): | |||
| def toDict(tks): | |||
| d = {} | |||
| if isinstance(tks, str): | |||
| @@ -149,9 +154,7 @@ class EsQueryer: | |||
| atks = toDict(atks) | |||
| btkss = [toDict(tks) for tks in btkss] | |||
| tksim = [self.similarity(atks, btks) for btks in btkss] | |||
| return np.array(sims[0]) * vtweight + \ | |||
| np.array(tksim) * tkweight, tksim, sims[0] | |||
| return [self.similarity(atks, btks) for btks in btkss] | |||
| def similarity(self, qtwt, dtwt): | |||
| if isinstance(dtwt, type("")): | |||
| @@ -241,11 +241,14 @@ class RagTokenizer: | |||
| return self.score_(res[::-1]) | |||
| def english_normalize_(self, tks): | |||
| return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks] | |||
| def tokenize(self, line): | |||
| line = self._strQ2B(line).lower() | |||
| line = self._tradi2simp(line) | |||
| zh_num = len([1 for c in line if is_chinese(c)]) | |||
| if zh_num < len(line) * 0.2: | |||
| if zh_num == 0: | |||
| return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)]) | |||
| arr = re.split(self.SPLIT_CHAR, line) | |||
| @@ -293,7 +296,7 @@ class RagTokenizer: | |||
| i = e + 1 | |||
| res = " ".join(res) | |||
| res = " ".join(self.english_normalize_(res)) | |||
| if self.DEBUG: | |||
| print("[TKS]", self.merge_(res)) | |||
| return self.merge_(res) | |||
| @@ -336,7 +339,7 @@ class RagTokenizer: | |||
| res.append(stk) | |||
| return " ".join(res) | |||
| return " ".join(self.english_normalize_(res)) | |||
| def is_chinese(s): | |||
| @@ -71,8 +71,8 @@ class Dealer: | |||
| s = Search() | |||
| pg = int(req.get("page", 1)) - 1 | |||
| ps = int(req.get("size", 1000)) | |||
| topk = int(req.get("topk", 1024)) | |||
| ps = int(req.get("size", topk)) | |||
| src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", | |||
| "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", | |||
| "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) | |||
| @@ -311,6 +311,26 @@ class Dealer: | |||
| ins_tw, tkweight, vtweight) | |||
| return sim, tksim, vtsim | |||
| def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, | |||
| vtweight=0.7, cfield="content_ltks"): | |||
| _, keywords = self.qryr.question(query) | |||
| for i in sres.ids: | |||
| if isinstance(sres.field[i].get("important_kwd", []), str): | |||
| sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]] | |||
| ins_tw = [] | |||
| for i in sres.ids: | |||
| content_ltks = sres.field[i][cfield].split(" ") | |||
| title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t] | |||
| important_kwd = sres.field[i].get("important_kwd", []) | |||
| tks = content_ltks + title_tks + important_kwd | |||
| ins_tw.append(tks) | |||
| tksim = self.qryr.token_similarity(keywords, ins_tw) | |||
| vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw]) | |||
| return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim | |||
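`rerank_by_model` keeps the token-overlap term and swaps only the vector-cosine term for the rerank model's output: `sim = tkweight * tksim + vtweight * vtsim`. One caveat worth flagging: `vtsim` is whatever the provider returns — raw logits for `DefaultRerank`, roughly [0, 1] relevance for Jina — so the two terms are not guaranteed to share a scale. A toy check of the blend (numbers made up):

```python
import numpy as np

tksim = np.array([0.10, 0.60])   # token-overlap similarity per chunk
vtsim = np.array([0.20, 0.95])   # rerank-model score per chunk
tkweight, vtweight = 0.3, 0.7    # 1 - vector_similarity_weight, and the weight itself
sim = tkweight * tksim + vtweight * vtsim
# -> array([0.17, 0.845]); the second chunk sorts first in retrieval()
```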
| def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): | |||
| return self.qryr.hybrid_similarity(ans_embd, | |||
| ins_embd, | |||
| @@ -318,17 +338,22 @@ class Dealer: | |||
| rag_tokenizer.tokenize(inst).split(" ")) | |||
| def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, | |||
| vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True): | |||
| vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None): | |||
| ranks = {"total": 0, "chunks": [], "doc_aggs": {}} | |||
| if not question: | |||
| return ranks | |||
| req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size, | |||
| "question": question, "vector": True, "topk": top, | |||
| "similarity": similarity_threshold} | |||
| "similarity": similarity_threshold, | |||
| "available_int": 1} | |||
| sres = self.search(req, index_name(tenant_id), embd_mdl) | |||
| sim, tsim, vsim = self.rerank( | |||
| sres, question, 1 - vector_similarity_weight, vector_similarity_weight) | |||
| if rerank_mdl: | |||
| sim, tsim, vsim = self.rerank_by_model(rerank_mdl, | |||
| sres, question, 1 - vector_similarity_weight, vector_similarity_weight) | |||
| else: | |||
| sim, tsim, vsim = self.rerank( | |||
| sres, question, 1 - vector_similarity_weight, vector_similarity_weight) | |||
| idx = np.argsort(sim * -1) | |||
| dim = len(sres.query_vector) | |||