瀏覽代碼

SparkTTS (#2535)

### What problem does this PR solve?

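Adds SparkTTS: text-to-speech (TTS) support for the XunFei Spark model provider, including the backend client, the credential handling for the `tts` model type, and the related UI fields.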

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
tags/v0.12.0
liuhua 1 年之前
父節點
當前提交
d9c2a128a5
沒有連結到貢獻者的電子郵件帳戶。

+ 4
- 1
api/apps/llm_app.py 查看文件



elif factory =="XunFei Spark": elif factory =="XunFei Spark":
llm_name = req["llm_name"] llm_name = req["llm_name"]
api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
if req["model_type"] == "chat":
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
elif req["model_type"] == "tts":
api_key = apikey_json(["spark_app_id", "spark_api_secret","spark_api_key"])


elif factory == "BaiduYiyan": elif factory == "BaiduYiyan":
llm_name = req["llm_name"] llm_name = req["llm_name"]

+ 2
- 1
rag/llm/__init__.py 查看文件

TTSModel = { TTSModel = {
"Fish Audio": FishAudioTTS, "Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS, "Tongyi-Qianwen": QwenTTS,
"OpenAI":OpenAITTS
"OpenAI":OpenAITTS,
"XunFei Spark":SparkTTS
} }

+ 118
- 6
rag/llm/tts_model.py 查看文件

# limitations under the License. # limitations under the License.
# #


import requests
from typing import Annotated, Literal
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import queue
import re
import ssl
import time
from abc import ABC from abc import ABC
from datetime import datetime
from time import mktime
from typing import Annotated, Literal
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import httpx import httpx
import ormsgpack import ormsgpack
import requests
import websocket
from pydantic import BaseModel, conint from pydantic import BaseModel, conint

from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
import json
import re
import time




class ServeReferenceAudio(BaseModel): class ServeReferenceAudio(BaseModel):


class OpenAITTS(Base): class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url: base_url = "https://api.openai.com/v1"
self.api_key = key self.api_key = key
self.model_name = model_name self.model_name = model_name
self.base_url = base_url self.base_url = base_url
for chunk in response.iter_content(): for chunk in response.iter_content():
if chunk: if chunk:
yield chunk yield chunk


class SparkTTS:
    """Text-to-speech client that talks to the XunFei Spark TTS WebSocket endpoint.

    ``key`` is a JSON string carrying ``spark_app_id``, ``spark_api_secret`` and
    ``spark_api_key``; ``model_name`` is sent to the service as the ``vcn``
    business argument.
    """

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2

    def __init__(self, key, model_name, base_url=""):
        """Parse the JSON credential blob and remember the model name.

        Args:
            key: JSON string with ``spark_app_id``, ``spark_api_secret`` and
                ``spark_api_key``; missing entries fall back to placeholders.
            model_name: Value sent as ``vcn`` in every request.
            base_url: Accepted for signature parity with callers; not used here.
        """
        key = json.loads(key)
        self.APPID = key.get("spark_app_id", "xxxxxxx")
        self.APISecret = key.get("spark_api_secret", "xxxxxxx")
        self.APIKey = key.get("spark_api_key", "xxxxxx")
        self.model_name = model_name
        self.CommonArgs = {"app_id": self.APPID}
        # Retained for backward compatibility only: tts() now uses a fresh
        # queue per call so state cannot leak from one request into the next.
        self.audio_queue = queue.Queue()

    def create_url(self):
        """Build the signed ``wss://`` URL used to open the TTS connection.

        The signature is an HMAC-SHA256 over the host, date and request line,
        keyed with the API secret; the result is base64-encoded and passed in
        the ``authorization`` query parameter together with ``date`` and ``host``.

        Returns:
            The endpoint URL with the authentication query string appended.
        """
        url = 'wss://tts-api.xfyun.cn/v2/tts'
        host = "ws-api.xfyun.cn"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    @staticmethod
    def _parse_message(message):
        """Decode one JSON frame received from the service.

        The ``code`` field is checked before ``data`` is touched, so an error
        frame that carries no ``data`` payload surfaces the service's own
        message instead of a ``KeyError``/``TypeError``.

        Args:
            message: Raw JSON text of the frame.

        Returns:
            ``(audio_bytes, is_last)`` where ``is_last`` is True when the frame
            reports ``status == 2``.

        Raises:
            Exception: If the frame reports a non-zero ``code``.
        """
        message = json.loads(message)
        code = message["code"]
        sid = message.get("sid")
        if code != 0:
            errMsg = message.get("message")
            raise Exception(f"sid:{sid} call error:{errMsg} code:{code}")
        data = message["data"]
        audio = base64.b64decode(data["audio"])
        return audio, data["status"] == 2

    def tts(self, text):
        """Synthesize ``text`` and yield the decoded audio chunks.

        The whole exchange runs inside ``run_forever``; chunks are buffered in
        a per-call queue and yielded once the connection has finished.

        Args:
            text: The text to synthesize.

        Yields:
            ``bytes`` audio chunks in the order they were received.

        Raises:
            Exception: If no audio chunk was received at all, or if a callback
                error propagates out of the WebSocket loop.
        """
        business_args = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
        payload = json.dumps({"common": {"app_id": self.APPID},
                              "business": business_args,
                              "data": data})
        # A fresh queue per call: leftovers from an abandoned generator can
        # never be replayed into a later request.
        audio_queue = queue.Queue()
        parse_message = self._parse_message

        class Callback:
            def on_message(self, ws, message):
                try:
                    audio, is_last = parse_message(message)
                except Exception:
                    # Stop the session before reporting the failure.
                    ws.close()
                    raise
                # Enqueue before closing so the final chunk precedes the sentinel.
                audio_queue.put(audio)
                if is_last:
                    ws.close()

            def on_error(self, ws, error):
                raise Exception(error)

            def on_close(self, ws, close_status_code, close_msg):
                audio_queue.put(None)  # None marks the end of the stream

            def on_open(self, ws):
                def run(*args):
                    ws.send(payload)

                thread.start_new_thread(run, ())

        wsUrl = self.create_url()
        websocket.enableTrace(False)
        callback = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=callback.on_open, on_error=callback.on_error,
                                    on_close=callback.on_close, on_message=callback.on_message)
        # NOTE(review): CERT_NONE disables TLS certificate verification while
        # signed credentials travel in the URL -- confirm this is intended.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        # Guarantee termination even if on_close was never invoked; an extra
        # sentinel is harmless because the queue is local to this call.
        audio_queue.put(None)

        received_audio = False
        while True:
            audio_chunk = audio_queue.get()
            if audio_chunk is None:
                break
            received_audio = True
            yield audio_chunk
        if not received_audio:
            raise Exception(
                f"Fail to access model({self.model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")

+ 2
- 0
requirements.txt 查看文件

volcengine==1.0.146 volcengine==1.0.146
voyageai==0.2.3 voyageai==0.2.3
webdriver_manager==4.0.1 webdriver_manager==4.0.1
websocket==0.2.1
websocket-client==1.8.0
Werkzeug==3.0.3 Werkzeug==3.0.3
wikipedia==1.4.0 wikipedia==1.4.0
word2number==1.1 word2number==1.1

+ 6
- 0
web/src/locales/en.ts 查看文件

SparkModelNameMessage: 'Please select Spark model', SparkModelNameMessage: 'Please select Spark model',
addSparkAPIPassword: 'Spark APIPassword', addSparkAPIPassword: 'Spark APIPassword',
SparkAPIPasswordMessage: 'please input your APIPassword', SparkAPIPasswordMessage: 'please input your APIPassword',
addSparkAPPID: 'Spark APPID',
SparkAPPIDMessage: 'please input your APPID',
addSparkAPISecret: 'Spark APISecret',
SparkAPISecretMessage: 'please input your APISecret',
addSparkAPIKey: 'Spark APIKey',
SparkAPIKeyMessage: 'please input your APIKey',
yiyanModelNameMessage: 'Please input model name', yiyanModelNameMessage: 'Please input model name',
addyiyanAK: 'yiyan API KEY', addyiyanAK: 'yiyan API KEY',
yiyanAKMessage: 'Please input your API KEY', yiyanAKMessage: 'Please input your API KEY',

+ 6
- 0
web/src/locales/zh-traditional.ts 查看文件

SparkModelNameMessage: '請選擇星火模型!', SparkModelNameMessage: '請選擇星火模型!',
addSparkAPIPassword: '星火 APIPassword', addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '請輸入 APIPassword', SparkAPIPasswordMessage: '請輸入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '請輸入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '請輸入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '請輸入 APIKey',
yiyanModelNameMessage: '輸入模型名稱', yiyanModelNameMessage: '輸入模型名稱',
addyiyanAK: '一言 API KEY', addyiyanAK: '一言 API KEY',
yiyanAKMessage: '請輸入 API KEY', yiyanAKMessage: '請輸入 API KEY',

+ 6
- 0
web/src/locales/zh.ts 查看文件

SparkModelNameMessage: '请选择星火模型!', SparkModelNameMessage: '请选择星火模型!',
addSparkAPIPassword: '星火 APIPassword', addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '请输入 APIPassword', SparkAPIPasswordMessage: '请输入 APIPassword',
addSparkAPPID: '星火 APPID',
SparkAPPIDMessage: '请输入 APPID',
addSparkAPISecret: '星火 APISecret',
SparkAPISecretMessage: '请输入 APISecret',
addSparkAPIKey: '星火 APIKey',
SparkAPIKeyMessage: '请输入 APIKey',
yiyanModelNameMessage: '请输入模型名称', yiyanModelNameMessage: '请输入模型名称',
addyiyanAK: '一言 API KEY', addyiyanAK: '一言 API KEY',
yiyanAKMessage: '请输入 API KEY', yiyanAKMessage: '请输入 API KEY',

+ 56
- 14
web/src/pages/user-setting/setting-model/spark-modal/index.tsx 查看文件

type FieldType = IAddLlmRequestBody & { type FieldType = IAddLlmRequestBody & {
vision: boolean; vision: boolean;
spark_api_password: string; spark_api_password: string;
spark_app_id: string;
spark_api_secret: string;
spark_api_key: string;
}; };


const { Option } = Select; const { Option } = Select;
> >
<Select placeholder={t('modelTypeMessage')}> <Select placeholder={t('modelTypeMessage')}>
<Option value="chat">chat</Option> <Option value="chat">chat</Option>
<Option value="tts">tts</Option>
</Select> </Select>
</Form.Item> </Form.Item>
<Form.Item<FieldType> <Form.Item<FieldType>
label={t('modelName')} label={t('modelName')}
name="llm_name" name="llm_name"
initialValue={'Spark-Max'}
rules={[{ required: true, message: t('SparkModelNameMessage') }]} rules={[{ required: true, message: t('SparkModelNameMessage') }]}
> >
<Select placeholder={t('modelTypeMessage')}>
<Option value="Spark-Max">Spark-Max</Option>
<Option value="Spark-Lite">Spark-Lite</Option>
<Option value="Spark-Pro">Spark-Pro</Option>
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
</Select>
<Input placeholder={t('modelNameMessage')} />
</Form.Item> </Form.Item>
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkAPIPasswordMessage') }]}
>
<Input placeholder={t('SparkAPIPasswordMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPPID')}
name="spark_app_id"
rules={[{ required: true, message: t('SparkAPPIDMessage') }]}
>
<Input placeholder={t('SparkAPPIDMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPISecret')}
name="spark_api_secret"
rules={[{ required: true, message: t('SparkAPISecretMessage') }]}
>
<Input placeholder={t('SparkAPISecretMessage')} />
</Form.Item>
)
}
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'tts' && (
<Form.Item<FieldType>
label={t('addSparkAPIKey')}
name="spark_api_key"
rules={[{ required: true, message: t('SparkAPIKeyMessage') }]}
>
<Input placeholder={t('SparkAPIKeyMessage')} />
</Form.Item>
)
}
</Form.Item> </Form.Item>
</Form> </Form>
</Modal> </Modal>

Loading…
取消
儲存