| @@ -17,7 +17,7 @@ from flask import request | |||
| from flask_login import login_required | |||
| from api.db.services.dialog_service import DialogService, ConversationService | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_json_result | |||
| @@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs): | |||
| if p["key"] not in kwargs: | |||
| prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ") | |||
| model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id) | |||
| if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id)) | |||
| question = messages[-1]["content"] | |||
| embd_mdl = TenantLLMService.model_instance( | |||
| dialog.tenant_id, LLMType.EMBEDDING.value) | |||
| embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) | |||
| chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) | |||
| kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, | |||
| dialog.vector_similarity_weight, top=1024, aggs=False) | |||
| knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]] | |||
| @@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs): | |||
| used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) | |||
| if "max_tokens" in gen_conf: | |||
| gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) | |||
| mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id) | |||
| answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) | |||
| answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) | |||
| answer = retrievaler.insert_citations(answer, | |||
| [ck["content_ltks"] for ck in kbinfos["chunks"]], | |||
| @@ -524,6 +524,7 @@ class Dialog(DataBaseModel): | |||
| similarity_threshold = FloatField(default=0.2) | |||
| vector_similarity_weight = FloatField(default=0.3) | |||
| top_n = IntegerField(default=6) | |||
| do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") | |||
| kb_ids = JSONField(null=False, default=[]) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| @@ -14,12 +14,12 @@ | |||
| # limitations under the License. | |||
| # | |||
| from api.db.services.user_service import TenantService | |||
| from rag.llm import EmbeddingModel, CvModel | |||
| from api.settings import database_logger | |||
| from rag.llm import EmbeddingModel, CvModel, ChatModel | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import LLMFactories, LLM, TenantLLM | |||
| from api.db.services.common_service import CommonService | |||
| from api.db import StatusEnum | |||
| class LLMFactoriesService(CommonService): | |||
| @@ -37,13 +37,19 @@ class TenantLLMService(CommonService): | |||
| @DB.connection_context() | |||
| def get_api_key(cls, tenant_id, model_name): | |||
| objs = cls.query(tenant_id=tenant_id, llm_name=model_name) | |||
| if not objs: return | |||
| if not objs: | |||
| return | |||
| return objs[0] | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_my_llms(cls, tenant_id): | |||
| fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name] | |||
| fields = [ | |||
| cls.model.llm_factory, | |||
| LLMFactories.logo, | |||
| LLMFactories.tags, | |||
| cls.model.model_type, | |||
| cls.model.llm_name] | |||
| objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where( | |||
| cls.model.tenant_id == tenant_id).dicts() | |||
| @@ -51,23 +57,96 @@ class TenantLLMService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def model_instance(cls, tenant_id, llm_type): | |||
| e,tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: raise LookupError("Tenant not found") | |||
| def model_instance(cls, tenant_id, llm_type, llm_name=None): | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| raise LookupError("Tenant not found") | |||
| if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id | |||
| elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id | |||
| elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id | |||
| elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id | |||
| else: assert False, "LLM type error" | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| mdlnm = tenant.embd_id | |||
| elif llm_type == LLMType.SPEECH2TEXT.value: | |||
| mdlnm = tenant.asr_id | |||
| elif llm_type == LLMType.IMAGE2TEXT.value: | |||
| mdlnm = tenant.img2txt_id | |||
| elif llm_type == LLMType.CHAT.value: | |||
| mdlnm = tenant.llm_id if not llm_name else llm_name | |||
| else: | |||
| assert False, "LLM type error" | |||
| model_config = cls.get_api_key(tenant_id, mdlnm) | |||
| if not model_config: raise LookupError("Model({}) not found".format(mdlnm)) | |||
| if not model_config: | |||
| raise LookupError("Model({}) not found".format(mdlnm)) | |||
| model_config = model_config.to_dict() | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| if model_config["llm_factory"] not in EmbeddingModel: return | |||
| return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) | |||
| if model_config["llm_factory"] not in EmbeddingModel: | |||
| return | |||
| return EmbeddingModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], model_config["llm_name"]) | |||
| if llm_type == LLMType.IMAGE2TEXT.value: | |||
| if model_config["llm_factory"] not in CvModel: return | |||
| return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) | |||
| if model_config["llm_factory"] not in CvModel: | |||
| return | |||
| return CvModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], model_config["llm_name"]) | |||
| if llm_type == LLMType.CHAT.value: | |||
| if model_config["llm_factory"] not in ChatModel: | |||
| return | |||
| return ChatModel[model_config["llm_factory"]]( | |||
| model_config["api_key"], model_config["llm_name"]) | |||
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
    """Add ``used_tokens`` to the usage counter of one of a tenant's models.

    The model is selected from the tenant record according to ``llm_type``;
    for chat models an explicit ``llm_name`` takes precedence over the
    tenant's default chat model.

    Args:
        tenant_id: id of the tenant whose counter is updated.
        llm_type: one of the ``LLMType`` values (embedding, speech2text,
            image2text or chat).
        used_tokens: number of tokens to add to the stored counter.
        llm_name: optional chat model name; ignored for other model types.

    Returns:
        Whatever the UPDATE query's ``execute()`` returns (the number of
        rows modified, per peewee).

    Raises:
        LookupError: the tenant does not exist.
        AssertionError: ``llm_type`` is not a known model type.
    """
    e, tenant = TenantService.get_by_id(tenant_id)
    if not e:
        raise LookupError("Tenant not found")

    if llm_type == LLMType.EMBEDDING.value:
        mdlnm = tenant.embd_id
    elif llm_type == LLMType.SPEECH2TEXT.value:
        mdlnm = tenant.asr_id
    elif llm_type == LLMType.IMAGE2TEXT.value:
        mdlnm = tenant.img2txt_id
    elif llm_type == LLMType.CHAT.value:
        mdlnm = tenant.llm_id if not llm_name else llm_name
    else:
        # Raised explicitly instead of ``assert False`` so the guard is not
        # stripped under ``python -O`` (which would leave ``mdlnm`` unbound).
        # The exception type callers see is unchanged.
        raise AssertionError("LLM type error")

    # Increment in SQL (column + value) rather than read-modify-write.
    return cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
        .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
        .execute()
