|
|
|
@@ -196,6 +196,8 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] |
|
|
|
else: |
|
|
|
questions = questions[-1:] |
|
|
|
refineQ_tm = timer() |
|
|
|
keyword_tm = timer() |
|
|
|
|
|
|
|
rerank_mdl = None |
|
|
|
if dialog.rerank_id: |
|
|
|
@@ -208,6 +210,7 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
else: |
|
|
|
if prompt_config.get("keyword", False): |
|
|
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1]) |
|
|
|
keyword_tm = timer() |
|
|
|
|
|
|
|
tenant_ids = list(set([kb.tenant_id for kb in kbs])) |
|
|
|
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, |
|
|
|
@@ -267,7 +270,9 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: |
|
|
|
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" |
|
|
|
done_tm = timer() |
|
|
|
prompt += "\n\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000) |
|
|
|
prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % ( |
|
|
|
(refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000, |
|
|
|
(done_tm - retrieval_tm) * 1000) |
|
|
|
return {"answer": answer, "reference": refs, "prompt": prompt} |
|
|
|
|
|
|
|
if stream: |