@@ -170,8 +170,40 @@ def label_question(question, kbs):
    return tags


def chat_solo(dialog, messages, stream=True):
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()
