### What problem does this PR solve?

Fix the TTS interface error.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

tags/v0.11.0
| @@ -196,12 +196,12 @@ def tts(): | |||
| tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) | |||
| def stream_audio(): | |||
| try: | |||
| for chunk in tts_mdl(text): | |||
| for chunk in tts_mdl.tts(text): | |||
| yield chunk | |||
| except Exception as e: | |||
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||
| yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||
| "data": {"answer": "**ERROR**: "+str(e)}}, | |||
| ensure_ascii=False).encode('utf-8') | |||
| ensure_ascii=False)).encode('utf-8') | |||
| resp = Response(stream_audio(), mimetype="audio/mpeg") | |||
| resp.headers.add_header("Cache-Control", "no-cache") | |||
| @@ -194,7 +194,7 @@ class LLMBundle(object): | |||
| for lm in LLMService.query(llm_name=llm_name): | |||
| self.max_length = lm.max_tokens | |||
| break | |||
| def encode(self, texts: list, batch_size=32): | |||
| emd, used_tokens = self.mdl.encode(texts, batch_size) | |||
| if not TenantLLMService.increase_usage( | |||
| @@ -235,6 +235,17 @@ class LLMBundle(object): | |||
| "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id)) | |||
| return txt | |||
def tts(self, text):
    """Yield the chunks produced by ``self.mdl.tts(text)`` and record usage.

    Every non-integer chunk coming from the underlying model is passed
    through to the caller unchanged.  The first integer chunk is not
    yielded: it is handed to ``TenantLLMService.increase_usage`` as the
    usage amount (presumably a token count, going by the log message),
    and the generator then stops.

    Args:
        text: The text forwarded to ``self.mdl.tts``.

    Yields:
        Each non-integer chunk, in the order the model produces it.
    """
    for piece in self.mdl.tts(text):
        if not isinstance(piece, int):
            yield piece
            continue

        # An integer chunk carries the usage figure rather than audio data.
        usage_saved = TenantLLMService.increase_usage(
            self.tenant_id, self.llm_type, piece, self.llm_name)
        if not usage_saved:
            # Failing to record usage is only logged; it does not raise.
            database_logger.error(
                "Can't update token usage for {}/TTS".format(self.tenant_id))
        # Nothing after the usage figure is forwarded to the caller.
        return
| def chat(self, system, history, gen_conf): | |||
| txt, used_tokens = self.mdl.chat(system, history, gen_conf) | |||
| if not TenantLLMService.increase_usage( | |||
| @@ -21,7 +21,7 @@ import ormsgpack | |||
| from pydantic import BaseModel, conint | |||
| from rag.utils import num_tokens_from_string | |||
| import json | |||
| import re | |||
| class ServeReferenceAudio(BaseModel): | |||
| audio: bytes | |||
| @@ -50,9 +50,11 @@ class Base(ABC): | |||
| def __init__(self, key, model_name, base_url): | |||
| pass | |||
| def transcription(self, audio): | |||
| def tts(self, audio): | |||
| pass | |||
| def normalize_text(text): | |||
| return re.sub(r'(\*\*|##\d+\$\$|#)', '', text) | |||
| class FishAudioTTS(Base): | |||
| def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"): | |||
| @@ -66,10 +68,11 @@ class FishAudioTTS(Base): | |||
| self.ref_id = key.get("fish_audio_refid") | |||
| self.base_url = base_url | |||
| def transcription(self, text): | |||
| def tts(self, text): | |||
| from http import HTTPStatus | |||
| request = request = ServeTTSRequest(text=text, reference_id=self.ref_id) | |||
| text = self.normalize_text(text) | |||
| request = ServeTTSRequest(text=text, reference_id=self.ref_id) | |||
| with httpx.Client() as client: | |||
| try: | |||