import logging
import re
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta

from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, LLMService, TenantLLMService
from api import settings
from rag.nlp.search import index_name
from rag.utils import rmSpace


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
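        """Return the tenant's dialogs as dicts, filtered by optional `id`/`name`,
        ordered by `orderby` (descending when `desc` is set) and paginated."""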
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."

    chat_start_ts = timer()

    # Get llm model name and model provider name
    llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)

    # Get llm model instance by model and provider name
    llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)

    if not llm:
        # Model name is provided by tenant, but not system built-in
        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
        if not llm:
            raise LookupError("LLM(%s) not found" % dialog.llm_id)
        max_tokens = 8192
    else:
        max_tokens = llm[0].max_tokens

    check_llm_ts = timer()
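    # The *_ts variables are stage timestamps; decorate_answer() below turns them
    # into a per-stage "Elapsed" report appended to the prompt for debugging.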
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
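    # NOTE: `chat` is a generator, so the `return` above only terminates iteration
    # (its value travels in StopIteration); it is not delivered like a yield.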
    embedding_model_name = embedding_list[0]

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
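    # Only when every selected KB was parsed as a knowledge graph do we use the KG
    # retriever; mixed or plain setups fall back to the default retriever.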
questions = [m["content"] for m in messages if m["role"] == "user"][-3:] |
|
|
questions = [m["content"] for m in messages if m["role"] == "user"][-3:] |
|
|
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None |
|
|
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None |
|
|
|
|
|
|
|
|
if "doc_ids" in m: |
|
|
if "doc_ids" in m: |
|
|
attachments.extend(m["doc_ids"]) |
|
|
attachments.extend(m["doc_ids"]) |
|
|
|
|
|
|
|
|
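    # `attachments` collects document ids pinned anywhere in the conversation;
    # retrieval below is restricted to those documents when present.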
    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()
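    # LLMBundle wraps the tenant-scoped model instance behind a uniform interface
    # (chat()/chat_streamly() below), choosing IMAGE2TEXT for multimodal chat models.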
    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None

    # Use SQL retrieval instead of vector search when the KBs expose a field map.
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]
    refine_question_ts = timer()
    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    for _ in range(len(questions) // 2):
        questions.append(questions[-1])

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
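    # Repeating the last question up-weights it when the questions are joined into one
    # retrieval query. generate_keyword_ts starts equal to bind_reranker_ts so the
    # keyword stage reports ~0 ms whenever keyword extraction is skipped.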
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: |
|
|
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: |
|
|
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} |
|
|
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} |
|
|
else: |
|
|
else: |
|
|
if prompt_config.get("keyword", False): |
|
|
if prompt_config.get("keyword", False): |
|
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1]) |
|
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1]) |
|
|
keyword_tm = timer() |
|
|
|
|
|
|
|
|
generate_keyword_ts = timer() |
|
|
|
|
|
|
|
|
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                      dialog.similarity_threshold,
                                      dialog.vector_similarity_weight,
                                      doc_ids=attachments,
                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)

    retrieval_ts = timer()
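    # retrieval(): the positional args after the query are page (1) and page size
    # (top_n); thresholds and weights come from the dialog config, and `doc_ids`
    # narrows the search to attached documents.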
    knowledges = kb_prompt(kbinfos, max_tokens)
    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos}
        return {"answer": empty_res, "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, "message_fit_in has bug: " + str(msg)
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)
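    # Clamp the completion budget so prompt tokens + completion tokens still fit the
    # model's context window.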
    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
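        # Post-process a finished answer: insert ##n$$ citation markers, keep only
        # the documents actually cited, and append the per-stage timing report.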
        refs = []
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retriever.insert_citations(answer,
                                                     [ck["content_ltks"]
                                                      for ck in kbinfos["chunks"]],
                                                     [ck["vector"]
                                                      for ck in kbinfos["chunks"]],
                                                     embd_mdl,
                                                     tkweight=1 - dialog.vector_similarity_weight,
                                                     vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            # Vectors are bulky and not needed by callers; strip them from references.
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]
        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()
        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt = f"{prompt}\n\n### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
return {"answer": answer, "reference": refs, "prompt": prompt} |
|
|
return {"answer": answer, "reference": refs, "prompt": prompt} |
|
|
|
|
|
|
|
|
if stream: |
|
|
if stream: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
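
# Minimal usage sketch (hypothetical tenant/dialog values; `chat` is a generator):
#
#   dia = DialogService.query(tenant_id="t0")[0]
#   for delta in chat(dia, [{"role": "user", "content": "What is in my KB?"}]):
#       print(delta["answer"])
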
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};

The database table fields are described as follows:
{}

The questions are as follows:
{}

Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0
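    # get_table() asks the model for SQL at low temperature, normalizes the output,
    # and runs it against the tenant's chunk index; on failure the prompt below is
    # rewritten with the bad SQL plus the database error, and the query is retried.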
    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        # Normalize the model output down to a single bare SELECT statement.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
if tbl.get("error") and tried_times <= 2: |
|
|
if tbl.get("error") and tried_times <= 2: |
|
|
user_promt = """ |
|
|
|
|
|
表名:{}; |
|
|
|
|
|
数据库表字段说明如下: |
|
|
|
|
|
|
|
|
user_prompt = """ |
|
|
|
|
|
Table name: {}; |
|
|
|
|
|
Table of database fields are as follows: |
|
|
{} |
|
|
{} |
|
|
|
|
|
|
|
|
问题如下: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Question are as follows: |
|
|
{} |
|
|
{} |
|
|
|
|
|
Please write the SQL, only SQL, without any other explanations or text. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
你上一次给出的错误SQL如下: |
|
|
|
|
|
|
|
|
The SQL error you provided last time is as follows: |
|
|
{} |
|
|
{} |
|
|
|
|
|
|
|
|
后台报错如下: |
|
|
|
|
|
|
|
|
Error issued by database as follows: |
|
|
{} |
|
|
{} |
|
|
|
|
|
|
|
|
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。 |
|
|
|
|
|
|
|
|
Please correct the error and write SQL again, only SQL, without any other explanations or text. |
|
|
""".format( |
|
|
""".format( |
|
|
index_name(tenant_id), |
|
|
index_name(tenant_id), |
|
|
"\n".join([f"{k}: {v}" for k, v in field_map.items()]), |
|
|
"\n".join([f"{k}: {v}" for k, v in field_map.items()]), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
    # compose Markdown table
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                      tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")
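    # e.g. columns == "|Name|Age|Source|" and line == "|------|------|------|"
    # (hypothetical field names; "Source" is only added when both doc columns exist).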
rows = ["|" + |
|
|
rows = ["|" + |
|
|
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + |
|
|
|
|
|
|
|
|
"|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + |
|
|
"|" for r in tbl["rows"]] |
|
|
"|" for r in tbl["rows"]] |
|
|
rows = [r for r in rows if re.sub(r"[ |]+", "", r)] |
|
|
rows = [r for r in rows if re.sub(r"[ |]+", "", r)] |
|
|
    if quota:
        # Tag each row with a citation marker (##N$$) so the answer can cite sources.
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    # Strip the time-of-day part from ISO timestamps to keep the table compact.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }
    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }


def keyword_extraction(chat_mdl, content, topn=3):
    # Ask the chat model for the top-n keywords of `content`, comma separated.
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
  - Summarize the text content, and give the top {topn} important keywords/phrases.
  - The keywords MUST be in the language of the given piece of text content.
  - The keywords are delimited by ENGLISH COMMA.
  - Output keywords ONLY.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    # Some chat backends return a (answer, token_count) tuple.
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
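    # One-shot Q&A over the given KBs: retrieve once with fixed defaults, stream the
    # raw answer, then yield a final answer decorated with citations.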
    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
    knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """ |
|
|
prompt = """ |
|
|
Role: You're a smart assistant. Your name is Miss R. |
|
|
Role: You're a smart assistant. Your name is Miss R. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        # Vectors are bulky and not needed by callers; strip them from references.
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, [{"role": "user", "content": question}], {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)