 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
-from api.db.services.llm_service import LLMService, TenantLLMService
+from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result

         if p["key"] not in kwargs:
             prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")

-    model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
-    if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
     question = messages[-1]["content"]
-    embd_mdl = TenantLLMService.model_instance(
-        dialog.tenant_id, LLMType.EMBEDDING.value)
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]

     used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
-    mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
-    answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
+    answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
     answer = retrievaler.insert_citations(answer,
                                           [ck["content_ltks"] for ck in kbinfos["chunks"]],
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
     top_n = IntegerField(default=6)
+    do_refer = CharField(max_length=1, null=False, help_text="whether to insert reference markers into the answer or not", default="1")
     kb_ids = JSONField(null=False, default=[])
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 # limitations under the License.
 #
 from api.db.services.user_service import TenantService
-from rag.llm import EmbeddingModel, CvModel
+from api.settings import database_logger
+from rag.llm import EmbeddingModel, CvModel, ChatModel
 from api.db import LLMType
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import LLMFactories, LLM, TenantLLM
 from api.db.services.common_service import CommonService
-from api.db import StatusEnum
 class LLMFactoriesService(CommonService):

     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
         objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
-        if not objs: return
+        if not objs:
+            return
         return objs[0]

     @classmethod
     @DB.connection_context()
     def get_my_llms(cls, tenant_id):
-        fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
+        fields = [
+            cls.model.llm_factory,
+            LLMFactories.logo,
+            LLMFactories.tags,
+            cls.model.model_type,
+            cls.model.llm_name]
         objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
             cls.model.tenant_id == tenant_id).dicts()
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type):
-        e,tenant = TenantService.get_by_id(tenant_id)
-        if not e: raise LookupError("Tenant not found")
+    def model_instance(cls, tenant_id, llm_type, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")
-        if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
-        elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
-        elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
-        elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
-        else: assert False, "LLM type error"
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
+        if not model_config:
+            raise LookupError("Model({}) not found".format(mdlnm))
         model_config = model_config.to_dict()
         if llm_type == LLMType.EMBEDDING.value:
-            if model_config["llm_factory"] not in EmbeddingModel: return
-            return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
+            if model_config["llm_factory"] not in EmbeddingModel:
+                return
+            return EmbeddingModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
         if llm_type == LLMType.IMAGE2TEXT.value:
-            if model_config["llm_factory"] not in CvModel: return
-            return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
+            if model_config["llm_factory"] not in CvModel:
+                return
+            return CvModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+        if llm_type == LLMType.CHAT.value:
+            if model_config["llm_factory"] not in ChatModel:
+                return
+            return ChatModel[model_config["llm_factory"]](
+                model_config["api_key"], model_config["llm_name"])
+    @classmethod
+    @DB.connection_context()
+    def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
+        e, tenant = TenantService.get_by_id(tenant_id)
+        if not e:
+            raise LookupError("Tenant not found")
+        if llm_type == LLMType.EMBEDDING.value:
+            mdlnm = tenant.embd_id
+        elif llm_type == LLMType.SPEECH2TEXT.value:
+            mdlnm = tenant.asr_id
+        elif llm_type == LLMType.IMAGE2TEXT.value:
+            mdlnm = tenant.img2txt_id
+        elif llm_type == LLMType.CHAT.value:
+            mdlnm = tenant.llm_id if not llm_name else llm_name
+        else:
+            assert False, "LLM type error"
+        num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
+            .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
+            .execute()
+        return num
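
increase_usage returns the row count from the peewee UPDATE, so a falsy return means no TenantLLM row matched that tenant/model pair; that is the failure case the LLMBundle methods below log. A quick illustration, with a placeholder tenant id:

# Sketch: .execute() on a peewee UPDATE returns the number of rows changed,
# so 0 (falsy) means no TenantLLM row exists for this tenant/model pair.
updated = TenantLLMService.increase_usage("tenant-0001", LLMType.CHAT.value, 42)
if not updated:
    database_logger.error("Can't update token usage for tenant-0001/CHAT")
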
+class LLMBundle(object):
+    def __init__(self, tenant_id, llm_type, llm_name=None):
+        self.tenant_id = tenant_id
+        self.llm_type = llm_type
+        self.llm_name = llm_name
+        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
+        assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
+
+    def encode(self, texts: list, batch_size=32):
+        emd, used_tokens = self.mdl.encode(texts, batch_size)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def encode_queries(self, query: str):
+        emd, used_tokens = self.mdl.encode_queries(query)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
+        return emd, used_tokens
+
+    def describe(self, image, max_tokens=300):
+        txt, used_tokens = self.mdl.describe(image, max_tokens)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
+            database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
+        return txt
+
+    def chat(self, system, history, gen_conf):
+        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
+        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
+            database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
+        return txt
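
LLMBundle puts model resolution and per-tenant token accounting behind a single object, so callers (the dialog app above, the task workers below) no longer juggle TenantLLMService and the raw model wrappers. A minimal usage sketch; the tenant id and chat model name are placeholders:

# Usage sketch for LLMBundle; tenant id and model name are placeholders.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

embd_mdl = LLMBundle("tenant-0001", LLMType.EMBEDDING)
chat_mdl = LLMBundle("tenant-0001", LLMType.CHAT, "gpt-3.5-turbo")

# Each call both invokes the model and books the tokens it consumed
# against the tenant via TenantLLMService.increase_usage.
vecs, used = embd_mdl.encode(["hello", "world"])
answer = chat_mdl.chat("You are a helpful assistant.",
                       [{"role": "user", "content": "hi"}],
                       {"temperature": 0.7})
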
     if re.match(r".*\.pdf$", filename):
         return FileType.PDF.value

-    if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
+    if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
         return FileType.DOC.value

     if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
         return FileType.AURAL.value

     if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
-        return FileType.VISUAL
+        return FileType.VISUAL.value
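
A quick check of the extension mapping above. This hunk does not show the enclosing function; the name filename_type and its module are assumptions based on the file_utils imports elsewhere in this diff:

# Assumed helper name/location; only the regex branches above are from the diff.
from api.db import FileType
from api.utils.file_utils import filename_type

assert filename_type("report.pdf") == FileType.PDF.value
assert filename_type("slides.pptx") == FileType.DOC.value  # now matched via the added pptx
assert filename_type("talk.mp3") == FileType.AURAL.value
assert filename_type("chart.png") == FileType.VISUAL.value
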
             model=self.model_name,
             messages=history,
             **gen_conf)
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.completion_tokens

         from dashscope import Generation

             result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
             messages=self.prompt(b64),
             max_tokens=max_tokens,
         )
-        return res.choices[0].message.content.strip()
+        return res.choices[0].message.content.strip(), res.usage.total_tokens


 class QWenCV(Base):

         response = MultiModalConversation.call(model=self.model_name,
                                                messages=self.prompt(self.image2base64(image)))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content']
-        return response.message
+            return response.output.choices[0]['message']['content'], response.usage.output_tokens
+        return response.message, 0
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")

+    def encode_queries(self, text: str):
+        raise NotImplementedError("Please implement encode method!")


 class HuEmbedding(Base):
     def __init__(self, key="", model_name=""):


 class OpenAIEmbed(Base):
     def __init__(self, key, model_name="text-embedding-ada-002"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name

     def encode(self, texts: list, batch_size=32):
-        token_count = 0
-        for t in texts: token_count += num_tokens_from_string(t)
         res = self.client.embeddings.create(input=texts,
                                             model=self.model_name)
-        return [d["embedding"] for d in res["data"]], token_count
+        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings.create(input=[text],
+                                            model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens


 class QWenEmbed(Base):

         dashscope.api_key = key
         self.model_name = model_name

-    def encode(self, texts: list, batch_size=32, text_type="document"):
+    def encode(self, texts: list, batch_size=10):
         import dashscope
         res = []
         token_count = 0
-        for txt in texts:
+        texts = [txt[:2048] for txt in texts]
+        for i in range(0, len(texts), batch_size):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
-                input=txt[:2048],
-                text_type=text_type
+                input=texts[i:i+batch_size],
+                text_type="document"
+            )
+            embds = [[]] * len(resp["output"]["embeddings"])
+            for e in resp["output"]["embeddings"]:
+                embds[e["text_index"]] = e["embedding"]
+            res.extend(embds)
+            token_count += resp["usage"]["input_tokens"]
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        resp = dashscope.TextEmbedding.call(
+            model=self.model_name,
+            input=text[:2048],
+            text_type="query"
         )
-            res.append(resp["output"]["embeddings"][0]["embedding"])
-            token_count += resp["usage"]["total_tokens"]
-        return res, token_count
+        return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
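
The rewritten QWenEmbed.encode sends batches (ten texts per call by default), and the text_index bookkeeping suggests the service does not guarantee input order, so the loop uses each result's text_index to restore it. A self-contained demonstration of that reordering with a faked response:

# Faked dashscope response: embeddings arrive tagged with the input index.
resp = {"output": {"embeddings": [
    {"text_index": 2, "embedding": [0.3]},
    {"text_index": 0, "embedding": [0.1]},
    {"text_index": 1, "embedding": [0.2]},
]}}

# Same reordering as in encode(); [[]] * n is safe here because every slot
# is reassigned (not mutated) before use.
embds = [[]] * len(resp["output"]["embeddings"])
for e in resp["output"]["embeddings"]:
    embds[e["text_index"]] = e["embedding"]

assert embds == [[0.1], [0.2], [0.3]]  # back in input order
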
 class HuChunker:

+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None

     def __init__(self):
         self.MAX_LVL = 12
         self.proj_patt = [


 class PdfChunker(HuChunker):

+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None

     def __init__(self, pdf_parser):
         self.pdf = pdf_parser
         super().__init__()


 class DocxChunker(HuChunker):

+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None

     def __init__(self, doc_parser):
         self.doc = doc_parser
         super().__init__()


 class ExcelChunker(HuChunker):

+    @dataclass
+    class Fields:
+        text_chunks: List = None
+        table_chunks: List = None

     def __init__(self, excel_parser):
         self.excel = excel_parser
         super().__init__()


     def __init__(self):
         super().__init__()

+    def __extract(self, shape):
+        if shape.shape_type == 19:
+            tb = shape.table
+            rows = []
+            for i in range(1, len(tb.rows)):
+                rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
+            return "\n".join(rows)
+
+        if shape.has_text_frame:
+            return shape.text_frame.text
+
+        if shape.shape_type == 6:
+            texts = []
+            for p in shape.shapes:
+                t = self.__extract(p)
+                if t: texts.append(t)
+            return "\n".join(texts)
+
     def __call__(self, fnm):
         from pptx import Presentation
         ppt = Presentation(fnm) if isinstance(
             fnm, str) else Presentation(
             BytesIO(fnm))
-        flds = self.Fields()
-        flds.text_chunks = []
+        txts = []
         for slide in ppt.slides:
+            texts = []
             for shape in slide.shapes:
-                if hasattr(shape, "text"):
-                    flds.text_chunks.append((shape.text, None))
+                txt = self.__extract(shape)
+                if txt: texts.append(txt)
+            txts.append("\n".join(texts))
+
+        import aspose.slides as slides
+        import aspose.pydrawing as drawing
+        imgs = []
+        with slides.Presentation(BytesIO(fnm)) as presentation:
+            for slide in presentation.slides:
+                buffered = BytesIO()
+                slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
+                imgs.append(buffered.getvalue())
+        assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
+
+        flds = self.Fields()
+        flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
         flds.table_chunks = []
         return flds
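
PptChunker now pairs each slide's extracted text (tables, text frames, and grouped shapes via __extract) with a JPEG thumbnail rendered through aspose.slides. A usage sketch; the path is a placeholder, and raw bytes are passed because the thumbnail branch always wraps fnm in BytesIO:

# Usage sketch; "deck.pptx" is a placeholder path. Pass raw bytes: the
# thumbnail step always does BytesIO(fnm), so a str path would fail there.
with open("deck.pptx", "rb") as f:
    binary = f.read()

flds = PptChunker()(binary)
for text, jpeg_bytes in flds.text_chunks:   # one (text, thumbnail) per slide
    print(len(jpeg_bytes), "bytes of JPEG;", text[:40])
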
| if req["available_int"] == 0: | if req["available_int"] == 0: | ||||
| bqry.filter.append(Q("range", available_int={"lt": 1})) | bqry.filter.append(Q("range", available_int={"lt": 1})) | ||||
| else: | else: | ||||
| bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||||
| bqry.filter.append( | |||||
| Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||||
| bqry.boost = 0.05 | bqry.boost = 0.05 | ||||
| s = Search() | s = Search() | ||||
| q_vec = [] | q_vec = [] | ||||
| if req.get("vector"): | if req.get("vector"): | ||||
| assert emb_mdl, "No embedding model selected" | assert emb_mdl, "No embedding model selected" | ||||
| s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps) | |||||
| s["knn"] = self._vector( | |||||
| qst, emb_mdl, req.get( | |||||
| "similarity", 0.4), ps) | |||||
| s["knn"]["filter"] = bqry.to_dict() | s["knn"]["filter"] = bqry.to_dict() | ||||
| if "highlight" in s: del s["highlight"] | |||||
| if "highlight" in s: | |||||
| del s["highlight"] | |||||
| q_vec = s["knn"]["query_vector"] | q_vec = s["knn"]["query_vector"] | ||||
| es_logger.info("【Q】: {}".format(json.dumps(s))) | es_logger.info("【Q】: {}".format(json.dumps(s))) | ||||
| res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) | res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) | ||||
| def trans2floats(txt): | def trans2floats(txt): | ||||
| return [float(t) for t in txt.split("\t")] | return [float(t) for t in txt.split("\t")] | ||||
| def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7): | |||||
| def insert_citations(self, answer, chunks, chunk_v, | |||||
| embd_mdl, tkweight=0.3, vtweight=0.7): | |||||
| pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) | pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) | ||||
| for i in range(1, len(pieces)): | for i in range(1, len(pieces)): | ||||
| if re.match(r"[a-z][.?;!][ \n]", pieces[i]): | if re.match(r"[a-z][.?;!][ \n]", pieces[i]): | ||||
| idx = [] | idx = [] | ||||
| pieces_ = [] | pieces_ = [] | ||||
| for i, t in enumerate(pieces): | for i, t in enumerate(pieces): | ||||
| if len(t) < 5: continue | |||||
| if len(t) < 5: | |||||
| continue | |||||
| idx.append(i) | idx.append(i) | ||||
| pieces_.append(t) | pieces_.append(t) | ||||
| es_logger.info("{} => {}".format(answer, pieces_)) | es_logger.info("{} => {}".format(answer, pieces_)) | ||||
| if not pieces_: return answer | |||||
| if not pieces_: | |||||
| return answer | |||||
| ans_v, c = embd_mdl.encode(pieces_) | |||||
| ans_v, _ = embd_mdl.encode(pieces_) | |||||
| assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( | assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( | ||||
| len(ans_v[0]), len(chunk_v[0])) | len(ans_v[0]), len(chunk_v[0])) | ||||
| chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks] | chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks] | ||||
| cites = {} | cites = {} | ||||
| for i,a in enumerate(pieces_): | |||||
| for i, a in enumerate(pieces_): | |||||
| sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], | sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], | ||||
| chunk_v, | chunk_v, | ||||
| huqie.qie(pieces_[i]).split(" "), | |||||
| huqie.qie( | |||||
| pieces_[i]).split(" "), | |||||
| chunks_tks, | chunks_tks, | ||||
| tkweight, vtweight) | tkweight, vtweight) | ||||
| mx = np.max(sim) * 0.99 | mx = np.max(sim) * 0.99 | ||||
| if mx < 0.55: continue | |||||
| cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] | |||||
| if mx < 0.55: | |||||
| continue | |||||
| cites[idx[i]] = list( | |||||
| set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] | |||||
| res = "" | res = "" | ||||
| for i,p in enumerate(pieces): | |||||
| for i, p in enumerate(pieces): | |||||
| res += p | res += p | ||||
| if i not in idx:continue | |||||
| if i not in cites:continue | |||||
| res += "##%s$$"%"$".join(cites[i]) | |||||
| if i not in idx: | |||||
| continue | |||||
| if i not in cites: | |||||
| continue | |||||
| res += "##%s$$" % "$".join(cites[i]) | |||||
| return res | return res | ||||
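
The citation markers appended by insert_citations have the shape ##i$j$$, where i, j, ... are indices of the supporting chunks. A tiny formatting check:

# Marker format check: chunk indices 0 and 3 cited for one answer piece.
cites = {4: ["0", "3"]}
marker = "##%s$$" % "$".join(cites[4])
assert marker == "##0$3$$"  # appended directly after the piece it supports
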
-    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
+    def rerank(self, sres, query, tkweight=0.3,
+               vtweight=0.7, cfield="content_ltks"):
         ins_embd = [
             Dealer.trans2floats(
-                sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
+                sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
+        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+                  for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
-                                                        huqie.qie(query).split(" "),
+                                                        huqie.qie(
+                                                            query).split(" "),
                                                         ins_tw, tkweight, vtweight)
         return sim, tksim, vtsim
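
rerank blends token overlap with vector similarity through qryr.hybrid_similarity, whose implementation sits outside this diff. A rough sketch of such a blend, assuming cosine similarity for vectors and Jaccard overlap for tokens (the real weighting may differ):

import numpy as np

# Assumed blend: the actual qryr.hybrid_similarity is not shown in this diff.
def hybrid_similarity(q_vec, ins_embd, q_tks, ins_tw, tkweight=0.3, vtweight=0.7):
    q = np.array(q_vec, dtype=float)
    vtsim = np.array([np.dot(q, np.array(v, dtype=float)) /
                      ((np.linalg.norm(q) * np.linalg.norm(v)) or 1.0)
                      for v in ins_embd])
    tksim = np.array([len(set(q_tks) & set(t)) / (len(set(q_tks) | set(t)) or 1)
                      for t in ins_tw])
    sim = tkweight * tksim + vtweight * vtsim
    return sim, tksim, vtsim
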
     def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
-        if not question: return ranks
+        if not question:
+            return ranks
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
                "question": question, "vector": True,
                "similarity": similarity_threshold}
 )
 from api.db import LLMType
 from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import TenantLLMService, LLMBundle
 from api.settings import database_logger
 from api.utils import get_format_time
 from api.utils.file_utils import get_project_base_directory

 PPT = PptChunker()


-def chuck_doc(name, binary, cvmdl=None):
+def chuck_doc(name, binary, tenant_id, cvmdl=None):
     suff = os.path.split(name)[-1].lower().split(".")[-1]
     if suff.find("pdf") >= 0:
         return PDF(binary)

             100., "Finished preparing! Start to slice file!", True)
     try:
         cron_logger.info("Chunking {}/{}".format(row["location"], row["name"]))
-        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
+        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             set_progress(

     tmf = open(tm_fnm, "a+")
     for _, r in rows.iterrows():
-        embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING)
-        if not embd_mdl:
-            set_progress(r["id"], -1, "Can't find embedding model!")
-            cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
+        try:
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
+            # TODO: sequence2text model
+        except Exception as e:
+            set_progress(r["id"], -1, str(e))
             continue
-        cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)

         st_tm = timer()
         cks = build(r, cv_mdl)
         if not cks: