### What problem does this PR solve?

Adds text-to-speech support for XunFei Spark (SparkTTS).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>

tags/v0.12.0
| elif factory =="XunFei Spark": | elif factory =="XunFei Spark": | ||||
| llm_name = req["llm_name"] | llm_name = req["llm_name"] | ||||
| api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") | |||||
| if req["model_type"] == "chat": | |||||
| api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx") | |||||
| elif req["model_type"] == "tts": | |||||
| api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"]) | |||||
| elif factory == "BaiduYiyan": | elif factory == "BaiduYiyan": | ||||
| llm_name = req["llm_name"] | llm_name = req["llm_name"] |
# Maps a model-provider (factory) name to the class implementing its
# text-to-speech client. Reconciled to the post-change side of the diff: the
# superseded `"OpenAI":OpenAITTS` row (no trailing comma) is dropped so the
# literal is valid and each provider appears exactly once.
TTSModel = {
    "Fish Audio": FishAudioTTS,
    "Tongyi-Qianwen": QwenTTS,
    "OpenAI": OpenAITTS,
    "XunFei Spark": SparkTTS,
}
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import requests | |||||
| from typing import Annotated, Literal | |||||
| import _thread as thread | |||||
| import base64 | |||||
| import datetime | |||||
| import hashlib | |||||
| import hmac | |||||
| import json | |||||
| import queue | |||||
| import re | |||||
| import ssl | |||||
| import time | |||||
| from abc import ABC | from abc import ABC | ||||
| from datetime import datetime | |||||
| from time import mktime | |||||
| from typing import Annotated, Literal | |||||
| from urllib.parse import urlencode | |||||
| from wsgiref.handlers import format_date_time | |||||
| import httpx | import httpx | ||||
| import ormsgpack | import ormsgpack | ||||
| import requests | |||||
| import websocket | |||||
| from pydantic import BaseModel, conint | from pydantic import BaseModel, conint | ||||
| from rag.utils import num_tokens_from_string | from rag.utils import num_tokens_from_string | ||||
| import json | |||||
| import re | |||||
| import time | |||||
| class ServeReferenceAudio(BaseModel): | class ServeReferenceAudio(BaseModel): | ||||
| class OpenAITTS(Base): | class OpenAITTS(Base): | ||||
| def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): | def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): | ||||
| if not base_url: base_url="https://api.openai.com/v1" | |||||
| if not base_url: base_url = "https://api.openai.com/v1" | |||||
| self.api_key = key | self.api_key = key | ||||
| self.model_name = model_name | self.model_name = model_name | ||||
| self.base_url = base_url | self.base_url = base_url | ||||
| for chunk in response.iter_content(): | for chunk in response.iter_content(): | ||||
| if chunk: | if chunk: | ||||
| yield chunk | yield chunk | ||||
class SparkTTS:
    """Text-to-speech client for the XunFei Spark WebSocket TTS endpoint.

    Credentials arrive as a JSON string holding ``spark_app_id``,
    ``spark_api_secret`` and ``spark_api_key``. ``tts`` opens one WebSocket
    connection per call, sends the whole text in a single frame and yields the
    decoded audio chunks returned by the service.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse the credential JSON and remember the voice name.

        Args:
            key: JSON object string with ``spark_app_id``,
                ``spark_api_secret`` and ``spark_api_key``; missing entries
                fall back to placeholder strings.
            model_name: Sent to the service as the ``vcn`` business argument.
            base_url: Accepted but not used by this class.

        Raises:
            json.JSONDecodeError: If ``key`` is not valid JSON.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Kept so existing readers of the attribute do not break; tts() now
        # buffers audio per call so chunks or end markers left over from one
        # call can never leak into the next.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed ``wss://`` URL used to open the TTS connection.

        The request line, host and an RFC 1123 date are signed with
        HMAC-SHA256 using the API secret; the signature is wrapped in a
        base64-encoded ``authorization`` value and passed, together with the
        date and host, as query parameters.

        Returns:
            The endpoint URL with ``authorization``, ``date`` and ``host``
            query parameters appended.
        """
        url = "wss://tts-api.xfyun.cn/v2/tts"
        # NOTE(review): the signed host differs from the host in `url`; this is
        # reproduced exactly as found - confirm against the service docs
        # before changing either value.
        host = "ws-api.xfyun.cn"
        date = format_date_time(mktime(datetime.now().timetuple()))

        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(
            self.APISecret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

        query = {"authorization": authorization, "date": date, "host": host}
        return url + "?" + urlencode(query)

    def tts(self, text):
        """Synthesize ``text`` and yield the audio as ``bytes`` chunks.

        This is a generator: nothing is sent until iteration starts. The
        WebSocket session runs to completion first and the buffered chunks are
        yielded afterwards, in arrival order.

        Args:
            text: The text to synthesize; sent UTF-8 encoded and base64-wrapped.

        Yields:
            Decoded audio chunks in the order the service returned them.

        Raises:
            Exception: If the service reports a non-zero ``code``, if the
                connection fails before the final frame arrives, or if no
                audio is received at all (reported as a credentials problem).
        """
        business_args = {
            "aue": "lame",
            "sfl": 1,
            "auf": "audio/L16;rate=16000",
            "vcn": self.model_name,
            "tte": "utf8",
        }
        data_args = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        request = json.dumps({"common": {"app_id": self.APPID}, "business": business_args, "data": data_args})

        chunks = []   # audio received during this call only
        errors = []   # failures captured inside callbacks, re-raised below
        finished = False

        def on_open(ws):
            ws.send(request)

        def on_message(ws, message):
            nonlocal finished
            payload = json.loads(message)
            code = payload["code"]
            # Check the status code before touching "data": an error reply is
            # not guaranteed to carry an audio payload.
            if code != 0:
                sid = payload.get("sid")
                errMsg = payload.get("message")
                errors.append(Exception(f"sid:{sid} call error:{errMsg} code:{code}"))
                ws.close()
                return
            frame = payload["data"]
            chunks.append(base64.b64decode(frame["audio"]))
            if frame["status"] == 2:
                finished = True
                ws.close()

        def on_error(ws, error):
            # Record instead of raising: an exception raised inside a callback
            # does not surface cleanly in the code that called tts().
            errors.append(Exception(error))

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.create_url(),
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
        )
        # NOTE(review): certificate verification is disabled here (kept from
        # the original); the signed credentials travel over this connection,
        # so consider enabling verification.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

        # run_forever() has returned, so every callback has already run; no
        # blocking wait on a sentinel is needed and the call cannot hang.
        if errors and not finished:
            raise errors[0]
        if not chunks:
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
        yield from chunks
| volcengine==1.0.146 | volcengine==1.0.146 | ||||
| voyageai==0.2.3 | voyageai==0.2.3 | ||||
| webdriver_manager==4.0.1 | webdriver_manager==4.0.1 | ||||
| websocket==0.2.1 | |||||
| websocket-client==1.8.0 | |||||
| Werkzeug==3.0.3 | Werkzeug==3.0.3 | ||||
| wikipedia==1.4.0 | wikipedia==1.4.0 | ||||
| word2number==1.1 | word2number==1.1 |
| SparkModelNameMessage: 'Please select Spark model', | SparkModelNameMessage: 'Please select Spark model', | ||||
| addSparkAPIPassword: 'Spark APIPassword', | addSparkAPIPassword: 'Spark APIPassword', | ||||
| SparkAPIPasswordMessage: 'please input your APIPassword', | SparkAPIPasswordMessage: 'please input your APIPassword', | ||||
| addSparkAPPID: 'Spark APPID', | |||||
| SparkAPPIDMessage: 'please input your APPID', | |||||
| addSparkAPISecret: 'Spark APISecret', | |||||
| SparkAPISecretMessage: 'please input your APISecret', | |||||
| addSparkAPIKey: 'Spark APIKey', | |||||
| SparkAPIKeyMessage: 'please input your APIKey', | |||||
| yiyanModelNameMessage: 'Please input model name', | yiyanModelNameMessage: 'Please input model name', | ||||
| addyiyanAK: 'yiyan API KEY', | addyiyanAK: 'yiyan API KEY', | ||||
| yiyanAKMessage: 'Please input your API KEY', | yiyanAKMessage: 'Please input your API KEY', |
| SparkModelNameMessage: '請選擇星火模型!', | SparkModelNameMessage: '請選擇星火模型!', | ||||
| addSparkAPIPassword: '星火 APIPassword', | addSparkAPIPassword: '星火 APIPassword', | ||||
| SparkAPIPasswordMessage: '請輸入 APIPassword', | SparkAPIPasswordMessage: '請輸入 APIPassword', | ||||
| addSparkAPPID: '星火 APPID', | |||||
| SparkAPPIDMessage: '請輸入 APPID', | |||||
| addSparkAPISecret: '星火 APISecret', | |||||
| SparkAPISecretMessage: '請輸入 APISecret', | |||||
| addSparkAPIKey: '星火 APIKey', | |||||
| SparkAPIKeyMessage: '請輸入 APIKey', | |||||
| yiyanModelNameMessage: '輸入模型名稱', | yiyanModelNameMessage: '輸入模型名稱', | ||||
| addyiyanAK: '一言 API KEY', | addyiyanAK: '一言 API KEY', | ||||
| yiyanAKMessage: '請輸入 API KEY', | yiyanAKMessage: '請輸入 API KEY', |
| SparkModelNameMessage: '请选择星火模型!', | SparkModelNameMessage: '请选择星火模型!', | ||||
| addSparkAPIPassword: '星火 APIPassword', | addSparkAPIPassword: '星火 APIPassword', | ||||
| SparkAPIPasswordMessage: '请输入 APIPassword', | SparkAPIPasswordMessage: '请输入 APIPassword', | ||||
| addSparkAPPID: '星火 APPID', | |||||
| SparkAPPIDMessage: '请输入 APPID', | |||||
| addSparkAPISecret: '星火 APISecret', | |||||
| SparkAPISecretMessage: '请输入 APISecret', | |||||
| addSparkAPIKey: '星火 APIKey', | |||||
| SparkAPIKeyMessage: '请输入 APIKey', | |||||
| yiyanModelNameMessage: '请输入模型名称', | yiyanModelNameMessage: '请输入模型名称', | ||||
| addyiyanAK: '一言 API KEY', | addyiyanAK: '一言 API KEY', | ||||
| yiyanAKMessage: '请输入 API KEY', | yiyanAKMessage: '请输入 API KEY', |
/**
 * Values collected by the Spark "add model" form: the shared add-LLM request
 * body plus the Spark-specific credential fields.
 */