class LLMBundle(object):
    """A tenant's model instance bundled with token-usage accounting.

    Wraps the object returned by ``TenantLLMService.model_instance`` and,
    after every call, reports the token count returned by the model to
    ``TenantLLMService.increase_usage``.
    """

    def __init__(self, tenant_id, llm_type, llm_name=None):
        """Resolve and hold the model for ``tenant_id``/``llm_type``.

        Args:
            tenant_id: id of the tenant that owns the model configuration.
            llm_type: model category, forwarded to ``model_instance``.
            llm_name: optional explicit model name, forwarded as-is.

        Raises:
            AssertionError: ``model_instance`` returned nothing usable.
        """
        self.tenant_id = tenant_id
        self.llm_type = llm_type
        self.llm_name = llm_name
        self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
        if not self.mdl:
            # Explicit raise (same exception type as the former ``assert``)
            # so the check is still performed under ``python -O``.
            raise AssertionError(
                "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name))

    def _record_usage(self, used_tokens, label, llm_name=None):
        """Add ``used_tokens`` to the tenant's counter; log if nothing was updated.

        ``increase_usage`` returns the result of the UPDATE's ``execute()``,
        so a falsy value means no row was updated.
        """
        updated = TenantLLMService.increase_usage(
            self.tenant_id, self.llm_type, used_tokens, llm_name)
        if not updated:
            database_logger.error(
                "Can't update token usage for {}/{}".format(self.tenant_id, label))

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts``; return ``(embeddings, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode(texts, batch_size)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def encode_queries(self, query: str):
        """Embed a single query; return ``(embedding, used_tokens)`` from the model."""
        emd, used_tokens = self.mdl.encode_queries(query)
        self._record_usage(used_tokens, "EMBEDDING")
        return emd, used_tokens

    def describe(self, image, max_tokens=300):
        """Return the model's description of ``image`` (text only)."""
        txt, used_tokens = self.mdl.describe(image, max_tokens)
        self._record_usage(used_tokens, "IMAGE2TEXT")
        return txt

    def chat(self, system, history, gen_conf):
        """Run one chat completion and return the answer text only."""
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        # The explicit model name is forwarded so usage is booked against the
        # model actually used rather than the tenant's default chat model.
        self._record_usage(used_tokens, "CHAT", self.llm_name)
        return txt
| @@ -143,11 +143,11 @@ def filename_type(filename): | |||
| if re.match(r".*\.pdf$", filename): | |||
| return FileType.PDF.value | |||
| if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| return FileType.DOC.value | |||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | |||
| return FileType.AURAL.value | |||
| if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): | |||
| return FileType.VISUAL | |||
| return FileType.VISUAL | |||
| @@ -37,7 +37,7 @@ class GptTurbo(Base): | |||
| model=self.model_name, | |||
| messages=history, | |||
| **gen_conf) | |||
| return res.choices[0].message.content.strip() | |||
| return res.choices[0].message.content.strip(), res.usage.completion_tokens | |||
| from dashscope import Generation | |||
| @@ -56,5 +56,5 @@ class QWenChat(Base): | |||
| result_format='message' | |||
| ) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'] | |||
| return response.message | |||
| return response.output.choices[0]['message']['content'], response.usage.output_tokens | |||
| return response.message, 0 | |||
| @@ -72,7 +72,7 @@ class GptV4(Base): | |||
| messages=self.prompt(b64), | |||
| max_tokens=max_tokens, | |||
| ) | |||
| return res.choices[0].message.content.strip() | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| class QWenCV(Base): | |||
| @@ -87,5 +87,5 @@ class QWenCV(Base): | |||
| response = MultiModalConversation.call(model=self.model_name, | |||
| messages=self.prompt(self.image2base64(image))) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'] | |||
| return response.message | |||
| return response.output.choices[0]['message']['content'], response.usage.output_tokens | |||
| return response.message, 0 | |||
| @@ -36,6 +36,9 @@ class Base(ABC): | |||
| def encode(self, texts: list, batch_size=32): | |||
| raise NotImplementedError("Please implement encode method!") | |||
def encode_queries(self, text: str):
    """Embed a single query string; must be overridden by subclasses.

    Args:
        text: the query to embed.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # The message names this method (it previously said "encode", copied
    # from the sibling stub, which pointed implementers at the wrong method).
    raise NotImplementedError("Please implement encode_queries method!")
| class HuEmbedding(Base): | |||
| def __init__(self, key="", model_name=""): | |||
| @@ -68,15 +71,18 @@ class HuEmbedding(Base): | |||
| class OpenAIEmbed(Base): | |||
| def __init__(self, key, model_name="text-embedding-ada-002"): | |||
| self.client = OpenAI(key) | |||
| self.client = OpenAI(api_key=key) | |||
| self.model_name = model_name | |||
| def encode(self, texts: list, batch_size=32): | |||
| token_count = 0 | |||
| for t in texts: token_count += num_tokens_from_string(t) | |||
| res = self.client.embeddings.create(input=texts, | |||
| model=self.model_name) | |||
| return [d["embedding"] for d in res["data"]], token_count | |||
| return np.array([d.embedding for d in res.data]), res.usage.total_tokens | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=[text], | |||
| model=self.model_name) | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| class QWenEmbed(Base): | |||
| @@ -84,16 +90,28 @@ class QWenEmbed(Base): | |||
| dashscope.api_key = key | |||
| self.model_name = model_name | |||
| def encode(self, texts: list, batch_size=32, text_type="document"): | |||
| def encode(self, texts: list, batch_size=10): | |||
| import dashscope | |||
| res = [] | |||
| token_count = 0 | |||
| for txt in texts: | |||
| texts = [txt[:2048] for txt in texts] | |||
| for i in range(0, len(texts), batch_size): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=txt[:2048], | |||
| text_type=text_type | |||
| input=texts[i:i+batch_size], | |||
| text_type="document" | |||
| ) | |||
| embds = [[]] * len(resp["output"]["embeddings"]) | |||
| for e in resp["output"]["embeddings"]: | |||
| embds[e["text_index"]] = e["embedding"] | |||
| res.extend(embds) | |||
| token_count += resp["usage"]["input_tokens"] | |||
| return np.array(res), token_count | |||
| def encode_queries(self, text): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=text[:2048], | |||
| text_type="query" | |||
| ) | |||
| res.append(resp["output"]["embeddings"][0]["embedding"]) | |||
| token_count += resp["usage"]["total_tokens"] | |||
| return res, token_count | |||
| return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"] | |||
| @@ -11,6 +11,11 @@ from io import BytesIO | |||
| class HuChunker: | |||
| @dataclass | |||
| class Fields: | |||
| text_chunks: List = None | |||
| table_chunks: List = None | |||
| def __init__(self): | |||
| self.MAX_LVL = 12 | |||
| self.proj_patt = [ | |||
| @@ -228,11 +233,6 @@ class HuChunker: | |||
| class PdfChunker(HuChunker): | |||
| @dataclass | |||
| class Fields: | |||
| text_chunks: List = None | |||
| table_chunks: List = None | |||
| def __init__(self, pdf_parser): | |||
| self.pdf = pdf_parser | |||
| super().__init__() | |||
| @@ -293,11 +293,6 @@ class PdfChunker(HuChunker): | |||
| class DocxChunker(HuChunker): | |||
| @dataclass | |||
| class Fields: | |||
| text_chunks: List = None | |||
| table_chunks: List = None | |||
| def __init__(self, doc_parser): | |||
| self.doc = doc_parser | |||
| super().__init__() | |||
| @@ -344,11 +339,6 @@ class DocxChunker(HuChunker): | |||
| class ExcelChunker(HuChunker): | |||
| @dataclass | |||
| class Fields: | |||
| text_chunks: List = None | |||
| table_chunks: List = None | |||
| def __init__(self, excel_parser): | |||
| self.excel = excel_parser | |||
| super().__init__() | |||
| @@ -370,18 +360,51 @@ class PptChunker(HuChunker): | |||
| def __init__(self): | |||
| super().__init__() | |||
def __extract(self, shape):
    """Return the text carried by a slide shape, or None if there is none.

    Tables are flattened to one line per data row in the form
    ``header: value; header: value`` using row 0 as the header row; shapes
    with a text frame yield that frame's text; group shapes are walked
    recursively and their non-empty texts joined with newlines.
    """
    kind = shape.shape_type

    # 19 is MSO_SHAPE_TYPE.TABLE in python-pptx.
    if kind == 19:
        table = shape.table
        lines = []
        for row in range(1, len(table.rows)):
            pairs = []
            for col in range(len(table.columns)):
                # NOTE(review): this tests the cell object itself, not its
                # text, so it likely never filters anything - confirm intent.
                if table.cell(row, col):
                    pairs.append(
                        table.cell(0, col).text + ": " + table.cell(row, col).text)
            lines.append("; ".join(pairs))
        return "\n".join(lines)

    if shape.has_text_frame:
        return shape.text_frame.text

    # 6 is MSO_SHAPE_TYPE.GROUP in python-pptx.
    if kind == 6:
        collected = []
        for child in shape.shapes:
            child_text = self.__extract(child)
            if child_text:
                collected.append(child_text)
        return "\n".join(collected)

    return None
| def __call__(self, fnm): | |||
| from pptx import Presentation | |||
| ppt = Presentation(fnm) if isinstance( | |||
| fnm, str) else Presentation( | |||
| BytesIO(fnm)) | |||
| flds = self.Fields() | |||
| flds.text_chunks = [] | |||
| txts = [] | |||
| for slide in ppt.slides: | |||
| texts = [] | |||
| for shape in slide.shapes: | |||
| if hasattr(shape, "text"): | |||
| flds.text_chunks.append((shape.text, None)) | |||
| txt = self.__extract(shape) | |||
| if txt: texts.append(txt) | |||
| txts.append("\n".join(texts)) | |||
| import aspose.slides as slides | |||
| import aspose.pydrawing as drawing | |||
| imgs = [] | |||
| with slides.Presentation(BytesIO(fnm)) as presentation: | |||
| for slide in presentation.slides: | |||
| buffered = BytesIO() | |||
| slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) | |||
| imgs.append(buffered.getvalue()) | |||
| assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| flds = self.Fields() | |||
| flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))] | |||
| flds.table_chunks = [] | |||
| return flds | |||
| @@ -58,7 +58,8 @@ class Dealer: | |||
| if req["available_int"] == 0: | |||
| bqry.filter.append(Q("range", available_int={"lt": 1})) | |||
| else: | |||
| bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||
| bqry.filter.append( | |||
| Q("bool", must_not=Q("range", available_int={"lt": 1}))) | |||
| bqry.boost = 0.05 | |||
| s = Search() | |||
| @@ -87,9 +88,12 @@ class Dealer: | |||
| q_vec = [] | |||
| if req.get("vector"): | |||
| assert emb_mdl, "No embedding model selected" | |||
| s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps) | |||
| s["knn"] = self._vector( | |||
| qst, emb_mdl, req.get( | |||
| "similarity", 0.4), ps) | |||
| s["knn"]["filter"] = bqry.to_dict() | |||
| if "highlight" in s: del s["highlight"] | |||
| if "highlight" in s: | |||
| del s["highlight"] | |||
| q_vec = s["knn"]["query_vector"] | |||
| es_logger.info("【Q】: {}".format(json.dumps(s))) | |||
| res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) | |||
| @@ -175,7 +179,8 @@ class Dealer: | |||
| def trans2floats(txt): | |||
| return [float(t) for t in txt.split("\t")] | |||
| def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7): | |||
| def insert_citations(self, answer, chunks, chunk_v, | |||
| embd_mdl, tkweight=0.3, vtweight=0.7): | |||
| pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer) | |||
| for i in range(1, len(pieces)): | |||
| if re.match(r"[a-z][.?;!][ \n]", pieces[i]): | |||
| @@ -184,47 +189,57 @@ class Dealer: | |||
| idx = [] | |||
| pieces_ = [] | |||
| for i, t in enumerate(pieces): | |||
| if len(t) < 5: continue | |||
| if len(t) < 5: | |||
| continue | |||
| idx.append(i) | |||
| pieces_.append(t) | |||
| es_logger.info("{} => {}".format(answer, pieces_)) | |||
| if not pieces_: return answer | |||
| if not pieces_: | |||
| return answer | |||
| ans_v, c = embd_mdl.encode(pieces_) | |||
| ans_v, _ = embd_mdl.encode(pieces_) | |||
| assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( | |||
| len(ans_v[0]), len(chunk_v[0])) | |||
| chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks] | |||
| cites = {} | |||
| for i,a in enumerate(pieces_): | |||
| for i, a in enumerate(pieces_): | |||
| sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], | |||
| chunk_v, | |||
| huqie.qie(pieces_[i]).split(" "), | |||
| huqie.qie( | |||
| pieces_[i]).split(" "), | |||
| chunks_tks, | |||
| tkweight, vtweight) | |||
| mx = np.max(sim) * 0.99 | |||
| if mx < 0.55: continue | |||
| cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] | |||
| if mx < 0.55: | |||
| continue | |||
| cites[idx[i]] = list( | |||
| set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4] | |||
| res = "" | |||
| for i,p in enumerate(pieces): | |||
| for i, p in enumerate(pieces): | |||
| res += p | |||
| if i not in idx:continue | |||
| if i not in cites:continue | |||
| res += "##%s$$"%"$".join(cites[i]) | |||
| if i not in idx: | |||
| continue | |||
| if i not in cites: | |||
| continue | |||
| res += "##%s$$" % "$".join(cites[i]) | |||
| return res | |||
| def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"): | |||
| def rerank(self, sres, query, tkweight=0.3, | |||
| vtweight=0.7, cfield="content_ltks"): | |||
| ins_embd = [ | |||
| Dealer.trans2floats( | |||
| sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids] | |||
| sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids] | |||
| if not ins_embd: | |||
| return [], [], [] | |||
| ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids] | |||
| ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") | |||
| for i in sres.ids] | |||
| sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector, | |||
| ins_embd, | |||
| huqie.qie(query).split(" "), | |||
| huqie.qie( | |||
| query).split(" "), | |||
| ins_tw, tkweight, vtweight) | |||
| return sim, tksim, vtsim | |||
| @@ -237,7 +252,8 @@ class Dealer: | |||
| def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, | |||
| vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True): | |||
| ranks = {"total": 0, "chunks": [], "doc_aggs": {}} | |||
| if not question: return ranks | |||
| if not question: | |||
| return ranks | |||
| req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top, | |||
| "question": question, "vector": True, | |||
| "similarity": similarity_threshold} | |||
| @@ -49,7 +49,7 @@ from rag.nlp.huchunk import ( | |||
| ) | |||
| from api.db import LLMType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.llm_service import TenantLLMService, LLMBundle | |||
| from api.settings import database_logger | |||
| from api.utils import get_format_time | |||
| from api.utils.file_utils import get_project_base_directory | |||
| @@ -62,7 +62,7 @@ EXC = ExcelChunker(ExcelParser()) | |||
| PPT = PptChunker() | |||
| def chuck_doc(name, binary, cvmdl=None): | |||
| def chuck_doc(name, binary, tenant_id, cvmdl=None): | |||
| suff = os.path.split(name)[-1].lower().split(".")[-1] | |||
| if suff.find("pdf") >= 0: | |||
| return PDF(binary) | |||
| @@ -127,7 +127,7 @@ def build(row, cvmdl): | |||
| 100., "Finished preparing! Start to slice file!", True) | |||
| try: | |||
| cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | |||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) | |||
| obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl) | |||
| except Exception as e: | |||
| if re.search("(No such file|not found)", str(e)): | |||
| set_progress( | |||
| @@ -236,12 +236,14 @@ def main(comm, mod): | |||
| tmf = open(tm_fnm, "a+") | |||
| for _, r in rows.iterrows(): | |||
| embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING) | |||
| if not embd_mdl: | |||
| set_progress(r["id"], -1, "Can't find embedding model!") | |||
| cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"])) | |||
| try: | |||
| embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING) | |||
| cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT) | |||
| #TODO: sequence2text model | |||
| except Exception as e: | |||
| set_progress(r["id"], -1, str(e)) | |||
| continue | |||
| cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT) | |||
| st_tm = timer() | |||
| cks = build(r, cv_mdl) | |||
| if not cks: | |||