|
|
|
@@ -193,14 +193,14 @@ def chat(dialog, messages, **kwargs): |
|
|
|
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
|
|
|
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
|
|
|
|
|
|
|
prompt_config = dialog.prompt_config
|
|
|
|
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
|
|
|
# try to use sql if field mapping is good to go
|
|
|
|
if field_map:
|
|
|
|
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
|
|
|
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
|
|
|
|
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
|
|
|
if ans: return ans
|
|
|
|
|
|
|
|
prompt_config = dialog.prompt_config
|
|
|
|
for p in prompt_config["parameters"]:
|
|
|
|
if p["key"] == "knowledge":
|
|
|
|
continue
|
|
|
|
@@ -255,6 +255,7 @@ def chat(dialog, messages, **kwargs): |
|
|
|
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
|
|
|
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
|
|
|
|
kbinfos["doc_aggs"] = recall_docs
|
|
|
|
|
|
|
|
for c in kbinfos["chunks"]:
|
|
|
|
if c.get("vector"):
|
|
|
|
del c["vector"]
|
|
|
|
@@ -263,7 +264,7 @@ def chat(dialog, messages, **kwargs): |
|
|
|
return {"answer": answer, "reference": kbinfos}
|
|
|
|
|
|
|
|
|
|
|
|
def use_sql(question, field_map, tenant_id, chat_mdl):
|
|
|
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
|
|
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
|
|
|
|
user_promt = """
|
|
|
|
表名:{};
|
|
|
|
@@ -353,12 +354,16 @@ def use_sql(question, field_map, tenant_id, chat_mdl): |
|
|
|
# compose markdown table
|
|
|
|
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
|
|
|
|
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
|
|
|
|
|
|
|
|
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
|
|
|
|
("|------|" if docid_idx and docid_idx else "")
|
|
|
|
|
|
|
|
rows = ["|" +
|
|
|
|
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
|
|
|
|
"|" for r in tbl["rows"]]
|
|
|
|
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
|
|
|
if quota:
|
|
|
|
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
|
|
|
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
|
|
|
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
|
|
|
|
|
|
|
if not docid_idx or not docnm_idx:
|