### What problem does this PR solve?

Add SparkTTS (text-to-speech support for the XunFei Spark provider).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>

tags/v0.12.0
| @@ -161,7 +161,10 @@ def add_llm(): | |||
| elif factory =="XunFei Spark": | |||
| llm_name = req["llm_name"] | |||
| api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx") | |||
| if req["model_type"] == "chat": | |||
| api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx") | |||
| elif req["model_type"] == "tts": | |||
| api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"]) | |||
| elif factory == "BaiduYiyan": | |||
| llm_name = req["llm_name"] | |||
| @@ -139,5 +139,6 @@ Seq2txtModel = { | |||
# Registry mapping a provider (factory) name to the class implementing its
# text-to-speech client. The keys are the provider names as they appear in
# requests (e.g. the "XunFei Spark" factory handled in add_llm).
TTSModel = {
    "Fish Audio": FishAudioTTS,
    "Tongyi-Qianwen": QwenTTS,
    "OpenAI": OpenAITTS,
    "XunFei Spark": SparkTTS,
}
| @@ -14,16 +14,30 @@ | |||
| # limitations under the License. | |||
| # | |||
| import requests | |||
| from typing import Annotated, Literal | |||
| import _thread as thread | |||
| import base64 | |||
| import datetime | |||
| import hashlib | |||
| import hmac | |||
| import json | |||
| import queue | |||
| import re | |||
| import ssl | |||
| import time | |||
| from abc import ABC | |||
| from datetime import datetime | |||
| from time import mktime | |||
| from typing import Annotated, Literal | |||
| from urllib.parse import urlencode | |||
| from wsgiref.handlers import format_date_time | |||
| import httpx | |||
| import ormsgpack | |||
| import requests | |||
| import websocket | |||
| from pydantic import BaseModel, conint | |||
| from rag.utils import num_tokens_from_string | |||
| import json | |||
| import re | |||
| import time | |||
| class ServeReferenceAudio(BaseModel): | |||
| @@ -161,7 +175,7 @@ class QwenTTS(Base): | |||
| class OpenAITTS(Base): | |||
| def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): | |||
| if not base_url: base_url="https://api.openai.com/v1" | |||
| if not base_url: base_url = "https://api.openai.com/v1" | |||
| self.api_key = key | |||
| self.model_name = model_name | |||
| self.base_url = base_url | |||
| @@ -185,3 +199,101 @@ class OpenAITTS(Base): | |||
| for chunk in response.iter_content(): | |||
| if chunk: | |||
| yield chunk | |||
class SparkTTS:
    """Text-to-speech client for the XunFei Spark streaming WebSocket API.

    ``key`` is a JSON string holding ``spark_app_id``, ``spark_api_secret``
    and ``spark_api_key``; ``model_name`` is sent to the service as the
    voice name (``vcn``). ``base_url`` is accepted for signature
    compatibility with the other TTS classes and is not used.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    # Endpoint the socket connects to.
    TTS_URL = "wss://tts-api.xfyun.cn/v2/tts"
    # Host name that is signed and sent in the query string.
    # NOTE(review): this differs from the host in TTS_URL; it is kept exactly
    # as in the original implementation -- confirm against the provider docs.
    SIGNED_HOST = "ws-api.xfyun.cn"
    REQUEST_LINE = "GET /v2/tts HTTP/1.1"

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential blob and remember the voice name."""
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Holds the audio chunks of the most recent tts() call.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed WebSocket URL.

        The string ``host``/``date``/request-line is signed with HMAC-SHA256
        using the API secret; the base64 signature is embedded in a base64
        ``authorization`` value passed as a query parameter.

        Returns:
            The endpoint URL with ``authorization``, ``date`` and ``host``
            query parameters appended.
        """
        date = format_date_time(mktime(datetime.now().timetuple()))
        signature_origin = f"host: {self.SIGNED_HOST}\ndate: {date}\n{self.REQUEST_LINE}"
        digest = hmac.new(
            self.APISecret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")
        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
        query = {
            "authorization": authorization,
            "date": date,
            "host": self.SIGNED_HOST,
        }
        return self.TTS_URL + "?" + urlencode(query)

    @staticmethod
    def _decode_frame(message):
        """Decode one server frame into ``(audio_bytes, is_last_frame)``.

        The error code is checked *before* the payload is touched, because an
        error frame is not guaranteed to carry a ``data`` object.

        Raises:
            Exception: if the frame reports a non-zero ``code``.
        """
        frame = json.loads(message)
        code = frame.get("code")
        if code != 0:
            raise Exception(f"sid:{frame.get('sid')} call error:{frame.get('message')} code:{code}")
        data = frame.get("data") or {}
        audio = base64.b64decode(data.get("audio") or "")
        return audio, data.get("status") == SparkTTS.STATUS_LAST_FRAME

    def tts(self, text):
        """Synthesize ``text`` and yield the resulting audio chunks (bytes).

        The whole text is sent as a single last-frame request. The socket is
        run to completion before the buffered chunks are yielded.

        Raises:
            Exception: if the service returned an error frame, or if no audio
                was received at all (connection or credential failure).
        """
        request = json.dumps({
            "common": {"app_id": self.APPID},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000",
                         "vcn": self.model_name, "tte": "utf8"},
            "data": {"status": self.STATUS_LAST_FRAME,
                     "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")},
        })

        # Fresh queue per call so chunks left over from an earlier, unfinished
        # call can never leak into this one.
        audio_queue = queue.Queue()
        self.audio_queue = audio_queue
        api_errors = []
        transport_errors = []

        def on_open(ws):
            ws.send(request)

        def on_message(ws, message):
            # Failures are recorded rather than raised: an exception raised
            # inside a callback does not reach the caller of tts().
            try:
                audio, is_last = self._decode_frame(message)
            except Exception as exc:
                api_errors.append(exc)
                ws.close()
                return
            audio_queue.put(audio)
            if is_last:
                ws.close()

        def on_error(ws, error):
            transport_errors.append(error)

        def on_close(ws, close_status_code, close_msg):
            audio_queue.put(None)  # None marks the end of the stream

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(self.create_url(), on_open=on_open, on_error=on_error,
                                    on_close=on_close, on_message=on_message)
        # Verify the server certificate: the request carries API credentials.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_REQUIRED})
        # Guarantee a terminator even if on_close was never invoked, so the
        # loop below cannot block forever.
        audio_queue.put(None)

        received_audio = False
        for audio_chunk in iter(audio_queue.get, None):
            received_audio = True
            yield audio_chunk

        if api_errors:
            raise api_errors[0]
        if not received_audio:
            if transport_errors:
                raise Exception(
                    f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: {transport_errors[0]}")
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
| @@ -94,6 +94,8 @@ vertexai==1.64.0 | |||
| volcengine==1.0.146 | |||
| voyageai==0.2.3 | |||
| webdriver_manager==4.0.1 | |||
| websocket==0.2.1 | |||
| websocket-client==1.8.0 | |||
| Werkzeug==3.0.3 | |||
| wikipedia==1.4.0 | |||
| word2number==1.1 | |||
| @@ -551,6 +551,12 @@ The above is the content you need to summarize.`, | |||
| SparkModelNameMessage: 'Please select Spark model', | |||
| addSparkAPIPassword: 'Spark APIPassword', | |||
| SparkAPIPasswordMessage: 'please input your APIPassword', | |||
| addSparkAPPID: 'Spark APPID', | |||
| SparkAPPIDMessage: 'please input your APPID', | |||
| addSparkAPISecret: 'Spark APISecret', | |||
| SparkAPISecretMessage: 'please input your APISecret', | |||
| addSparkAPIKey: 'Spark APIKey', | |||
| SparkAPIKeyMessage: 'please input your APIKey', | |||
| yiyanModelNameMessage: 'Please input model name', | |||
| addyiyanAK: 'yiyan API KEY', | |||
| yiyanAKMessage: 'Please input your API KEY', | |||
| @@ -512,6 +512,12 @@ export default { | |||
| SparkModelNameMessage: '請選擇星火模型!', | |||
| addSparkAPIPassword: '星火 APIPassword', | |||
| SparkAPIPasswordMessage: '請輸入 APIPassword', | |||
| addSparkAPPID: '星火 APPID', | |||
| SparkAPPIDMessage: '請輸入 APPID', | |||
| addSparkAPISecret: '星火 APISecret', | |||
| SparkAPISecretMessage: '請輸入 APISecret', | |||
| addSparkAPIKey: '星火 APIKey', | |||
| SparkAPIKeyMessage: '請輸入 APIKey', | |||
| yiyanModelNameMessage: '輸入模型名稱', | |||
| addyiyanAK: '一言 API KEY', | |||
| yiyanAKMessage: '請輸入 API KEY', | |||
| @@ -529,6 +529,12 @@ export default { | |||
| SparkModelNameMessage: '请选择星火模型!', | |||
| addSparkAPIPassword: '星火 APIPassword', | |||
| SparkAPIPasswordMessage: '请输入 APIPassword', | |||
| addSparkAPPID: '星火 APPID', | |||
| SparkAPPIDMessage: '请输入 APPID', | |||
| addSparkAPISecret: '星火 APISecret', | |||
| SparkAPISecretMessage: '请输入 APISecret', | |||
| addSparkAPIKey: '星火 APIKey', | |||
| SparkAPIKeyMessage: '请输入 APIKey', | |||
| yiyanModelNameMessage: '请输入模型名称', | |||
| addyiyanAK: '一言 API KEY', | |||
| yiyanAKMessage: '请输入 API KEY', | |||
| @@ -7,6 +7,9 @@ import omit from 'lodash/omit'; | |||
// Credential inputs of the Spark modal. The form renders
// `spark_api_password` when the model type is "chat", and the
// APPID / APISecret / APIKey trio when the model type is "tts".
type SparkCredentialFields = {
  spark_api_password: string;
  spark_app_id: string;
  spark_api_secret: string;
  spark_api_key: string;
};

// Values held by the Spark "add model" form: the generic add-LLM request
// body plus the vision flag and the Spark credential inputs.
type FieldType = IAddLlmRequestBody & { vision: boolean } & SparkCredentialFields;
| const { Option } = Select; | |||
| @@ -63,28 +66,67 @@ const SparkModal = ({ | |||
| > | |||
| <Select placeholder={t('modelTypeMessage')}> | |||
| <Option value="chat">chat</Option> | |||
| <Option value="tts">tts</Option> | |||
| </Select> | |||
| </Form.Item> | |||
| <Form.Item<FieldType> | |||
| label={t('modelName')} | |||
| name="llm_name" | |||
| initialValue={'Spark-Max'} | |||
| rules={[{ required: true, message: t('SparkModelNameMessage') }]} | |||
| > | |||
| <Select placeholder={t('modelTypeMessage')}> | |||
| <Option value="Spark-Max">Spark-Max</Option> | |||
| <Option value="Spark-Lite">Spark-Lite</Option> | |||
| <Option value="Spark-Pro">Spark-Pro</Option> | |||
| <Option value="Spark-Pro-128K">Spark-Pro-128K</Option> | |||
| <Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option> | |||
| </Select> | |||
| <Input placeholder={t('modelNameMessage')} /> | |||
| </Form.Item> | |||
| <Form.Item<FieldType> | |||
| label={t('addSparkAPIPassword')} | |||
| name="spark_api_password" | |||
| rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} | |||
| > | |||
| <Input placeholder={t('SparkAPIPasswordMessage')} /> | |||
| <Form.Item noStyle dependencies={['model_type']}> | |||
| {({ getFieldValue }) => | |||
| getFieldValue('model_type') === 'chat' && ( | |||
| <Form.Item<FieldType> | |||
| label={t('addSparkAPIPassword')} | |||
| name="spark_api_password" | |||
| rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]} | |||
| > | |||
| <Input placeholder={t('SparkAPIPasswordMessage')} /> | |||
| </Form.Item> | |||
| ) | |||
| } | |||
| </Form.Item> | |||
| <Form.Item noStyle dependencies={['model_type']}> | |||
| {({ getFieldValue }) => | |||
| getFieldValue('model_type') === 'tts' && ( | |||
| <Form.Item<FieldType> | |||
| label={t('addSparkAPPID')} | |||
| name="spark_app_id" | |||
| rules={[{ required: true, message: t('SparkAPPIDMessage') }]} | |||
| > | |||
| <Input placeholder={t('SparkAPPIDMessage')} /> | |||
| </Form.Item> | |||
| ) | |||
| } | |||
| </Form.Item> | |||
| <Form.Item noStyle dependencies={['model_type']}> | |||
| {({ getFieldValue }) => | |||
| getFieldValue('model_type') === 'tts' && ( | |||
| <Form.Item<FieldType> | |||
| label={t('addSparkAPISecret')} | |||
| name="spark_api_secret" | |||
| rules={[{ required: true, message: t('SparkAPISecretMessage') }]} | |||
| > | |||
| <Input placeholder={t('SparkAPISecretMessage')} /> | |||
| </Form.Item> | |||
| ) | |||
| } | |||
| </Form.Item> | |||
| <Form.Item noStyle dependencies={['model_type']}> | |||
| {({ getFieldValue }) => | |||
| getFieldValue('model_type') === 'tts' && ( | |||
| <Form.Item<FieldType> | |||
| label={t('addSparkAPIKey')} | |||
| name="spark_api_key" | |||
| rules={[{ required: true, message: t('SparkAPIKeyMessage') }]} | |||
| > | |||
| <Input placeholder={t('SparkAPIKeyMessage')} /> | |||
| </Form.Item> | |||
| ) | |||
| } | |||
| </Form.Item> | |||
| </Form> | |||
| </Modal> | |||