### What problem does this PR solve?

Fix TTS interface error.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

tags/v0.11.0
| tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) | tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) | ||||
| def stream_audio(): | def stream_audio(): | ||||
| try: | try: | ||||
| for chunk in tts_mdl(text): | |||||
| for chunk in tts_mdl.tts(text): | |||||
| yield chunk | yield chunk | ||||
| except Exception as e: | except Exception as e: | ||||
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||||
| yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||||
| "data": {"answer": "**ERROR**: "+str(e)}}, | "data": {"answer": "**ERROR**: "+str(e)}}, | ||||
| ensure_ascii=False).encode('utf-8') | |||||
| ensure_ascii=False)).encode('utf-8') | |||||
| resp = Response(stream_audio(), mimetype="audio/mpeg") | resp = Response(stream_audio(), mimetype="audio/mpeg") | ||||
| resp.headers.add_header("Cache-Control", "no-cache") | resp.headers.add_header("Cache-Control", "no-cache") | 
| for lm in LLMService.query(llm_name=llm_name): | for lm in LLMService.query(llm_name=llm_name): | ||||
| self.max_length = lm.max_tokens | self.max_length = lm.max_tokens | ||||
| break | break | ||||
| def encode(self, texts: list, batch_size=32): | def encode(self, texts: list, batch_size=32): | ||||
| emd, used_tokens = self.mdl.encode(texts, batch_size) | emd, used_tokens = self.mdl.encode(texts, batch_size) | ||||
| if not TenantLLMService.increase_usage( | if not TenantLLMService.increase_usage( | ||||
| "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id)) | "Can't update token usage for {}/SEQUENCE2TXT".format(self.tenant_id)) | ||||
| return txt | return txt | ||||
| def tts(self, text): | |||||
| for chunk in self.mdl.tts(text): | |||||
| if isinstance(chunk,int): | |||||
| if not TenantLLMService.increase_usage( | |||||
| self.tenant_id, self.llm_type, chunk, self.llm_name): | |||||
| database_logger.error( | |||||
| "Can't update token usage for {}/TTS".format(self.tenant_id)) | |||||
| return | |||||
| yield chunk | |||||
| def chat(self, system, history, gen_conf): | def chat(self, system, history, gen_conf): | ||||
| txt, used_tokens = self.mdl.chat(system, history, gen_conf) | txt, used_tokens = self.mdl.chat(system, history, gen_conf) | ||||
| if not TenantLLMService.increase_usage( | if not TenantLLMService.increase_usage( | 
| from pydantic import BaseModel, conint | from pydantic import BaseModel, conint | ||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| import json | import json | ||||
| import re | |||||
| class ServeReferenceAudio(BaseModel): | class ServeReferenceAudio(BaseModel): | ||||
| audio: bytes | audio: bytes | ||||
| def __init__(self, key, model_name, base_url): | def __init__(self, key, model_name, base_url): | ||||
| pass | pass | ||||
| def transcription(self, audio): | |||||
| def tts(self, audio): | |||||
| pass | pass | ||||
| def normalize_text(text): | |||||
| return re.sub(r'(\*\*|##\d+\$\$|#)', '', text) | |||||
| class FishAudioTTS(Base): | class FishAudioTTS(Base): | ||||
| def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"): | def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"): | ||||
| self.ref_id = key.get("fish_audio_refid") | self.ref_id = key.get("fish_audio_refid") | ||||
| self.base_url = base_url | self.base_url = base_url | ||||
| def transcription(self, text): | |||||
| def tts(self, text): | |||||
| from http import HTTPStatus | from http import HTTPStatus | ||||
| request = request = ServeTTSRequest(text=text, reference_id=self.ref_id) | |||||
| text = self.normalize_text(text) | |||||
| request = ServeTTSRequest(text=text, reference_id=self.ref_id) | |||||
| with httpx.Client() as client: | with httpx.Client() as client: | ||||
| try: | try: |