You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tts_model.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import _thread as thread
  17. import base64
  18. import hashlib
  19. import hmac
  20. import json
  21. import queue
  22. import re
  23. import ssl
  24. import time
  25. from abc import ABC
  26. from datetime import datetime
  27. from time import mktime
  28. from typing import Annotated, Literal
  29. from urllib.parse import urlencode
  30. from wsgiref.handlers import format_date_time
  31. import httpx
  32. import ormsgpack
  33. import requests
  34. import websocket
  35. from pydantic import BaseModel, conint
  36. from rag.utils import num_tokens_from_string
# One reference sample for Fish Audio in-context learning; instances are sent
# in ServeTTSRequest.references.
class ServeReferenceAudio(BaseModel):
    # Raw reference audio content.
    audio: bytes
    # Text paired with the reference audio (presumably its transcript — confirm
    # against the Fish Audio API docs).
    text: str
# Request body for the Fish Audio TTS endpoint. FishAudioTTS.tts() packs it with
# ormsgpack (OPT_SERIALIZE_PYDANTIC) and posts it as application/msgpack.
# Keep the field order stable: it presumably fixes the key order of the packed map.
class ServeTTSRequest(BaseModel):
    # Text to synthesize.
    text: str
    # Chunk length, intended to be an int in 100..300.
    # NOTE(review): conint(...) is nested inside Annotated here; pydantic v2 may
    # treat it as unknown metadata and skip the bounds — confirm they are enforced.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Output audio format.
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    # Bitrate for MP3 output (presumably ignored for other formats).
    mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning.
    # (pydantic copies this mutable default per instance, so the [] is not shared.)
    references: list[ServeReferenceAudio] = []
    # Reference (voice model) id.
    # For example, to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize text for en & zh; this increases stability for numbers.
    normalize: bool = True
    # "balanced" mode reduces latency to ~300 ms but may decrease stability.
    latency: Literal["normal", "balanced"] = "normal"
  56. class Base(ABC):
  57. def __init__(self, key, model_name, base_url, **kwargs):
  58. """
  59. Abstract base class constructor.
  60. Parameters are not stored; subclasses should handle their own initialization.
  61. """
  62. pass
  63. def tts(self, audio):
  64. pass
  65. def normalize_text(self, text):
  66. return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
  67. class FishAudioTTS(Base):
  68. _FACTORY_NAME = "Fish Audio"
  69. def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
  70. if not base_url:
  71. base_url = "https://api.fish.audio/v1/tts"
  72. key = json.loads(key)
  73. self.headers = {
  74. "api-key": key.get("fish_audio_ak"),
  75. "content-type": "application/msgpack",
  76. }
  77. self.ref_id = key.get("fish_audio_refid")
  78. self.base_url = base_url
  79. def tts(self, text):
  80. from http import HTTPStatus
  81. text = self.normalize_text(text)
  82. request = ServeTTSRequest(text=text, reference_id=self.ref_id)
  83. with httpx.Client() as client:
  84. try:
  85. with client.stream(
  86. method="POST",
  87. url=self.base_url,
  88. content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
  89. headers=self.headers,
  90. timeout=None,
  91. ) as response:
  92. if response.status_code == HTTPStatus.OK:
  93. for chunk in response.iter_bytes():
  94. yield chunk
  95. else:
  96. response.raise_for_status()
  97. yield num_tokens_from_string(text)
  98. except httpx.HTTPStatusError as e:
  99. raise RuntimeError(f"**ERROR**: {e}")
  100. class QwenTTS(Base):
  101. _FACTORY_NAME = "Tongyi-Qianwen"
  102. def __init__(self, key, model_name, base_url=""):
  103. import dashscope
  104. self.model_name = model_name
  105. dashscope.api_key = key
  106. def tts(self, text):
  107. from collections import deque
  108. from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
  109. from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
  110. class Callback(ResultCallback):
  111. def __init__(self) -> None:
  112. self.dque = deque()
  113. def _run(self):
  114. while True:
  115. if not self.dque:
  116. time.sleep(0)
  117. continue
  118. val = self.dque.popleft()
  119. if val:
  120. yield val
  121. else:
  122. break
  123. def on_open(self):
  124. pass
  125. def on_complete(self):
  126. self.dque.append(None)
  127. def on_error(self, response: SpeechSynthesisResponse):
  128. raise RuntimeError(str(response))
  129. def on_close(self):
  130. pass
  131. def on_event(self, result: SpeechSynthesisResult):
  132. if result.get_audio_frame() is not None:
  133. self.dque.append(result.get_audio_frame())
  134. text = self.normalize_text(text)
  135. callback = Callback()
  136. SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
  137. try:
  138. for data in callback._run():
  139. yield data
  140. yield num_tokens_from_string(text)
  141. except Exception as e:
  142. raise RuntimeError(f"**ERROR**: {e}")
  143. class OpenAITTS(Base):
  144. _FACTORY_NAME = "OpenAI"
  145. def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
  146. if not base_url:
  147. base_url = "https://api.openai.com/v1"
  148. self.api_key = key
  149. self.model_name = model_name
  150. self.base_url = base_url
  151. self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
  152. def tts(self, text, voice="alloy"):
  153. text = self.normalize_text(text)
  154. payload = {"model": self.model_name, "voice": voice, "input": text}
  155. response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
  156. if response.status_code != 200:
  157. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  158. for chunk in response.iter_content():
  159. if chunk:
  160. yield chunk
  161. class SparkTTS(Base):
  162. _FACTORY_NAME = "XunFei Spark"
  163. STATUS_FIRST_FRAME = 0
  164. STATUS_CONTINUE_FRAME = 1
  165. STATUS_LAST_FRAME = 2
  166. def __init__(self, key, model_name, base_url=""):
  167. key = json.loads(key)
  168. self.APPID = key.get("spark_app_id", "xxxxxxx")
  169. self.APISecret = key.get("spark_api_secret", "xxxxxxx")
  170. self.APIKey = key.get("spark_api_key", "xxxxxx")
  171. self.model_name = model_name
  172. self.CommonArgs = {"app_id": self.APPID}
  173. self.audio_queue = queue.Queue()
  174. # 用来存储音频数据
  175. # 生成url
  176. def create_url(self):
  177. url = "wss://tts-api.xfyun.cn/v2/tts"
  178. now = datetime.now()
  179. date = format_date_time(mktime(now.timetuple()))
  180. signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
  181. signature_origin += "date: " + date + "\n"
  182. signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
  183. signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
  184. signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
  185. authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
  186. authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
  187. v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
  188. url = url + "?" + urlencode(v)
  189. return url
  190. def tts(self, text):
  191. BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
  192. Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
  193. CommonArgs = {"app_id": self.APPID}
  194. audio_queue = self.audio_queue
  195. model_name = self.model_name
  196. class Callback:
  197. def __init__(self):
  198. self.audio_queue = audio_queue
  199. def on_message(self, ws, message):
  200. message = json.loads(message)
  201. code = message["code"]
  202. sid = message["sid"]
  203. audio = message["data"]["audio"]
  204. audio = base64.b64decode(audio)
  205. status = message["data"]["status"]
  206. if status == 2:
  207. ws.close()
  208. if code != 0:
  209. errMsg = message["message"]
  210. raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
  211. else:
  212. self.audio_queue.put(audio)
  213. def on_error(self, ws, error):
  214. raise Exception(error)
  215. def on_close(self, ws, close_status_code, close_msg):
  216. self.audio_queue.put(None) # 放入 None 作为结束标志
  217. def on_open(self, ws):
  218. def run(*args):
  219. d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
  220. ws.send(json.dumps(d))
  221. thread.start_new_thread(run, ())
  222. wsUrl = self.create_url()
  223. websocket.enableTrace(False)
  224. a = Callback()
  225. ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
  226. status_code = 0
  227. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  228. while True:
  229. audio_chunk = self.audio_queue.get()
  230. if audio_chunk is None:
  231. if status_code == 0:
  232. raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
  233. else:
  234. break
  235. status_code = 1
  236. yield audio_chunk
  237. class XinferenceTTS(Base):
  238. _FACTORY_NAME = "Xinference"
  239. def __init__(self, key, model_name, **kwargs):
  240. self.base_url = kwargs.get("base_url", None)
  241. self.model_name = model_name
  242. self.headers = {"accept": "application/json", "Content-Type": "application/json"}
  243. def tts(self, text, voice="中文女", stream=True):
  244. payload = {"model": self.model_name, "input": text, "voice": voice}
  245. response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
  246. if response.status_code != 200:
  247. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  248. for chunk in response.iter_content(chunk_size=1024):
  249. if chunk:
  250. yield chunk
  251. class OllamaTTS(Base):
  252. def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
  253. if not base_url:
  254. base_url = "https://api.ollama.ai/v1"
  255. self.model_name = model_name
  256. self.base_url = base_url
  257. self.headers = {"Content-Type": "application/json"}
  258. if key and key != "x":
  259. self.headers["Authorization"] = f"Bearer {key}"
  260. def tts(self, text, voice="standard-voice"):
  261. payload = {"model": self.model_name, "voice": voice, "input": text}
  262. response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
  263. if response.status_code != 200:
  264. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  265. for chunk in response.iter_content():
  266. if chunk:
  267. yield chunk
  268. class GPUStackTTS(Base):
  269. _FACTORY_NAME = "GPUStack"
  270. def __init__(self, key, model_name, **kwargs):
  271. self.base_url = kwargs.get("base_url", None)
  272. self.api_key = key
  273. self.model_name = model_name
  274. self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
  275. def tts(self, text, voice="Chinese Female", stream=True):
  276. payload = {"model": self.model_name, "input": text, "voice": voice}
  277. response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
  278. if response.status_code != 200:
  279. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  280. for chunk in response.iter_content(chunk_size=1024):
  281. if chunk:
  282. yield chunk
  283. class SILICONFLOWTTS(Base):
  284. _FACTORY_NAME = "SILICONFLOW"
  285. def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
  286. if not base_url:
  287. base_url = "https://api.siliconflow.cn/v1"
  288. self.api_key = key
  289. self.model_name = model_name
  290. self.base_url = base_url
  291. self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
  292. def tts(self, text, voice="anna"):
  293. text = self.normalize_text(text)
  294. payload = {
  295. "model": self.model_name,
  296. "input": text,
  297. "voice": f"{self.model_name}:{voice}",
  298. "response_format": "mp3",
  299. "sample_rate": 123,
  300. "stream": True,
  301. "speed": 1,
  302. "gain": 0,
  303. }
  304. response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
  305. if response.status_code != 200:
  306. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  307. for chunk in response.iter_content():
  308. if chunk:
  309. yield chunk
  310. class DeepInfraTTS(OpenAITTS):
  311. _FACTORY_NAME = "DeepInfra"
  312. def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
  313. if not base_url:
  314. base_url = "https://api.deepinfra.com/v1/openai"
  315. super().__init__(key, model_name, base_url, **kwargs)