| @@ -17,8 +17,12 @@ import websocket | |||
| class SparkLLMClient: | |||
| def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): | |||
| domain = 'spark-api.xf-yun.com' if not api_domain else api_domain | |||
| domain = 'spark-api.xf-yun.com' | |||
| endpoint = 'chat' | |||
| if api_domain: | |||
| domain = api_domain | |||
| if model_name == 'spark-v3': | |||
| endpoint = 'multimodal' | |||
| model_api_configs = { | |||
| 'spark': { | |||
| @@ -38,7 +42,7 @@ class SparkLLMClient: | |||
| api_version = model_api_configs[model_name]['version'] | |||
| self.chat_domain = model_api_configs[model_name]['chat_domain'] | |||
| self.api_base = f"wss://{domain}/{api_version}/chat" | |||
| self.api_base = f"wss://{domain}/{api_version}/{endpoint}" | |||
| self.app_id = app_id | |||
| self.ws_url = self.create_url( | |||
| urlparse(self.api_base).netloc, | |||