You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import _thread as thread
  17. import base64
  18. import hashlib
  19. import hmac
  20. import json
  21. import queue
  22. import re
  23. import ssl
  24. import time
  25. from abc import ABC
  26. from datetime import datetime
  27. from time import mktime
  28. from typing import Annotated, Literal
  29. from urllib.parse import urlencode
  30. from wsgiref.handlers import format_date_time
  31. import httpx
  32. import ormsgpack
  33. import requests
  34. import websocket
  35. from pydantic import BaseModel, conint
  36. from rag.utils import num_tokens_from_string
  37. class ServeReferenceAudio(BaseModel):
  38. audio: bytes
  39. text: str
  40. class ServeTTSRequest(BaseModel):
  41. text: str
  42. chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
  43. # Audio format
  44. format: Literal["wav", "pcm", "mp3"] = "mp3"
  45. mp3_bitrate: Literal[64, 128, 192] = 128
  46. # References audios for in-context learning
  47. references: list[ServeReferenceAudio] = []
  48. # Reference id
  49. # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
  50. # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
  51. reference_id: str | None = None
  52. # Normalize text for en & zh, this increase stability for numbers
  53. normalize: bool = True
  54. # Balance mode will reduce latency to 300ms, but may decrease stability
  55. latency: Literal["normal", "balanced"] = "normal"
  56. class Base(ABC):
  57. def __init__(self, key, model_name, base_url):
  58. pass
  59. def tts(self, audio):
  60. pass
  61. def normalize_text(self, text):
  62. return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
  63. class FishAudioTTS(Base):
  64. def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
  65. if not base_url:
  66. base_url = "https://api.fish.audio/v1/tts"
  67. key = json.loads(key)
  68. self.headers = {
  69. "api-key": key.get("fish_audio_ak"),
  70. "content-type": "application/msgpack",
  71. }
  72. self.ref_id = key.get("fish_audio_refid")
  73. self.base_url = base_url
  74. def tts(self, text):
  75. from http import HTTPStatus
  76. text = self.normalize_text(text)
  77. request = ServeTTSRequest(text=text, reference_id=self.ref_id)
  78. with httpx.Client() as client:
  79. try:
  80. with client.stream(
  81. method="POST",
  82. url=self.base_url,
  83. content=ormsgpack.packb(
  84. request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
  85. ),
  86. headers=self.headers,
  87. timeout=None,
  88. ) as response:
  89. if response.status_code == HTTPStatus.OK:
  90. for chunk in response.iter_bytes():
  91. yield chunk
  92. else:
  93. response.raise_for_status()
  94. yield num_tokens_from_string(text)
  95. except httpx.HTTPStatusError as e:
  96. raise RuntimeError(f"**ERROR**: {e}")
  97. class QwenTTS(Base):
  98. def __init__(self, key, model_name, base_url=""):
  99. import dashscope
  100. self.model_name = model_name
  101. dashscope.api_key = key
  102. def tts(self, text):
  103. from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
  104. from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
  105. from collections import deque
  106. class Callback(ResultCallback):
  107. def __init__(self) -> None:
  108. self.dque = deque()
  109. def _run(self):
  110. while True:
  111. if not self.dque:
  112. time.sleep(0)
  113. continue
  114. val = self.dque.popleft()
  115. if val:
  116. yield val
  117. else:
  118. break
  119. def on_open(self):
  120. pass
  121. def on_complete(self):
  122. self.dque.append(None)
  123. def on_error(self, response: SpeechSynthesisResponse):
  124. raise RuntimeError(str(response))
  125. def on_close(self):
  126. pass
  127. def on_event(self, result: SpeechSynthesisResult):
  128. if result.get_audio_frame() is not None:
  129. self.dque.append(result.get_audio_frame())
  130. text = self.normalize_text(text)
  131. callback = Callback()
  132. SpeechSynthesizer.call(model=self.model_name,
  133. text=text,
  134. callback=callback,
  135. format="mp3")
  136. try:
  137. for data in callback._run():
  138. yield data
  139. yield num_tokens_from_string(text)
  140. except Exception as e:
  141. raise RuntimeError(f"**ERROR**: {e}")
  142. class OpenAITTS(Base):
  143. def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
  144. if not base_url:
  145. base_url = "https://api.openai.com/v1"
  146. self.api_key = key
  147. self.model_name = model_name
  148. self.base_url = base_url
  149. self.headers = {
  150. "Authorization": f"Bearer {self.api_key}",
  151. "Content-Type": "application/json"
  152. }
  153. def tts(self, text, voice="alloy"):
  154. text = self.normalize_text(text)
  155. payload = {
  156. "model": self.model_name,
  157. "voice": voice,
  158. "input": text
  159. }
  160. response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
  161. if response.status_code != 200:
  162. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  163. for chunk in response.iter_content():
  164. if chunk:
  165. yield chunk
  166. class SparkTTS:
  167. STATUS_FIRST_FRAME = 0
  168. STATUS_CONTINUE_FRAME = 1
  169. STATUS_LAST_FRAME = 2
  170. def __init__(self, key, model_name, base_url=""):
  171. key = json.loads(key)
  172. self.APPID = key.get("spark_app_id", "xxxxxxx")
  173. self.APISecret = key.get("spark_api_secret", "xxxxxxx")
  174. self.APIKey = key.get("spark_api_key", "xxxxxx")
  175. self.model_name = model_name
  176. self.CommonArgs = {"app_id": self.APPID}
  177. self.audio_queue = queue.Queue()
  178. # 用来存储音频数据
  179. # 生成url
  180. def create_url(self):
  181. url = 'wss://tts-api.xfyun.cn/v2/tts'
  182. now = datetime.now()
  183. date = format_date_time(mktime(now.timetuple()))
  184. signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
  185. signature_origin += "date: " + date + "\n"
  186. signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
  187. signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
  188. digestmod=hashlib.sha256).digest()
  189. signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
  190. authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
  191. self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
  192. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  193. v = {
  194. "authorization": authorization,
  195. "date": date,
  196. "host": "ws-api.xfyun.cn"
  197. }
  198. url = url + '?' + urlencode(v)
  199. return url
  200. def tts(self, text):
  201. BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
  202. Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
  203. CommonArgs = {"app_id": self.APPID}
  204. audio_queue = self.audio_queue
  205. model_name = self.model_name
  206. class Callback:
  207. def __init__(self):
  208. self.audio_queue = audio_queue
  209. def on_message(self, ws, message):
  210. message = json.loads(message)
  211. code = message["code"]
  212. sid = message["sid"]
  213. audio = message["data"]["audio"]
  214. audio = base64.b64decode(audio)
  215. status = message["data"]["status"]
  216. if status == 2:
  217. ws.close()
  218. if code != 0:
  219. errMsg = message["message"]
  220. raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
  221. else:
  222. self.audio_queue.put(audio)
  223. def on_error(self, ws, error):
  224. raise Exception(error)
  225. def on_close(self, ws, close_status_code, close_msg):
  226. self.audio_queue.put(None) # 放入 None 作为结束标志
  227. def on_open(self, ws):
  228. def run(*args):
  229. d = {"common": CommonArgs,
  230. "business": BusinessArgs,
  231. "data": Data}
  232. ws.send(json.dumps(d))
  233. thread.start_new_thread(run, ())
  234. wsUrl = self.create_url()
  235. websocket.enableTrace(False)
  236. a = Callback()
  237. ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
  238. on_message=a.on_message)
  239. status_code = 0
  240. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  241. while True:
  242. audio_chunk = self.audio_queue.get()
  243. if audio_chunk is None:
  244. if status_code == 0:
  245. raise Exception(
  246. f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
  247. else:
  248. break
  249. status_code = 1
  250. yield audio_chunk
  251. class XinferenceTTS:
  252. def __init__(self, key, model_name, **kwargs):
  253. self.base_url = kwargs.get("base_url", None)
  254. self.model_name = model_name
  255. self.headers = {
  256. "accept": "application/json",
  257. "Content-Type": "application/json"
  258. }
  259. def tts(self, text, voice="中文女", stream=True):
  260. payload = {
  261. "model": self.model_name,
  262. "input": text,
  263. "voice": voice
  264. }
  265. response = requests.post(
  266. f"{self.base_url}/v1/audio/speech",
  267. headers=self.headers,
  268. json=payload,
  269. stream=stream
  270. )
  271. if response.status_code != 200:
  272. raise Exception(f"**Error**: {response.status_code}, {response.text}")
  273. for chunk in response.iter_content(chunk_size=1024):
  274. if chunk:
  275. yield chunk