Browse Source

add stream chat with TTS (#2228)

### What problem does this PR solve?



### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.11.0
Kevin Hu 1 year ago
parent
commit
abc32803cc
No account linked to committer's email address
2 changed files with 29 additions and 5 deletions
  1. 2
    2
      api/apps/conversation_app.py
  2. 27
    3
      api/db/services/dialog_service.py

+ 2
- 2
api/apps/conversation_app.py View File

tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
def stream_audio(): def stream_audio():
try: try:
for chunk in tts_mdl.tts(text):
yield chunk
for chunk in tts_mdl.tts(text):
yield chunk
except Exception as e: except Exception as e:
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e)}}, "data": {"answer": "**ERROR**: "+str(e)}},

+ 27
- 3
api/db/services/dialog_service.py View File

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import binascii
import os import os
import json import json
import re import re


prompt_config = dialog.prompt_config prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
# try to use sql if field mapping is good to go # try to use sql if field mapping is good to go
if field_map: if field_map:
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
"{}->{}".format(" ".join(questions), "\n->".join(knowledges))) "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))


if not knowledges and prompt_config.get("empty_response"): if not knowledges and prompt_config.get("empty_response"):
yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos} return {"answer": prompt_config["empty_response"], "reference": kbinfos}


kwargs["knowledge"] = "\n".join(knowledges) kwargs["knowledge"] = "\n".join(knowledges)
return {"answer": answer, "reference": refs, "prompt": prompt} return {"answer": answer, "reference": refs, "prompt": prompt}


if stream: if stream:
last_ans = ""
answer = "" answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans answer = ans
yield {"answer": answer, "reference": {}}
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 12:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
yield decorate_answer(answer) yield decorate_answer(answer)
else: else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf) answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
chat_logger.info("User: {}|Assistant: {}".format( chat_logger.info("User: {}|Assistant: {}".format(
msg[-1]["content"], answer)) msg[-1]["content"], answer))
yield decorate_answer(answer)
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
yield res




def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
""" """
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8}) ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
return ans return ans


def tts(tts_mdl, text):
return
if not tts_mdl or not text: return
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")

Loading…
Cancel
Save