### What problem does this PR solve?

feat: add rerank models to the project #724 #162

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
embd_mdl = TenantLLMService.model_instance(
    kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

rerank_mdl = None
if req.get("rerank_id"):
    rerank_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
                              similarity_threshold, vector_similarity_weight, top,
                              doc_ids, rerank_mdl=rerank_mdl)

for c in ranks["chunks"]:
    if "vector" in c:
        del c["vector"]
name = req.get("name", "New Dialog")
description = req.get("description", "A helpful Dialog")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id: req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {})

    "llm_setting": llm_setting,
    "prompt_config": prompt_config,
    "top_n": top_n,
    "top_k": top_k,
    "rerank_id": rerank_id,
    "similarity_threshold": similarity_threshold,
    "vector_similarity_weight": vector_similarity_weight
}
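
For context, a dialog-creation request exercising the new fields might look like this (a sketch built only from the defaults above; other required fields omitted):

```python
req = {
    "name": "Support assistant",
    "description": "A helpful Dialog",
    "top_n": 6,                                # chunks kept after reranking
    "top_k": 1024,                             # candidate pool size for the reranker
    "rerank_id": "BAAI/bge-reranker-v2-m3",    # "" keeps the weighted-cosine path
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,
}
```
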
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel, RerankModel


@manager.route('/factories', methods=['GET'])
def factories():
    try:
        fac = LLMFactoriesService.get_all()
        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
    except Exception as e:
        return server_error_response(e)
        except Exception as e:
            msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                e)
    elif llm.model_type == LLMType.RERANK:
        mdl = RerankModel[factory](
            req["api_key"], llm.llm_name, base_url=req.get("base_url"))
        try:
            m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
            if len(m) == 0 or tc == 0:
                raise Exception("Fail")
        except Exception as e:
            msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                e)
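
The smoke test above relies on the contract every RerankModel implementation shares: similarity(query, texts) returns one relevance score per candidate text plus the token count billed for the call. A minimal sketch of that contract:

```python
# Sketch of the RerankModel.similarity contract assumed by the check above.
scores, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
assert len(scores) == 1   # one score per candidate text
assert tc > 0             # a zero token count is treated as a failed call
```
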
if msg:
    return get_data_error_result(retmsg=msg)

llms = [m.to_dict()
        for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
    m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao", "FastEmbed", "BAAI"]

llm_set = set([m["llm_name"] for m in llms])
for o in objs:
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
from api.db import UserTenantRole, LLMType, FileType
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
    API_KEY, LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.settings import stat_logger

    "embd_id": EMBEDDING_MDL,
    "asr_id": ASR_MDL,
    "parser_ids": PARSERS,
    "img2txt_id": IMAGE2TEXT_MDL,
    "rerank_id": RERANK_MDL
}
usr_tenant = {
    "tenant_id": user_id,
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'


class ChatStyle(StrEnum):
    max_length=128,
    null=False,
    help_text="default image to text model ID")
rerank_id = CharField(
    max_length=128,
    null=False,
    help_text="default rerank model ID")
parser_ids = CharField(
    max_length=256,
    null=False,

similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
top_k = IntegerField(default=1024)
do_refer = CharField(
    max_length=1,
    null=False,
    help_text="whether to insert a reference index into the answer",
    default="1")
rerank_id = CharField(
    max_length=128,
    null=False,
    help_text="default rerank model ID")
kb_ids = JSONField(null=False, default=[])
status = CharField(
def migrate_db():
    with DB.transaction():
        migrator = MySQLMigrator(DB)
        try:
            migrate(
                migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where does this document come from"))
            )
        except Exception as e:
            pass
        try:
            migrate(
                migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
            )
        except Exception as e:
            pass
        try:
            migrate(
                migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
            )
        except Exception as e:
            pass
        try:
            migrate(
                migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
            )
        except Exception as e:
            pass
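
Each add_column sits inside its own try/except, so migrate_db() stays idempotent: on a schema that already has the columns, every migration raises and is swallowed. A minimal sketch of what that buys, assuming migrate_db() runs at every startup:

```python
# Safe to call on every boot: the second run no-ops because each
# duplicate add_column raises inside its own try/except.
migrate_db()
migrate_db()
```
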
| "logo": "", | "logo": "", | ||||
| "tags": "LLM,TEXT EMBEDDING", | "tags": "LLM,TEXT EMBEDDING", | ||||
| "status": "1", | "status": "1", | ||||
| }, | |||||
| },{ | |||||
| "name": "Jina", | |||||
| "logo": "", | |||||
| "tags": "TEXT EMBEDDING, TEXT RE-RANK", | |||||
| "status": "1", | |||||
| },{ | |||||
| "name": "BAAI", | |||||
| "logo": "", | |||||
| "tags": "TEXT EMBEDDING, TEXT RE-RANK", | |||||
| "status": "1", | |||||
| } | |||||
| # { | # { | ||||
| # "name": "文心一言", | # "name": "文心一言", | ||||
| # "logo": "", | # "logo": "", | ||||
| "max_tokens": 512, | "max_tokens": 512, | ||||
| "model_type": LLMType.EMBEDDING.value | "model_type": LLMType.EMBEDDING.value | ||||
| }, | }, | ||||
| { | |||||
| "fid": factory_infos[7]["name"], | |||||
| "llm_name": "maidalun1020/bce-reranker-base_v1", | |||||
| "tags": "RE-RANK, 8K", | |||||
| "max_tokens": 8196, | |||||
| "model_type": LLMType.RERANK.value | |||||
| }, | |||||
# ------------------------ DeepSeek -----------------------
{
    "fid": factory_infos[8]["name"],

    "max_tokens": 512,
    "model_type": LLMType.EMBEDDING.value
},
# ------------------------ Jina -----------------------
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-reranker-v1-base-en",
    "tags": "RE-RANK,8k",
    "max_tokens": 8196,
    "model_type": LLMType.RERANK.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-reranker-v1-turbo-en",
    "tags": "RE-RANK,8k",
    "max_tokens": 8196,
    "model_type": LLMType.RERANK.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-reranker-v1-tiny-en",
    "tags": "RE-RANK,8k",
    "max_tokens": 8196,
    "model_type": LLMType.RERANK.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-colbert-v1-en",
    "tags": "RE-RANK,8k",
    "max_tokens": 8196,
    "model_type": LLMType.RERANK.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-embeddings-v2-base-en",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 8196,
    "model_type": LLMType.EMBEDDING.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-embeddings-v2-base-de",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 8196,
    "model_type": LLMType.EMBEDDING.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-embeddings-v2-base-es",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 8196,
    "model_type": LLMType.EMBEDDING.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-embeddings-v2-base-code",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 8196,
    "model_type": LLMType.EMBEDDING.value
},
{
    "fid": factory_infos[11]["name"],
    "llm_name": "jina-embeddings-v2-base-zh",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 8196,
    "model_type": LLMType.EMBEDDING.value
},
# ------------------------ BAAI -----------------------
{
    "fid": factory_infos[12]["name"],
    "llm_name": "BAAI/bge-large-zh-v1.5",
    "tags": "TEXT EMBEDDING,",
    "max_tokens": 1024,
    "model_type": LLMType.EMBEDDING.value
},
{
    "fid": factory_infos[12]["name"],
    "llm_name": "BAAI/bge-reranker-v2-m3",
    "tags": "RE-RANK,",
    "max_tokens": 16385,
    "model_type": LLMType.RERANK.value
},
]
for info in factory_infos:
    try:
| if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: | if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: | ||||
| kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} | kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} | ||||
| else: | else: | ||||
| rerank_mdl = None | |||||
| if dialog.rerank_id: | |||||
| rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) | |||||
| kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, | kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, | ||||
| dialog.similarity_threshold, | dialog.similarity_threshold, | ||||
| dialog.vector_similarity_weight, | dialog.vector_similarity_weight, | ||||
| doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, | doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, | ||||
| top=1024, aggs=False) | |||||
| top=1024, aggs=False, rerank_mdl=rerank_mdl) | |||||
| knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | ||||
| chat_logger.info( | chat_logger.info( | ||||
| "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) | "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) | ||||
| kwargs["knowledge"] = "\n".join(knowledges) | kwargs["knowledge"] = "\n".join(knowledges) | ||||
| gen_conf = dialog.llm_setting | gen_conf = dialog.llm_setting | ||||
| msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] | msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] | ||||
| msg.extend([{"role": m["role"], "content": m["content"]} | msg.extend([{"role": m["role"], "content": m["content"]} | ||||
| for m in messages if m["role"] != "system"]) | for m in messages if m["role"] != "system"]) |
#
from api.db.services.user_service import TenantService
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
    mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
    mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
    mdlnm = tenant.rerank_id if not llm_name else llm_name
else:
    assert False, "LLM type error"

model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config: model_config = model_config.to_dict()
if not model_config:
    if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
        llm = LLMService.query(llm_name=llm_name)
        if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
            model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": llm_name, "api_base": ""}
    if not model_config:
        if llm_name == "flag-embedding":
            model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
                            "llm_name": llm_name, "api_base": ""}
        else:
            if not mdlnm:
                raise LookupError(f"Type of {llm_type} model is not set.")
            raise LookupError("Model({}) not authorized".format(mdlnm))

if llm_type == LLMType.EMBEDDING.value:
    return EmbeddingModel[model_config["llm_factory"]](
        model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
    if model_config["llm_factory"] not in RerankModel:
        return
    return RerankModel[model_config["llm_factory"]](
        model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
    if model_config["llm_factory"] not in CvModel:
        return
    mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
    mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
    mdlnm = tenant.rerank_id if not llm_name else llm_name
else:
    assert False, "LLM type error"

num = 0
try:
    for u in cls.query(tenant_id=tenant_id, llm_name=mdlnm):
        num += cls.model.update(used_tokens=u.used_tokens + used_tokens)\
            .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
            .execute()
except Exception as e:
    print(e)
return num

@classmethod
| "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | ||||
| return emd, used_tokens | return emd, used_tokens | ||||
| def similarity(self, query: str, texts: list): | |||||
| sim, used_tokens = self.mdl.similarity(query, texts) | |||||
| if not TenantLLMService.increase_usage( | |||||
| self.tenant_id, self.llm_type, used_tokens): | |||||
| database_logger.error( | |||||
| "Can't update token usage for {}/RERANK".format(self.tenant_id)) | |||||
| return sim, used_tokens | |||||
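
LLMBundle.similarity mirrors the other wrappers: it delegates to the underlying model and charges the consumed tokens to the tenant. The dialog path shown earlier boils down to this sketch (dialog, query, and candidate_texts are placeholders):

```python
# Placeholder values; mirrors the dialog_service hunk above.
rerank_bundle = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
scores, used_tokens = rerank_bundle.similarity(query, candidate_texts)
# increase_usage() has already been charged against the tenant's rerank model
```
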
def describe(self, image, max_tokens=300):
    txt, used_tokens = self.mdl.describe(image, max_tokens)
    if not TenantLLMService.increase_usage(
cls.model.name,
cls.model.llm_id,
cls.model.embd_id,
cls.model.rerank_id,
cls.model.asr_id,
cls.model.img2txt_id,
cls.model.parser_ids,
},
"DeepSeek": {
    "chat_model": "deepseek-chat",
    "embedding_model": "",
    "image2text_model": "",
    "asr_model": "",
},
"VolcEngine": {
    "chat_model": "",
    "embedding_model": "",
    "image2text_model": "",
    "asr_model": "",
},
"BAAI": {
    "chat_model": "",
    "embedding_model": "BAAI/bge-large-zh-v1.5",
    "image2text_model": "",
    "asr_model": "",
    "rerank_model": "BAAI/bge-reranker-v2-m3",
}
}
LLM = get_base_config("user_default_llm", {})

    f"LLM factory {LLM_FACTORY} is not supported yet; switching to 'Tongyi-Qianwen/QWen' automatically. Please check the API_KEY in service_conf.yaml.")
LLM_FACTORY = "Tongyi-Qianwen"
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"]
RERANK_MDL = default_llm["BAAI"]["rerank_model"]
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
from .embedding_model import *
from .chat_model import *
from .cv_model import *
from .rerank_model import *


EmbeddingModel = {
    "Ollama": OllamaEmbed,
    "OpenAI": OpenAIEmbed,
    "Xinference": XinferenceEmbed,
    "Tongyi-Qianwen": DefaultEmbedding,  # QWenEmbed,
    "ZHIPU-AI": ZhipuEmbed,
    "FastEmbed": FastEmbed,
    "Youdao": YoudaoEmbed,
    "BaiChuan": BaiChuanEmbed,
    "BAAI": DefaultEmbedding
}

    "BaiChuan": BaiChuanChat
}

RerankModel = {
    "BAAI": DefaultRerank,
    "Jina": JinaRerank,
    "Youdao": YoudaoRerank,
}
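
RerankModel is keyed by factory name, like EmbeddingModel and ChatModel, and every entry implements the same similarity interface. Direct registry use looks like this (callers normally go through TenantLLMService.model_instance; the key below is a placeholder):

```python
from rag.llm import RerankModel

mdl = RerankModel["Jina"]("JINA_API_KEY", "jina-reranker-v1-base-en")
scores, used_tokens = mdl.similarity(
    "what is a cross-encoder",
    ["Cross-encoders score query-document pairs jointly.",
     "Bananas are rich in potassium."])
```
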
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Optional
import requests
from huggingface_hub import snapshot_download
from zhipuai import ZhipuAI
import os
import torch
import numpy as np
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate


class Base(ABC):
    def __init__(self, key, model_name):
class DefaultEmbedding(Base):
    _model = None

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        ^_-
        """
        if not DefaultEmbedding._model:
            try:
                # Cache on the class so repeated instantiations reuse one model.
                DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
                                                    query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                    use_fp16=torch.cuda.is_available())
            except Exception as e:
                model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                              local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
                                              local_dir_use_symlinks=False)
                DefaultEmbedding._model = FlagModel(model_dir,
                                                    query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                    use_fp16=torch.cuda.is_available())
        self._model = DefaultEmbedding._model
    def encode(self, texts: list, batch_size=32):
        texts = [truncate(t, 2048) for t in texts]
        token_count = 0
        for t in texts:
            token_count += num_tokens_from_string(t)
        res = []
        for i in range(0, len(texts), batch_size):
            res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res), token_count

    def encode_queries(self, text: str):
        token_count = num_tokens_from_string(text)
        return self._model.encode_queries([text]).tolist()[0], token_count
class OpenAIEmbed(Base):

class FastEmbed(Base):
    _model = None

    def __init__(
            self,
            key: Optional[str] = None,
            model_name: str = "BAAI/bge-small-en-v1.5",
            cache_dir: Optional[str] = None,
            threads: Optional[int] = None,
            **kwargs,
    ):
        from fastembed import TextEmbedding
        if not FastEmbed._model:
            # Cache on the class so the model is only loaded once.
            FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
        self._model = FastEmbed._model
    def encode(self, texts: list, batch_size=32):
        # Using the internal tokenizer to encode the texts and get the total

    def encode_queries(self, text):
        embds = YoudaoEmbed._client.encode([text])
        return np.array(embds[0]), num_tokens_from_string(text)


class JinaEmbed(Base):
    def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
                 base_url="https://api.jina.ai/v1/embeddings"):
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def encode(self, texts: list, batch_size=None):
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "input": texts,
            'encoding_type': 'float'
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]

    def encode_queries(self, text):
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import requests
import torch
from FlagEmbedding import FlagReranker
from huggingface_hub import snapshot_download
import os
from abc import ABC
import numpy as np
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("Please implement similarity method!")
class DefaultRerank(Base):
    _model = None

    def __init__(self, key, model_name, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com
        For Windows:
        Good luck
        ^_-
        """
        if not DefaultRerank._model:
            try:
                # Cache on the class so repeated instantiations reuse one model.
                DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
                                                    use_fp16=torch.cuda.is_available())
            except Exception as e:
                model_dir = snapshot_download(repo_id=model_name,
                                              local_dir=os.path.join(get_home_cache_dir(),
                                                                     re.sub(r"^[a-zA-Z]+/", "", model_name)),
                                              local_dir_use_symlinks=False)
                # Load from the downloaded snapshot directory.
                DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
    def similarity(self, query: str, texts: list):
        pairs = [(query, truncate(t, 2048)) for t in texts]
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        batch_size = 32
        res = []
        for i in range(0, len(pairs), batch_size):
            scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
            # compute_score returns a bare float when given a single pair
            if isinstance(scores, float):
                scores = [scores]
            res.extend(scores)
        return np.array(res), token_count
class JinaRerank(Base):
    def __init__(self, key, model_name="jina-reranker-v1-base-en",
                 base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = "https://api.jina.ai/v1/rerank"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "query": query,
            "documents": texts,
            "top_n": len(texts)
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        # Results come back sorted by relevance; map scores back to the
        # input order through each result's "index" field.
        rank = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            rank[d["index"]] = d["relevance_score"]
        return rank, res["usage"]["total_tokens"]
class YoudaoRerank(DefaultRerank):
    _model = None

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        from BCEmbedding import RerankerModel
        if not YoudaoRerank._model:
            try:
                print("LOADING BCE...")
                YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
                    get_home_cache_dir(),
                    re.sub(r"^[a-zA-Z]+/", "", model_name)))
            except Exception as e:
                YoudaoRerank._model = RerankerModel(
                    model_name_or_path=model_name.replace(
                        "maidalun1020", "InfiniFlow"))
        self._model = YoudaoRerank._model
if not self.isChinese(txt):
    tks = rag_tokenizer.tokenize(txt).split(" ")
    tks_w = self.tw.weights(tks)
    tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
    q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
    for i in range(1, len(tks_w)):
        q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
    if not q:
    from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
    import numpy as np
    sims = CosineSimilarity([avec], bvecs)
    tksim = self.token_similarity(atks, btkss)
    return np.array(sims[0]) * vtweight + \
        np.array(tksim) * tkweight, tksim, sims[0]

def token_similarity(self, atks, btkss):
    def toDict(tks):
        d = {}
        if isinstance(tks, str):

    atks = toDict(atks)
    btkss = [toDict(tks) for tks in btkss]
    return [self.similarity(atks, btks) for btks in btkss]

def similarity(self, qtwt, dtwt):
    if isinstance(dtwt, type("")):
    return self.score_(res[::-1])

def english_normalize_(self, tks):
    return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]

def tokenize(self, line):
    line = self._strQ2B(line).lower()
    line = self._tradi2simp(line)
    zh_num = len([1 for c in line if is_chinese(c)])
    if zh_num == 0:
        return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)])

    arr = re.split(self.SPLIT_CHAR, line)

        i = e + 1
    res = " ".join(self.english_normalize_(res))
    if self.DEBUG:
        print("[TKS]", self.merge_(res))
    return self.merge_(res)

        res.append(stk)
    return " ".join(self.english_normalize_(res))


def is_chinese(s):
s = Search()
pg = int(req.get("page", 1)) - 1
topk = int(req.get("topk", 1024))
ps = int(req.get("size", topk))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                         "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int",
                         "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
                                ins_tw, tkweight, vtweight)
    return sim, tksim, vtsim

def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                    vtweight=0.7, cfield="content_ltks"):
    _, keywords = self.qryr.question(query)

    for i in sres.ids:
        if isinstance(sres.field[i].get("important_kwd", []), str):
            sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
    ins_tw = []
    for i in sres.ids:
        content_ltks = sres.field[i][cfield].split(" ")
        title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t]
        important_kwd = sres.field[i].get("important_kwd", [])
        tks = content_ltks + title_tks + important_kwd
        ins_tw.append(tks)

    tksim = self.qryr.token_similarity(keywords, ins_tw)
    vtsim, _ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw])

    return tkweight * np.array(tksim) + vtweight * vtsim, tksim, vtsim
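
rerank_by_model keeps the token-overlap signal and only swaps the vector-cosine term for the cross-encoder score, so the final ranking is tkweight * tksim + vtweight * vtsim. A worked example with made-up scores:

```python
import numpy as np

# Hypothetical scores for two candidate chunks.
tkweight, vtweight = 0.3, 0.7
tksim = np.array([0.42, 0.10])   # keyword/token overlap
vtsim = np.array([0.91, 0.15])   # cross-encoder relevance
sim = tkweight * tksim + vtweight * vtsim
print(sim)                       # [0.763 0.135] -> the first chunk wins
```
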
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
    return self.qryr.hybrid_similarity(ans_embd,
                                       ins_embd,
                                       rag_tokenizer.tokenize(ans).split(" "),
                                       rag_tokenizer.tokenize(inst).split(" "))

def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
              vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
    if not question:
        return ranks
    req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
           "question": question, "vector": True, "topk": top,
           "similarity": similarity_threshold,
           "available_int": 1}
    sres = self.search(req, index_name(tenant_id), embd_mdl)

    if rerank_mdl:
        sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
                                               sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
    else:
        sim, tsim, vsim = self.rerank(
            sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
    idx = np.argsort(sim * -1)
    dim = len(sres.query_vector)