|
|
|
@@ -13,6 +13,7 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# |
|
|
|
import binascii |
|
|
|
import os |
|
|
|
import json |
|
|
|
import re |
|
|
|
@@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
|
|
|
|
prompt_config = dialog.prompt_config |
|
|
|
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) |
|
|
|
tts_mdl = None |
|
|
|
if prompt_config.get("tts"): |
|
|
|
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) |
|
|
|
# try to use sql if field mapping is good to go |
|
|
|
if field_map: |
|
|
|
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) |
|
|
|
@@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
"{}->{}".format(" ".join(questions), "\n->".join(knowledges))) |
|
|
|
|
|
|
|
if not knowledges and prompt_config.get("empty_response"): |
|
|
|
yield {"answer": prompt_config["empty_response"], "reference": kbinfos} |
|
|
|
empty_res = prompt_config["empty_response"] |
|
|
|
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)} |
|
|
|
return {"answer": prompt_config["empty_response"], "reference": kbinfos} |
|
|
|
|
|
|
|
kwargs["knowledge"] = "\n".join(knowledges) |
|
|
|
@@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs): |
|
|
|
return {"answer": answer, "reference": refs, "prompt": prompt} |
|
|
|
|
|
|
|
if stream: |
|
|
|
last_ans = "" |
|
|
|
answer = "" |
|
|
|
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): |
|
|
|
answer = ans |
|
|
|
yield {"answer": answer, "reference": {}} |
|
|
|
delta_ans = ans[len(last_ans):] |
|
|
|
if num_tokens_from_string(delta_ans) < 12: |
|
|
|
continue |
|
|
|
last_ans = answer |
|
|
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} |
|
|
|
delta_ans = answer[len(last_ans):] |
|
|
|
if delta_ans: |
|
|
|
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} |
|
|
|
yield decorate_answer(answer) |
|
|
|
else: |
|
|
|
answer = chat_mdl.chat(prompt, msg[1:], gen_conf) |
|
|
|
chat_logger.info("User: {}|Assistant: {}".format( |
|
|
|
msg[-1]["content"], answer)) |
|
|
|
yield decorate_answer(answer) |
|
|
|
res = decorate_answer(answer) |
|
|
|
res["audio_binary"] = tts(tts_mdl, answer) |
|
|
|
yield res |
|
|
|
|
|
|
|
|
|
|
|
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): |
|
|
|
@@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question): |
|
|
|
""" |
|
|
|
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8}) |
|
|
|
return ans |
|
|
|
|
|
|
|
|
|
|
|
def tts(tts_mdl, text):
    """Synthesize *text* with *tts_mdl* and return the result hex-encoded.

    Args:
        tts_mdl: Object exposing ``tts(text)``, which yields the synthesized
            output as an iterable of ``bytes`` chunks. May be ``None`` when
            no TTS model is configured.
        text: The text to synthesize.

    Returns:
        The concatenated chunks as a lowercase hexadecimal ``str``, or
        ``None`` when there is no model or no text to synthesize.
    """
    # Nothing to synthesize without both a model and some text.
    if not tts_mdl or not text:
        return None

    # Join the streamed chunks in one pass rather than repeated ``+=``.
    audio = b"".join(tts_mdl.tts(text))

    # Hex-encode so the binary payload can travel inside a text response.
    return binascii.hexlify(audio).decode("utf-8")