瀏覽代碼

feat: add spark v2 support (#885)

tags/0.3.15
takatost 2 年之前
父節點
當前提交
f42e7d1a61
沒有連結到貢獻者的電子郵件帳戶。

+ 1
- 1
api/core/model_providers/models/llm/spark_model.py 查看文件

import decimal import decimal
from functools import wraps
from typing import List, Optional, Any from typing import List, Optional, Any


from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ChatSpark( return ChatSpark(
model_name=self.name,
streaming=self.streaming, streaming=self.streaming,
callbacks=self.callbacks, callbacks=self.callbacks,
**self.credentials, **self.credentials,

+ 5
- 1
api/core/model_providers/providers/spark_provider.py 查看文件

return [ return [
{ {
'id': 'spark', 'id': 'spark',
'name': '星火认知大模型',
'name': 'Spark V1.5',
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
} }
] ]
else: else:

+ 5
- 0
api/core/third_party/langchain/llms/spark.py 查看文件

.. code-block:: python .. code-block:: python


client = SparkLLMClient( client = SparkLLMClient(
model_name="<model_name>",
app_id="<app_id>", app_id="<app_id>",
api_key="<api_key>", api_key="<api_key>",
api_secret="<api_secret>" api_secret="<api_secret>"
""" """
client: Any = None #: :meta private: client: Any = None #: :meta private:


model_name: str = "spark"
"""The Spark model name."""

max_tokens: int = 256 max_tokens: int = 256
"""Denotes the number of tokens to predict per generation.""" """Denotes the number of tokens to predict per generation."""


) )


values["client"] = SparkLLMClient( values["client"] = SparkLLMClient(
model_name=values["model_name"],
app_id=values["app_id"], app_id=values["app_id"],
api_key=values["api_key"], api_key=values["api_key"],
api_secret=values["api_secret"], api_secret=values["api_secret"],

+ 19
- 5
api/core/third_party/spark/spark_llm.py 查看文件





class SparkLLMClient: class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):


self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'

self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
self.api_base = f"wss://{domain}/{api_version}/chat"
self.app_id = app_id self.app_id = app_id
self.ws_url = self.create_url( self.ws_url = self.create_url(
urlparse(self.api_base).netloc, urlparse(self.api_base).netloc,
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})


def on_error(self, ws, error): def on_error(self, ws, error):
self.queue.put({'error': error})
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close() ws.close()


def on_close(self, ws, close_status_code, close_reason): def on_close(self, ws, close_status_code, close_reason):
}, },
"parameter": { "parameter": {
"chat": { "chat": {
"domain": "general"
"domain": self.chat_domain
} }
}, },
"payload": { "payload": {
while True: while True:
content = self.queue.get() content = self.queue.get()
if 'error' in content: if 'error' in content:
raise SparkError(content['error'])
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")


if 'data' not in content: if 'data' not in content:
break break

+ 1
- 0
api/services/provider_service.py 查看文件

for model in model_list: for model in model_list:
valid_model_dict = { valid_model_dict = {
"model_name": model['id'], "model_name": model['id'],
"model_display_name": model['name'],
"model_type": model_type, "model_type": model_type,
"model_provider": { "model_provider": {
"provider_name": provider.provider_name, "provider_name": provider.provider_name,

Loading…
取消
儲存