Co-authored-by: Joel <iamjoel007@gmail.com>
| @@ -1,61 +0,0 @@ | |||
| """Wrapper around ZhipuAI APIs.""" | |||
| from __future__ import annotations | |||
| import logging | |||
| import posixpath | |||
| from pydantic import BaseModel, Extra | |||
| from zhipuai.model_api.api import InvokeType | |||
| from zhipuai.utils import jwt_token | |||
| from zhipuai.utils.http_client import post, stream | |||
| from zhipuai.utils.sse_client import SSEClient | |||
| logger = logging.getLogger(__name__) | |||
| class ZhipuModelAPI(BaseModel): | |||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||
| api_key: str | |||
| api_timeout_seconds = 60 | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| def invoke(self, **kwargs): | |||
| url = self._build_api_url(kwargs, InvokeType.SYNC) | |||
| response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||
| if not response['success']: | |||
| raise ValueError( | |||
| f"Error Code: {response['code']}, Message: {response['msg']} " | |||
| ) | |||
| return response | |||
| def sse_invoke(self, **kwargs): | |||
| url = self._build_api_url(kwargs, InvokeType.SSE) | |||
| data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||
| return SSEClient(data) | |||
| def _build_api_url(self, kwargs, *path): | |||
| if kwargs: | |||
| if "model" not in kwargs: | |||
| raise Exception("model param missed") | |||
| model = kwargs.pop("model") | |||
| else: | |||
| model = "-" | |||
| return posixpath.join(self.base_url, model, *path) | |||
| def _generate_token(self): | |||
| if not self.api_key: | |||
| raise Exception( | |||
| "api_key not provided, you could provide it." | |||
| ) | |||
| try: | |||
| return jwt_token.generate_token(self.api_key) | |||
| except Exception: | |||
| raise ValueError( | |||
| f"Your api_key is invalid, please check it." | |||
| ) | |||
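This PR deletes the hand-rolled `ZhipuModelAPI` wrapper above in favor of the vendored `zhipuai_sdk` client. A minimal sketch of the replacement call style, assuming a placeholder credential (the OpenAI-style surface is confirmed by the call sites later in this diff):

from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI

client = ZhipuAI(api_key="key-id.key-secret")  # placeholder "<id>.<secret>" credential

# Old: client.invoke(model="glm-4", prompt=[...]) against the v3 model-api.
# New: OpenAI-style chat completions against /api/paas/v4.
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)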
| @@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, | |||
| PromptMessageTool, SystemPromptMessage, UserPromptMessage, | |||
| PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage, | |||
| TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils import helper | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI | |||
| from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion | |||
| class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| @@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| credentials_kwargs = self._to_credential_kwargs(credentials) | |||
| # invoke model | |||
| return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user) | |||
| return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) | |||
| def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | |||
| tools: Optional[list[PromptMessageTool]] = None) -> int: | |||
| @@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| :param tools: tools for tool calling | |||
| :return: | |||
| """ | |||
| prompt = self._convert_messages_to_prompt(prompt_messages) | |||
| prompt = self._convert_messages_to_prompt(prompt_messages, tools) | |||
| return self._get_num_tokens_by_gpt2(prompt) | |||
| @@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| model_parameters={ | |||
| "temperature": 0.5, | |||
| }, | |||
| tools=[], | |||
| stream=False | |||
| ) | |||
| except Exception as ex: | |||
| @@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| def _generate(self, model: str, credentials_kwargs: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None) -> Union[LLMResult, Generator]: | |||
| """ | |||
| @@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| if stop: | |||
| extra_model_kwargs['stop_sequences'] = stop | |||
| client = ZhipuModelAPI( | |||
| client = ZhipuAI( | |||
| api_key=credentials_kwargs['api_key'] | |||
| ) | |||
| @@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| # not support image message | |||
| continue | |||
| if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER: | |||
| if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ | |||
| copy_prompt_message.role == PromptMessageRole.USER: | |||
| new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content | |||
| else: | |||
| if copy_prompt_message.role == PromptMessageRole.USER: | |||
| new_prompt_messages.append(copy_prompt_message) | |||
| elif copy_prompt_message.role == PromptMessageRole.TOOL: | |||
| new_prompt_messages.append(copy_prompt_message) | |||
| elif copy_prompt_message.role == PromptMessageRole.SYSTEM: | |||
| new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) | |||
| new_prompt_messages.append(new_prompt_message) | |||
| else: | |||
| new_prompt_message = UserPromptMessage(content=copy_prompt_message.content) | |||
| new_prompt_messages.append(new_prompt_message) | |||
| @@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| if model == 'glm-4v': | |||
| params = { | |||
| 'model': model, | |||
| 'prompt': [{ | |||
| 'messages': [{ | |||
| 'role': prompt_message.role.value, | |||
| 'content': | |||
| [ | |||
| @@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| else: | |||
| params = { | |||
| 'model': model, | |||
| 'prompt': [{ | |||
| 'role': prompt_message.role.value, | |||
| 'content': prompt_message.content, | |||
| } for prompt_message in new_prompt_messages], | |||
| 'messages': [], | |||
| **model_parameters | |||
| } | |||
| # glm model | |||
| if not model.startswith('chatglm'): | |||
| for prompt_message in new_prompt_messages: | |||
| if prompt_message.role == PromptMessageRole.TOOL: | |||
| params['messages'].append({ | |||
| 'role': 'tool', | |||
| 'content': prompt_message.content, | |||
| 'tool_call_id': prompt_message.tool_call_id | |||
| }) | |||
| else: | |||
| params['messages'].append({ | |||
| 'role': prompt_message.role.value, | |||
| 'content': prompt_message.content | |||
| }) | |||
| else: | |||
| # chatglm model | |||
| for prompt_message in new_prompt_messages: | |||
| # merge system message to user message | |||
| if prompt_message.role == PromptMessageRole.SYSTEM or \ | |||
| prompt_message.role == PromptMessageRole.TOOL or \ | |||
| prompt_message.role == PromptMessageRole.USER: | |||
| if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': | |||
| params['messages'][-1]['content'] += "\n\n" + prompt_message.content | |||
| else: | |||
| params['messages'].append({ | |||
| 'role': 'user', | |||
| 'content': prompt_message.content | |||
| }) | |||
| else: | |||
| params['messages'].append({ | |||
| 'role': prompt_message.role.value, | |||
| 'content': prompt_message.content | |||
| }) | |||
| if tools and len(tools) > 0: | |||
| params['tools'] = [ | |||
| { | |||
| 'type': 'function', | |||
| 'function': helper.dump_model(tool) | |||
| } for tool in tools | |||
| ] | |||
| if stream: | |||
| response = client.sse_invoke(incremental=True, **params).events() | |||
| return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages) | |||
| response = client.chat.completions.create(stream=stream, **params) | |||
| return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) | |||
| response = client.invoke(**params) | |||
| return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages) | |||
| response = client.chat.completions.create(**params) | |||
| return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) | |||
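For reference, each `PromptMessageTool` above is serialized into ZhipuAI's function-calling schema via `helper.dump_model(tool)`. A sketch of the resulting payload for one hypothetical weather tool:

# Hypothetical tool; helper.dump_model(tool) yields the inner "function" dict.
tools_payload = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]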
| def _handle_generate_response(self, model: str, | |||
| credentials: dict, | |||
| response: Dict[str, Any], | |||
| tools: Optional[list[PromptMessageTool]], | |||
| response: Completion, | |||
| prompt_messages: list[PromptMessage]) -> LLMResult: | |||
| """ | |||
| Handle llm response | |||
| @@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| :param prompt_messages: prompt messages | |||
| :return: llm response | |||
| """ | |||
| data = response["data"] | |||
| text = '' | |||
| for res in data["choices"]: | |||
| text += res['content'] | |||
| assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] | |||
| for choice in response.choices: | |||
| if choice.message.tool_calls: | |||
| for tool_call in choice.message.tool_calls: | |||
| if tool_call.type == 'function': | |||
| assistant_tool_calls.append( | |||
| AssistantPromptMessage.ToolCall( | |||
| id=tool_call.id, | |||
| type=tool_call.type, | |||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | |||
| name=tool_call.function.name, | |||
| arguments=tool_call.function.arguments, | |||
| ) | |||
| ) | |||
| ) | |||
| text += choice.message.content or '' | |||
| token_usage = data.get("usage") | |||
| if token_usage is not None: | |||
| if 'prompt_tokens' not in token_usage: | |||
| token_usage['prompt_tokens'] = 0 | |||
| if 'completion_tokens' not in token_usage: | |||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||
| prompt_usage = response.usage.prompt_tokens | |||
| completion_usage = response.usage.completion_tokens | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens']) | |||
| usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage) | |||
| # transform response | |||
| result = LLMResult( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| message=AssistantPromptMessage(content=text), | |||
| message=AssistantPromptMessage( | |||
| content=text, | |||
| tool_calls=assistant_tool_calls | |||
| ), | |||
| usage=usage, | |||
| ) | |||
| @@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| def _handle_generate_stream_response(self, model: str, | |||
| credentials: dict, | |||
| responses: list[Generator], | |||
| tools: Optional[list[PromptMessageTool]], | |||
| responses: Generator[ChatCompletionChunk, None, None], | |||
| prompt_messages: list[PromptMessage]) -> Generator: | |||
| """ | |||
| Handle llm stream response | |||
| @@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| :param prompt_messages: prompt messages | |||
| :return: llm response chunk generator result | |||
| """ | |||
| for index, event in enumerate(responses): | |||
| if event.event == "add": | |||
| full_assistant_content = '' | |||
| for chunk in responses: | |||
| if len(chunk.choices) == 0: | |||
| continue | |||
| delta = chunk.choices[0] | |||
| if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): | |||
| continue | |||
| assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] | |||
| for tool_call in delta.delta.tool_calls or []: | |||
| if tool_call.type == 'function': | |||
| assistant_tool_calls.append( | |||
| AssistantPromptMessage.ToolCall( | |||
| id=tool_call.id, | |||
| type=tool_call.type, | |||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | |||
| name=tool_call.function.name, | |||
| arguments=tool_call.function.arguments, | |||
| ) | |||
| ) | |||
| ) | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=delta.delta.content if delta.delta.content else '', | |||
| tool_calls=assistant_tool_calls | |||
| ) | |||
| full_assistant_content += delta.delta.content if delta.delta.content else '' | |||
| if delta.finish_reason is not None and chunk.usage is not None: | |||
| completion_tokens = chunk.usage.completion_tokens | |||
| prompt_tokens = chunk.usage.prompt_tokens | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) | |||
| yield LLMResultChunk( | |||
| model=chunk.model, | |||
| prompt_messages=prompt_messages, | |||
| model=model, | |||
| system_fingerprint='', | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=AssistantPromptMessage(content=event.data) | |||
| index=delta.index, | |||
| message=assistant_prompt_message, | |||
| finish_reason=delta.finish_reason, | |||
| usage=usage | |||
| ) | |||
| ) | |||
| elif event.event == "error" or event.event == "interrupted": | |||
| raise ValueError( | |||
| f"{event.data}" | |||
| ) | |||
| elif event.event == "finish": | |||
| meta = json.loads(event.meta) | |||
| token_usage = meta['usage'] | |||
| if token_usage is not None: | |||
| if 'prompt_tokens' not in token_usage: | |||
| token_usage['prompt_tokens'] = 0 | |||
| if 'completion_tokens' not in token_usage: | |||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||
| usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens']) | |||
| else: | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| model=chunk.model, | |||
| prompt_messages=prompt_messages, | |||
| system_fingerprint='', | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=AssistantPromptMessage(content=event.data), | |||
| finish_reason='finish', | |||
| usage=usage | |||
| index=delta.index, | |||
| message=assistant_prompt_message, | |||
| ) | |||
| ) | |||
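Outside the runtime wrapper, the same streaming surface can be consumed directly from the SDK. A minimal sketch, assuming a configured `client`, mirroring the chunk handling above:

stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0]
    if delta.delta.content:
        print(delta.delta.content, end="")
    if delta.finish_reason is not None and chunk.usage is not None:
        total = chunk.usage.prompt_tokens + chunk.usage.completion_tokens
        print(f"\n[{total} tokens]")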
| @@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| raise ValueError(f"Got unknown type {message}") | |||
| return message_text | |||
| def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: | |||
| """ | |||
| Format a list of messages into a full prompt for the ZhipuAI model | |||
| def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: | |||
| """ | |||
| :param messages: List of PromptMessage to combine. | |||
| :return: Combined string with necessary human_prompt and ai_prompt tags. | |||
| """ | |||
| @@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| for message in messages | |||
| ) | |||
| if tools and len(tools) > 0: | |||
| text += "\n\nTools:" | |||
| for tool in tools: | |||
| text += f"\n{tool.json()}" | |||
| # trim off the trailing ' ' that might come from the "Assistant: " | |||
| return text.rstrip() | |||
| return text.rstrip() | |||
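So `get_num_tokens` is a local GPT-2 approximation: messages are flattened to text, tool JSON is appended under a `Tools:` header, and the result is tokenized offline. Roughly what the flattened text looks like for one user message and one tool (role prefixes illustrative; they come from `_convert_one_message_to_text`):

# Human: What's the weather in Paris?
#
# Assistant:
#
# Tools:
# {"name": "get_weather", "description": "Look up weather", "parameters": {...}}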
| @@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI | |||
| from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI | |||
| from langchain.schema.language_model import _get_token_ids_default_method | |||
| @@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): | |||
| :return: embeddings result | |||
| """ | |||
| credentials_kwargs = self._to_credential_kwargs(credentials) | |||
| client = ZhipuModelAPI( | |||
| client = ZhipuAI( | |||
| api_key=credentials_kwargs['api_key'] | |||
| ) | |||
| @@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): | |||
| try: | |||
| # transform credentials to kwargs for model instance | |||
| credentials_kwargs = self._to_credential_kwargs(credentials) | |||
| client = ZhipuModelAPI( | |||
| client = ZhipuAI( | |||
| api_key=credentials_kwargs['api_key'] | |||
| ) | |||
| @@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]: | |||
| def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]: | |||
| """Call out to ZhipuAI's embedding endpoint. | |||
| Args: | |||
| @@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): | |||
| Returns: | |||
| List of embeddings, one for each text. | |||
| """ | |||
| embeddings = [] | |||
| for text in texts: | |||
| response = client.invoke(model=model, prompt=text) | |||
| data = response["data"] | |||
| embeddings.append(data.get('embedding')) | |||
| embedding_used_tokens = 0 | |||
| embedding_used_tokens = data.get('usage') | |||
| for text in texts: | |||
| response = client.embeddings.create(model=model, input=text) | |||
| data = response.data[0] | |||
| embeddings.append(data.embedding) | |||
| embedding_used_tokens += response.usage.total_tokens | |||
| return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0 | |||
| return [list(map(float, e)) for e in embeddings], embedding_used_tokens | |||
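A usage sketch of the rewritten embedding loop against the SDK directly, assuming a configured `client` and a placeholder embedding model id:

texts = ["hello world", "zhipuai embeddings"]
embeddings, used_tokens = [], 0
for text in texts:
    response = client.embeddings.create(model="embedding-2", input=text)  # placeholder model id
    embeddings.append(response.data[0].embedding)
    used_tokens += response.usage.total_tokens
print(len(embeddings), used_tokens)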
| def embed_query(self, text: str) -> List[float]: | |||
| """Call out to ZhipuAI's embedding endpoint. | |||
| @@ -0,0 +1,17 @@ | |||
| from ._client import ZhipuAI | |||
| from .core._errors import ( | |||
| ZhipuAIError, | |||
| APIStatusError, | |||
| APIRequestFailedError, | |||
| APIAuthenticationError, | |||
| APIReachLimitError, | |||
| APIInternalError, | |||
| APIServerFlowExceedError, | |||
| APIResponseError, | |||
| APIResponseValidationError, | |||
| APITimeoutError, | |||
| ) | |||
| from .__version__ import __version__ | |||
| @@ -0,0 +1,2 @@ | |||
| __version__ = 'v2.0.1' | |||
| @@ -0,0 +1,71 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, Mapping | |||
| from typing_extensions import override | |||
| from .core import _jwt_token | |||
| from .core._errors import ZhipuAIError | |||
| from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES | |||
| from .core._base_type import NotGiven, NOT_GIVEN | |||
| from . import api_resource | |||
| import os | |||
| import httpx | |||
| from httpx import Timeout | |||
| class ZhipuAI(HttpClient): | |||
| chat: api_resource.chat | |||
| api_key: str | |||
| def __init__( | |||
| self, | |||
| *, | |||
| api_key: str | None = None, | |||
| base_url: str | httpx.URL | None = None, | |||
| timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, | |||
| max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, | |||
| http_client: httpx.Client | None = None, | |||
| custom_headers: Mapping[str, str] | None = None | |||
| ) -> None: | |||
| # if api_key is None: | |||
| # api_key = os.environ.get("ZHIPUAI_API_KEY") | |||
| if api_key is None: | |||
| raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供") | |||
| self.api_key = api_key | |||
| if base_url is None: | |||
| base_url = os.environ.get("ZHIPUAI_BASE_URL") | |||
| if base_url is None: | |||
| base_url = f"https://open.bigmodel.cn/api/paas/v4" | |||
| from .__version__ import __version__ | |||
| super().__init__( | |||
| version=__version__, | |||
| base_url=base_url, | |||
| timeout=timeout, | |||
| custom_httpx_client=http_client, | |||
| custom_headers=custom_headers, | |||
| ) | |||
| self.chat = api_resource.chat.Chat(self) | |||
| self.images = api_resource.images.Images(self) | |||
| self.embeddings = api_resource.embeddings.Embeddings(self) | |||
| self.files = api_resource.files.Files(self) | |||
| self.fine_tuning = api_resource.fine_tuning.FineTuning(self) | |||
| @property | |||
| @override | |||
| def _auth_headers(self) -> dict[str, str]: | |||
| api_key = self.api_key | |||
| return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} | |||
| def __del__(self) -> None: | |||
| if (not hasattr(self, "_has_custom_http_client") | |||
| or not hasattr(self, "close") | |||
| or not hasattr(self, "_client")): | |||
| # if the '__init__' method raised an error, self would not have client attr | |||
| return | |||
| if self._has_custom_http_client: | |||
| return | |||
| self.close() | |||
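Construction wires JWT auth, the v4 base URL, and the resource namespaces (`chat`, `images`, `embeddings`, `files`, `fine_tuning`). A sketch with an explicit timeout; the key is a placeholder in the `<id>.<secret>` form the JWT signer expects:

import httpx

client = ZhipuAI(
    api_key="key-id.key-secret",
    timeout=httpx.Timeout(60.0, connect=5.0),
    custom_headers={"X-Trace-Id": "demo"},  # hypothetical extra header
)
with client:  # HttpClient implements the context-manager protocol
    result = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "ping"}],
    )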
| @@ -0,0 +1,5 @@ | |||
| from .chat import chat | |||
| from .images import Images | |||
| from .embeddings import Embeddings | |||
| from .files import Files | |||
| from .fine_tuning import fine_tuning | |||
| @@ -0,0 +1,87 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, List, Optional, TYPE_CHECKING | |||
| import httpx | |||
| from typing_extensions import Literal | |||
| from ...core._base_api import BaseAPI | |||
| from ...core._base_type import NotGiven, NOT_GIVEN, Headers | |||
| from ...core._http_client import make_user_request_input | |||
| from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class AsyncCompletions(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| temperature: Optional[float] | NotGiven = NOT_GIVEN, | |||
| top_p: Optional[float] | NotGiven = NOT_GIVEN, | |||
| max_tokens: int | NotGiven = NOT_GIVEN, | |||
| seed: int | NotGiven = NOT_GIVEN, | |||
| messages: Union[str, List[str], List[int], List[List[int]], None], | |||
| stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tools: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tool_choice: str | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> AsyncTaskStatus: | |||
| _cast_type = AsyncTaskStatus | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._post( | |||
| "/async/chat/completions", | |||
| body={ | |||
| "model": model, | |||
| "request_id": request_id, | |||
| "temperature": temperature, | |||
| "top_p": top_p, | |||
| "do_sample": do_sample, | |||
| "max_tokens": max_tokens, | |||
| "seed": seed, | |||
| "messages": messages, | |||
| "stop": stop, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "tools": tools, | |||
| "tool_choice": tool_choice, | |||
| }, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=_cast_type, | |||
| enable_stream=False, | |||
| ) | |||
| def retrieve_completion_result( | |||
| self, | |||
| id: str, | |||
| extra_headers: Headers | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Union[AsyncCompletion, AsyncTaskStatus]: | |||
| _cast_type = Union[AsyncCompletion,AsyncTaskStatus] | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._get( | |||
| path=f"/async-result/{id}", | |||
| cast_type=_cast_type, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, | |||
| timeout=timeout | |||
| ) | |||
| ) | |||
| @@ -0,0 +1,16 @@ | |||
| from typing import TYPE_CHECKING | |||
| from .completions import Completions | |||
| from .async_completions import AsyncCompletions | |||
| from ...core._base_api import BaseAPI | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class Chat(BaseAPI): | |||
| completions: Completions | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| self.completions = Completions(client) | |||
| self.asyncCompletions = AsyncCompletions(client) | |||
| @@ -0,0 +1,71 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, List, Optional, TYPE_CHECKING | |||
| import httpx | |||
| from typing_extensions import Literal | |||
| from ...core._base_api import BaseAPI | |||
| from ...core._base_type import NotGiven, NOT_GIVEN, Headers | |||
| from ...core._http_client import make_user_request_input | |||
| from ...core._sse_client import StreamResponse | |||
| from ...types.chat.chat_completion import Completion | |||
| from ...types.chat.chat_completion_chunk import ChatCompletionChunk | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class Completions(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| temperature: Optional[float] | NotGiven = NOT_GIVEN, | |||
| top_p: Optional[float] | NotGiven = NOT_GIVEN, | |||
| max_tokens: int | NotGiven = NOT_GIVEN, | |||
| seed: int | NotGiven = NOT_GIVEN, | |||
| messages: Union[str, List[str], List[int], object, None], | |||
| stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tools: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tool_choice: str | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Completion | StreamResponse[ChatCompletionChunk]: | |||
| _cast_type = Completion | |||
| _stream_cls = StreamResponse[ChatCompletionChunk] | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| _stream_cls = StreamResponse[object] | |||
| return self._post( | |||
| "/chat/completions", | |||
| body={ | |||
| "model": model, | |||
| "request_id": request_id, | |||
| "temperature": temperature, | |||
| "top_p": top_p, | |||
| "do_sample": do_sample, | |||
| "max_tokens": max_tokens, | |||
| "seed": seed, | |||
| "messages": messages, | |||
| "stop": stop, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "stream": stream, | |||
| "tools": tools, | |||
| "tool_choice": tool_choice, | |||
| }, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, | |||
| ), | |||
| cast_type=_cast_type, | |||
| enable_stream=stream or False, | |||
| stream_cls=_stream_cls, | |||
| ) | |||
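Call-site sketch, assuming a configured `client`: omitted keyword arguments stay `NOT_GIVEN` and are stripped before the request is built, so only what you pass is sent:

completion = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Summarize SSE in one line"}],
    temperature=0.7,  # optional knobs pass straight through
    max_tokens=128,
)
print(completion.choices[0].message.content)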
| @@ -0,0 +1,49 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, List, Optional, TYPE_CHECKING | |||
| import httpx | |||
| from ..core._base_api import BaseAPI | |||
| from ..core._base_type import NotGiven, NOT_GIVEN, Headers | |||
| from ..core._http_client import make_user_request_input | |||
| from ..types.embeddings import EmbeddingsResponded | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class Embeddings(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| input: Union[str, List[str], List[int], List[List[int]]], | |||
| model: Union[str], | |||
| encoding_format: str | NotGiven = NOT_GIVEN, | |||
| user: str | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> EmbeddingsResponded: | |||
| _cast_type = EmbeddingsResponded | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._post( | |||
| "/embeddings", | |||
| body={ | |||
| "input": input, | |||
| "model": model, | |||
| "encoding_format": encoding_format, | |||
| "user": user, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| }, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=_cast_type, | |||
| enable_stream=False, | |||
| ) | |||
| @@ -0,0 +1,78 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| import httpx | |||
| from ..core._base_api import BaseAPI | |||
| from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes | |||
| from ..core._files import is_file_content | |||
| from ..core._http_client import ( | |||
| make_user_request_input, | |||
| ) | |||
| from ..types.file_object import FileObject, ListOfFileObject | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| __all__ = ["Files"] | |||
| class Files(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| file: FileTypes, | |||
| purpose: str, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FileObject: | |||
| if not is_file_content(file): | |||
| prefix = f"Expected file input `{file!r}`" | |||
| raise RuntimeError( | |||
| f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead." | |||
| ) from None | |||
| files = [("file", file)] | |||
| extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} | |||
| return self._post( | |||
| "/files", | |||
| body={ | |||
| "purpose": purpose, | |||
| }, | |||
| files=files, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=FileObject, | |||
| ) | |||
| def list( | |||
| self, | |||
| *, | |||
| purpose: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| order: str | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ListOfFileObject: | |||
| return self._get( | |||
| "/files", | |||
| cast_type=ListOfFileObject, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, | |||
| timeout=timeout, | |||
| query={ | |||
| "purpose": purpose, | |||
| "limit": limit, | |||
| "after": after, | |||
| "order": order, | |||
| }, | |||
| ), | |||
| ) | |||
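Usage sketch for the Files resource, assuming a configured `client` and a local `train.jsonl`; the `purpose` value and the `id` attribute on `FileObject` are assumed from the OpenAI-style convention this SDK mirrors:

with open("train.jsonl", "rb") as f:
    file_object = client.files.create(file=f, purpose="fine-tune")
print(file_object.id)  # assumed OpenAI-style field

listing = client.files.list(purpose="fine-tune", limit=10)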
| @@ -0,0 +1,15 @@ | |||
| from typing import TYPE_CHECKING | |||
| from .jobs import Jobs | |||
| from ...core._base_api import BaseAPI | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class FineTuning(BaseAPI): | |||
| jobs: Jobs | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| self.jobs = Jobs(client) | |||
| @@ -0,0 +1,115 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional, TYPE_CHECKING | |||
| import httpx | |||
| from ...core._base_api import BaseAPI | |||
| from ...core._base_type import NOT_GIVEN, Headers, NotGiven | |||
| from ...core._http_client import ( | |||
| make_user_request_input, | |||
| ) | |||
| from ...types.fine_tuning import ( | |||
| FineTuningJob, | |||
| job_create_params, | |||
| ListOfFineTuningJob, | |||
| FineTuningJobEvent, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| __all__ = ["Jobs"] | |||
| class Jobs(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| training_file: str, | |||
| hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, | |||
| suffix: Optional[str] | NotGiven = NOT_GIVEN, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| validation_file: Optional[str] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| return self._post( | |||
| "/fine_tuning/jobs", | |||
| body={ | |||
| "model": model, | |||
| "training_file": training_file, | |||
| "hyperparameters": hyperparameters, | |||
| "suffix": suffix, | |||
| "validation_file": validation_file, | |||
| "request_id": request_id, | |||
| }, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| def retrieve( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| return self._get( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}", | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| def list( | |||
| self, | |||
| *, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ListOfFineTuningJob: | |||
| return self._get( | |||
| "/fine_tuning/jobs", | |||
| cast_type=ListOfFineTuningJob, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, | |||
| timeout=timeout, | |||
| query={ | |||
| "after": after, | |||
| "limit": limit, | |||
| }, | |||
| ), | |||
| ) | |||
| def list_events( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJobEvent: | |||
| return self._get( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}/events", | |||
| cast_type=FineTuningJobEvent, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, | |||
| timeout=timeout, | |||
| query={ | |||
| "after": after, | |||
| "limit": limit, | |||
| }, | |||
| ), | |||
| ) | |||
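Lifecycle sketch for the fine-tuning jobs resource, assuming a configured `client`; the base model name and file id are placeholders, and the `id` field on `FineTuningJob` is assumed from the OpenAI-style schema:

job = client.fine_tuning.jobs.create(
    model="chatglm3-6b",       # placeholder base model
    training_file="file-123",  # id returned by client.files.create
)
job = client.fine_tuning.jobs.retrieve(job.id)
events = client.fine_tuning.jobs.list_events(job.id, limit=20)
jobs = client.fine_tuning.jobs.list(limit=5)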
| @@ -0,0 +1,55 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, List, Optional, TYPE_CHECKING | |||
| import httpx | |||
| from ..core._base_api import BaseAPI | |||
| from ..core._base_type import NotGiven, NOT_GIVEN, Headers | |||
| from ..core._http_client import make_user_request_input | |||
| from ..types.image import ImagesResponded | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class Images(BaseAPI): | |||
| def __init__(self, client: "ZhipuAI") -> None: | |||
| super().__init__(client) | |||
| def generations( | |||
| self, | |||
| *, | |||
| prompt: str, | |||
| model: str | NotGiven = NOT_GIVEN, | |||
| n: Optional[int] | NotGiven = NOT_GIVEN, | |||
| quality: Optional[str] | NotGiven = NOT_GIVEN, | |||
| response_format: Optional[str] | NotGiven = NOT_GIVEN, | |||
| size: Optional[str] | NotGiven = NOT_GIVEN, | |||
| style: Optional[str] | NotGiven = NOT_GIVEN, | |||
| user: str | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ImagesResponded: | |||
| _cast_type = ImagesResponded | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._post( | |||
| "/images/generations", | |||
| body={ | |||
| "prompt": prompt, | |||
| "model": model, | |||
| "n": n, | |||
| "quality": quality, | |||
| "response_format": response_format, | |||
| "size": size, | |||
| "style": style, | |||
| "user": user, | |||
| }, | |||
| options=make_user_request_input( | |||
| extra_headers=extra_headers, timeout=timeout | |||
| ), | |||
| cast_type=_cast_type, | |||
| enable_stream=False, | |||
| ) | |||
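Usage sketch for image generation, assuming a configured `client`; the model name is a placeholder and the `data[0].url` access assumes the OpenAI-style response shape that `ImagesResponded` models:

result = client.images.generations(
    model="cogview-3",  # placeholder model id
    prompt="a watercolor fox in a bamboo forest",
)
print(result.data[0].url)  # assumed response shape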
| @@ -0,0 +1,17 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class BaseAPI: | |||
| _client: ZhipuAI | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| self._client = client | |||
| self._delete = client.delete | |||
| self._get = client.get | |||
| self._post = client.post | |||
| self._put = client.put | |||
| self._patch = client.patch | |||
| @@ -0,0 +1,115 @@ | |||
| from __future__ import annotations | |||
| from os import PathLike | |||
| from typing import ( | |||
| TYPE_CHECKING, | |||
| Type, | |||
| Union, | |||
| Mapping, | |||
| TypeVar, IO, Tuple, Sequence, Any, List, | |||
| ) | |||
| import pydantic | |||
| from typing_extensions import ( | |||
| Literal, | |||
| override, | |||
| ) | |||
| Query = Mapping[str, object] | |||
| Body = object | |||
| AnyMapping = Mapping[str, object] | |||
| PrimitiveData = Union[str, int, float, bool, None] | |||
| Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] | |||
| ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) | |||
| _T = TypeVar("_T") | |||
| if TYPE_CHECKING: | |||
| NoneType: Type[None] | |||
| else: | |||
| NoneType = type(None) | |||
| # Sentinel class used until PEP 0661 is accepted | |||
| class NotGiven(pydantic.BaseModel): | |||
| """ | |||
| A sentinel singleton class used to distinguish omitted keyword arguments | |||
| from those passed in with the value None (which may have different behavior). | |||
| For example: | |||
| ```py | |||
| def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... | |||
| get(timeout=1) # 1s timeout | |||
| get(timeout=None) # No timeout | |||
| get() # Default timeout behavior, which may not be statically known at the method definition. | |||
| ``` | |||
| """ | |||
| def __bool__(self) -> Literal[False]: | |||
| return False | |||
| @override | |||
| def __repr__(self) -> str: | |||
| return "NOT_GIVEN" | |||
| NotGivenOr = Union[_T, NotGiven] | |||
| NOT_GIVEN = NotGiven() | |||
| class Omit(pydantic.BaseModel): | |||
| """In certain situations you need to be able to represent a case where a default value has | |||
| to be explicitly removed and `None` is not an appropriate substitute, for example: | |||
| ```py | |||
| # as the default `Content-Type` header is `application/json` that will be sent | |||
| client.post('/upload/files', files={'file': b'my raw file content'}) | |||
| # you can't explicitly override the header as it has to be dynamically generated | |||
| # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' | |||
| client.post(..., headers={'Content-Type': 'multipart/form-data'}) | |||
| # instead you can remove the default `application/json` header by passing Omit | |||
| client.post(..., headers={'Content-Type': Omit()}) | |||
| ``` | |||
| """ | |||
| def __bool__(self) -> Literal[False]: | |||
| return False | |||
| Headers = Mapping[str, Union[str, Omit]] | |||
| ResponseT = TypeVar( | |||
| "ResponseT", | |||
| bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", | |||
| ) | |||
| # for user input files | |||
| if TYPE_CHECKING: | |||
| FileContent = Union[IO[bytes], bytes, PathLike[str]] | |||
| else: | |||
| FileContent = Union[IO[bytes], bytes, PathLike] | |||
| FileTypes = Union[ | |||
| FileContent, # file content | |||
| Tuple[str, FileContent], # (filename, file) | |||
| Tuple[str, FileContent, str], # (filename, file , content_type) | |||
| Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) | |||
| ] | |||
| RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] | |||
| # for httpx client supported files | |||
| HttpxFileContent = Union[bytes, IO[bytes]] | |||
| HttpxFileTypes = Union[ | |||
| FileContent, # file content | |||
| Tuple[str, HttpxFileContent], # (filename, file) | |||
| Tuple[str, HttpxFileContent, str], # (filename, file , content_type) | |||
| Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) | |||
| ] | |||
| HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] | |||
| @@ -0,0 +1,90 @@ | |||
| from __future__ import annotations | |||
| import httpx | |||
| __all__ = [ | |||
| "ZhipuAIError", | |||
| "APIStatusError", | |||
| "APIRequestFailedError", | |||
| "APIAuthenticationError", | |||
| "APIReachLimitError", | |||
| "APIInternalError", | |||
| "APIServerFlowExceedError", | |||
| "APIResponseError", | |||
| "APIResponseValidationError", | |||
| "APITimeoutError", | |||
| ] | |||
| class ZhipuAIError(Exception): | |||
| def __init__(self, message: str, ) -> None: | |||
| super().__init__(message) | |||
| class APIStatusError(Exception): | |||
| response: httpx.Response | |||
| status_code: int | |||
| def __init__(self, message: str, *, response: httpx.Response) -> None: | |||
| super().__init__(message) | |||
| self.response = response | |||
| self.status_code = response.status_code | |||
| class APIRequestFailedError(APIStatusError): | |||
| ... | |||
| class APIAuthenticationError(APIStatusError): | |||
| ... | |||
| class APIReachLimitError(APIStatusError): | |||
| ... | |||
| class APIInternalError(APIStatusError): | |||
| ... | |||
| class APIServerFlowExceedError(APIStatusError): | |||
| ... | |||
| class APIResponseError(Exception): | |||
| message: str | |||
| request: httpx.Request | |||
| json_data: object | |||
| def __init__(self, message: str, request: httpx.Request, json_data: object): | |||
| self.message = message | |||
| self.request = request | |||
| self.json_data = json_data | |||
| super().__init__(message) | |||
| class APIResponseValidationError(APIResponseError): | |||
| status_code: int | |||
| response: httpx.Response | |||
| def __init__( | |||
| self, | |||
| response: httpx.Response, | |||
| json_data: object | None, *, | |||
| message: str | None = None | |||
| ) -> None: | |||
| super().__init__( | |||
| message=message or "Data returned by API invalid for expected schema.", | |||
| request=response.request, | |||
| json_data=json_data | |||
| ) | |||
| self.response = response | |||
| self.status_code = response.status_code | |||
| class APITimeoutError(Exception): | |||
| request: httpx.Request | |||
| def __init__(self, request: httpx.Request): | |||
| self.request = request | |||
| super().__init__("Request Timeout") | |||
| @@ -0,0 +1,46 @@ | |||
| from __future__ import annotations | |||
| import io | |||
| import os | |||
| from pathlib import Path | |||
| from typing import Mapping, Sequence | |||
| from ._base_type import ( | |||
| FileTypes, | |||
| HttpxFileTypes, | |||
| HttpxRequestFiles, | |||
| RequestFiles, | |||
| ) | |||
| def is_file_content(obj: object) -> bool: | |||
| return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike)) | |||
| def _transform_file(file: FileTypes) -> HttpxFileTypes: | |||
| if is_file_content(file): | |||
| if isinstance(file, os.PathLike): | |||
| path = Path(file) | |||
| return path.name, path.read_bytes() | |||
| else: | |||
| return file | |||
| if isinstance(file, tuple): | |||
| if isinstance(file[1], os.PathLike): | |||
| return (file[0], Path(file[1]).read_bytes(), *file[2:]) | |||
| else: | |||
| return (file[0], file[1], *file[2:]) | |||
| else: | |||
| raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type") | |||
| def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: | |||
| if files is None: | |||
| return None | |||
| if isinstance(files, Mapping): | |||
| files = {key: _transform_file(file) for key, file in files.items()} | |||
| elif isinstance(files, Sequence): | |||
| files = [(key, _transform_file(file)) for key, file in files] | |||
| else: | |||
| raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence") | |||
| return files | |||
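Sketch of what the normalization does, assuming `train.jsonl` exists: path-like values are read into `(filename, bytes)` pairs that httpx can send as multipart:

from pathlib import Path

httpx_files = make_httpx_files({"file": Path("train.jsonl")})
# -> {"file": ("train.jsonl", b"<file bytes>")}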
| @@ -0,0 +1,377 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from __future__ import annotations | |||
| import inspect | |||
| from typing import ( | |||
| Any, | |||
| Type, | |||
| Union, | |||
| cast, | |||
| Mapping, | |||
| ) | |||
| import httpx | |||
| import pydantic | |||
| from httpx import URL, Timeout | |||
| from . import _errors | |||
| from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data | |||
| from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError | |||
| from ._files import make_httpx_files | |||
| from ._request_opt import ClientRequestParam, UserRequestInput | |||
| from ._response import HttpResponse | |||
| from ._sse_client import StreamResponse | |||
| from ._utils import flatten | |||
| headers = { | |||
| "Accept": "application/json", | |||
| "Content-Type": "application/json; charset=UTF-8", | |||
| } | |||
| def _merge_map(map1: Mapping, map2: Mapping) -> Mapping: | |||
| merged = {**map1, **map2} | |||
| return {key: val for key, val in merged.items() if val is not None} | |||
| from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT | |||
| ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) | |||
| ZHIPUAI_DEFAULT_MAX_RETRIES = 3 | |||
| ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) | |||
| class HttpClient: | |||
| _client: httpx.Client | |||
| _version: str | |||
| _base_url: URL | |||
| timeout: Union[float, Timeout, None] | |||
| _limits: httpx.Limits | |||
| _has_custom_http_client: bool | |||
| _default_stream_cls: type[StreamResponse[Any]] | None = None | |||
| def __init__( | |||
| self, | |||
| *, | |||
| version: str, | |||
| base_url: URL, | |||
| timeout: Union[float, Timeout, None], | |||
| custom_httpx_client: httpx.Client | None = None, | |||
| custom_headers: Mapping[str, str] | None = None, | |||
| ) -> None: | |||
| if timeout is None or isinstance(timeout, NotGiven): | |||
| if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: | |||
| timeout = custom_httpx_client.timeout | |||
| else: | |||
| timeout = ZHIPUAI_DEFAULT_TIMEOUT | |||
| self.timeout = cast(Timeout, timeout) | |||
| self._has_custom_http_client = bool(custom_httpx_client) | |||
| self._client = custom_httpx_client or httpx.Client( | |||
| base_url=base_url, | |||
| timeout=self.timeout, | |||
| limits=ZHIPUAI_DEFAULT_LIMITS, | |||
| ) | |||
| self._version = version | |||
| url = URL(url=base_url) | |||
| if not url.raw_path.endswith(b"/"): | |||
| url = url.copy_with(raw_path=url.raw_path + b"/") | |||
| self._base_url = url | |||
| self._custom_headers = custom_headers or {} | |||
| def _prepare_url(self, url: str) -> URL: | |||
| sub_url = URL(url) | |||
| if sub_url.is_relative_url: | |||
| request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") | |||
| return self._base_url.copy_with(raw_path=request_raw_url) | |||
| return sub_url | |||
| @property | |||
| def _default_headers(self): | |||
| return { | |||
| "Accept": "application/json", | |||
| "Content-Type": "application/json; charset=UTF-8", | |||
| "ZhipuAI-SDK-Ver": self._version, | |||
| "source_type": "zhipu-sdk-python", | |||
| "x-request-sdk": "zhipu-sdk-python", | |||
| **self._auth_headers, | |||
| **self._custom_headers, | |||
| } | |||
| @property | |||
| def _auth_headers(self): | |||
| return {} | |||
| def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers: | |||
| custom_headers = request_param.headers or {} | |||
| headers_dict = _merge_map(self._default_headers, custom_headers) | |||
| httpx_headers = httpx.Headers(headers_dict) | |||
| return httpx_headers | |||
| def _prepare_request( | |||
| self, | |||
| request_param: ClientRequestParam | |||
| ) -> httpx.Request: | |||
| kwargs: dict[str, Any] = {} | |||
| json_data = request_param.json_data | |||
| headers = self._prepare_headers(request_param) | |||
| url = self._prepare_url(request_param.url) | |||
| if headers.get("Content-Type") == "multipart/form-data": | |||
| headers.pop("Content-Type") | |||
| if json_data: | |||
| kwargs["data"] = self._make_multipartform(json_data) | |||
| return self._client.build_request( | |||
| headers=headers, | |||
| timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout, | |||
| method=request_param.method, | |||
| url=url, | |||
| json=json_data, | |||
| files=request_param.files, | |||
| params=request_param.params, | |||
| **kwargs, | |||
| ) | |||
| def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: | |||
| items = [] | |||
| if isinstance(value, Mapping): | |||
| for k, v in value.items(): | |||
| items.extend(self._object_to_formfata(f"{key}[{k}]", v)) | |||
| return items | |||
| if isinstance(value, (list, tuple)): | |||
| for v in value: | |||
| items.extend(self._object_to_formdata(key + "[]", v)) | |||
| return items | |||
| def _primitive_value_to_str(val) -> str: | |||
| # copied from httpx | |||
| if val is True: | |||
| return "true" | |||
| elif val is False: | |||
| return "false" | |||
| elif val is None: | |||
| return "" | |||
| return str(val) | |||
| str_data = _primitive_value_to_str(value) | |||
| if not str_data: | |||
| return [] | |||
| return [(key, str_data)] | |||
| def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: | |||
| items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) | |||
| serialized: dict[str, object] = {} | |||
| for key, value in items: | |||
| if key in serialized: | |||
| raise ValueError(f"存在重复的键: {key};") | |||
| serialized[key] = value | |||
| return serialized | |||
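The flattening follows the bracketed form-field convention. For illustration:

# {"meta": {"k": "v"}, "tags": ["a", "b"], "flag": True}
# flattens to:
# [("meta[k]", "v"), ("tags[]", "a"), ("tags[]", "b"), ("flag", "true")]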
| def _parse_response( | |||
| self, | |||
| *, | |||
| cast_type: Type[ResponseT], | |||
| response: httpx.Response, | |||
| enable_stream: bool, | |||
| request_param: ClientRequestParam, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| ) -> HttpResponse: | |||
| http_response = HttpResponse( | |||
| raw_response=response, | |||
| cast_type=cast_type, | |||
| client=self, | |||
| enable_stream=enable_stream, | |||
| stream_cls=stream_cls | |||
| ) | |||
| return http_response.parse() | |||
| def _process_response_data( | |||
| self, | |||
| *, | |||
| data: object, | |||
| cast_type: type[ResponseT], | |||
| response: httpx.Response, | |||
| ) -> ResponseT: | |||
| if data is None: | |||
| return cast(ResponseT, None) | |||
| try: | |||
| if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel): | |||
| return cast(ResponseT, cast_type.validate(data)) | |||
| return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)) | |||
| except pydantic.ValidationError as err: | |||
| raise APIResponseValidationError(response=response, json_data=data) from err | |||
| def is_closed(self) -> bool: | |||
| return self._client.is_closed | |||
| def close(self): | |||
| self._client.close() | |||
| def __enter__(self): | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| self.close() | |||
| def request( | |||
| self, | |||
| *, | |||
| cast_type: Type[ResponseT], | |||
| params: ClientRequestParam, | |||
| enable_stream: bool = False, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| request = self._prepare_request(params) | |||
| try: | |||
| response = self._client.send( | |||
| request, | |||
| stream=enable_stream, | |||
| ) | |||
| response.raise_for_status() | |||
| except httpx.TimeoutException as err: | |||
| raise APITimeoutError(request=request) from err | |||
| except httpx.HTTPStatusError as err: | |||
| err.response.read() | |||
| # raise err | |||
| raise self._make_status_error(err.response) from None | |||
| except Exception as err: | |||
| raise err | |||
| return self._parse_response( | |||
| cast_type=cast_type, | |||
| request_param=params, | |||
| response=response, | |||
| enable_stream=enable_stream, | |||
| stream_cls=stream_cls, | |||
| ) | |||
| def get( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: Type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| enable_stream: bool = False, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = ClientRequestParam.construct(method="get", url=path, **options) | |||
| return self.request( | |||
| cast_type=cast_type, params=opts, | |||
| enable_stream=enable_stream | |||
| ) | |||
| def post( | |||
| self, | |||
| path: str, | |||
| *, | |||
| body: Body | None = None, | |||
| cast_type: Type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| enable_stream: bool = False, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, | |||
| **options) | |||
| return self.request( | |||
| cast_type=cast_type, params=opts, | |||
| enable_stream=enable_stream, | |||
| stream_cls=stream_cls | |||
| ) | |||
| def patch( | |||
| self, | |||
| path: str, | |||
| *, | |||
| body: Body | None = None, | |||
| cast_type: Type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| ) -> ResponseT: | |||
| opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) | |||
| return self.request( | |||
| cast_type=cast_type, params=opts, | |||
| ) | |||
| def put( | |||
| self, | |||
| path: str, | |||
| *, | |||
| body: Body | None = None, | |||
| cast_type: Type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), | |||
| **options) | |||
| return self.request( | |||
| cast_type=cast_type, params=opts, | |||
| ) | |||
| def delete( | |||
| self, | |||
| path: str, | |||
| *, | |||
| body: Body | None = None, | |||
| cast_type: Type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) | |||
| return self.request( | |||
| cast_type=cast_type, params=opts, | |||
| ) | |||
| def _make_status_error(self, response) -> APIStatusError: | |||
| response_text = response.text.strip() | |||
| status_code = response.status_code | |||
| error_msg = f"Error code: {status_code}, with error text {response_text}" | |||
| if status_code == 400: | |||
| return _errors.APIRequestFailedError(message=error_msg, response=response) | |||
| elif status_code == 401: | |||
| return _errors.APIAuthenticationError(message=error_msg, response=response) | |||
| elif status_code == 429: | |||
| return _errors.APIReachLimitError(message=error_msg, response=response) | |||
| elif status_code == 500: | |||
| return _errors.APIInternalError(message=error_msg, response=response) | |||
| elif status_code == 503: | |||
| return _errors.APIServerFlowExceedError(message=error_msg, response=response) | |||
| return APIStatusError(message=error_msg, response=response) | |||
| def make_user_request_input( | |||
| max_retries: int | None = None, | |||
| timeout: float | Timeout | None | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| query: Query | None = None, | |||
| ) -> UserRequestInput: | |||
| options: UserRequestInput = {} | |||
| if extra_headers is not None: | |||
| options["headers"] = extra_headers | |||
| if max_retries is not None: | |||
| options["max_retries"] = max_retries | |||
| if not isinstance(timeout, NotGiven): | |||
| options['timeout'] = timeout | |||
| if query is not None: | |||
| options["params"] = query | |||
| return options | |||
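| # Illustrative usage (not part of the diff): composing per-request options | |||
| # with the verb helpers above. The client instance and FileObject cast type | |||
| # are assumptions for illustration. | |||
| options = make_user_request_input(max_retries=3, timeout=30.0, query={"limit": 10}) | |||
| files = client.get("/files", cast_type=FileObject, options=options) | |||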
| @@ -0,0 +1,30 @@ | |||
| # -*- coding:utf-8 -*- | |||
| import time | |||
| import cachetools.func | |||
| import jwt | |||
| API_TOKEN_TTL_SECONDS = 3 * 60 | |||
| CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 | |||
| @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) | |||
| def generate_token(apikey: str) -> str: | |||
| try: | |||
| api_key, secret = apikey.split(".") | |||
| except Exception as e: | |||
| raise ValueError("invalid api_key, expected the '<key id>.<secret>' format") from e | |||
| payload = { | |||
| "api_key": api_key, | |||
| "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, | |||
| "timestamp": int(round(time.time() * 1000)), | |||
| } | |||
| ret = jwt.encode( | |||
| payload, | |||
| secret, | |||
| algorithm="HS256", | |||
| headers={"alg": "HS256", "sign_type": "SIGN"}, | |||
| ) | |||
| return ret | |||
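| # Illustrative usage (not part of the diff): the key below is a made-up | |||
| # placeholder in the "<id>.<secret>" form the function expects. Note that exp | |||
| # and timestamp are in milliseconds (ZhipuAI's convention, not standard JWT), | |||
| # and the cache TTL is 30 seconds shorter than the token TTL so a cached token | |||
| # never expires while in use. | |||
| token = generate_token("my-key-id.my-key-secret") | |||
| print(jwt.get_unverified_header(token))  # -> {'alg': 'HS256', 'sign_type': 'SIGN', 'typ': 'JWT'} | |||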
| @@ -0,0 +1,54 @@ | |||
| from __future__ import annotations | |||
| from typing import Union, Any, cast | |||
| import pydantic.generics | |||
| from httpx import Timeout | |||
| from pydantic import ConfigDict | |||
| from typing_extensions import ( | |||
| Unpack, ClassVar, TypedDict | |||
| ) | |||
| from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query | |||
| from ._utils import remove_notgiven_indict | |||
| class UserRequestInput(TypedDict, total=False): | |||
| max_retries: int | |||
| timeout: float | Timeout | None | |||
| headers: Headers | |||
| params: Query | None | |||
| class ClientRequestParam: | |||
| method: str | |||
| url: str | |||
| max_retries: Union[int, NotGiven] = NotGiven() | |||
| timeout: Union[float, NotGiven] = NotGiven() | |||
| headers: Union[Headers, NotGiven] = NotGiven() | |||
| json_data: Union[Body, None] = None | |||
| files: Union[HttpxRequestFiles, None] = None | |||
| params: Query = {} | |||
| model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) | |||
| def get_max_retries(self, max_retries) -> int: | |||
| if isinstance(self.max_retries, NotGiven): | |||
| return max_retries | |||
| return self.max_retries | |||
| @classmethod | |||
| def construct( # type: ignore | |||
| cls, | |||
| _fields_set: set[str] | None = None, | |||
| **values: Unpack[UserRequestInput], | |||
| ) -> ClientRequestParam: | |||
| kwargs: dict[str, Any] = { | |||
| key: remove_notgiven_indict(value) for key, value in values.items() | |||
| } | |||
| client = cls() | |||
| client.__dict__.update(kwargs) | |||
| return client | |||
| model_construct = construct | |||
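| # Illustrative sketch (not part of the diff): construct() skips validation and | |||
| # simply copies the given values onto a fresh instance, so omitted fields keep | |||
| # their NotGiven class defaults. | |||
| params = ClientRequestParam.construct(method="get", url="/models") | |||
| assert isinstance(params.max_retries, NotGiven)  # untouched default | |||
| assert params.get_max_retries(2) == 2  # falls back to the client-wide setting | |||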
| @@ -0,0 +1,121 @@ | |||
| from __future__ import annotations | |||
| import datetime | |||
| from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING | |||
| import httpx | |||
| import pydantic | |||
| from typing_extensions import ParamSpec, get_origin, get_args | |||
| from ._base_type import NoneType | |||
| from ._sse_client import StreamResponse | |||
| if TYPE_CHECKING: | |||
| from ._http_client import HttpClient | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class HttpResponse(Generic[R]): | |||
| _cast_type: type[R] | |||
| _client: "HttpClient" | |||
| _parsed: R | None | |||
| _enable_stream: bool | |||
| _stream_cls: type[StreamResponse[Any]] | |||
| http_response: httpx.Response | |||
| def __init__( | |||
| self, | |||
| *, | |||
| raw_response: httpx.Response, | |||
| cast_type: type[R], | |||
| client: "HttpClient", | |||
| enable_stream: bool = False, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| ) -> None: | |||
| self._cast_type = cast_type | |||
| self._client = client | |||
| self._parsed = None | |||
| self._stream_cls = stream_cls | |||
| self._enable_stream = enable_stream | |||
| self.http_response = raw_response | |||
| def parse(self) -> R: | |||
| self._parsed = self._parse() | |||
| return self._parsed | |||
| def _parse(self) -> R: | |||
| if self._enable_stream: | |||
| self._parsed = cast( | |||
| R, | |||
| self._stream_cls( | |||
| cast_type=cast(type, get_args(self._stream_cls)[0]), | |||
| response=self.http_response, | |||
| client=self._client | |||
| ) | |||
| ) | |||
| return self._parsed | |||
| cast_type = self._cast_type | |||
| if cast_type is NoneType: | |||
| return cast(R, None) | |||
| http_response = self.http_response | |||
| if cast_type is str: | |||
| return cast(R, http_response.text) | |||
| content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") | |||
| origin = get_origin(cast_type) or cast_type | |||
| if content_type != "application/json": | |||
| if issubclass(origin, pydantic.BaseModel): | |||
| data = http_response.json() | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=http_response, | |||
| ) | |||
| return http_response.text | |||
| data = http_response.json() | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=http_response, | |||
| ) | |||
| @property | |||
| def headers(self) -> httpx.Headers: | |||
| return self.http_response.headers | |||
| @property | |||
| def http_request(self) -> httpx.Request: | |||
| return self.http_response.request | |||
| @property | |||
| def status_code(self) -> int: | |||
| return self.http_response.status_code | |||
| @property | |||
| def url(self) -> httpx.URL: | |||
| return self.http_response.url | |||
| @property | |||
| def method(self) -> str: | |||
| return self.http_request.method | |||
| @property | |||
| def content(self) -> bytes: | |||
| return self.http_response.content | |||
| @property | |||
| def text(self) -> str: | |||
| return self.http_response.text | |||
| @property | |||
| def http_version(self) -> str: | |||
| return self.http_response.http_version | |||
| @property | |||
| def elapsed(self) -> datetime.timedelta: | |||
| return self.http_response.elapsed | |||
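| # Illustrative sketch (not part of the diff): wrapping a raw httpx.Response. | |||
| # The `raw` and `client` objects and the Completion cast type are assumed to | |||
| # exist; parse() dispatches on the stream flag and content type as above. | |||
| wrapped = HttpResponse(raw_response=raw, cast_type=Completion, client=client) | |||
| completion = wrapped.parse()  # JSON body -> Completion via _process_response_data | |||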
| @@ -0,0 +1,149 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from __future__ import annotations | |||
| import json | |||
| from typing import Generic, Iterator, TYPE_CHECKING, Mapping | |||
| import httpx | |||
| from ._base_type import ResponseT | |||
| from ._errors import APIResponseError | |||
| _FIELD_SEPARATOR = ":" | |||
| if TYPE_CHECKING: | |||
| from ._http_client import HttpClient | |||
| class StreamResponse(Generic[ResponseT]): | |||
| response: httpx.Response | |||
| _cast_type: type[ResponseT] | |||
| def __init__( | |||
| self, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| response: httpx.Response, | |||
| client: HttpClient, | |||
| ) -> None: | |||
| self.response = response | |||
| self._cast_type = cast_type | |||
| self._data_process_func = client._process_response_data | |||
| self._stream_chunks = self.__stream__() | |||
| def __next__(self) -> ResponseT: | |||
| return self._stream_chunks.__next__() | |||
| def __iter__(self) -> Iterator[ResponseT]: | |||
| for item in self._stream_chunks: | |||
| yield item | |||
| def __stream__(self) -> Iterator[ResponseT]: | |||
| sse_line_parser = SSELineParser() | |||
| iterator = sse_line_parser.iter_lines(self.response.iter_lines()) | |||
| for sse in iterator: | |||
| if sse.data.startswith("[DONE]"): | |||
| break | |||
| if sse.event is None: | |||
| data = sse.json_data() | |||
| if isinstance(data, Mapping) and data.get("error"): | |||
| raise APIResponseError( | |||
| message="An error occurred during streaming", | |||
| request=self.response.request, | |||
| json_data=data["error"], | |||
| ) | |||
| yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) | |||
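| # drain any remaining events so the underlying HTTP response is fully consumed | |||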
| for sse in iterator: | |||
| pass | |||
| class Event: | |||
| def __init__( | |||
| self, | |||
| event: str | None = None, | |||
| data: str | None = None, | |||
| id: str | None = None, | |||
| retry: int | None = None | |||
| ): | |||
| self._event = event | |||
| self._data = data | |||
| self._id = id | |||
| self._retry = retry | |||
| def __repr__(self): | |||
| data_len = len(self._data) if self._data else 0 | |||
| return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" | |||
| @property | |||
| def event(self): return self._event | |||
| @property | |||
| def data(self): return self._data | |||
| def json_data(self): return json.loads(self._data) | |||
| @property | |||
| def id(self): return self._id | |||
| @property | |||
| def retry(self): return self._retry | |||
| class SSELineParser: | |||
| _data: list[str] | |||
| _event: str | None | |||
| _retry: int | None | |||
| _id: str | None | |||
| def __init__(self): | |||
| self._event = None | |||
| self._data = [] | |||
| self._id = None | |||
| self._retry = None | |||
| def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: | |||
| for line in lines: | |||
| line = line.rstrip('\n') | |||
| if not line: | |||
| if self._event is None and \ | |||
| not self._data and \ | |||
| self._id is None and \ | |||
| self._retry is None: | |||
| continue | |||
| sse_event = Event( | |||
| event=self._event, | |||
| data='\n'.join(self._data), | |||
| id=self._id, | |||
| retry=self._retry | |||
| ) | |||
| self._event = None | |||
| self._data = [] | |||
| self._id = None | |||
| self._retry = None | |||
| yield sse_event | |||
| self.decode_line(line) | |||
| def decode_line(self, line: str): | |||
| if line.startswith(":") or not line: | |||
| return | |||
| field, _p, value = line.partition(":") | |||
| if value.startswith(' '): | |||
| value = value[1:] | |||
| if field == "data": | |||
| self._data.append(value) | |||
| elif field == "event": | |||
| self._event = value | |||
| elif field == "retry": | |||
| try: | |||
| self._retry = int(value) | |||
| except (TypeError, ValueError): | |||
| pass | |||
| return | |||
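| # Illustrative example (not part of the diff): a blank line flushes the | |||
| # accumulated fields into an Event; consecutive data: fields join with '\n'. | |||
| parser = SSELineParser() | |||
| raw = iter(['event: message', 'data: {"x": 1}', 'data: {"y": 2}', '']) | |||
| event = next(parser.iter_lines(raw)) | |||
| assert event.event == "message" | |||
| assert event.data == '{"x": 1}\n{"y": 2}' | |||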
| @@ -0,0 +1,18 @@ | |||
| from __future__ import annotations | |||
| from typing import Mapping, Iterable, TypeVar | |||
| from ._base_type import NotGiven | |||
| def remove_notgiven_indict(obj): | |||
| if obj is None or (not isinstance(obj, Mapping)): | |||
| return obj | |||
| return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} | |||
| _T = TypeVar("_T") | |||
| def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: | |||
| return [item for sublist in t for item in sublist] | |||
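| # Illustrative examples (not part of the diff): | |||
| assert flatten([[1, 2], [3]]) == [1, 2, 3] | |||
| assert remove_notgiven_indict({"timeout": NotGiven(), "retries": 2}) == {"retries": 2} | |||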
| @@ -0,0 +1,23 @@ | |||
| from typing import List, Optional | |||
| from pydantic import BaseModel | |||
| from .chat_completion import CompletionChoice, CompletionUsage | |||
| __all__ = ["AsyncTaskStatus"] | |||
| class AsyncTaskStatus(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| model: Optional[str] = None | |||
| task_status: Optional[str] = None | |||
| class AsyncCompletion(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| model: Optional[str] = None | |||
| task_status: str | |||
| choices: List[CompletionChoice] | |||
| usage: CompletionUsage | |||
| @@ -0,0 +1,45 @@ | |||
| from typing import List, Optional | |||
| from pydantic import BaseModel | |||
| __all__ = ["Completion", "CompletionUsage"] | |||
| class Function(BaseModel): | |||
| arguments: str | |||
| name: str | |||
| class CompletionMessageToolCall(BaseModel): | |||
| id: str | |||
| function: Function | |||
| type: str | |||
| class CompletionMessage(BaseModel): | |||
| content: Optional[str] = None | |||
| role: str | |||
| tool_calls: Optional[List[CompletionMessageToolCall]] = None | |||
| class CompletionUsage(BaseModel): | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| total_tokens: int | |||
| class CompletionChoice(BaseModel): | |||
| index: int | |||
| finish_reason: str | |||
| message: CompletionMessage | |||
| class Completion(BaseModel): | |||
| model: Optional[str] = None | |||
| created: Optional[int] = None | |||
| choices: List[CompletionChoice] | |||
| request_id: Optional[str] = None | |||
| id: Optional[str] = None | |||
| usage: CompletionUsage | |||
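| # Illustrative example (not part of the diff): pydantic parses nested dicts, | |||
| # so a raw API payload (field values made up here) loads directly. | |||
| completion = Completion( | |||
|     model="glm-4", | |||
|     choices=[{"index": 0, "finish_reason": "stop", | |||
|               "message": {"role": "assistant", "content": "hi"}}], | |||
|     usage={"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}, | |||
| ) | |||
| assert completion.choices[0].message.content == "hi" | |||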
| @@ -0,0 +1,55 @@ | |||
| from typing import List, Optional | |||
| from pydantic import BaseModel | |||
| __all__ = [ | |||
| "ChatCompletionChunk", | |||
| "Choice", | |||
| "ChoiceDelta", | |||
| "ChoiceDeltaFunctionCall", | |||
| "ChoiceDeltaToolCall", | |||
| "ChoiceDeltaToolCallFunction", | |||
| "CompletionUsage", | |||
| ] | |||
| class ChoiceDeltaFunctionCall(BaseModel): | |||
| arguments: Optional[str] = None | |||
| name: Optional[str] = None | |||
| class ChoiceDeltaToolCallFunction(BaseModel): | |||
| arguments: Optional[str] = None | |||
| name: Optional[str] = None | |||
| class ChoiceDeltaToolCall(BaseModel): | |||
| index: int | |||
| id: Optional[str] = None | |||
| function: Optional[ChoiceDeltaToolCallFunction] = None | |||
| type: Optional[str] = None | |||
| class ChoiceDelta(BaseModel): | |||
| content: Optional[str] = None | |||
| role: Optional[str] = None | |||
| tool_calls: Optional[List[ChoiceDeltaToolCall]] = None | |||
| class Choice(BaseModel): | |||
| delta: ChoiceDelta | |||
| finish_reason: Optional[str] = None | |||
| index: int | |||
| class CompletionUsage(BaseModel): | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| total_tokens: int | |||
| class ChatCompletionChunk(BaseModel): | |||
| id: Optional[str] = None | |||
| choices: List[Choice] | |||
| created: Optional[int] = None | |||
| model: Optional[str] = None | |||
| usage: Optional[CompletionUsage] = None | |||
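| # Illustrative sketch (not part of the diff): accumulating streamed deltas | |||
| # into the final assistant text, assuming `stream` yields ChatCompletionChunk | |||
| # objects (e.g. a StreamResponse[ChatCompletionChunk]). | |||
| text = "" | |||
| for chunk in stream: | |||
|     for choice in chunk.choices: | |||
|         if choice.delta.content: | |||
|             text += choice.delta.content | |||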
| @@ -0,0 +1,8 @@ | |||
| from typing import Optional | |||
| from typing_extensions import TypedDict | |||
| class Reference(TypedDict, total=False): | |||
| enable: Optional[bool] | |||
| search_query: Optional[str] | |||
| @@ -0,0 +1,20 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional, List | |||
| from pydantic import BaseModel | |||
| from .chat.chat_completion import CompletionUsage | |||
| __all__ = ["Embedding", "EmbeddingsResponded"] | |||
| class Embedding(BaseModel): | |||
| object: str | |||
| index: Optional[int] = None | |||
| embedding: List[float] | |||
| class EmbeddingsResponded(BaseModel): | |||
| object: str | |||
| data: List[Embedding] | |||
| model: str | |||
| usage: CompletionUsage | |||
| @@ -0,0 +1,24 @@ | |||
| from typing import Optional, List | |||
| from pydantic import BaseModel | |||
| __all__ = ["FileObject"] | |||
| class FileObject(BaseModel): | |||
| id: Optional[str] = None | |||
| bytes: Optional[int] = None | |||
| created_at: Optional[int] = None | |||
| filename: Optional[str] = None | |||
| object: Optional[str] = None | |||
| purpose: Optional[str] = None | |||
| status: Optional[str] = None | |||
| status_details: Optional[str] = None | |||
| class ListOfFileObject(BaseModel): | |||
| object: Optional[str] = None | |||
| data: List[FileObject] | |||
| has_more: Optional[bool] = None | |||
| @@ -0,0 +1,5 @@ | |||
| from __future__ import annotations | |||
| from .fine_tuning_job import FineTuningJob as FineTuningJob | |||
| from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob | |||
| from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent | |||
| @@ -0,0 +1,52 @@ | |||
| from typing import List, Union, Optional | |||
| from pydantic import BaseModel | |||
| __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] | |||
| class Error(BaseModel): | |||
| code: str | |||
| message: str | |||
| param: Optional[str] = None | |||
| class Hyperparameters(BaseModel): | |||
| n_epochs: Union[str, int, None] = None | |||
| class FineTuningJob(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| created_at: Optional[int] = None | |||
| error: Optional[Error] = None | |||
| fine_tuned_model: Optional[str] = None | |||
| finished_at: Optional[int] = None | |||
| hyperparameters: Optional[Hyperparameters] = None | |||
| model: Optional[str] = None | |||
| object: Optional[str] = None | |||
| result_files: List[str] | |||
| status: str | |||
| trained_tokens: Optional[int] = None | |||
| training_file: str | |||
| validation_file: Optional[str] = None | |||
| class ListOfFineTuningJob(BaseModel): | |||
| object: Optional[str] = None | |||
| data: List[FineTuningJob] | |||
| has_more: Optional[bool] = None | |||
| @@ -0,0 +1,36 @@ | |||
| from typing import List, Union, Optional | |||
| from pydantic import BaseModel | |||
| __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] | |||
| class Metric(BaseModel): | |||
| epoch: Optional[Union[str, int, float]] = None | |||
| current_steps: Optional[int] = None | |||
| total_steps: Optional[int] = None | |||
| elapsed_time: Optional[str] = None | |||
| remaining_time: Optional[str] = None | |||
| trained_tokens: Optional[int] = None | |||
| loss: Optional[Union[str, int, float]] = None | |||
| eval_loss: Optional[Union[str, int, float]] = None | |||
| acc: Optional[Union[str, int, float]] = None | |||
| eval_acc: Optional[Union[str, int, float]] = None | |||
| learning_rate: Optional[Union[str, int, float]] = None | |||
| class JobEvent(BaseModel): | |||
| object: Optional[str] = None | |||
| id: Optional[str] = None | |||
| type: Optional[str] = None | |||
| created_at: Optional[int] = None | |||
| level: Optional[str] = None | |||
| message: Optional[str] = None | |||
| data: Optional[Metric] = None | |||
| class FineTuningJobEvent(BaseModel): | |||
| object: Optional[str] = None | |||
| data: List[JobEvent] | |||
| has_more: Optional[bool] = None | |||
| @@ -0,0 +1,15 @@ | |||
| from __future__ import annotations | |||
| from typing import Union | |||
| from typing_extensions import Literal, TypedDict | |||
| __all__ = ["Hyperparameters"] | |||
| class Hyperparameters(TypedDict, total=False): | |||
| batch_size: Union[Literal["auto"], int] | |||
| learning_rate_multiplier: Union[Literal["auto"], float] | |||
| n_epochs: Union[Literal["auto"], int] | |||
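| # Illustrative usage (not part of the diff): a TypedDict is an ordinary dict | |||
| # at runtime, annotated for type checkers; all keys are optional (total=False). | |||
| hp: Hyperparameters = {"n_epochs": "auto", "batch_size": 16} | |||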
| @@ -0,0 +1,18 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional, List | |||
| from pydantic import BaseModel | |||
| __all__ = ["GeneratedImage", "ImagesResponded"] | |||
| class GeneratedImage(BaseModel): | |||
| b64_json: Optional[str] = None | |||
| url: Optional[str] = None | |||
| revised_prompt: Optional[str] = None | |||
| class ImagesResponded(BaseModel): | |||
| created: int | |||
| data: List[GeneratedImage] | |||
| @@ -3,7 +3,8 @@ from typing import Generator | |||
| import pytest | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage | |||
| from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage, | |||
| UserPromptMessage, PromptMessageTool) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel | |||
| @@ -102,3 +103,48 @@ def test_get_num_tokens(): | |||
| ) | |||
| assert num_tokens == 14 | |||
| def test_get_tools_num_tokens(): | |||
| model = ZhipuAILargeLanguageModel() | |||
| num_tokens = model.get_num_tokens( | |||
| model='tools', | |||
| credentials={ | |||
| 'api_key': os.environ.get('ZHIPUAI_API_KEY') | |||
| }, | |||
| tools=[ | |||
| PromptMessageTool( | |||
| name='get_current_weather', | |||
| description='Get the current weather in a given location', | |||
| parameters={ | |||
| "type": "object", | |||
| "properties": { | |||
| "location": { | |||
| "type": "string", | |||
| "description": "The city and state e.g. San Francisco, CA" | |||
| }, | |||
| "unit": { | |||
| "type": "string", | |||
| "enum": [ | |||
| "c", | |||
| "f" | |||
| ] | |||
| } | |||
| }, | |||
| "required": [ | |||
| "location" | |||
| ] | |||
| } | |||
| ) | |||
| ], | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ] | |||
| ) | |||
| assert num_tokens == 108 | |||
| @@ -42,7 +42,7 @@ def test_invoke_model(): | |||
| assert isinstance(result, TextEmbeddingResult) | |||
| assert len(result.embeddings) == 2 | |||
| assert result.usage.total_tokens == 2 | |||
| assert result.usage.total_tokens > 0 | |||
| def test_get_num_tokens(): | |||
| @@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({ | |||
| <Thought | |||
| thought={item} | |||
| allToolIcons={allToolIcons || {}} | |||
| isFinished={!!item.observation} | |||
| isFinished={!!item.observation || !isResponsing} | |||
| /> | |||
| )} | |||
| @@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets' | |||
| import { useProviderContext } from '@/context/provider-context' | |||
| import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app' | |||
| import { PromptMode } from '@/models/debug' | |||
| import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' | |||
| import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config' | |||
| import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset' | |||
| import I18n from '@/context/i18n' | |||
| import { useModalContext } from '@/context/modal-context' | |||
| @@ -163,8 +163,7 @@ const Configuration: FC = () => { | |||
| doSetModelConfig(newModelConfig) | |||
| } | |||
| const isOpenAI = modelConfig.provider === 'openai' | |||
| const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat | |||
| const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id) | |||
| const [collectionList, setCollectionList] = useState<Collection[]>([]) | |||
| useEffect(() => { | |||
| @@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = { | |||
| tools: [], | |||
| } | |||
| export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4'] | |||
| export const DEFAULT_AGENT_PROMPT = { | |||
| chat: `Respond to the human as helpfully and accurately as possible. | |||