You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tts_model.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import _thread as thread
  17. import base64
  18. import datetime
  19. import hashlib
  20. import hmac
  21. import json
  22. import queue
  23. import re
  24. import ssl
  25. import time
  26. from abc import ABC
  27. from datetime import datetime
  28. from time import mktime
  29. from typing import Annotated, Literal
  30. from urllib.parse import urlencode
  31. from wsgiref.handlers import format_date_time
  32. import httpx
  33. import ormsgpack
  34. import requests
  35. import websocket
  36. from pydantic import BaseModel, conint
  37. from rag.utils import num_tokens_from_string
class ServeReferenceAudio(BaseModel):
    """One reference audio clip used by Fish Audio for in-context voice learning.

    Instances are carried in ``ServeTTSRequest.references``.
    """

    audio: bytes  # raw bytes of the reference clip
    text: str  # text accompanying the clip (presumably its transcript — confirm with the Fish Audio API docs)
class ServeTTSRequest(BaseModel):
    """Request body for the Fish Audio TTS endpoint.

    ``FishAudioTTS.tts`` builds one of these, serializes it with ormsgpack
    (``OPT_SERIALIZE_PYDANTIC``) and POSTs it as ``application/msgpack``.
    """

    text: str  # text to synthesize
    # NOTE(review): conint(...) is nested inside Annotated here; confirm the
    # installed pydantic version actually enforces the strict 100..300 bound
    # (using ``conint(ge=100, le=300, strict=True)`` directly is the usual form).
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Output audio format
    format: Literal["wav", "pcm", "mp3"] = "mp3"
    # Bitrate for mp3 output (presumably kbps — confirm with the API docs)
    mp3_bitrate: Literal[64, 128, 192] = 128
    # Reference audios for in-context learning.
    # NOTE(review): a mutable [] default is relied on being copied per instance
    # by pydantic; confirm for the installed pydantic version.
    references: list[ServeReferenceAudio] = []
    # ID of a voice model hosted on fish.audio.
    # For example, to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # pass just 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    # Normalize English & Chinese text; this improves stability for numbers
    normalize: bool = True
    # "balanced" mode reduces latency to about 300ms but may decrease stability
    latency: Literal["normal", "balanced"] = "normal"
  57. class Base(ABC):
  58. def __init__(self, key, model_name, base_url):
  59. pass
  60. def tts(self, audio):
  61. pass
  62. def normalize_text(self, text):
  63. return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
  64. class FishAudioTTS(Base):
  65. def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
  66. if not base_url:
  67. base_url = "https://api.fish.audio/v1/tts"
  68. key = json.loads(key)
  69. self.headers = {
  70. "api-key": key.get("fish_audio_ak"),
  71. "content-type": "application/msgpack",
  72. }
  73. self.ref_id = key.get("fish_audio_refid")
  74. self.base_url = base_url
  75. def tts(self, text):
  76. from http import HTTPStatus
  77. text = self.normalize_text(text)
  78. request = ServeTTSRequest(text=text, reference_id=self.ref_id)
  79. with httpx.Client() as client:
  80. try:
  81. with client.stream(
  82. method="POST",
  83. url=self.base_url,
  84. content=ormsgpack.packb(
  85. request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
  86. ),
  87. headers=self.headers,
  88. timeout=None,
  89. ) as response:
  90. if response.status_code == HTTPStatus.OK:
  91. for chunk in response.iter_bytes():
  92. yield chunk
  93. else:
  94. response.raise_for_status()
  95. yield num_tokens_from_string(text)
  96. except httpx.HTTPStatusError as e:
  97. raise RuntimeError(f"**ERROR**: {e}")
  98. class QwenTTS(Base):
  99. def __init__(self, key, model_name, base_url=""):
  100. import dashscope
  101. self.model_name = model_name
  102. dashscope.api_key = key
  103. def tts(self, text):
  104. from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
  105. from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
  106. from collections import deque
  107. class Callback(ResultCallback):
  108. def __init__(self) -> None:
  109. self.dque = deque()
  110. def _run(self):
  111. while True:
  112. if not self.dque:
  113. time.sleep(0)
  114. continue
  115. val = self.dque.popleft()
  116. if val:
  117. yield val
  118. else:
  119. break
  120. def on_open(self):
  121. pass
  122. def on_complete(self):
  123. self.dque.append(None)
  124. def on_error(self, response: SpeechSynthesisResponse):
  125. raise RuntimeError(str(response))
  126. def on_close(self):
  127. pass
  128. def on_event(self, result: SpeechSynthesisResult):
  129. if result.get_audio_frame() is not None:
  130. self.dque.append(result.get_audio_frame())
  131. text = self.normalize_text(text)
  132. callback = Callback()
  133. SpeechSynthesizer.call(model=self.model_name,
  134. text=text,
  135. callback=callback,
  136. format="mp3")
  137. try:
  138. for data in callback._run():
  139. yield data
  140. yield num_tokens_from_string(text)
  141. except Exception as e:
  142. raise RuntimeError(f"**ERROR**: {e}")
  143. class OpenAITTS(Base):
  144. def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
  145. if not base_url: base_url = "https://api.openai.com/v1"
  146. self.api_key = key
  147. self.model_name = model_name
  148. self.base_url = base_url
  149. self.headers = {
  150. "Authorization": f"Bearer {self.api_key}",
  151. "Content-Type": "application/json"
  152. }
  153. def tts(self, text, voice="alloy"):
  154. text = self.normalize_text(text)
  155. payload = {
  156. "model": self.model_name,
  157. "voice": voice,
  158. "input": text
  159. }
  160. response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
  161. if response.status_code != 200:
  162. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  163. for chunk in response.iter_content():
  164. if chunk:
  165. yield chunk
  166. class SparkTTS:
  167. STATUS_FIRST_FRAME = 0
  168. STATUS_CONTINUE_FRAME = 1
  169. STATUS_LAST_FRAME = 2
  170. def __init__(self, key, model_name, base_url=""):
  171. key = json.loads(key)
  172. self.APPID = key.get("spark_app_id", "xxxxxxx")
  173. self.APISecret = key.get("spark_api_secret", "xxxxxxx")
  174. self.APIKey = key.get("spark_api_key", "xxxxxx")
  175. self.model_name = model_name
  176. self.CommonArgs = {"app_id": self.APPID}
  177. self.audio_queue = queue.Queue()
  178. # 用来存储音频数据
  179. # 生成url
  180. def create_url(self):
  181. url = 'wss://tts-api.xfyun.cn/v2/tts'
  182. now = datetime.now()
  183. date = format_date_time(mktime(now.timetuple()))
  184. signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
  185. signature_origin += "date: " + date + "\n"
  186. signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
  187. signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
  188. digestmod=hashlib.sha256).digest()
  189. signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
  190. authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
  191. self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
  192. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  193. v = {
  194. "authorization": authorization,
  195. "date": date,
  196. "host": "ws-api.xfyun.cn"
  197. }
  198. url = url + '?' + urlencode(v)
  199. return url
  200. def tts(self, text):
  201. BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
  202. Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
  203. CommonArgs = {"app_id": self.APPID}
  204. audio_queue = self.audio_queue
  205. model_name = self.model_name
  206. class Callback:
  207. def __init__(self):
  208. self.audio_queue = audio_queue
  209. def on_message(self, ws, message):
  210. message = json.loads(message)
  211. code = message["code"]
  212. sid = message["sid"]
  213. audio = message["data"]["audio"]
  214. audio = base64.b64decode(audio)
  215. status = message["data"]["status"]
  216. if status == 2:
  217. ws.close()
  218. if code != 0:
  219. errMsg = message["message"]
  220. raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
  221. else:
  222. self.audio_queue.put(audio)
  223. def on_error(self, ws, error):
  224. raise Exception(error)
  225. def on_close(self, ws, close_status_code, close_msg):
  226. self.audio_queue.put(None) # 放入 None 作为结束标志
  227. def on_open(self, ws):
  228. def run(*args):
  229. d = {"common": CommonArgs,
  230. "business": BusinessArgs,
  231. "data": Data}
  232. ws.send(json.dumps(d))
  233. thread.start_new_thread(run, ())
  234. wsUrl = self.create_url()
  235. websocket.enableTrace(False)
  236. a = Callback()
  237. ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
  238. on_message=a.on_message)
  239. status_code = 0
  240. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  241. while True:
  242. audio_chunk = self.audio_queue.get()
  243. if audio_chunk is None:
  244. if status_code == 0:
  245. raise Exception(
  246. f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
  247. else:
  248. break
  249. status_code = 1
  250. yield audio_chunk
  251. class XinferenceTTS:
  252. def __init__(self, key, model_name, **kwargs):
  253. self.base_url = kwargs.get("base_url", None)
  254. self.model_name = model_name
  255. self.headers = {
  256. "accept": "application/json",
  257. "Content-Type": "application/json"
  258. }
  259. def tts(self, text, voice="中文女", stream=True):
  260. payload = {
  261. "model": self.model_name,
  262. "input": text,
  263. "voice": voice
  264. }
  265. response = requests.post(
  266. f"{self.base_url}/v1/audio/speech",
  267. headers=self.headers,
  268. json=payload,
  269. stream=stream
  270. )
  271. if response.status_code != 200:
  272. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  273. for chunk in response.iter_content(chunk_size=1024):
  274. if chunk:
  275. yield chunk