type FieldType = IAddLlmRequestBody & {
  vision: boolean;
  // Rendered only when model_type is "chat".
  spark_api_password: string;
  // The three fields below are rendered only when model_type is "tts".
  spark_app_id: string;
  spark_api_secret: string;
  spark_api_key: string;
};
| const { Option } = Select; | const { Option } = Select; | ||||
| > | > | ||||
| <Select placeholder={t('modelTypeMessage')}> | <Select placeholder={t('modelTypeMessage')}> | ||||
| <Option value="chat">chat</Option> | <Option value="chat">chat</Option> | ||||
| <Option value="tts">tts</Option> | |||||
| </Select> | </Select> | ||||
| </Form.Item> | </Form.Item> | ||||
| <Form.Item<FieldType> | <Form.Item<FieldType> | ||||
| label={t('modelName')} | label={t('modelName')} | ||||
| name="llm_name" | name="llm_name" | ||||
| initialValue={'Spark-Max'} | |||||
| rules={[{ required: true, message: t('SparkModelNameMessage') }]} | rules={[{ required: true, message: t('SparkModelNameMessage') }]} | ||||
| > | > | ||||
| <Select placeholder={t('modelTypeMessage')}> | |||||
| <Option value="Spark-Max">Spark-Max</Option> | |||||
| <Option value="Spark-Lite">Spark-Lite</Option> | |||||
| <Option value="Spark-Pro">Spark-Pro</Option> | |||||
| <Option value="Spark-Pro-128K">Spark-Pro-128K</Option> | |||||
| <Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option> | |||||
| </Select> | |||||
| <Input placeholder={t('modelNameMessage')} /> | |||||
| </Form.Item> | </Form.Item> | ||||
| <Form.Item<FieldType> | |||||
| label={t('addSparkAPIPassword')} | |||||
| name="spark_api_password" | |||||
| rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} | |||||
| > | |||||
| <Input placeholder={t('SparkAPIPasswordMessage')} /> | |||||
| <Form.Item noStyle dependencies={['model_type']}> | |||||
| {({ getFieldValue }) => | |||||
| getFieldValue('model_type') === 'chat' && ( | |||||
| <Form.Item<FieldType> | |||||
| label={t('addSparkAPIPassword')} | |||||
| name="spark_api_password" | |||||
| rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} | |||||
| > | |||||
| <Input placeholder={t('SparkAPIPasswordMessage')} /> | |||||
| </Form.Item> | |||||
| ) | |||||
| } | |||||
| </Form.Item> | |||||
| <Form.Item noStyle dependencies={['model_type']}> | |||||
| {({ getFieldValue }) => | |||||
| getFieldValue('model_type') === 'tts' && ( | |||||
| <Form.Item<FieldType> | |||||
| label={t('addSparkAPPID')} | |||||
| name="spark_app_id" | |||||
| rules={[{ required: true, message: t('SparkAPPIDMessage') }]} | |||||
| > | |||||
| <Input placeholder={t('SparkAPPIDMessage')} /> | |||||
| </Form.Item> | |||||
| ) | |||||
| } | |||||
| </Form.Item> | |||||
| <Form.Item noStyle dependencies={['model_type']}> | |||||
| {({ getFieldValue }) => | |||||
| getFieldValue('model_type') === 'tts' && ( | |||||
| <Form.Item<FieldType> | |||||
| label={t('addSparkAPISecret')} | |||||
| name="spark_api_secret" | |||||
| rules={[{ required: true, message: t('SparkAPISecretMessage') }]} | |||||
| > | |||||
| <Input placeholder={t('SparkAPISecretMessage')} /> | |||||
| </Form.Item> | |||||
| ) | |||||
| } | |||||
| </Form.Item> | |||||
| <Form.Item noStyle dependencies={['model_type']}> | |||||
| {({ getFieldValue }) => | |||||
| getFieldValue('model_type') === 'tts' && ( | |||||
| <Form.Item<FieldType> | |||||
| label={t('addSparkAPIKey')} | |||||
| name="spark_api_key" | |||||
| rules={[{ required: true, message: t('SparkAPIKeyMessage') }]} | |||||
| > | |||||
| <Input placeholder={t('SparkAPIKeyMessage')} /> | |||||
| </Form.Item> | |||||
| ) | |||||
| } | |||||
| </Form.Item> | </Form.Item> | ||||
| </Form> | </Form> | ||||
| </Modal> | </Modal> |