| import decimal | import decimal | ||||
| from functools import wraps | |||||
| from typing import List, Optional, Any | from typing import List, Optional, Any | ||||
| from langchain.callbacks.manager import Callbacks | from langchain.callbacks.manager import Callbacks | ||||
| def _init_client(self) -> Any: | def _init_client(self) -> Any: | ||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | ||||
| return ChatSpark( | return ChatSpark( | ||||
| model_name=self.name, | |||||
| streaming=self.streaming, | streaming=self.streaming, | ||||
| callbacks=self.callbacks, | callbacks=self.callbacks, | ||||
| **self.credentials, | **self.credentials, |
| return [ | return [ | ||||
| { | { | ||||
| 'id': 'spark', | 'id': 'spark', | ||||
| 'name': '星火认知大模型', | |||||
| 'name': 'Spark V1.5', | |||||
| }, | |||||
| { | |||||
| 'id': 'spark-v2', | |||||
| 'name': 'Spark V2.0', | |||||
| } | } | ||||
| ] | ] | ||||
| else: | else: |
| .. code-block:: python | .. code-block:: python | ||||
| client = SparkLLMClient( | client = SparkLLMClient( | ||||
| model_name="<model_name>", | |||||
| app_id="<app_id>", | app_id="<app_id>", | ||||
| api_key="<api_key>", | api_key="<api_key>", | ||||
| api_secret="<api_secret>" | api_secret="<api_secret>" | ||||
| """ | """ | ||||
| client: Any = None #: :meta private: | client: Any = None #: :meta private: | ||||
| model_name: str = "spark" | |||||
| """The Spark model name.""" | |||||
| max_tokens: int = 256 | max_tokens: int = 256 | ||||
| """Denotes the number of tokens to predict per generation.""" | """Denotes the number of tokens to predict per generation.""" | ||||
| ) | ) | ||||
| values["client"] = SparkLLMClient( | values["client"] = SparkLLMClient( | ||||
| model_name=values["model_name"], | |||||
| app_id=values["app_id"], | app_id=values["app_id"], | ||||
| api_key=values["api_key"], | api_key=values["api_key"], | ||||
| api_secret=values["api_secret"], | api_secret=values["api_secret"], |
| class SparkLLMClient: | class SparkLLMClient: | ||||
| def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): | |||||
| def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): | |||||
| self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat') | |||||
| domain = 'spark-api.xf-yun.com' if not api_domain else api_domain | |||||
| api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1' | |||||
| self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general' | |||||
| self.api_base = f"wss://{domain}/{api_version}/chat" | |||||
| self.app_id = app_id | self.app_id = app_id | ||||
| self.ws_url = self.create_url( | self.ws_url = self.create_url( | ||||
| urlparse(self.api_base).netloc, | urlparse(self.api_base).netloc, | ||||
| ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) | ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) | ||||
| def on_error(self, ws, error): | def on_error(self, ws, error): | ||||
| self.queue.put({'error': error}) | |||||
| self.queue.put({ | |||||
| 'status_code': error.status_code, | |||||
| 'error': error.resp_body.decode('utf-8') | |||||
| }) | |||||
| ws.close() | ws.close() | ||||
| def on_close(self, ws, close_status_code, close_reason): | def on_close(self, ws, close_status_code, close_reason): | ||||
| }, | }, | ||||
| "parameter": { | "parameter": { | ||||
| "chat": { | "chat": { | ||||
| "domain": "general" | |||||
| "domain": self.chat_domain | |||||
| } | } | ||||
| }, | }, | ||||
| "payload": { | "payload": { | ||||
| while True: | while True: | ||||
| content = self.queue.get() | content = self.queue.get() | ||||
| if 'error' in content: | if 'error' in content: | ||||
| raise SparkError(content['error']) | |||||
| if content['status_code'] == 401: | |||||
| raise SparkError('[Spark] The credentials you provided are incorrect. ' | |||||
| 'Please double-check and fill them in again.') | |||||
| elif content['status_code'] == 403: | |||||
| raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " | |||||
| "Please try again after obtaining the necessary permissions.") | |||||
| else: | |||||
| raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") | |||||
| if 'data' not in content: | if 'data' not in content: | ||||
| break | break |
| for model in model_list: | for model in model_list: | ||||
| valid_model_dict = { | valid_model_dict = { | ||||
| "model_name": model['id'], | "model_name": model['id'], | ||||
| "model_display_name": model['name'], | |||||
| "model_type": model_type, | "model_type": model_type, | ||||
| "model_provider": { | "model_provider": { | ||||
| "provider_name": provider.provider_name, | "provider_name": provider.provider_name, |