|
|
|
@@ -179,6 +179,7 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
for m in messages if m["role"] != "system"]) |
|
|
|
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97)) |
|
|
|
assert len(msg) >= 2, f"message_fit_in has bug: {msg}" |
|
|
|
prompt = msg[0]["content"] |
|
|
|
|
|
|
|
if "max_tokens" in gen_conf: |
|
|
|
gen_conf["max_tokens"] = min( |
|
|
|
@@ -186,7 +187,7 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
max_tokens - used_token_count) |
|
|
|
|
|
|
|
def decorate_answer(answer): |
|
|
|
nonlocal prompt_config, knowledges, kwargs, kbinfos |
|
|
|
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt |
|
|
|
refs = [] |
|
|
|
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): |
|
|
|
answer, idx = retr.insert_citations(answer, |
|
|
|
@@ -210,17 +211,16 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
|
|
|
|
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: |
|
|
|
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" |
|
|
|
return {"answer": answer, "reference": refs} |
|
|
|
return {"answer": answer, "reference": refs, "prompt": prompt} |
|
|
|
|
|
|
|
if stream: |
|
|
|
answer = "" |
|
|
|
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf): |
|
|
|
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): |
|
|
|
answer = ans |
|
|
|
yield {"answer": answer, "reference": {}} |
|
|
|
yield {"answer": answer, "reference": {}, "prompt": prompt} |
|
|
|
yield decorate_answer(answer) |
|
|
|
else: |
|
|
|
answer = chat_mdl.chat( |
|
|
|
msg[0]["content"], msg[1:], gen_conf) |
|
|
|
answer = chat_mdl.chat(prompt, msg[1:], gen_conf) |
|
|
|
chat_logger.info("User: {}|Assistant: {}".format( |
|
|
|
msg[-1]["content"], answer)) |
|
|
|
yield decorate_answer(answer) |
|
|
|
@@ -334,7 +334,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): |
|
|
|
chat_logger.warning("SQL missing field: " + sql) |
|
|
|
return { |
|
|
|
"answer": "\n".join([clmns, line, rows]), |
|
|
|
"reference": {"chunks": [], "doc_aggs": []} |
|
|
|
"reference": {"chunks": [], "doc_aggs": []}, |
|
|
|
"prompt": sys_prompt |
|
|
|
} |
|
|
|
|
|
|
|
docid_idx = list(docid_idx)[0] |
|
|
|
@@ -348,7 +349,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): |
|
|
|
"answer": "\n".join([clmns, line, rows]), |
|
|
|
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], |
|
|
|
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in |
|
|
|
doc_aggs.items()]} |
|
|
|
doc_aggs.items()]}, |
|
|
|
"prompt": sys_prompt |
|
|
|
} |
|
|
|
|
|
|
|
|