Co-authored-by: Joel <iamjoel007@gmail.com>
| """Wrapper around ZhipuAI APIs.""" | |||||
| from __future__ import annotations | |||||
| import logging | |||||
| import posixpath | |||||
| from pydantic import BaseModel, Extra | |||||
| from zhipuai.model_api.api import InvokeType | |||||
| from zhipuai.utils import jwt_token | |||||
| from zhipuai.utils.http_client import post, stream | |||||
| from zhipuai.utils.sse_client import SSEClient | |||||
| logger = logging.getLogger(__name__) | |||||
| class ZhipuModelAPI(BaseModel): | |||||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||||
| api_key: str | |||||
| api_timeout_seconds = 60 | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| def invoke(self, **kwargs): | |||||
| url = self._build_api_url(kwargs, InvokeType.SYNC) | |||||
| response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||||
| if not response['success']: | |||||
| raise ValueError( | |||||
| f"Error Code: {response['code']}, Message: {response['msg']} " | |||||
| ) | |||||
| return response | |||||
| def sse_invoke(self, **kwargs): | |||||
| url = self._build_api_url(kwargs, InvokeType.SSE) | |||||
| data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||||
| return SSEClient(data) | |||||
| def _build_api_url(self, kwargs, *path): | |||||
| if kwargs: | |||||
| if "model" not in kwargs: | |||||
| raise Exception("model param missed") | |||||
| model = kwargs.pop("model") | |||||
| else: | |||||
| model = "-" | |||||
| return posixpath.join(self.base_url, model, *path) | |||||
| def _generate_token(self): | |||||
| if not self.api_key: | |||||
| raise Exception( | |||||
| "api_key not provided, you could provide it." | |||||
| ) | |||||
| try: | |||||
| return jwt_token.generate_token(self.api_key) | |||||
| except Exception: | |||||
| raise ValueError( | |||||
| f"Your api_key is invalid, please check it." | |||||
| ) |
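For reference, a minimal usage sketch of the legacy wrapper above, assuming a valid API key; "chatglm_turbo" and the prompt shape are illustrative, not taken from this file:

    api = ZhipuModelAPI(api_key="your-api-key")

    # Blocking call: returns the parsed response dict, or raises ValueError on failure.
    result = api.invoke(model="chatglm_turbo", prompt=[{"role": "user", "content": "Hello"}])
    print(result["data"])

    # Streaming call: iterate server-sent events from the SSE client.
    for event in api.sse_invoke(model="chatglm_turbo", prompt=[{"role": "user", "content": "Hello"}], incremental=True).events():
        print(event.data)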
from typing import Generator, List, Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
                                                          ToolPromptMessage, TextPromptMessageContent,
                                                          ImagePromptMessageContent, PromptMessageContentType)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils import helper
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk


class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

        credentials_kwargs = self._to_credential_kwargs(credentials)

        # invoke model
        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        :param tools: tools for tool calling
        :return:
        """
        prompt = self._convert_messages_to_prompt(prompt_messages, tools)

        return self._get_num_tokens_by_gpt2(prompt)

                model_parameters={
                    "temperature": 0.5,
                },
                tools=[],
                stream=False
            )
        except Exception as ex:
    def _generate(self, model: str, credentials_kwargs: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None,
                  stop: Optional[List[str]] = None, stream: bool = True,
                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        if stop:
            extra_model_kwargs['stop_sequences'] = stop

        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )

                # image messages are not supported
                continue

            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
                    copy_prompt_message.role == PromptMessageRole.USER:
                new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
            else:
                if copy_prompt_message.role == PromptMessageRole.USER:
                    new_prompt_messages.append(copy_prompt_message)
                elif copy_prompt_message.role == PromptMessageRole.TOOL:
                    new_prompt_messages.append(copy_prompt_message)
                elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                    new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
                    new_prompt_messages.append(new_prompt_message)
                else:
                    new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
                    new_prompt_messages.append(new_prompt_message)

        if model == 'glm-4v':
            params = {
                'model': model,
                'messages': [{
                    'role': prompt_message.role.value,
                    'content':
                        [
        else:
            params = {
                'model': model,
                'messages': [],
                **model_parameters
            }

            # glm model
            if not model.startswith('chatglm'):
                for prompt_message in new_prompt_messages:
                    if prompt_message.role == PromptMessageRole.TOOL:
                        params['messages'].append({
                            'role': 'tool',
                            'content': prompt_message.content,
                            'tool_call_id': prompt_message.tool_call_id
                        })
                    else:
                        params['messages'].append({
                            'role': prompt_message.role.value,
                            'content': prompt_message.content
                        })
            else:
                # chatglm model
                for prompt_message in new_prompt_messages:
                    # merge system message to user message
                    if prompt_message.role == PromptMessageRole.SYSTEM or \
                            prompt_message.role == PromptMessageRole.TOOL or \
                            prompt_message.role == PromptMessageRole.USER:
                        if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
                            params['messages'][-1]['content'] += "\n\n" + prompt_message.content
                        else:
                            params['messages'].append({
                                'role': 'user',
                                'content': prompt_message.content
                            })
                    else:
                        params['messages'].append({
                            'role': prompt_message.role.value,
                            'content': prompt_message.content
                        })

        if tools and len(tools) > 0:
            params['tools'] = [
                {
                    'type': 'function',
                    'function': helper.dump_model(tool)
                } for tool in tools
            ]

        if stream:
            response = client.chat.completions.create(stream=stream, **params)
            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

        response = client.chat.completions.create(**params)
        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
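For illustration, a tool serialized through helper.dump_model reaches the API in the OpenAI-style function-call shape sketched below; the get_weather schema is a made-up example, and the exact field set depends on PromptMessageTool's definition:

    params['tools'] = [{
        'type': 'function',
        'function': {
            'name': 'get_weather',  # hypothetical tool name
            'description': 'Get the current weather for a city',
            'parameters': {
                'type': 'object',
                'properties': {'city': {'type': 'string'}},
                'required': ['city'],
            },
        },
    }]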
    def _handle_generate_response(self, model: str,
                                  credentials: dict,
                                  tools: Optional[list[PromptMessageTool]],
                                  response: Completion,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param prompt_messages: prompt messages
        :return: llm response
        """
        text = ''
        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
        for choice in response.choices:
            if choice.message.tool_calls:
                for tool_call in choice.message.tool_calls:
                    if tool_call.type == 'function':
                        assistant_tool_calls.append(
                            AssistantPromptMessage.ToolCall(
                                id=tool_call.id,
                                type=tool_call.type,
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool_call.function.name,
                                    arguments=tool_call.function.arguments,
                                )
                            )
                        )

            text += choice.message.content or ''

        prompt_usage = response.usage.prompt_tokens
        completion_usage = response.usage.completion_tokens

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=text,
                tool_calls=assistant_tool_calls
            ),
            usage=usage,
        )
    def _handle_generate_stream_response(self, model: str,
                                         credentials: dict,
                                         tools: Optional[list[PromptMessageTool]],
                                         responses: Generator[ChatCompletionChunk, None, None],
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        full_assistant_content = ''
        for chunk in responses:
            if len(chunk.choices) == 0:
                continue

            delta = chunk.choices[0]

            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                continue

            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
            for tool_call in delta.delta.tool_calls or []:
                if tool_call.type == 'function':
                    assistant_tool_calls.append(
                        AssistantPromptMessage.ToolCall(
                            id=tool_call.id,
                            type=tool_call.type,
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool_call.function.name,
                                arguments=tool_call.function.arguments,
                            )
                        )
                    )

            # transform assistant message to prompt message
            assistant_prompt_message = AssistantPromptMessage(
                content=delta.delta.content if delta.delta.content else '',
                tool_calls=assistant_tool_calls
            )

            full_assistant_content += delta.delta.content if delta.delta.content else ''

            if delta.finish_reason is not None and chunk.usage is not None:
                completion_tokens = chunk.usage.completion_tokens
                prompt_tokens = chunk.usage.prompt_tokens

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint='',
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                        finish_reason=delta.finish_reason,
                        usage=usage
                    )
                )
            else:
                yield LLMResultChunk(
                    model=chunk.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint='',
                    delta=LLMResultChunkDelta(
                        index=delta.index,
                        message=assistant_prompt_message,
                    )
                )
| raise ValueError(f"Got unknown type {message}") | raise ValueError(f"Got unknown type {message}") | ||||
| return message_text | return message_text | ||||
| def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: | |||||
| """ | |||||
| Format a list of messages into a full prompt for the Anthropic model | |||||
| def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: | |||||
| """ | |||||
| :param messages: List of PromptMessage to combine. | :param messages: List of PromptMessage to combine. | ||||
| :return: Combined string with necessary human_prompt and ai_prompt tags. | :return: Combined string with necessary human_prompt and ai_prompt tags. | ||||
| """ | """ | ||||
| for message in messages | for message in messages | ||||
| ) | ) | ||||
| if tools and len(tools) > 0: | |||||
| text += "\n\nTools:" | |||||
| for tool in tools: | |||||
| text += f"\n{tool.json()}" | |||||
| # trim off the trailing ' ' that might come from the "Assistant: " | # trim off the trailing ' ' that might come from the "Assistant: " | ||||
| return text.rstrip() | |||||
| return text.rstrip() |
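As a rough sketch, with one user message and one tool the combined prompt handed to the GPT-2 tokenizer in get_num_tokens would look like the following; the Human:/Assistant: tags come from the per-message conversion elided above, and the tool line is whatever tool.json() emits:

    Human: What's the weather in Paris?

    Tools:
    {"name": "get_weather", "description": "Get the current weather for a city", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}

    Assistant: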
from typing import List, Tuple

from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from langchain.schema.language_model import _get_token_ids_default_method

        :return: embeddings result
        """
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )

        try:
            # transform credentials to kwargs for model instance
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = ZhipuAI(
                api_key=credentials_kwargs['api_key']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        embedding_used_tokens = 0
        for text in texts:
            response = client.embeddings.create(model=model, input=text)
            data = response.data[0]
            embeddings.append(data.embedding)
            embedding_used_tokens += response.usage.total_tokens

        return [list(map(float, e)) for e in embeddings], embedding_used_tokens

    def embed_query(self, text: str) -> List[float]:
        """Call out to ZhipuAI's embedding endpoint.
from ._client import ZhipuAI
from .core._errors import (
    ZhipuAIError,
    APIStatusError,
    APIRequestFailedError,
    APIAuthenticationError,
    APIReachLimitError,
    APIInternalError,
    APIServerFlowExceedError,
    APIResponseError,
    APIResponseValidationError,
    APITimeoutError,
)
from .__version__ import __version__

__version__ = 'v2.0.1'

from __future__ import annotations

import os
from typing import Union, Mapping

import httpx
from httpx import Timeout
from typing_extensions import override

from . import api_resource
from .core import _jwt_token
from .core._errors import ZhipuAIError
from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
from .core._base_type import NotGiven, NOT_GIVEN


class ZhipuAI(HttpClient):
    chat: api_resource.chat.Chat
    api_key: str

    def __init__(
            self,
            *,
            api_key: str | None = None,
            base_url: str | httpx.URL | None = None,
            timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
            max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
            http_client: httpx.Client | None = None,
            custom_headers: Mapping[str, str] | None = None
    ) -> None:
        if api_key is None:
            api_key = os.environ.get("ZHIPUAI_API_KEY")
        if api_key is None:
            raise ZhipuAIError("No api_key provided; pass it as an argument or set the ZHIPUAI_API_KEY environment variable")
        self.api_key = api_key

        if base_url is None:
            base_url = os.environ.get("ZHIPUAI_BASE_URL")
        if base_url is None:
            base_url = "https://open.bigmodel.cn/api/paas/v4"

        from .__version__ import __version__
        super().__init__(
            version=__version__,
            base_url=base_url,
            timeout=timeout,
            custom_httpx_client=http_client,
            custom_headers=custom_headers,
        )
        self.chat = api_resource.chat.Chat(self)
        self.images = api_resource.images.Images(self)
        self.embeddings = api_resource.embeddings.Embeddings(self)
        self.files = api_resource.files.Files(self)
        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)

    @property
    @override
    def _auth_headers(self) -> dict[str, str]:
        api_key = self.api_key
        return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}

    def __del__(self) -> None:
        if (not hasattr(self, "_has_custom_http_client")
                or not hasattr(self, "close")
                or not hasattr(self, "_client")):
            # if the '__init__' method raised an error, self would not have the client attributes
            return

        if self._has_custom_http_client:
            return

        self.close()
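A minimal end-to-end sketch of the client defined above; assumes a valid key, and "glm-4" is an illustrative model name:

    client = ZhipuAI(api_key="your-api-key")
    completion = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    # The choices/message access mirrors how llm.py reads a Completion above.
    print(completion.choices[0].message.content)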
from .chat import chat
from .images import Images
from .embeddings import Embeddings
from .files import Files
from .fine_tuning import fine_tuning

from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING

import httpx
from typing_extensions import Literal

from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class AsyncCompletions(BaseAPI):
    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def create(
            self,
            *,
            model: str,
            request_id: Optional[str] | NotGiven = NOT_GIVEN,
            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
            temperature: Optional[float] | NotGiven = NOT_GIVEN,
            top_p: Optional[float] | NotGiven = NOT_GIVEN,
            max_tokens: int | NotGiven = NOT_GIVEN,
            seed: int | NotGiven = NOT_GIVEN,
            messages: Union[str, List[str], List[int], List[List[int]], None],
            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
            tools: Optional[object] | NotGiven = NOT_GIVEN,
            tool_choice: str | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            disable_strict_validation: Optional[bool] | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> AsyncTaskStatus:
        _cast_type = AsyncTaskStatus
        if disable_strict_validation:
            _cast_type = object
        return self._post(
            "/async/chat/completions",
            body={
                "model": model,
                "request_id": request_id,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": do_sample,
                "max_tokens": max_tokens,
                "seed": seed,
                "messages": messages,
                "stop": stop,
                "sensitive_word_check": sensitive_word_check,
                "tools": tools,
                "tool_choice": tool_choice,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=_cast_type,
            enable_stream=False,
        )

    def retrieve_completion_result(
            self,
            id: str,
            extra_headers: Headers | None = None,
            disable_strict_validation: Optional[bool] | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
        _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
        if disable_strict_validation:
            _cast_type = object
        return self._get(
            path=f"/async-result/{id}",
            cast_type=_cast_type,
            options=make_user_request_input(
                extra_headers=extra_headers,
                timeout=timeout
            )
        )

from typing import TYPE_CHECKING

from .completions import Completions
from .async_completions import AsyncCompletions
from ...core._base_api import BaseAPI

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class Chat(BaseAPI):
    completions: Completions

    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)
        self.completions = Completions(client)
        self.asyncCompletions = AsyncCompletions(client)

from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING

import httpx
from typing_extensions import Literal

from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...core._sse_client import StreamResponse
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class Completions(BaseAPI):
    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def create(
            self,
            *,
            model: str,
            request_id: Optional[str] | NotGiven = NOT_GIVEN,
            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
            stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
            temperature: Optional[float] | NotGiven = NOT_GIVEN,
            top_p: Optional[float] | NotGiven = NOT_GIVEN,
            max_tokens: int | NotGiven = NOT_GIVEN,
            seed: int | NotGiven = NOT_GIVEN,
            messages: Union[str, List[str], List[int], object, None],
            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
            tools: Optional[object] | NotGiven = NOT_GIVEN,
            tool_choice: str | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            disable_strict_validation: Optional[bool] | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Completion | StreamResponse[ChatCompletionChunk]:
        _cast_type = Completion
        _stream_cls = StreamResponse[ChatCompletionChunk]
        if disable_strict_validation:
            _cast_type = object
            _stream_cls = StreamResponse[object]
        return self._post(
            "/chat/completions",
            body={
                "model": model,
                "request_id": request_id,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": do_sample,
                "max_tokens": max_tokens,
                "seed": seed,
                "messages": messages,
                "stop": stop,
                "sensitive_word_check": sensitive_word_check,
                "stream": stream,
                "tools": tools,
                "tool_choice": tool_choice,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=_cast_type,
            enable_stream=stream or False,
            stream_cls=_stream_cls,
        )
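Because create() returns StreamResponse[ChatCompletionChunk] when stream=True, callers iterate chunks as they arrive; a sketch ("glm-4" is an illustrative model name, and the chunk field access mirrors the stream handler in llm.py above):

    for chunk in client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "Tell me a short story"}],
        stream=True,
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")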
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING

import httpx

from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.embeddings import EmbeddingsResponded

if TYPE_CHECKING:
    from .._client import ZhipuAI


class Embeddings(BaseAPI):
    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def create(
            self,
            *,
            input: Union[str, List[str], List[int], List[List[int]]],
            model: str,
            encoding_format: str | NotGiven = NOT_GIVEN,
            user: str | NotGiven = NOT_GIVEN,
            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            disable_strict_validation: Optional[bool] | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> EmbeddingsResponded:
        _cast_type = EmbeddingsResponded
        if disable_strict_validation:
            _cast_type = object
        return self._post(
            "/embeddings",
            body={
                "input": input,
                "model": model,
                "encoding_format": encoding_format,
                "user": user,
                "sensitive_word_check": sensitive_word_check,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=_cast_type,
            enable_stream=False,
        )
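Usage sketch; "embedding-2" is again an illustrative model name, and EmbeddingsResponded is assumed to expose data and usage the way the provider code above reads them:

    response = client.embeddings.create(model="embedding-2", input="hello world")
    print(len(response.data[0].embedding), response.usage.total_tokens)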
from __future__ import annotations

from typing import TYPE_CHECKING

import httpx

from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..core._files import is_file_content
from ..core._http_client import (
    make_user_request_input,
)
from ..types.file_object import FileObject, ListOfFileObject

if TYPE_CHECKING:
    from .._client import ZhipuAI

__all__ = ["Files"]


class Files(BaseAPI):

    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def create(
            self,
            *,
            file: FileTypes,
            purpose: str,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FileObject:
        if not is_file_content(file):
            prefix = f"Expected file input `{file!r}`"
            raise RuntimeError(
                f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
            ) from None
        files = [("file", file)]

        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}

        return self._post(
            "/files",
            body={
                "purpose": purpose,
            },
            files=files,
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=FileObject,
        )

    def list(
            self,
            *,
            purpose: str | NotGiven = NOT_GIVEN,
            limit: int | NotGiven = NOT_GIVEN,
            after: str | NotGiven = NOT_GIVEN,
            order: str | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ListOfFileObject:
        return self._get(
            "/files",
            cast_type=ListOfFileObject,
            options=make_user_request_input(
                extra_headers=extra_headers,
                timeout=timeout,
                query={
                    "purpose": purpose,
                    "limit": limit,
                    "after": after,
                    "order": order,
                },
            ),
        )
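Upload sketch; the "fine-tune" purpose string and the file name are assumptions, not taken from this diff:

    with open("training.jsonl", "rb") as f:
        file_object = client.files.create(file=f, purpose="fine-tune")

    # Listing passes filters through as query parameters:
    files_page = client.files.list(purpose="fine-tune", limit=10)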
from typing import TYPE_CHECKING

from .jobs import Jobs
from ...core._base_api import BaseAPI

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class FineTuning(BaseAPI):
    jobs: Jobs

    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)
        self.jobs = Jobs(client)


from __future__ import annotations

from typing import Optional, TYPE_CHECKING

import httpx

from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import (
    make_user_request_input,
)
from ...types.fine_tuning import (
    FineTuningJob,
    job_create_params,
    ListOfFineTuningJob,
    FineTuningJobEvent,
)

if TYPE_CHECKING:
    from ..._client import ZhipuAI

__all__ = ["Jobs"]


class Jobs(BaseAPI):
    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def create(
            self,
            *,
            model: str,
            training_file: str,
            hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
            suffix: Optional[str] | NotGiven = NOT_GIVEN,
            request_id: Optional[str] | NotGiven = NOT_GIVEN,
            validation_file: Optional[str] | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FineTuningJob:
        return self._post(
            "/fine_tuning/jobs",
            body={
                "model": model,
                "training_file": training_file,
                "hyperparameters": hyperparameters,
                "suffix": suffix,
                "validation_file": validation_file,
                "request_id": request_id,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=FineTuningJob,
        )

    def retrieve(
            self,
            fine_tuning_job_id: str,
            *,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FineTuningJob:
        return self._get(
            f"/fine_tuning/jobs/{fine_tuning_job_id}",
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=FineTuningJob,
        )

    def list(
            self,
            *,
            after: str | NotGiven = NOT_GIVEN,
            limit: int | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ListOfFineTuningJob:
        return self._get(
            "/fine_tuning/jobs",
            cast_type=ListOfFineTuningJob,
            options=make_user_request_input(
                extra_headers=extra_headers,
                timeout=timeout,
                query={
                    "after": after,
                    "limit": limit,
                },
            ),
        )

    def list_events(
            self,
            fine_tuning_job_id: str,
            *,
            after: str | NotGiven = NOT_GIVEN,
            limit: int | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FineTuningJobEvent:
        return self._get(
            f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
            cast_type=FineTuningJobEvent,
            options=make_user_request_input(
                extra_headers=extra_headers,
                timeout=timeout,
                query={
                    "after": after,
                    "limit": limit,
                },
            ),
        )
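Job lifecycle sketch; the model name, file id, and the id field on FineTuningJob are assumptions, not taken from this diff:

    job = client.fine_tuning.jobs.create(
        model="chatglm3-6b",           # hypothetical base model
        training_file="file-abc123",   # id returned by files.create
    )
    job = client.fine_tuning.jobs.retrieve(job.id)   # assumes FineTuningJob exposes an `id` field
    events = client.fine_tuning.jobs.list_events(job.id, limit=20)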
from __future__ import annotations

from typing import Union, List, Optional, TYPE_CHECKING

import httpx

from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded

if TYPE_CHECKING:
    from .._client import ZhipuAI


class Images(BaseAPI):
    def __init__(self, client: "ZhipuAI") -> None:
        super().__init__(client)

    def generations(
            self,
            *,
            prompt: str,
            model: str | NotGiven = NOT_GIVEN,
            n: Optional[int] | NotGiven = NOT_GIVEN,
            quality: Optional[str] | NotGiven = NOT_GIVEN,
            response_format: Optional[str] | NotGiven = NOT_GIVEN,
            size: Optional[str] | NotGiven = NOT_GIVEN,
            style: Optional[str] | NotGiven = NOT_GIVEN,
            user: str | NotGiven = NOT_GIVEN,
            extra_headers: Headers | None = None,
            disable_strict_validation: Optional[bool] | None = None,
            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ImagesResponded:
        _cast_type = ImagesResponded
        if disable_strict_validation:
            _cast_type = object
        return self._post(
            "/images/generations",
            body={
                "prompt": prompt,
                "model": model,
                "n": n,
                "quality": quality,
                "response_format": response_format,
                "size": size,
                "style": style,
                "user": user,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
            ),
            cast_type=_cast_type,
            enable_stream=False,
        )
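Generation sketch; "cogview-3" is an illustrative model name, and the data[0].url access assumes ImagesResponded mirrors the OpenAI-style response shape:

    result = client.images.generations(model="cogview-3", prompt="a watercolor fox in the snow")
    print(result.data[0].url)   # assumption: ImagesResponded carries a list of generated-image URLs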
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .._client import ZhipuAI


class BaseAPI:
    _client: ZhipuAI

    def __init__(self, client: ZhipuAI) -> None:
        self._client = client
        self._delete = client.delete
        self._get = client.get
        self._post = client.post
        self._put = client.put
        self._patch = client.patch

from __future__ import annotations

from os import PathLike
from typing import (
    TYPE_CHECKING,
    Type,
    Union,
    Mapping,
    TypeVar, IO, Tuple, Sequence, Any, List,
)

import pydantic
from typing_extensions import (
    Literal,
    override,
)

Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")

if TYPE_CHECKING:
    NoneType: Type[None]
else:
    NoneType = type(None)


# Sentinel class used until PEP 0661 is accepted
class NotGiven(pydantic.BaseModel):
    """
    A sentinel singleton class used to distinguish omitted keyword arguments
    from those passed in with the value None (which may have different behavior).

    For example:

    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...

    get(timeout=1)  # 1s timeout
    get(timeout=None)  # No timeout
    get()  # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    def __bool__(self) -> Literal[False]:
        return False

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"


NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()


class Omit(pydantic.BaseModel):
    """In certain situations you need to be able to represent a case where a default value has
    to be explicitly removed and `None` is not an appropriate substitute, for example:

    ```py
    # as the default `Content-Type` header is `application/json` that will be sent
    client.post('/upload/files', files={'file': b'my raw file content'})

    # you can't explicitly override the header as it has to be dynamically generated
    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
    client.post(..., headers={'Content-Type': 'multipart/form-data'})

    # instead you can remove the default `application/json` header by passing Omit
    client.post(..., headers={'Content-Type': Omit()})
    ```
    """

    def __bool__(self) -> Literal[False]:
        return False


Headers = Mapping[str, Union[str, Omit]]

ResponseT = TypeVar(
    "ResponseT",
    bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
)

# for user input files
if TYPE_CHECKING:
    FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
    FileContent = Union[IO[bytes], bytes, PathLike]

FileTypes = Union[
    FileContent,  # file content
    Tuple[str, FileContent],  # (filename, file)
    Tuple[str, FileContent, str],  # (filename, file, content_type)
    Tuple[str, FileContent, str, Mapping[str, str]],  # (filename, file, content_type, headers)
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]

# for httpx client supported files
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
    HttpxFileContent,  # file content (PathLike is not accepted by httpx; see _files._transform_file)
    Tuple[str, HttpxFileContent],  # (filename, file)
    Tuple[str, HttpxFileContent, str],  # (filename, file, content_type)
    Tuple[str, HttpxFileContent, str, Mapping[str, str]],  # (filename, file, content_type, headers)
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]

from __future__ import annotations

import httpx

__all__ = [
    "ZhipuAIError",
    "APIStatusError",
    "APIRequestFailedError",
    "APIAuthenticationError",
    "APIReachLimitError",
    "APIInternalError",
    "APIServerFlowExceedError",
    "APIResponseError",
    "APIResponseValidationError",
    "APITimeoutError",
]


class ZhipuAIError(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)


class APIStatusError(Exception):
    response: httpx.Response
    status_code: int

    def __init__(self, message: str, *, response: httpx.Response) -> None:
        super().__init__(message)
        self.response = response
        self.status_code = response.status_code


class APIRequestFailedError(APIStatusError):
    ...


class APIAuthenticationError(APIStatusError):
    ...


class APIReachLimitError(APIStatusError):
    ...


class APIInternalError(APIStatusError):
    ...


class APIServerFlowExceedError(APIStatusError):
    ...


class APIResponseError(Exception):
    message: str
    request: httpx.Request
    json_data: object

    def __init__(self, message: str, request: httpx.Request, json_data: object):
        self.message = message
        self.request = request
        self.json_data = json_data
        super().__init__(message)


class APIResponseValidationError(APIResponseError):
    status_code: int
    response: httpx.Response

    def __init__(
            self,
            response: httpx.Response,
            json_data: object | None, *,
            message: str | None = None
    ) -> None:
        super().__init__(
            message=message or "Data returned by API invalid for expected schema.",
            request=response.request,
            json_data=json_data
        )
        self.response = response
        self.status_code = response.status_code


class APITimeoutError(Exception):
    request: httpx.Request

    def __init__(self, request: httpx.Request):
        self.request = request
        super().__init__("Request Timeout")
from __future__ import annotations

import io
import os
from pathlib import Path
from typing import Mapping, Sequence

from ._base_type import (
    FileTypes,
    HttpxFileTypes,
    HttpxRequestFiles,
    RequestFiles,
)


def is_file_content(obj: object) -> bool:
    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))


def _transform_file(file: FileTypes) -> HttpxFileTypes:
    # Handle tuples first: tuples also satisfy is_file_content, so checking
    # them afterwards would leave this branch unreachable.
    if isinstance(file, tuple):
        if isinstance(file[1], os.PathLike):
            return (file[0], Path(file[1]).read_bytes(), *file[2:])
        return (file[0], file[1], *file[2:])

    if is_file_content(file):
        if isinstance(file, os.PathLike):
            path = Path(file)
            return path.name, path.read_bytes()
        return file

    raise TypeError(f"Unexpected input file with type {type(file)}; expected FileContent or tuple")


def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
    if files is None:
        return None

    if isinstance(files, Mapping):
        files = {key: _transform_file(file) for key, file in files.items()}
    elif isinstance(files, Sequence):
        files = [(key, _transform_file(file)) for key, file in files]
    else:
        raise TypeError(f"Unexpected input file with type {type(files)}; expected Mapping or Sequence")

    return files
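A sketch of what make_httpx_files produces; os.PathLike content is read eagerly into (filename, bytes) tuples that httpx can send as multipart parts (the file name is hypothetical):

    from pathlib import Path

    httpx_files = make_httpx_files({"file": Path("report.pdf")})
    # -> {"file": ("report.pdf", <bytes of report.pdf>)}

    httpx_files = make_httpx_files([("file", b"raw bytes")])
    # -> [("file", b"raw bytes")]  (bare content passes through unchanged)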
# -*- coding:utf-8 -*-
from __future__ import annotations

import inspect
from typing import (
    Any,
    Type,
    Union,
    cast,
    Mapping,
)

import httpx
import pydantic
from httpx import URL, Timeout
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

from . import _errors
from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
from ._response import HttpResponse
from ._sse_client import StreamResponse
from ._utils import flatten

headers = {
    "Accept": "application/json",
    "Content-Type": "application/json; charset=UTF-8",
}


def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
    merged = {**map1, **map2}
    return {key: val for key, val in merged.items() if val is not None}


ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)


class HttpClient:
    _client: httpx.Client
    _version: str
    _base_url: URL

    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _has_custom_http_client: bool
    _default_stream_cls: type[StreamResponse[Any]] | None = None

    def __init__(
            self,
            *,
            version: str,
            base_url: URL,
            timeout: Union[float, Timeout, None],
            custom_httpx_client: httpx.Client | None = None,
            custom_headers: Mapping[str, str] | None = None,
    ) -> None:
        if timeout is None or isinstance(timeout, NotGiven):
            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
                timeout = custom_httpx_client.timeout
            else:
                timeout = ZHIPUAI_DEFAULT_TIMEOUT
        self.timeout = cast(Timeout, timeout)
        self._has_custom_http_client = bool(custom_httpx_client)
        self._client = custom_httpx_client or httpx.Client(
            base_url=base_url,
            timeout=self.timeout,
            limits=ZHIPUAI_DEFAULT_LIMITS,
        )
        self._version = version
        url = URL(url=base_url)
        if not url.raw_path.endswith(b"/"):
            url = url.copy_with(raw_path=url.raw_path + b"/")
        self._base_url = url
        self._custom_headers = custom_headers or {}

    def _prepare_url(self, url: str) -> URL:
        sub_url = URL(url)
        if sub_url.is_relative_url:
            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
            return self._base_url.copy_with(raw_path=request_raw_url)

        return sub_url

    @property
    def _default_headers(self):
        return {
            "Accept": "application/json",
            "Content-Type": "application/json; charset=UTF-8",
            "ZhipuAI-SDK-Ver": self._version,
            "source_type": "zhipu-sdk-python",
            "x-request-sdk": "zhipu-sdk-python",
            **self._auth_headers,
            **self._custom_headers,
        }

    @property
    def _auth_headers(self):
        return {}

    def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
        custom_headers = request_param.headers or {}
        headers_dict = _merge_map(self._default_headers, custom_headers)

        httpx_headers = httpx.Headers(headers_dict)

        return httpx_headers

    def _prepare_request(
            self,
            request_param: ClientRequestParam
    ) -> httpx.Request:
        kwargs: dict[str, Any] = {}
        headers = self._prepare_headers(request_param)
        url = self._prepare_url(request_param.url)
        json_data = request_param.json_data
        if headers.get("Content-Type") == "multipart/form-data":
            headers.pop("Content-Type")

            if json_data:
                kwargs["data"] = self._make_multipartform(json_data)

        return self._client.build_request(
            headers=headers,
            timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
            method=request_param.method,
            url=url,
            json=json_data,
            files=request_param.files,
            params=request_param.params,
            **kwargs,
        )

    def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
        items = []

        if isinstance(value, Mapping):
            for k, v in value.items():
                items.extend(self._object_to_formdata(f"{key}[{k}]", v))
            return items
        if isinstance(value, (list, tuple)):
            for v in value:
                items.extend(self._object_to_formdata(key + "[]", v))
            return items

        def _primitive_value_to_str(val) -> str:
            # copied from httpx
            if val is True:
                return "true"
            elif val is False:
                return "false"
            elif val is None:
                return ""
            return str(val)

        str_data = _primitive_value_to_str(value)
        if not str_data:
            return []
        return [(key, str_data)]

    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])

        serialized: dict[str, object] = {}
        for key, value in items:
            if key in serialized:
                raise ValueError(f"Duplicate key encountered: {key}")
            serialized[key] = value
        return serialized
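The flattening above uses the bracketed form-field convention; a sketch of inputs and outputs (field names hypothetical):

    # {"purpose": "fine-tune", "meta": {"lang": "en"}, "tags": ["a", "b"]}
    # flattens via _object_to_formdata into:
    #   [("purpose", "fine-tune"), ("meta[lang]", "en"), ("tags[]", "a"), ("tags[]", "b")]
    # Note that _make_multipartform then rejects the repeated "tags[]" key with ValueError,
    # and booleans become "true"/"false" while None values are dropped.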
| def _parse_response( | |||||
| self, | |||||
| *, | |||||
| cast_type: Type[ResponseT], | |||||
| response: httpx.Response, | |||||
| enable_stream: bool, | |||||
| request_param: ClientRequestParam, | |||||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||||
| ) -> HttpResponse: | |||||
| http_response = HttpResponse( | |||||
| raw_response=response, | |||||
| cast_type=cast_type, | |||||
| client=self, | |||||
| enable_stream=enable_stream, | |||||
| stream_cls=stream_cls | |||||
| ) | |||||
| return http_response.parse() | |||||
| def _process_response_data( | |||||
| self, | |||||
| *, | |||||
| data: object, | |||||
| cast_type: type[ResponseT], | |||||
| response: httpx.Response, | |||||
| ) -> ResponseT: | |||||
| if data is None: | |||||
| return cast(ResponseT, None) | |||||
| try: | |||||
| if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel): | |||||
| return cast(ResponseT, cast_type.validate(data)) | |||||
| return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)) | |||||
| except pydantic.ValidationError as err: | |||||
| raise APIResponseValidationError(response=response, json_data=data) from err | |||||
| def is_closed(self) -> bool: | |||||
| return self._client.is_closed | |||||
| def close(self): | |||||
| self._client.close() | |||||
| def __enter__(self): | |||||
| return self | |||||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||||
| self.close() | |||||
| def request( | |||||
| self, | |||||
| *, | |||||
| cast_type: Type[ResponseT], | |||||
| params: ClientRequestParam, | |||||
| enable_stream: bool = False, | |||||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||||
| ) -> ResponseT | StreamResponse: | |||||
| request = self._prepare_request(params) | |||||
| try: | |||||
| response = self._client.send( | |||||
| request, | |||||
| stream=enable_stream, | |||||
| ) | |||||
| response.raise_for_status() | |||||
| except httpx.TimeoutException as err: | |||||
| raise APITimeoutError(request=request) from err | |||||
| except httpx.HTTPStatusError as err: | |||||
| # read the body so the status error can include the response text | |||||
| err.response.read() | |||||
| raise self._make_status_error(err.response) from None | |||||
| return self._parse_response( | |||||
| cast_type=cast_type, | |||||
| request_param=params, | |||||
| response=response, | |||||
| enable_stream=enable_stream, | |||||
| stream_cls=stream_cls, | |||||
| ) | |||||
| def get( | |||||
| self, | |||||
| path: str, | |||||
| *, | |||||
| cast_type: Type[ResponseT], | |||||
| options: UserRequestInput = {}, | |||||
| enable_stream: bool = False, | |||||
| ) -> ResponseT | StreamResponse: | |||||
| opts = ClientRequestParam.construct(method="get", url=path, **options) | |||||
| return self.request( | |||||
| cast_type=cast_type, params=opts, | |||||
| enable_stream=enable_stream | |||||
| ) | |||||
| def post( | |||||
| self, | |||||
| path: str, | |||||
| *, | |||||
| body: Body | None = None, | |||||
| cast_type: Type[ResponseT], | |||||
| options: UserRequestInput = {}, | |||||
| files: RequestFiles | None = None, | |||||
| enable_stream: bool = False, | |||||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||||
| ) -> ResponseT | StreamResponse: | |||||
| opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, | |||||
| **options) | |||||
| return self.request( | |||||
| cast_type=cast_type, params=opts, | |||||
| enable_stream=enable_stream, | |||||
| stream_cls=stream_cls | |||||
| ) | |||||
| def patch( | |||||
| self, | |||||
| path: str, | |||||
| *, | |||||
| body: Body | None = None, | |||||
| cast_type: Type[ResponseT], | |||||
| options: UserRequestInput = {}, | |||||
| ) -> ResponseT: | |||||
| opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) | |||||
| return self.request( | |||||
| cast_type=cast_type, params=opts, | |||||
| ) | |||||
| def put( | |||||
| self, | |||||
| path: str, | |||||
| *, | |||||
| body: Body | None = None, | |||||
| cast_type: Type[ResponseT], | |||||
| options: UserRequestInput = {}, | |||||
| files: RequestFiles | None = None, | |||||
| ) -> ResponseT | StreamResponse: | |||||
| opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), | |||||
| **options) | |||||
| return self.request( | |||||
| cast_type=cast_type, params=opts, | |||||
| ) | |||||
| def delete( | |||||
| self, | |||||
| path: str, | |||||
| *, | |||||
| body: Body | None = None, | |||||
| cast_type: Type[ResponseT], | |||||
| options: UserRequestInput = {}, | |||||
| ) -> ResponseT | StreamResponse: | |||||
| opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) | |||||
| return self.request( | |||||
| cast_type=cast_type, params=opts, | |||||
| ) | |||||
| def _make_status_error(self, response: httpx.Response) -> APIStatusError: | |||||
| response_text = response.text.strip() | |||||
| status_code = response.status_code | |||||
| error_msg = f"Error code: {status_code}, error text: {response_text}" | |||||
| if status_code == 400: | |||||
| return _errors.APIRequestFailedError(message=error_msg, response=response) | |||||
| elif status_code == 401: | |||||
| return _errors.APIAuthenticationError(message=error_msg, response=response) | |||||
| elif status_code == 429: | |||||
| return _errors.APIReachLimitError(message=error_msg, response=response) | |||||
| elif status_code == 500: | |||||
| return _errors.APIInternalError(message=error_msg, response=response) | |||||
| elif status_code == 503: | |||||
| return _errors.APIServerFlowExceedError(message=error_msg, response=response) | |||||
| return APIStatusError(message=error_msg, response=response) | |||||
| def make_user_request_input( | |||||
| max_retries: int | None = None, | |||||
| timeout: float | Timeout | None | NotGiven = NOT_GIVEN, | |||||
| extra_headers: Headers | None = None, | |||||
| query: Query | None = None, | |||||
| ) -> UserRequestInput: | |||||
| options: UserRequestInput = {} | |||||
| if extra_headers is not None: | |||||
| options["headers"] = extra_headers | |||||
| if max_retries is not None: | |||||
| options["max_retries"] = max_retries | |||||
| if not isinstance(timeout, NotGiven): | |||||
| options['timeout'] = timeout | |||||
| if query is not None: | |||||
| options["params"] = query | |||||
| return options |
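A hypothetical call site for the helper above; the header value is made up for illustration:

options = make_user_request_input(
    max_retries=2,
    timeout=30.0,
    extra_headers={"X-Request-Id": "demo-123"},  # illustrative value, not a required header
)
assert options == {"headers": {"X-Request-Id": "demo-123"}, "max_retries": 2, "timeout": 30.0}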
| # -*- coding:utf-8 -*- | |||||
| import time | |||||
| import cachetools.func | |||||
| import jwt | |||||
| API_TOKEN_TTL_SECONDS = 3 * 60 | |||||
| # cache tokens for slightly less than their lifetime so a cached token is never served expired | |||||
| CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 | |||||
| @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) | |||||
| def generate_token(apikey: str): | |||||
| try: | |||||
| api_key, secret = apikey.split(".") | |||||
| except Exception as e: | |||||
| raise ValueError("invalid api_key") from e | |||||
| payload = { | |||||
| "api_key": api_key, | |||||
| "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, | |||||
| "timestamp": int(round(time.time() * 1000)), | |||||
| } | |||||
| ret = jwt.encode( | |||||
| payload, | |||||
| secret, | |||||
| algorithm="HS256", | |||||
| headers={"alg": "HS256", "sign_type": "SIGN"}, | |||||
| ) | |||||
| return ret |
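ZhipuAI API keys take the form "<id>.<secret>", and generate_token signs a short-lived HS256 JWT whose exp/timestamp claims are in milliseconds. A round-trip sketch with a made-up key, not a real credential:

token = generate_token("demo-id.demo-secret")  # hypothetical key in "<id>.<secret>" form
claims = jwt.decode(token, "demo-secret", algorithms=["HS256"])
assert claims["api_key"] == "demo-id"
assert claims["exp"] - claims["timestamp"] == API_TOKEN_TTL_SECONDS * 1000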
| from __future__ import annotations | |||||
| from typing import Union, Any | |||||
| from httpx import Timeout | |||||
| from pydantic import ConfigDict | |||||
| from typing_extensions import ( | |||||
| Unpack, ClassVar, TypedDict | |||||
| ) | |||||
| from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query | |||||
| from ._utils import remove_notgiven_indict | |||||
| class UserRequestInput(TypedDict, total=False): | |||||
| max_retries: int | |||||
| timeout: float | Timeout | None | |||||
| headers: Headers | |||||
| params: Query | None | |||||
| class ClientRequestParam(): | |||||
| method: str | |||||
| url: str | |||||
| max_retries: Union[int, NotGiven] = NotGiven() | |||||
| timeout: Union[float, NotGiven] = NotGiven() | |||||
| headers: Union[Headers, NotGiven] = NotGiven() | |||||
| json_data: Union[Body, None] = None | |||||
| files: Union[HttpxRequestFiles, None] = None | |||||
| params: Query = {} | |||||
| model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) | |||||
| def get_max_retries(self, max_retries) -> int: | |||||
| if isinstance(self.max_retries, NotGiven): | |||||
| return max_retries | |||||
| return self.max_retries | |||||
| @classmethod | |||||
| def construct( # type: ignore | |||||
| cls, | |||||
| _fields_set: set[str] | None = None, | |||||
| **values: Unpack[UserRequestInput], | |||||
| ) -> ClientRequestParam : | |||||
| kwargs: dict[str, Any] = { | |||||
| key: remove_notgiven_indict(value) for key, value in values.items() | |||||
| } | |||||
| client = cls() | |||||
| client.__dict__.update(kwargs) | |||||
| return client | |||||
| model_construct = construct | |||||
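A sketch of how the HTTP client layer is expected to use this class; the path and body are illustrative:

param = ClientRequestParam.construct(
    method="post",
    url="/chat/completions",  # illustrative path
    json_data={"model": "glm-4", "messages": []},
    timeout=60.0,
)
assert param.method == "post" and param.timeout == 60.0
# unset fields keep their NotGiven defaults, so the client-level fallback applies
assert param.get_max_retries(3) == 3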
| from __future__ import annotations | |||||
| import datetime | |||||
| from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING | |||||
| import httpx | |||||
| import pydantic | |||||
| from typing_extensions import ParamSpec, get_origin, get_args | |||||
| from ._base_type import NoneType | |||||
| from ._sse_client import StreamResponse | |||||
| if TYPE_CHECKING: | |||||
| from ._http_client import HttpClient | |||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| class HttpResponse(Generic[R]): | |||||
| _cast_type: type[R] | |||||
| _client: "HttpClient" | |||||
| _parsed: R | None | |||||
| _enable_stream: bool | |||||
| _stream_cls: type[StreamResponse[Any]] | |||||
| http_response: httpx.Response | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| raw_response: httpx.Response, | |||||
| cast_type: type[R], | |||||
| client: "HttpClient", | |||||
| enable_stream: bool = False, | |||||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||||
| ) -> None: | |||||
| self._cast_type = cast_type | |||||
| self._client = client | |||||
| self._parsed = None | |||||
| self._stream_cls = stream_cls | |||||
| self._enable_stream = enable_stream | |||||
| self.http_response = raw_response | |||||
| def parse(self) -> R: | |||||
| self._parsed = self._parse() | |||||
| return self._parsed | |||||
| def _parse(self) -> R: | |||||
| if self._enable_stream: | |||||
| self._parsed = cast( | |||||
| R, | |||||
| self._stream_cls( | |||||
| cast_type=cast(type, get_args(self._stream_cls)[0]), | |||||
| response=self.http_response, | |||||
| client=self._client | |||||
| ) | |||||
| ) | |||||
| return self._parsed | |||||
| cast_type = self._cast_type | |||||
| if cast_type is NoneType: | |||||
| return cast(R, None) | |||||
| http_response = self.http_response | |||||
| if cast_type == str: | |||||
| return cast(R, http_response.text) | |||||
| content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") | |||||
| origin = get_origin(cast_type) or cast_type | |||||
| if content_type != "application/json": | |||||
| if issubclass(origin, pydantic.BaseModel): | |||||
| data = http_response.json() | |||||
| return self._client._process_response_data( | |||||
| data=data, | |||||
| cast_type=cast_type, # type: ignore | |||||
| response=http_response, | |||||
| ) | |||||
| return http_response.text | |||||
| data = http_response.json() | |||||
| return self._client._process_response_data( | |||||
| data=data, | |||||
| cast_type=cast_type, # type: ignore | |||||
| response=http_response, | |||||
| ) | |||||
| @property | |||||
| def headers(self) -> httpx.Headers: | |||||
| return self.http_response.headers | |||||
| @property | |||||
| def http_request(self) -> httpx.Request: | |||||
| return self.http_response.request | |||||
| @property | |||||
| def status_code(self) -> int: | |||||
| return self.http_response.status_code | |||||
| @property | |||||
| def url(self) -> httpx.URL: | |||||
| return self.http_response.url | |||||
| @property | |||||
| def method(self) -> str: | |||||
| return self.http_request.method | |||||
| @property | |||||
| def content(self) -> bytes: | |||||
| return self.http_response.content | |||||
| @property | |||||
| def text(self) -> str: | |||||
| return self.http_response.text | |||||
| @property | |||||
| def http_version(self) -> str: | |||||
| return self.http_response.http_version | |||||
| @property | |||||
| def elapsed(self) -> datetime.timedelta: | |||||
| return self.http_response.elapsed |
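A minimal sketch of the non-streaming parse path, substituting a stub for the real HttpClient since _parse only needs its _process_response_data hook:

import httpx
import pydantic

class _StubClient:
    def _process_response_data(self, *, data, cast_type, response):
        # same pydantic validation step the real client performs
        return pydantic.TypeAdapter(cast_type).validate_python(data)

class Pong(pydantic.BaseModel):
    ok: bool

raw = httpx.Response(
    200,
    json={"ok": True},
    request=httpx.Request("GET", "https://example.invalid/ping"),
)
parsed = HttpResponse(raw_response=raw, cast_type=Pong, client=_StubClient()).parse()
assert parsed.ok is True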
| # -*- coding:utf-8 -*- | |||||
| from __future__ import annotations | |||||
| import json | |||||
| from typing import Generic, Iterator, TYPE_CHECKING, Mapping | |||||
| import httpx | |||||
| from ._base_type import ResponseT | |||||
| from ._errors import APIResponseError | |||||
| _FIELD_SEPARATOR = ":" | |||||
| if TYPE_CHECKING: | |||||
| from ._http_client import HttpClient | |||||
| class StreamResponse(Generic[ResponseT]): | |||||
| response: httpx.Response | |||||
| _cast_type: type[ResponseT] | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| cast_type: type[ResponseT], | |||||
| response: httpx.Response, | |||||
| client: HttpClient, | |||||
| ) -> None: | |||||
| self.response = response | |||||
| self._cast_type = cast_type | |||||
| self._data_process_func = client._process_response_data | |||||
| self._stream_chunks = self.__stream__() | |||||
| def __next__(self) -> ResponseT: | |||||
| return self._stream_chunks.__next__() | |||||
| def __iter__(self) -> Iterator[ResponseT]: | |||||
| for item in self._stream_chunks: | |||||
| yield item | |||||
| def __stream__(self) -> Iterator[ResponseT]: | |||||
| sse_line_parser = SSELineParser() | |||||
| iterator = sse_line_parser.iter_lines(self.response.iter_lines()) | |||||
| for sse in iterator: | |||||
| if sse.data.startswith("[DONE]"): | |||||
| break | |||||
| if sse.event is None: | |||||
| data = sse.json_data() | |||||
| if isinstance(data, Mapping) and data.get("error"): | |||||
| raise APIResponseError( | |||||
| message="An error occurred during streaming", | |||||
| request=self.response.request, | |||||
| json_data=data["error"], | |||||
| ) | |||||
| yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) | |||||
| # drain any remaining events so the HTTP connection can be released | |||||
| for sse in iterator: | |||||
| pass | |||||
| class Event(object): | |||||
| def __init__( | |||||
| self, | |||||
| event: str | None = None, | |||||
| data: str | None = None, | |||||
| id: str | None = None, | |||||
| retry: int | None = None | |||||
| ): | |||||
| self._event = event | |||||
| self._data = data | |||||
| self._id = id | |||||
| self._retry = retry | |||||
| def __repr__(self): | |||||
| data_len = len(self._data) if self._data else 0 | |||||
| return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" | |||||
| @property | |||||
| def event(self): return self._event | |||||
| @property | |||||
| def data(self): return self._data | |||||
| def json_data(self): return json.loads(self._data) | |||||
| @property | |||||
| def id(self): return self._id | |||||
| @property | |||||
| def retry(self): return self._retry | |||||
| class SSELineParser: | |||||
| _data: list[str] | |||||
| _event: str | None | |||||
| _retry: int | None | |||||
| _id: str | None | |||||
| def __init__(self): | |||||
| self._event = None | |||||
| self._data = [] | |||||
| self._id = None | |||||
| self._retry = None | |||||
| def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: | |||||
| for line in lines: | |||||
| line = line.rstrip('\n') | |||||
| if not line: | |||||
| if self._event is None and \ | |||||
| not self._data and \ | |||||
| self._id is None and \ | |||||
| self._retry is None: | |||||
| continue | |||||
| sse_event = Event( | |||||
| event=self._event, | |||||
| data='\n'.join(self._data), | |||||
| id=self._id, | |||||
| retry=self._retry | |||||
| ) | |||||
| self._event = None | |||||
| self._data = [] | |||||
| self._id = None | |||||
| self._retry = None | |||||
| yield sse_event | |||||
| self.decode_line(line) | |||||
| def decode_line(self, line: str): | |||||
| if line.startswith(":") or not line: | |||||
| return | |||||
| field, _p, value = line.partition(":") | |||||
| if value.startswith(' '): | |||||
| value = value[1:] | |||||
| if field == "data": | |||||
| self._data.append(value) | |||||
| elif field == "event": | |||||
| self._event = value | |||||
| elif field == "retry": | |||||
| try: | |||||
| self._retry = int(value) | |||||
| except (TypeError, ValueError): | |||||
| pass | |||||
| return |
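Feeding a hand-written SSE frame through the parser shows the accumulation rules; the payload is made up:

parser = SSELineParser()
frame = iter([
    "event: message",
    'data: {"id": "chatcmpl-demo", "choices": []}',  # illustrative payload
    "",  # the blank line terminates the event
])
event = next(parser.iter_lines(frame))
assert event.event == "message"
assert event.json_data()["id"] == "chatcmpl-demo"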
| from __future__ import annotations | |||||
| from typing import Mapping, Iterable, TypeVar | |||||
| from ._base_type import NotGiven | |||||
| def remove_notgiven_indict(obj): | |||||
| if obj is None or (not isinstance(obj, Mapping)): | |||||
| return obj | |||||
| return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} | |||||
| _T = TypeVar("_T") | |||||
| def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: | |||||
| return [item for sublist in t for item in sublist] |
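Doctest-style checks of the two helpers, with illustrative values:

assert flatten([[1, 2], [3]]) == [1, 2, 3]
# NotGiven sentinels are stripped from mappings; everything else passes through
assert remove_notgiven_indict({"timeout": 30, "headers": NotGiven()}) == {"timeout": 30}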
| from typing import List, Optional | |||||
| from pydantic import BaseModel | |||||
| from .chat_completion import CompletionChoice, CompletionUsage | |||||
| __all__ = ["AsyncTaskStatus"] | |||||
| class AsyncTaskStatus(BaseModel): | |||||
| id: Optional[str] = None | |||||
| request_id: Optional[str] = None | |||||
| model: Optional[str] = None | |||||
| task_status: Optional[str] = None | |||||
| class AsyncCompletion(BaseModel): | |||||
| id: Optional[str] = None | |||||
| request_id: Optional[str] = None | |||||
| model: Optional[str] = None | |||||
| task_status: str | |||||
| choices: List[CompletionChoice] | |||||
| usage: CompletionUsage |
| from typing import List, Optional | |||||
| from pydantic import BaseModel | |||||
| __all__ = ["Completion", "CompletionUsage"] | |||||
| class Function(BaseModel): | |||||
| arguments: str | |||||
| name: str | |||||
| class CompletionMessageToolCall(BaseModel): | |||||
| id: str | |||||
| function: Function | |||||
| type: str | |||||
| class CompletionMessage(BaseModel): | |||||
| content: Optional[str] = None | |||||
| role: str | |||||
| tool_calls: Optional[List[CompletionMessageToolCall]] = None | |||||
| class CompletionUsage(BaseModel): | |||||
| prompt_tokens: int | |||||
| completion_tokens: int | |||||
| total_tokens: int | |||||
| class CompletionChoice(BaseModel): | |||||
| index: int | |||||
| finish_reason: str | |||||
| message: CompletionMessage | |||||
| class Completion(BaseModel): | |||||
| model: Optional[str] = None | |||||
| created: Optional[int] = None | |||||
| choices: List[CompletionChoice] | |||||
| request_id: Optional[str] = None | |||||
| id: Optional[str] = None | |||||
| usage: CompletionUsage | |||||
| from typing import List, Optional | |||||
| from pydantic import BaseModel | |||||
| __all__ = [ | |||||
| "ChatCompletionChunk", | |||||
| "Choice", | |||||
| "ChoiceDelta", | |||||
| "ChoiceDeltaFunctionCall", | |||||
| "ChoiceDeltaToolCall", | |||||
| "ChoiceDeltaToolCallFunction", | |||||
| ] | |||||
| class ChoiceDeltaFunctionCall(BaseModel): | |||||
| arguments: Optional[str] = None | |||||
| name: Optional[str] = None | |||||
| class ChoiceDeltaToolCallFunction(BaseModel): | |||||
| arguments: Optional[str] = None | |||||
| name: Optional[str] = None | |||||
| class ChoiceDeltaToolCall(BaseModel): | |||||
| index: int | |||||
| id: Optional[str] = None | |||||
| function: Optional[ChoiceDeltaToolCallFunction] = None | |||||
| type: Optional[str] = None | |||||
| class ChoiceDelta(BaseModel): | |||||
| content: Optional[str] = None | |||||
| role: Optional[str] = None | |||||
| tool_calls: Optional[List[ChoiceDeltaToolCall]] = None | |||||
| class Choice(BaseModel): | |||||
| delta: ChoiceDelta | |||||
| finish_reason: Optional[str] = None | |||||
| index: int | |||||
| class CompletionUsage(BaseModel): | |||||
| prompt_tokens: int | |||||
| completion_tokens: int | |||||
| total_tokens: int | |||||
| class ChatCompletionChunk(BaseModel): | |||||
| id: Optional[str] = None | |||||
| choices: List[Choice] | |||||
| created: Optional[int] = None | |||||
| model: Optional[str] = None | |||||
| usage: Optional[CompletionUsage] = None |
| from typing import Optional | |||||
| from typing_extensions import TypedDict | |||||
| class Reference(TypedDict, total=False): | |||||
| enable: Optional[bool] | |||||
| search_query: Optional[str] |
| from __future__ import annotations | |||||
| from typing import Optional, List | |||||
| from pydantic import BaseModel | |||||
| from .chat.chat_completion import CompletionUsage | |||||
| __all__ = ["Embedding", "EmbeddingsResponded"] | |||||
| class Embedding(BaseModel): | |||||
| object: str | |||||
| index: Optional[int] = None | |||||
| embedding: List[float] | |||||
| class EmbeddingsResponded(BaseModel): | |||||
| object: str | |||||
| data: List[Embedding] | |||||
| model: str | |||||
| usage: CompletionUsage |
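A validation example, assuming pydantic v2's model_validate; all field values below are made up:

resp = EmbeddingsResponded.model_validate({
    "object": "list",
    "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}],
    "model": "embedding-2",  # illustrative model name
    "usage": {"prompt_tokens": 2, "completion_tokens": 0, "total_tokens": 2},
})
assert len(resp.data[0].embedding) == 3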
| from typing import Optional, List | |||||
| from pydantic import BaseModel | |||||
| __all__ = ["FileObject"] | |||||
| class FileObject(BaseModel): | |||||
| id: Optional[str] = None | |||||
| bytes: Optional[int] = None | |||||
| created_at: Optional[int] = None | |||||
| filename: Optional[str] = None | |||||
| object: Optional[str] = None | |||||
| purpose: Optional[str] = None | |||||
| status: Optional[str] = None | |||||
| status_details: Optional[str] = None | |||||
| class ListOfFileObject(BaseModel): | |||||
| object: Optional[str] = None | |||||
| data: List[FileObject] | |||||
| has_more: Optional[bool] = None |
| from __future__ import annotations | |||||
| from .fine_tuning_job import FineTuningJob as FineTuningJob | |||||
| from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob | |||||
| from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent |
| from typing import List, Union, Optional | |||||
| from pydantic import BaseModel | |||||
| __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] | |||||
| class Error(BaseModel): | |||||
| code: str | |||||
| message: str | |||||
| param: Optional[str] = None | |||||
| class Hyperparameters(BaseModel): | |||||
| n_epochs: Union[str, int, None] = None | |||||
| class FineTuningJob(BaseModel): | |||||
| id: Optional[str] = None | |||||
| request_id: Optional[str] = None | |||||
| created_at: Optional[int] = None | |||||
| error: Optional[Error] = None | |||||
| fine_tuned_model: Optional[str] = None | |||||
| finished_at: Optional[int] = None | |||||
| hyperparameters: Optional[Hyperparameters] = None | |||||
| model: Optional[str] = None | |||||
| object: Optional[str] = None | |||||
| result_files: List[str] | |||||
| status: str | |||||
| trained_tokens: Optional[int] = None | |||||
| training_file: str | |||||
| validation_file: Optional[str] = None | |||||
| class ListOfFineTuningJob(BaseModel): | |||||
| object: Optional[str] = None | |||||
| data: List[FineTuningJob] | |||||
| has_more: Optional[bool] = None |
| from typing import List, Union, Optional | |||||
| from pydantic import BaseModel | |||||
| __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] | |||||
| class Metric(BaseModel): | |||||
| epoch: Optional[Union[str, int, float]] = None | |||||
| current_steps: Optional[int] = None | |||||
| total_steps: Optional[int] = None | |||||
| elapsed_time: Optional[str] = None | |||||
| remaining_time: Optional[str] = None | |||||
| trained_tokens: Optional[int] = None | |||||
| loss: Optional[Union[str, int, float]] = None | |||||
| eval_loss: Optional[Union[str, int, float]] = None | |||||
| acc: Optional[Union[str, int, float]] = None | |||||
| eval_acc: Optional[Union[str, int, float]] = None | |||||
| learning_rate: Optional[Union[str, int, float]] = None | |||||
| class JobEvent(BaseModel): | |||||
| object: Optional[str] = None | |||||
| id: Optional[str] = None | |||||
| type: Optional[str] = None | |||||
| created_at: Optional[int] = None | |||||
| level: Optional[str] = None | |||||
| message: Optional[str] = None | |||||
| data: Optional[Metric] = None | |||||
| class FineTuningJobEvent(BaseModel): | |||||
| object: Optional[str] = None | |||||
| data: List[JobEvent] | |||||
| has_more: Optional[bool] = None |
| from __future__ import annotations | |||||
| from typing import Union | |||||
| from typing_extensions import Literal, TypedDict | |||||
| __all__ = ["Hyperparameters"] | |||||
| class Hyperparameters(TypedDict, total=False): | |||||
| batch_size: Union[Literal["auto"], int] | |||||
| learning_rate_multiplier: Union[Literal["auto"], float] | |||||
| n_epochs: Union[Literal["auto"], int] |
| from __future__ import annotations | |||||
| from typing import Optional, List | |||||
| from pydantic import BaseModel | |||||
| __all__ = ["GeneratedImage", "ImagesResponded"] | |||||
| class GeneratedImage(BaseModel): | |||||
| b64_json: Optional[str] = None | |||||
| url: Optional[str] = None | |||||
| revised_prompt: Optional[str] = None | |||||
| class ImagesResponded(BaseModel): | |||||
| created: int | |||||
| data: List[GeneratedImage] |
| import pytest | import pytest | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | ||||
| from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage | |||||
| from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage, | |||||
| UserPromptMessage, PromptMessageTool) | |||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel | from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel | ||||
| ) | ) | ||||
| assert num_tokens == 14 | assert num_tokens == 14 | ||||
| def test_get_tools_num_tokens(): | |||||
| model = ZhipuAILargeLanguageModel() | |||||
| num_tokens = model.get_num_tokens( | |||||
| model='tools', | |||||
| credentials={ | |||||
| 'api_key': os.environ.get('ZHIPUAI_API_KEY') | |||||
| }, | |||||
| tools=[ | |||||
| PromptMessageTool( | |||||
| name='get_current_weather', | |||||
| description='Get the current weather in a given location', | |||||
| parameters={ | |||||
| "type": "object", | |||||
| "properties": { | |||||
| "location": { | |||||
| "type": "string", | |||||
| "description": "The city and state e.g. San Francisco, CA" | |||||
| }, | |||||
| "unit": { | |||||
| "type": "string", | |||||
| "enum": [ | |||||
| "c", | |||||
| "f" | |||||
| ] | |||||
| } | |||||
| }, | |||||
| "required": [ | |||||
| "location" | |||||
| ] | |||||
| } | |||||
| ) | |||||
| ], | |||||
| prompt_messages=[ | |||||
| SystemPromptMessage( | |||||
| content='You are a helpful AI assistant.', | |||||
| ), | |||||
| UserPromptMessage( | |||||
| content='Hello World!' | |||||
| ) | |||||
| ] | |||||
| ) | |||||
| assert num_tokens == 108 |
| assert isinstance(result, TextEmbeddingResult) | assert isinstance(result, TextEmbeddingResult) | ||||
| assert len(result.embeddings) == 2 | assert len(result.embeddings) == 2 | ||||
| assert result.usage.total_tokens == 2 | |||||
| assert result.usage.total_tokens > 0 | |||||
| def test_get_num_tokens(): | def test_get_num_tokens(): |
| <Thought | <Thought | ||||
| thought={item} | thought={item} | ||||
| allToolIcons={allToolIcons || {}} | allToolIcons={allToolIcons || {}} | ||||
| isFinished={!!item.observation} | |||||
| isFinished={!!item.observation || !isResponsing} | |||||
| /> | /> | ||||
| )} | )} | ||||
| import { useProviderContext } from '@/context/provider-context' | import { useProviderContext } from '@/context/provider-context' | ||||
| import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app' | import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app' | ||||
| import { PromptMode } from '@/models/debug' | import { PromptMode } from '@/models/debug' | ||||
| import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' | |||||
| import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config' | |||||
| import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset' | import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset' | ||||
| import I18n from '@/context/i18n' | import I18n from '@/context/i18n' | ||||
| import { useModalContext } from '@/context/modal-context' | import { useModalContext } from '@/context/modal-context' | ||||
| doSetModelConfig(newModelConfig) | doSetModelConfig(newModelConfig) | ||||
| } | } | ||||
| const isOpenAI = modelConfig.provider === 'openai' | const isOpenAI = modelConfig.provider === 'openai' | ||||
| const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat | |||||
| const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id) | |||||
| const [collectionList, setCollectionList] = useState<Collection[]>([]) | const [collectionList, setCollectionList] = useState<Collection[]>([]) | ||||
| useEffect(() => { | useEffect(() => { | ||||
| tools: [], | tools: [], | ||||
| } | } | ||||
| export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4'] | |||||
| export const DEFAULT_AGENT_PROMPT = { | export const DEFAULT_AGENT_PROMPT = { | ||||
| chat: `Respond to the human as helpfully and accurately as possible. | chat: `Respond to the human as helpfully and accurately as possible. | ||||