|  |  | @@ -23,7 +23,7 @@ from copy import deepcopy | 
		
	
		
			
			|  |  |  | from timeit import default_timer as timer | 
		
	
		
			
			|  |  |  | import datetime | 
		
	
		
			
			|  |  |  | from datetime import timedelta | 
		
	
		
			
			|  |  |  | from api.db import LLMType, ParserType,StatusEnum | 
		
	
		
			
			|  |  |  | from api.db import LLMType, ParserType, StatusEnum | 
		
	
		
			
			|  |  |  | from api.db.db_models import Dialog, DB | 
		
	
		
			
			|  |  |  | from api.db.services.common_service import CommonService | 
		
	
		
			
			|  |  |  | from api.db.services.knowledgebase_service import KnowledgebaseService | 
		
	
	
		
			
			|  |  | @@ -41,14 +41,14 @@ class DialogService(CommonService): | 
		
	
		
			
			|  |  |  | @classmethod | 
		
	
		
			
			|  |  |  | @DB.connection_context() | 
		
	
		
			
			|  |  |  | def get_list(cls, tenant_id, | 
		
	
		
			
			|  |  |  | page_number, items_per_page, orderby, desc, id , name): | 
		
	
		
			
			|  |  |  | page_number, items_per_page, orderby, desc, id, name): | 
		
	
		
			
			|  |  |  | chats = cls.model.select() | 
		
	
		
			
			|  |  |  | if id: | 
		
	
		
			
			|  |  |  | chats = chats.where(cls.model.id == id) | 
		
	
		
			
			|  |  |  | if name: | 
		
	
		
			
			|  |  |  | chats = chats.where(cls.model.name == name) | 
		
	
		
			
			|  |  |  | chats = chats.where( | 
		
	
		
			
			|  |  |  | (cls.model.tenant_id == tenant_id) | 
		
	
		
			
			|  |  |  | (cls.model.tenant_id == tenant_id) | 
		
	
		
			
			|  |  |  | & (cls.model.status == StatusEnum.VALID.value) | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | if desc: | 
		
	
	
		
			
			|  |  | @@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." | 
		
	
		
			
			|  |  |  | st = timer() | 
		
	
		
			
			|  |  |  | llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id) | 
		
	
		
			
			|  |  |  | llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | chat_start_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | # Get llm model name and model provider name | 
		
	
		
			
			|  |  |  | llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | # Get llm model instance by model and provide name | 
		
	
		
			
			|  |  |  | llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if not llm: | 
		
	
		
			
			|  |  |  | llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \ | 
		
	
		
			
			|  |  |  | TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid) | 
		
	
		
			
			|  |  |  | # Model name is provided by tenant, but not system built-in | 
		
	
		
			
			|  |  |  | llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \ | 
		
	
		
			
			|  |  |  | TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider) | 
		
	
		
			
			|  |  |  | if not llm: | 
		
	
		
			
			|  |  |  | raise LookupError("LLM(%s) not found" % dialog.llm_id) | 
		
	
		
			
			|  |  |  | max_tokens = 8192 | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | max_tokens = llm[0].max_tokens | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | check_llm_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) | 
		
	
		
			
			|  |  |  | embd_nms = list(set([kb.embd_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | if len(embd_nms) != 1: | 
		
	
		
			
			|  |  |  | embedding_list = list(set([kb.embd_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | if len(embedding_list) != 1: | 
		
	
		
			
			|  |  |  | yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} | 
		
	
		
			
			|  |  |  | return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | 
		
	
		
			
			|  |  |  | retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | 
		
	
		
			
			|  |  |  | embedding_model_name = embedding_list[0] | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) | 
		
	
		
			
			|  |  |  | retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | questions = [m["content"] for m in messages if m["role"] == "user"][-3:] | 
		
	
		
			
			|  |  |  | attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None | 
		
	
	
		
			
			|  |  | @@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | if "doc_ids" in m: | 
		
	
		
			
			|  |  |  | attachments.extend(m["doc_ids"]) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0]) | 
		
	
		
			
			|  |  |  | create_retriever_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name) | 
		
	
		
			
			|  |  |  | if not embd_mdl: | 
		
	
		
			
			|  |  |  | raise LookupError("Embedding model(%s) not found" % embd_nms[0]) | 
		
	
		
			
			|  |  |  | raise LookupError("Embedding model(%s) not found" % embedding_model_name) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | bind_embedding_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if llm_id2llm_type(dialog.llm_id) == "image2text": | 
		
	
		
			
			|  |  |  | chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | bind_llm_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | prompt_config = dialog.prompt_config | 
		
	
		
			
			|  |  |  | field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) | 
		
	
		
			
			|  |  |  | tts_mdl = None | 
		
	
	
		
			
			|  |  | @@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | questions = questions[-1:] | 
		
	
		
			
			|  |  |  | refineQ_tm = timer() | 
		
	
		
			
			|  |  |  | keyword_tm = timer() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | refine_question_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | rerank_mdl = None | 
		
	
		
			
			|  |  |  | if dialog.rerank_id: | 
		
	
		
			
			|  |  |  | rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for _ in range(len(questions) // 2): | 
		
	
		
			
			|  |  |  | questions.append(questions[-1]) | 
		
	
		
			
			|  |  |  | bind_reranker_ts = timer() | 
		
	
		
			
			|  |  |  | generate_keyword_ts = bind_reranker_ts | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: | 
		
	
		
			
			|  |  |  | kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | if prompt_config.get("keyword", False): | 
		
	
		
			
			|  |  |  | questions[-1] += keyword_extraction(chat_mdl, questions[-1]) | 
		
	
		
			
			|  |  |  | keyword_tm = timer() | 
		
	
		
			
			|  |  |  | generate_keyword_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | tenant_ids = list(set([kb.tenant_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, | 
		
	
		
			
			|  |  |  | dialog.similarity_threshold, | 
		
	
		
			
			|  |  |  | dialog.vector_similarity_weight, | 
		
	
		
			
			|  |  |  | doc_ids=attachments, | 
		
	
		
			
			|  |  |  | top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl) | 
		
	
		
			
			|  |  |  | kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, | 
		
	
		
			
			|  |  |  | dialog.similarity_threshold, | 
		
	
		
			
			|  |  |  | dialog.vector_similarity_weight, | 
		
	
		
			
			|  |  |  | doc_ids=attachments, | 
		
	
		
			
			|  |  |  | top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | retrieval_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | knowledges = kb_prompt(kbinfos, max_tokens) | 
		
	
		
			
			|  |  |  | logging.debug( | 
		
	
		
			
			|  |  |  | "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) | 
		
	
		
			
			|  |  |  | retrieval_tm = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if not knowledges and prompt_config.get("empty_response"): | 
		
	
		
			
			|  |  |  | empty_res = prompt_config["empty_response"] | 
		
	
	
		
			
			|  |  | @@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | max_tokens - used_token_count) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def decorate_answer(answer): | 
		
	
		
			
			|  |  |  | nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm | 
		
	
		
			
			|  |  |  | nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | finish_chat_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | refs = [] | 
		
	
		
			
			|  |  |  | if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): | 
		
	
		
			
			|  |  |  | answer, idx = retr.insert_citations(answer, | 
		
	
		
			
			|  |  |  | [ck["content_ltks"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | [ck["vector"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | embd_mdl, | 
		
	
		
			
			|  |  |  | tkweight=1 - dialog.vector_similarity_weight, | 
		
	
		
			
			|  |  |  | vtweight=dialog.vector_similarity_weight) | 
		
	
		
			
			|  |  |  | answer, idx = retriever.insert_citations(answer, | 
		
	
		
			
			|  |  |  | [ck["content_ltks"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | [ck["vector"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | embd_mdl, | 
		
	
		
			
			|  |  |  | tkweight=1 - dialog.vector_similarity_weight, | 
		
	
		
			
			|  |  |  | vtweight=dialog.vector_similarity_weight) | 
		
	
		
			
			|  |  |  | idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) | 
		
	
		
			
			|  |  |  | recall_docs = [ | 
		
	
		
			
			|  |  |  | d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] | 
		
	
	
		
			
			|  |  | @@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: | 
		
	
		
			
			|  |  |  | answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" | 
		
	
		
			
			|  |  |  | done_tm = timer() | 
		
	
		
			
			|  |  |  | prompt += "\n\n### Elapsed\n  - Refine Question: %.1f ms\n  - Keywords: %.1f ms\n  - Retrieval: %.1f ms\n  - LLM: %.1f ms" % ( | 
		
	
		
			
			|  |  |  | (refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000, | 
		
	
		
			
			|  |  |  | (done_tm - retrieval_tm) * 1000) | 
		
	
		
			
			|  |  |  | finish_chat_ts = timer() | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 | 
		
	
		
			
			|  |  |  | check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 | 
		
	
		
			
			|  |  |  | create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000 | 
		
	
		
			
			|  |  |  | bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000 | 
		
	
		
			
			|  |  |  | bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000 | 
		
	
		
			
			|  |  |  | refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000 | 
		
	
		
			
			|  |  |  | bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000 | 
		
	
		
			
			|  |  |  | generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000 | 
		
	
		
			
			|  |  |  | retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000 | 
		
	
		
			
			|  |  |  | generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | prompt = f"{prompt} ### Elapsed\n  - Total: {total_time_cost:.1f}ms\n  - Check LLM: {check_llm_time_cost:.1f}ms\n  - Create retriever: {create_retriever_time_cost:.1f}ms\n  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n  - Bind LLM: {bind_llm_time_cost:.1f}ms\n  - Tune question: {refine_question_time_cost:.1f}ms\n  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n  - Retrieval: {retrieval_time_cost:.1f}ms\n  - Generate answer: {generate_result_time_cost:.1f}ms" | 
		
	
		
			
			|  |  |  | return {"answer": answer, "reference": refs, "prompt": prompt} | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if stream: | 
		
	
	
		
			
			|  |  | @@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | 
		
	
		
			
			|  |  |  | sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。" | 
		
	
		
			
			|  |  |  | user_promt = """ | 
		
	
		
			
			|  |  |  | 表名:{}; | 
		
	
		
			
			|  |  |  | 数据库表字段说明如下: | 
		
	
		
			
			|  |  |  | sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." | 
		
	
		
			
			|  |  |  | user_prompt = """ | 
		
	
		
			
			|  |  |  | Table name: {}; | 
		
	
		
			
			|  |  |  | Table of database fields are as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 问题如下: | 
		
	
		
			
			|  |  |  | Question are as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  | 请写出SQL, 且只要SQL,不要有其他说明及文字。 | 
		
	
		
			
			|  |  |  | Please write the SQL, only SQL, without any other explanations or text. | 
		
	
		
			
			|  |  |  | """.format( | 
		
	
		
			
			|  |  |  | index_name(tenant_id), | 
		
	
		
			
			|  |  |  | "\n".join([f"{k}: {v}" for k, v in field_map.items()]), | 
		
	
	
		
			
			|  |  | @@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | 
		
	
		
			
			|  |  |  | tried_times = 0 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def get_table(): | 
		
	
		
			
			|  |  |  | nonlocal sys_prompt, user_promt, question, tried_times | 
		
	
		
			
			|  |  |  | sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { | 
		
	
		
			
			|  |  |  | nonlocal sys_prompt, user_prompt, question, tried_times | 
		
	
		
			
			|  |  |  | sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], { | 
		
	
		
			
			|  |  |  | "temperature": 0.06}) | 
		
	
		
			
			|  |  |  | logging.debug(f"{question} ==> {user_promt} get SQL: {sql}") | 
		
	
		
			
			|  |  |  | logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") | 
		
	
		
			
			|  |  |  | sql = re.sub(r"[\r\n]+", " ", sql.lower()) | 
		
	
		
			
			|  |  |  | sql = re.sub(r".*select ", "select ", sql.lower()) | 
		
	
		
			
			|  |  |  | sql = re.sub(r" +", " ", sql) | 
		
	
	
		
			
			|  |  | @@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | 
		
	
		
			
			|  |  |  | if tbl is None: | 
		
	
		
			
			|  |  |  | return None | 
		
	
		
			
			|  |  |  | if tbl.get("error") and tried_times <= 2: | 
		
	
		
			
			|  |  |  | user_promt = """ | 
		
	
		
			
			|  |  |  | 表名:{}; | 
		
	
		
			
			|  |  |  | 数据库表字段说明如下: | 
		
	
		
			
			|  |  |  | user_prompt = """ | 
		
	
		
			
			|  |  |  | Table name: {}; | 
		
	
		
			
			|  |  |  | Table of database fields are as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | 问题如下: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | Question are as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  | Please write the SQL, only SQL, without any other explanations or text. | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 你上一次给出的错误SQL如下: | 
		
	
		
			
			|  |  |  | The SQL error you provided last time is as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 后台报错如下: | 
		
	
		
			
			|  |  |  | Error issued by database as follows: | 
		
	
		
			
			|  |  |  | {} | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。 | 
		
	
		
			
			|  |  |  | Please correct the error and write SQL again, only SQL, without any other explanations or text. | 
		
	
		
			
			|  |  |  | """.format( | 
		
	
		
			
			|  |  |  | index_name(tenant_id), | 
		
	
		
			
			|  |  |  | "\n".join([f"{k}: {v}" for k, v in field_map.items()]), | 
		
	
	
		
			
			|  |  | @@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | docid_idx = set([ii for ii, c in enumerate( | 
		
	
		
			
			|  |  |  | tbl["columns"]) if c["name"] == "doc_id"]) | 
		
	
		
			
			|  |  |  | docnm_idx = set([ii for ii, c in enumerate( | 
		
	
		
			
			|  |  |  | doc_name_idx = set([ii for ii, c in enumerate( | 
		
	
		
			
			|  |  |  | tbl["columns"]) if c["name"] == "docnm_kwd"]) | 
		
	
		
			
			|  |  |  | clmn_idx = [ii for ii in range( | 
		
	
		
			
			|  |  |  | len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)] | 
		
	
		
			
			|  |  |  | column_idx = [ii for ii in range( | 
		
	
		
			
			|  |  |  | len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | # compose markdown table | 
		
	
		
			
			|  |  |  | clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], | 
		
	
		
			
			|  |  |  | tbl["columns"][i]["name"])) for i in | 
		
	
		
			
			|  |  |  | clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") | 
		
	
		
			
			|  |  |  | # compose Markdown table | 
		
	
		
			
			|  |  |  | columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], | 
		
	
		
			
			|  |  |  | tbl["columns"][i]["name"])) for i in | 
		
	
		
			
			|  |  |  | column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ | 
		
	
		
			
			|  |  |  | line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \ | 
		
	
		
			
			|  |  |  | ("|------|" if docid_idx and docid_idx else "") | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | rows = ["|" + | 
		
	
		
			
			|  |  |  | "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + | 
		
	
		
			
			|  |  |  | "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + | 
		
	
		
			
			|  |  |  | "|" for r in tbl["rows"]] | 
		
	
		
			
			|  |  |  | rows = [r for r in rows if re.sub(r"[ |]+", "", r)] | 
		
	
		
			
			|  |  |  | if quota: | 
		
	
	
		
			
			|  |  | @@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): | 
		
	
		
			
			|  |  |  | rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) | 
		
	
		
			
			|  |  |  | rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if not docid_idx or not docnm_idx: | 
		
	
		
			
			|  |  |  | if not docid_idx or not doc_name_idx: | 
		
	
		
			
			|  |  |  | logging.warning("SQL missing field: " + sql) | 
		
	
		
			
			|  |  |  | return { | 
		
	
		
			
			|  |  |  | "answer": "\n".join([clmns, line, rows]), | 
		
	
		
			
			|  |  |  | "answer": "\n".join([columns, line, rows]), | 
		
	
		
			
			|  |  |  | "reference": {"chunks": [], "doc_aggs": []}, | 
		
	
		
			
			|  |  |  | "prompt": sys_prompt | 
		
	
		
			
			|  |  |  | } | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | docid_idx = list(docid_idx)[0] | 
		
	
		
			
			|  |  |  | docnm_idx = list(docnm_idx)[0] | 
		
	
		
			
			|  |  |  | doc_name_idx = list(doc_name_idx)[0] | 
		
	
		
			
			|  |  |  | doc_aggs = {} | 
		
	
		
			
			|  |  |  | for r in tbl["rows"]: | 
		
	
		
			
			|  |  |  | if r[docid_idx] not in doc_aggs: | 
		
	
		
			
			|  |  |  | doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0} | 
		
	
		
			
			|  |  |  | doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0} | 
		
	
		
			
			|  |  |  | doc_aggs[r[docid_idx]]["count"] += 1 | 
		
	
		
			
			|  |  |  | return { | 
		
	
		
			
			|  |  |  | "answer": "\n".join([clmns, line, rows]), | 
		
	
		
			
			|  |  |  | "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], | 
		
	
		
			
			|  |  |  | "answer": "\n".join([columns, line, rows]), | 
		
	
		
			
			|  |  |  | "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], | 
		
	
		
			
			|  |  |  | "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in | 
		
	
		
			
			|  |  |  | doc_aggs.items()]}, | 
		
	
		
			
			|  |  |  | "prompt": sys_prompt | 
		
	
	
		
			
			|  |  | @@ -492,7 +528,7 @@ Requirements: | 
		
	
		
			
			|  |  |  | kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) | 
		
	
		
			
			|  |  |  | if isinstance(kwd, tuple): | 
		
	
		
			
			|  |  |  | kwd = kwd[0] | 
		
	
		
			
			|  |  |  | if kwd.find("**ERROR**") >=0: | 
		
	
		
			
			|  |  |  | if kwd.find("**ERROR**") >= 0: | 
		
	
		
			
			|  |  |  | return "" | 
		
	
		
			
			|  |  |  | return kwd | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -605,16 +641,16 @@ def tts(tts_mdl, text): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def ask(question, kb_ids, tenant_id): | 
		
	
		
			
			|  |  |  | kbs = KnowledgebaseService.get_by_ids(kb_ids) | 
		
	
		
			
			|  |  |  | embd_nms = list(set([kb.embd_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | embedding_list = list(set([kb.embd_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) | 
		
	
		
			
			|  |  |  | retr = settings.retrievaler if not is_kg else settings.kg_retrievaler | 
		
	
		
			
			|  |  |  | is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) | 
		
	
		
			
			|  |  |  | retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) | 
		
	
		
			
			|  |  |  | embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) | 
		
	
		
			
			|  |  |  | chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | 
		
	
		
			
			|  |  |  | max_tokens = chat_mdl.max_length | 
		
	
		
			
			|  |  |  | tenant_ids = list(set([kb.tenant_id for kb in kbs])) | 
		
	
		
			
			|  |  |  | kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) | 
		
	
		
			
			|  |  |  | kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) | 
		
	
		
			
			|  |  |  | knowledges = kb_prompt(kbinfos, max_tokens) | 
		
	
		
			
			|  |  |  | prompt = """ | 
		
	
		
			
			|  |  |  | Role: You're a smart assistant. Your name is Miss R. | 
		
	
	
		
			
			|  |  | @@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def decorate_answer(answer): | 
		
	
		
			
			|  |  |  | nonlocal knowledges, kbinfos, prompt | 
		
	
		
			
			|  |  |  | answer, idx = retr.insert_citations(answer, | 
		
	
		
			
			|  |  |  | [ck["content_ltks"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | [ck["vector"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | embd_mdl, | 
		
	
		
			
			|  |  |  | tkweight=0.7, | 
		
	
		
			
			|  |  |  | vtweight=0.3) | 
		
	
		
			
			|  |  |  | answer, idx = retriever.insert_citations(answer, | 
		
	
		
			
			|  |  |  | [ck["content_ltks"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | [ck["vector"] | 
		
	
		
			
			|  |  |  | for ck in kbinfos["chunks"]], | 
		
	
		
			
			|  |  |  | embd_mdl, | 
		
	
		
			
			|  |  |  | tkweight=0.7, | 
		
	
		
			
			|  |  |  | vtweight=0.3) | 
		
	
		
			
			|  |  |  | idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) | 
		
	
		
			
			|  |  |  | recall_docs = [ | 
		
	
		
			
			|  |  |  | d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] | 
		
	
	
		
			
			|  |  | @@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id): | 
		
	
		
			
			|  |  |  | answer = ans | 
		
	
		
			
			|  |  |  | yield {"answer": answer, "reference": {}} | 
		
	
		
			
			|  |  |  | yield decorate_answer(answer) | 
		
	
		
			
			|  |  |  | 
 |