
add prompt to message (#2099)

### What problem does this PR solve?

#2098

### Type of change
 
- [x] New Feature (non-breaking change which adds functionality)
tags/v0.11.0
Kevin Hu, 1 year ago
Commit 6d3e3e4e3c
2 changed files, 12 insertions(+), 9 deletions(-)

1. api/apps/conversation_app.py (+2, -1)
2. api/db/services/dialog_service.py (+10, -8)

api/apps/conversation_app.py (+2, -1)

```diff
             if not conv.reference:
                 conv.reference.append(ans["reference"])
             else: conv.reference[-1] = ans["reference"]
-            conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"],
+                                "id": message_id, "prompt": ans.get("prompt", "")}

     def stream():
         nonlocal dia, msg, req, conv
```
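
For reviewers, a minimal sketch of what the edited line now persists, using hypothetical stand-ins (`ans`, `message_id`, and `conv_message` are illustrative, not repo fixtures): the assistant message stored on the conversation records the prompt that produced the answer, falling back to `""` when the model layer returns none.

```python
# Hypothetical stand-in data; the last assignment mirrors the edited line
# in conversation_app.py above.
ans = {
    "answer": "Paris.",
    "reference": {"chunks": [], "doc_aggs": []},
    "prompt": "You are an assistant. Answer only from the knowledge base.",
}
message_id = "msg-001"  # hypothetical id

conv_message = [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "", "id": message_id},
]

# Replace the placeholder assistant message with one that records the prompt.
conv_message[-1] = {"role": "assistant", "content": ans["answer"],
                    "id": message_id, "prompt": ans.get("prompt", "")}

assert conv_message[-1]["prompt"]  # present even when older answers lack one
```

Using `ans.get("prompt", "")` rather than `ans["prompt"]` keeps the write safe for answers produced before this change.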

api/db/services/dialog_service.py (+10, -8)

```diff
                        for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
+    prompt = msg[0]["content"]

     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(
             max_tokens - used_token_count)

     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,

         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs}
+        return {"answer": answer, "reference": refs, "prompt": prompt}

     if stream:
         answer = ""
-        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
+        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
             answer = ans
-            yield {"answer": answer, "reference": {}}
+            yield {"answer": answer, "reference": {}, "prompt": prompt}
         yield decorate_answer(answer)
     else:
-        answer = chat_mdl.chat(
-            msg[0]["content"], msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
         chat_logger.info("User: {}|Assistant: {}".format(
             msg[-1]["content"], answer))
         yield decorate_answer(answer)

         chat_logger.warning("SQL missing field: " + sql)
         return {
             "answer": "\n".join([clmns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []}
+            "reference": {"chunks": [], "doc_aggs": []},
+            "prompt": sys_prompt
         }

     docid_idx = list(docid_idx)[0]

         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                      doc_aggs.items()]}
+                      doc_aggs.items()]},
+        "prompt": sys_prompt
     }
```
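
To make the new payload contract concrete, here is a self-contained consumer sketch. `fake_chat` is a toy stand-in for the patched `chat` generator (hypothetical, not repo code); what it mirrors from the diff is the payload shape: every streamed chunk and the final `decorate_answer` result now carry a `prompt` key alongside `answer` and `reference`.

```python
# Toy generator mimicking the dicts the patched chat() yields; fake_chat and
# all literals here are hypothetical illustration, not repo code.
def fake_chat(stream=True):
    prompt = "system prompt assembled from the dialog's prompt template"
    if stream:
        for partial in ("The capital", "The capital of France is Paris."):
            yield {"answer": partial, "reference": {}, "prompt": prompt}
    # Final payload, mirroring decorate_answer's return shape.
    yield {"answer": "The capital of France is Paris.",
           "reference": {"chunks": [], "doc_aggs": []},
           "prompt": prompt}

final = None
for delta in fake_chat():
    final = delta  # consumers typically keep only the last payload
print(final["prompt"])  # prompt is now available for display and auditing
```

Surfacing the prompt on every payload is what lets the API layer in conversation_app.py persist it onto the assistant message without a second lookup.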




