
feat: use xinference tts stream mode (#8616)

tags/0.9.0
呆萌闷油瓶 1 year ago
parent commit c8b9bdebfe

api/core/model_runtime/model_providers/xinference/llm/llm.py (+1, -2)

 from openai.types.completion import Completion
 from xinference_client.client.restful.restful_client import (
     Client,
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,
 )

         if tools and len(tools) > 0:
             generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
         vision = credentials.get("support_vision", False)
-        if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
+        if isinstance(xinference_model, RESTfulChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials["model_uid"],
                 messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
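
xinference-client 0.15 no longer ships RESTfulChatglmCppChatModelHandle, so the chat branch is gated on RESTfulChatModelHandle alone. A minimal sketch of the narrowed dispatch, not the Dify implementation; the server URL and model UID are hypothetical:

from xinference_client.client.restful.restful_client import (
    Client,
    RESTfulChatModelHandle,
    RESTfulGenerateModelHandle,
)

client = Client("http://localhost:9997")   # hypothetical server URL
handle = client.get_model("my-model-uid")  # hypothetical model UID

if isinstance(handle, RESTfulChatModelHandle):
    # chat-capable model: route the request through the chat endpoint
    print("chat model")
elif isinstance(handle, RESTfulGenerateModelHandle):
    # plain completion model: route through the text generation endpoint
    print("generate model")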

api/core/model_runtime/model_providers/xinference/tts/tts.py (+6, -6)

                 executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
                 futures = [
                     executor.submit(
-                        handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False
+                        handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=True
                     )
                     for i in range(len(sentences))
                 ]

                 for future in futures:
                     response = future.result()
-                    for i in range(0, len(response), 1024):
-                        yield response[i : i + 1024]
+                    for chunk in response:
+                        yield chunk
             else:
                 response = handle.speech(
-                    input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False
+                    input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=True
                 )

-                for i in range(0, len(response), 1024):
-                    yield response[i : i + 1024]
+                for chunk in response:
+                    yield chunk
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
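
The behavioural change: with stream=False, handle.speech(...) returns the finished audio as one bytes object, which the old code then sliced into 1024-byte pieces itself; with stream=True it returns an iterator that yields chunks as the server produces them, so the generator can forward audio before synthesis completes. A minimal sketch of the two modes, assuming a hypothetical local server, TTS model UID, and voice name:

import sys

from xinference_client.client.restful.restful_client import Client

def write_chunk(chunk: bytes) -> None:
    # hypothetical sink; a real consumer might feed an HTTP response or a player
    sys.stdout.buffer.write(chunk)

client = Client("http://localhost:9997")  # hypothetical server URL
handle = client.get_model("my-tts-uid")   # hypothetical TTS model UID

# Blocking mode (old behaviour): one bytes blob, chunked by the caller afterwards.
audio = handle.speech(input="Hello.", voice="default", response_format="mp3", speed=1.0, stream=False)
for i in range(0, len(audio), 1024):
    write_chunk(audio[i : i + 1024])

# Streaming mode (new behaviour): chunks arrive incrementally and are forwarded as-is.
for chunk in handle.speech(input="Hello.", voice="default", response_format="mp3", speed=1.0, stream=True):
    write_chunk(chunk)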

api/poetry.lock (+4, -4)



 [[package]]
 name = "xinference-client"
-version = "0.13.3"
+version = "0.15.2"
 description = "Client for Xinference"
 optional = false
 python-versions = "*"
 files = [
-    {file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"},
-    {file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"},
+    {file = "xinference-client-0.15.2.tar.gz", hash = "sha256:5c2259bb133148d1cc9bd2b8ec6eb8b5bbeba7f11d6252959f4e6cd79baa53ed"},
+    {file = "xinference_client-0.15.2-py3-none-any.whl", hash = "sha256:b6275adab695e75e75a33e21e0ad212488fc2d5a4d0f693d544c0e78469abbe3"},
 ]

 [package.dependencies]

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "18924ae12a00bde4438a46168bc167ed69613ab1ab0c387f193cd47ac24379b2"
+content-hash = "85aa4be7defee8fe6622cf95ba03e81895121502ebf6d666d6ce376ff019fac7"

api/pyproject.toml (+1, -1)

 unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
 websocket-client = "~1.7.0"
 werkzeug = "~3.0.1"
-xinference-client = "0.13.3"
+xinference-client = "0.15.2"
 yarl = "~1.9.4"
 zhipuai = "1.0.7"
 # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.

api/tests/integration_tests/model_runtime/__mock/xinference.py (+1, -4)

 from requests.sessions import Session
 from xinference_client.client.restful.restful_client import (
     Client,
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulEmbeddingModelHandle,
     RESTfulGenerateModelHandle,

 class MockXinferenceClass:
-    def get_chat_model(
-        self: Client, model_uid: str
-    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
         if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
             raise RuntimeError("404 Not Found")
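
The self: Client annotation hints at how the mock is meant to be used: its functions are grafted onto the real Client class, so the patched method runs against a genuine client instance (hence self.base_url above). A hedged sketch of that wiring; the actual Dify test fixture may attach the mock differently:

from unittest.mock import patch

from xinference_client.client.restful.restful_client import Client

# Assumption: the mock's unbound function stands in for Client.get_model
# for the duration of the test.
with patch.object(Client, "get_model", MockXinferenceClass.get_chat_model):
    client = Client("http://localhost:9997")     # hypothetical server URL
    handle = client.get_model("chat-model-uid")  # answered by the mock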


