| @@ -1,6 +1,10 @@ | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from zhipuai import ZhipuAI | |||
| from zhipuai.types.chat.chat_completion import Completion | |||
| from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| @@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk | |||
| from core.model_runtime.utils import helper | |||
| GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. | |||
| @@ -1,13 +1,14 @@ | |||
| import time | |||
| from typing import Optional | |||
| from zhipuai import ZhipuAI | |||
| from core.embedding.embedding_constant import EmbeddingInputType | |||
| from core.model_runtime.entities.model_entities import PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI | |||
| from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI | |||
| class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): | |||
| @@ -1,15 +0,0 @@ | |||
| from .__version__ import __version__ | |||
| from ._client import ZhipuAI | |||
| from .core import ( | |||
| APIAuthenticationError, | |||
| APIConnectionError, | |||
| APIInternalError, | |||
| APIReachLimitError, | |||
| APIRequestFailedError, | |||
| APIResponseError, | |||
| APIResponseValidationError, | |||
| APIServerFlowExceedError, | |||
| APIStatusError, | |||
| APITimeoutError, | |||
| ZhipuAIError, | |||
| ) | |||
| @@ -1 +0,0 @@ | |||
| __version__ = "v2.1.0" | |||
| @@ -1,82 +0,0 @@ | |||
| from __future__ import annotations | |||
| import os | |||
| from collections.abc import Mapping | |||
| from typing import Union | |||
| import httpx | |||
| from httpx import Timeout | |||
| from typing_extensions import override | |||
| from . import api_resource | |||
| from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token | |||
| class ZhipuAI(HttpClient): | |||
| chat: api_resource.chat.Chat | |||
| api_key: str | |||
| _disable_token_cache: bool = True | |||
| def __init__( | |||
| self, | |||
| *, | |||
| api_key: str | None = None, | |||
| base_url: str | httpx.URL | None = None, | |||
| timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, | |||
| max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, | |||
| http_client: httpx.Client | None = None, | |||
| custom_headers: Mapping[str, str] | None = None, | |||
| disable_token_cache: bool = True, | |||
| _strict_response_validation: bool = False, | |||
| ) -> None: | |||
| if api_key is None: | |||
| api_key = os.environ.get("ZHIPUAI_API_KEY") | |||
| if api_key is None: | |||
| raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供") | |||
| self.api_key = api_key | |||
| self._disable_token_cache = disable_token_cache | |||
| if base_url is None: | |||
| base_url = os.environ.get("ZHIPUAI_BASE_URL") | |||
| if base_url is None: | |||
| base_url = "https://open.bigmodel.cn/api/paas/v4" | |||
| from .__version__ import __version__ | |||
| super().__init__( | |||
| version=__version__, | |||
| base_url=base_url, | |||
| max_retries=max_retries, | |||
| timeout=timeout, | |||
| custom_httpx_client=http_client, | |||
| custom_headers=custom_headers, | |||
| _strict_response_validation=_strict_response_validation, | |||
| ) | |||
| self.chat = api_resource.chat.Chat(self) | |||
| self.images = api_resource.images.Images(self) | |||
| self.embeddings = api_resource.embeddings.Embeddings(self) | |||
| self.files = api_resource.files.Files(self) | |||
| self.fine_tuning = api_resource.fine_tuning.FineTuning(self) | |||
| self.batches = api_resource.Batches(self) | |||
| self.knowledge = api_resource.Knowledge(self) | |||
| self.tools = api_resource.Tools(self) | |||
| self.videos = api_resource.Videos(self) | |||
| self.assistant = api_resource.Assistant(self) | |||
| @property | |||
| @override | |||
| def auth_headers(self) -> dict[str, str]: | |||
| api_key = self.api_key | |||
| if self._disable_token_cache: | |||
| return {"Authorization": f"Bearer {api_key}"} | |||
| else: | |||
| return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"} | |||
| def __del__(self) -> None: | |||
| if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): | |||
| # if the '__init__' method raised an error, self would not have client attr | |||
| return | |||
| if self._has_custom_http_client: | |||
| return | |||
| self.close() | |||
| @@ -1,34 +0,0 @@ | |||
| from .assistant import ( | |||
| Assistant, | |||
| ) | |||
| from .batches import Batches | |||
| from .chat import ( | |||
| AsyncCompletions, | |||
| Chat, | |||
| Completions, | |||
| ) | |||
| from .embeddings import Embeddings | |||
| from .files import Files, FilesWithRawResponse | |||
| from .fine_tuning import FineTuning | |||
| from .images import Images | |||
| from .knowledge import Knowledge | |||
| from .tools import Tools | |||
| from .videos import ( | |||
| Videos, | |||
| ) | |||
| __all__ = [ | |||
| "Videos", | |||
| "AsyncCompletions", | |||
| "Chat", | |||
| "Completions", | |||
| "Images", | |||
| "Embeddings", | |||
| "Files", | |||
| "FilesWithRawResponse", | |||
| "FineTuning", | |||
| "Batches", | |||
| "Knowledge", | |||
| "Tools", | |||
| "Assistant", | |||
| ] | |||
| @@ -1,3 +0,0 @@ | |||
| from .assistant import Assistant | |||
| __all__ = ["Assistant"] | |||
| @@ -1,122 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Optional | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| StreamResponse, | |||
| deepcopy_minimal, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.assistant import AssistantCompletion | |||
| from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp | |||
| from ...types.assistant.assistant_support_resp import AssistantSupportResp | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| from ...types.assistant import assistant_conversation_params, assistant_create_params | |||
| __all__ = ["Assistant"] | |||
| class Assistant(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def conversation( | |||
| self, | |||
| assistant_id: str, | |||
| model: str, | |||
| messages: list[assistant_create_params.ConversationMessage], | |||
| *, | |||
| stream: bool = True, | |||
| conversation_id: Optional[str] = None, | |||
| attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None, | |||
| metadata: dict | None = None, | |||
| request_id: Optional[str] = None, | |||
| user_id: Optional[str] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> StreamResponse[AssistantCompletion]: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "assistant_id": assistant_id, | |||
| "model": model, | |||
| "messages": messages, | |||
| "stream": stream, | |||
| "conversation_id": conversation_id, | |||
| "attachments": attachments, | |||
| "metadata": metadata, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/assistant", | |||
| body=maybe_transform(body, assistant_create_params.AssistantParameters), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=AssistantCompletion, | |||
| stream=stream or True, | |||
| stream_cls=StreamResponse[AssistantCompletion], | |||
| ) | |||
| def query_support( | |||
| self, | |||
| *, | |||
| assistant_id_list: Optional[list[str]] = None, | |||
| request_id: Optional[str] = None, | |||
| user_id: Optional[str] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> AssistantSupportResp: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "assistant_id_list": assistant_id_list, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/assistant/list", | |||
| body=body, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=AssistantSupportResp, | |||
| ) | |||
| def query_conversation_usage( | |||
| self, | |||
| assistant_id: str, | |||
| page: int = 1, | |||
| page_size: int = 10, | |||
| *, | |||
| request_id: Optional[str] = None, | |||
| user_id: Optional[str] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ConversationUsageListResp: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "assistant_id": assistant_id, | |||
| "page": page, | |||
| "page_size": page_size, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/assistant/conversation/list", | |||
| body=maybe_transform(body, assistant_conversation_params.ConversationParameters), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=ConversationUsageListResp, | |||
| ) | |||
| @@ -1,146 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Literal, Optional | |||
| import httpx | |||
| from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform | |||
| from ..core.pagination import SyncCursorPage | |||
| from ..types import batch_create_params, batch_list_params | |||
| from ..types.batch import Batch | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class Batches(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| completion_window: str | None = None, | |||
| endpoint: Literal["/v1/chat/completions", "/v1/embeddings"], | |||
| input_file_id: str, | |||
| metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, | |||
| auto_delete_input_file: bool = True, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Batch: | |||
| return self._post( | |||
| "/batches", | |||
| body=maybe_transform( | |||
| { | |||
| "completion_window": completion_window, | |||
| "endpoint": endpoint, | |||
| "input_file_id": input_file_id, | |||
| "metadata": metadata, | |||
| "auto_delete_input_file": auto_delete_input_file, | |||
| }, | |||
| batch_create_params.BatchCreateParams, | |||
| ), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=Batch, | |||
| ) | |||
| def retrieve( | |||
| self, | |||
| batch_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Batch: | |||
| """ | |||
| Retrieves a batch. | |||
| Args: | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not batch_id: | |||
| raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") | |||
| return self._get( | |||
| f"/batches/{batch_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=Batch, | |||
| ) | |||
| def list( | |||
| self, | |||
| *, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> SyncCursorPage[Batch]: | |||
| """List your organization's batches. | |||
| Args: | |||
| after: A cursor for use in pagination. | |||
| `after` is an object ID that defines your place | |||
| in the list. For instance, if you make a list request and receive 100 objects, | |||
| ending with obj_foo, your subsequent call can include after=obj_foo in order to | |||
| fetch the next page of the list. | |||
| limit: A limit on the number of objects to be returned. Limit can range between 1 and | |||
| 100, and the default is 20. | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| return self._get_api_list( | |||
| "/batches", | |||
| page=SyncCursorPage[Batch], | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query=maybe_transform( | |||
| { | |||
| "after": after, | |||
| "limit": limit, | |||
| }, | |||
| batch_list_params.BatchListParams, | |||
| ), | |||
| ), | |||
| model=Batch, | |||
| ) | |||
| def cancel( | |||
| self, | |||
| batch_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Batch: | |||
| """ | |||
| Cancels an in-progress batch. | |||
| Args: | |||
| batch_id: The ID of the batch to cancel. | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not batch_id: | |||
| raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") | |||
| return self._post( | |||
| f"/batches/{batch_id}/cancel", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=Batch, | |||
| ) | |||
| @@ -1,5 +0,0 @@ | |||
| from .async_completions import AsyncCompletions | |||
| from .chat import Chat | |||
| from .completions import Completions | |||
| __all__ = ["AsyncCompletions", "Chat", "Completions"] | |||
| @@ -1,115 +0,0 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| from typing import TYPE_CHECKING, Literal, Optional, Union | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| drop_prefix_image_data, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus | |||
| from ...types.chat.code_geex import code_geex_params | |||
| from ...types.sensitive_word_check import SensitiveWordCheckRequest | |||
| logger = logging.getLogger(__name__) | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class AsyncCompletions(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| user_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| temperature: Optional[float] | NotGiven = NOT_GIVEN, | |||
| top_p: Optional[float] | NotGiven = NOT_GIVEN, | |||
| max_tokens: int | NotGiven = NOT_GIVEN, | |||
| seed: int | NotGiven = NOT_GIVEN, | |||
| messages: Union[str, list[str], list[int], list[list[int]], None], | |||
| stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, | |||
| tools: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tool_choice: str | NotGiven = NOT_GIVEN, | |||
| meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, | |||
| extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> AsyncTaskStatus: | |||
| _cast_type = AsyncTaskStatus | |||
| logger.debug(f"temperature:{temperature}, top_p:{top_p}") | |||
| if temperature is not None and temperature != NOT_GIVEN: | |||
| if temperature <= 0: | |||
| do_sample = False | |||
| temperature = 0.01 | |||
| # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501 | |||
| if temperature >= 1: | |||
| temperature = 0.99 | |||
| # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") | |||
| if top_p is not None and top_p != NOT_GIVEN: | |||
| if top_p >= 1: | |||
| top_p = 0.99 | |||
| # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") | |||
| if top_p <= 0: | |||
| top_p = 0.01 | |||
| # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") | |||
| logger.debug(f"temperature:{temperature}, top_p:{top_p}") | |||
| if isinstance(messages, list): | |||
| for item in messages: | |||
| if item.get("content"): | |||
| item["content"] = drop_prefix_image_data(item["content"]) | |||
| body = { | |||
| "model": model, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| "temperature": temperature, | |||
| "top_p": top_p, | |||
| "do_sample": do_sample, | |||
| "max_tokens": max_tokens, | |||
| "seed": seed, | |||
| "messages": messages, | |||
| "stop": stop, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "tools": tools, | |||
| "tool_choice": tool_choice, | |||
| "meta": meta, | |||
| "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), | |||
| } | |||
| return self._post( | |||
| "/async/chat/completions", | |||
| body=body, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=_cast_type, | |||
| stream=False, | |||
| ) | |||
| def retrieve_completion_result( | |||
| self, | |||
| id: str, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Union[AsyncCompletion, AsyncTaskStatus]: | |||
| _cast_type = Union[AsyncCompletion, AsyncTaskStatus] | |||
| return self._get( | |||
| path=f"/async-result/{id}", | |||
| cast_type=_cast_type, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| ) | |||
| @@ -1,18 +0,0 @@ | |||
| from typing import TYPE_CHECKING | |||
| from ...core import BaseAPI, cached_property | |||
| from .async_completions import AsyncCompletions | |||
| from .completions import Completions | |||
| if TYPE_CHECKING: | |||
| pass | |||
| class Chat(BaseAPI): | |||
| @cached_property | |||
| def completions(self) -> Completions: | |||
| return Completions(self._client) | |||
| @cached_property | |||
| def asyncCompletions(self) -> AsyncCompletions: # noqa: N802 | |||
| return AsyncCompletions(self._client) | |||
| @@ -1,108 +0,0 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| from typing import TYPE_CHECKING, Literal, Optional, Union | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| StreamResponse, | |||
| deepcopy_minimal, | |||
| drop_prefix_image_data, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.chat.chat_completion import Completion | |||
| from ...types.chat.chat_completion_chunk import ChatCompletionChunk | |||
| from ...types.chat.code_geex import code_geex_params | |||
| from ...types.sensitive_word_check import SensitiveWordCheckRequest | |||
| logger = logging.getLogger(__name__) | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| class Completions(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| user_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| temperature: Optional[float] | NotGiven = NOT_GIVEN, | |||
| top_p: Optional[float] | NotGiven = NOT_GIVEN, | |||
| max_tokens: int | NotGiven = NOT_GIVEN, | |||
| seed: int | NotGiven = NOT_GIVEN, | |||
| messages: Union[str, list[str], list[int], object, None], | |||
| stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, | |||
| tools: Optional[object] | NotGiven = NOT_GIVEN, | |||
| tool_choice: str | NotGiven = NOT_GIVEN, | |||
| meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, | |||
| extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> Completion | StreamResponse[ChatCompletionChunk]: | |||
| logger.debug(f"temperature:{temperature}, top_p:{top_p}") | |||
| if temperature is not None and temperature != NOT_GIVEN: | |||
| if temperature <= 0: | |||
| do_sample = False | |||
| temperature = 0.01 | |||
| # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperature不生效)") # noqa: E501 | |||
| if temperature >= 1: | |||
| temperature = 0.99 | |||
| # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") | |||
| if top_p is not None and top_p != NOT_GIVEN: | |||
| if top_p >= 1: | |||
| top_p = 0.99 | |||
| # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") | |||
| if top_p <= 0: | |||
| top_p = 0.01 | |||
| # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") | |||
| logger.debug(f"temperature:{temperature}, top_p:{top_p}") | |||
| if isinstance(messages, list): | |||
| for item in messages: | |||
| if item.get("content"): | |||
| item["content"] = drop_prefix_image_data(item["content"]) | |||
| body = deepcopy_minimal( | |||
| { | |||
| "model": model, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| "temperature": temperature, | |||
| "top_p": top_p, | |||
| "do_sample": do_sample, | |||
| "max_tokens": max_tokens, | |||
| "seed": seed, | |||
| "messages": messages, | |||
| "stop": stop, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "stream": stream, | |||
| "tools": tools, | |||
| "tool_choice": tool_choice, | |||
| "meta": meta, | |||
| "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/chat/completions", | |||
| body=body, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=Completion, | |||
| stream=stream or False, | |||
| stream_cls=StreamResponse[ChatCompletionChunk], | |||
| ) | |||
| @@ -1,50 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Optional, Union | |||
| import httpx | |||
| from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options | |||
| from ..types.embeddings import EmbeddingsResponded | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class Embeddings(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| input: Union[str, list[str], list[int], list[list[int]]], | |||
| model: Union[str], | |||
| dimensions: Union[int] | NotGiven = NOT_GIVEN, | |||
| encoding_format: str | NotGiven = NOT_GIVEN, | |||
| user: str | NotGiven = NOT_GIVEN, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> EmbeddingsResponded: | |||
| _cast_type = EmbeddingsResponded | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._post( | |||
| "/embeddings", | |||
| body={ | |||
| "input": input, | |||
| "model": model, | |||
| "dimensions": dimensions, | |||
| "encoding_format": encoding_format, | |||
| "user": user, | |||
| "request_id": request_id, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| }, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=_cast_type, | |||
| stream=False, | |||
| ) | |||
| @@ -1,194 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections.abc import Mapping | |||
| from typing import TYPE_CHECKING, Literal, Optional, cast | |||
| import httpx | |||
| from ..core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| FileTypes, | |||
| Headers, | |||
| NotGiven, | |||
| _legacy_binary_response, | |||
| _legacy_response, | |||
| deepcopy_minimal, | |||
| extract_files, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| __all__ = ["Files", "FilesWithRawResponse"] | |||
| class Files(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| file: Optional[FileTypes] = None, | |||
| upload_detail: Optional[list[UploadDetail]] = None, | |||
| purpose: Literal["fine-tune", "retrieval", "batch"], | |||
| knowledge_id: Optional[str] = None, | |||
| sentence_size: Optional[int] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FileObject: | |||
| if not file and not upload_detail: | |||
| raise ValueError("At least one of `file` and `upload_detail` must be provided.") | |||
| body = deepcopy_minimal( | |||
| { | |||
| "file": file, | |||
| "upload_detail": upload_detail, | |||
| "purpose": purpose, | |||
| "knowledge_id": knowledge_id, | |||
| "sentence_size": sentence_size, | |||
| } | |||
| ) | |||
| files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) | |||
| if files: | |||
| # It should be noted that the actual Content-Type header that will be | |||
| # sent to the server will contain a `boundary` parameter, e.g. | |||
| # multipart/form-data; boundary=---abc-- | |||
| extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} | |||
| return self._post( | |||
| "/files", | |||
| body=maybe_transform(body, file_create_params.FileCreateParams), | |||
| files=files, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FileObject, | |||
| ) | |||
| # def retrieve( | |||
| # self, | |||
| # file_id: str, | |||
| # *, | |||
| # extra_headers: Headers | None = None, | |||
| # extra_body: Body | None = None, | |||
| # timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| # ) -> FileObject: | |||
| # """ | |||
| # Returns information about a specific file. | |||
| # | |||
| # Args: | |||
| # file_id: The ID of the file to retrieve information about | |||
| # extra_headers: Send extra headers | |||
| # | |||
| # extra_body: Add additional JSON properties to the request | |||
| # | |||
| # timeout: Override the client-level default timeout for this request, in seconds | |||
| # """ | |||
| # if not file_id: | |||
| # raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") | |||
| # return self._get( | |||
| # f"/files/{file_id}", | |||
| # options=make_request_options( | |||
| # extra_headers=extra_headers, extra_body=extra_body, timeout=timeout | |||
| # ), | |||
| # cast_type=FileObject, | |||
| # ) | |||
| def list( | |||
| self, | |||
| *, | |||
| purpose: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| order: str | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ListOfFileObject: | |||
| return self._get( | |||
| "/files", | |||
| cast_type=ListOfFileObject, | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query={ | |||
| "purpose": purpose, | |||
| "limit": limit, | |||
| "after": after, | |||
| "order": order, | |||
| }, | |||
| ), | |||
| ) | |||
| def delete( | |||
| self, | |||
| file_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FileDeleted: | |||
| """ | |||
| Delete a file. | |||
| Args: | |||
| file_id: The ID of the file to delete | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not file_id: | |||
| raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") | |||
| return self._delete( | |||
| f"/files/{file_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FileDeleted, | |||
| ) | |||
| def content( | |||
| self, | |||
| file_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> _legacy_response.HttpxBinaryResponseContent: | |||
| """ | |||
| Returns the contents of the specified file. | |||
| Args: | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not file_id: | |||
| raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") | |||
| extra_headers = {"Accept": "application/binary", **(extra_headers or {})} | |||
| return self._get( | |||
| f"/files/{file_id}/content", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=_legacy_binary_response.HttpxBinaryResponseContent, | |||
| ) | |||
| class FilesWithRawResponse: | |||
| def __init__(self, files: Files) -> None: | |||
| self._files = files | |||
| self.create = _legacy_response.to_raw_response_wrapper( | |||
| files.create, | |||
| ) | |||
| self.list = _legacy_response.to_raw_response_wrapper( | |||
| files.list, | |||
| ) | |||
| self.content = _legacy_response.to_raw_response_wrapper( | |||
| files.content, | |||
| ) | |||
| @@ -1,5 +0,0 @@ | |||
| from .fine_tuning import FineTuning | |||
| from .jobs import Jobs | |||
| from .models import FineTunedModels | |||
| __all__ = ["Jobs", "FineTunedModels", "FineTuning"] | |||
| @@ -1,18 +0,0 @@ | |||
| from typing import TYPE_CHECKING | |||
| from ...core import BaseAPI, cached_property | |||
| from .jobs import Jobs | |||
| from .models import FineTunedModels | |||
| if TYPE_CHECKING: | |||
| pass | |||
| class FineTuning(BaseAPI): | |||
| @cached_property | |||
| def jobs(self) -> Jobs: | |||
| return Jobs(self._client) | |||
| @cached_property | |||
| def models(self) -> FineTunedModels: | |||
| return FineTunedModels(self._client) | |||
| @@ -1,3 +0,0 @@ | |||
| from .jobs import Jobs | |||
| __all__ = ["Jobs"] | |||
| @@ -1,152 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Optional | |||
| import httpx | |||
| from ....core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| make_request_options, | |||
| ) | |||
| from ....types.fine_tuning import ( | |||
| FineTuningJob, | |||
| FineTuningJobEvent, | |||
| ListOfFineTuningJob, | |||
| job_create_params, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from ...._client import ZhipuAI | |||
| __all__ = ["Jobs"] | |||
| class Jobs(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| model: str, | |||
| training_file: str, | |||
| hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, | |||
| suffix: Optional[str] | NotGiven = NOT_GIVEN, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| validation_file: Optional[str] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| return self._post( | |||
| "/fine_tuning/jobs", | |||
| body={ | |||
| "model": model, | |||
| "training_file": training_file, | |||
| "hyperparameters": hyperparameters, | |||
| "suffix": suffix, | |||
| "validation_file": validation_file, | |||
| "request_id": request_id, | |||
| }, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| def retrieve( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| return self._get( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| def list( | |||
| self, | |||
| *, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ListOfFineTuningJob: | |||
| return self._get( | |||
| "/fine_tuning/jobs", | |||
| cast_type=ListOfFineTuningJob, | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query={ | |||
| "after": after, | |||
| "limit": limit, | |||
| }, | |||
| ), | |||
| ) | |||
| def cancel( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501 | |||
| # The extra values given here take precedence over values defined on the client or passed to this method. | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| if not fine_tuning_job_id: | |||
| raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") | |||
| return self._post( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| def list_events( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| after: str | NotGiven = NOT_GIVEN, | |||
| limit: int | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJobEvent: | |||
| return self._get( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}/events", | |||
| cast_type=FineTuningJobEvent, | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query={ | |||
| "after": after, | |||
| "limit": limit, | |||
| }, | |||
| ), | |||
| ) | |||
| def delete( | |||
| self, | |||
| fine_tuning_job_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTuningJob: | |||
| if not fine_tuning_job_id: | |||
| raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") | |||
| return self._delete( | |||
| f"/fine_tuning/jobs/{fine_tuning_job_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FineTuningJob, | |||
| ) | |||
| @@ -1,3 +0,0 @@ | |||
| from .fine_tuned_models import FineTunedModels | |||
| __all__ = ["FineTunedModels"] | |||
| @@ -1,41 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| import httpx | |||
| from ....core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| make_request_options, | |||
| ) | |||
| from ....types.fine_tuning.models import FineTunedModelsStatus | |||
| if TYPE_CHECKING: | |||
| from ...._client import ZhipuAI | |||
| __all__ = ["FineTunedModels"] | |||
| class FineTunedModels(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def delete( | |||
| self, | |||
| fine_tuned_model: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> FineTunedModelsStatus: | |||
| if not fine_tuned_model: | |||
| raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}") | |||
| return self._delete( | |||
| f"fine_tuning/fine_tuned_models/{fine_tuned_model}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=FineTunedModelsStatus, | |||
| ) | |||
| @@ -1,59 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Optional | |||
| import httpx | |||
| from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options | |||
| from ..types.image import ImagesResponded | |||
| from ..types.sensitive_word_check import SensitiveWordCheckRequest | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class Images(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def generations( | |||
| self, | |||
| *, | |||
| prompt: str, | |||
| model: str | NotGiven = NOT_GIVEN, | |||
| n: Optional[int] | NotGiven = NOT_GIVEN, | |||
| quality: Optional[str] | NotGiven = NOT_GIVEN, | |||
| response_format: Optional[str] | NotGiven = NOT_GIVEN, | |||
| size: Optional[str] | NotGiven = NOT_GIVEN, | |||
| style: Optional[str] | NotGiven = NOT_GIVEN, | |||
| sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, | |||
| user: str | NotGiven = NOT_GIVEN, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| user_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| disable_strict_validation: Optional[bool] | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> ImagesResponded: | |||
| _cast_type = ImagesResponded | |||
| if disable_strict_validation: | |||
| _cast_type = object | |||
| return self._post( | |||
| "/images/generations", | |||
| body={ | |||
| "prompt": prompt, | |||
| "model": model, | |||
| "n": n, | |||
| "quality": quality, | |||
| "response_format": response_format, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "size": size, | |||
| "style": style, | |||
| "user": user, | |||
| "user_id": user_id, | |||
| "request_id": request_id, | |||
| }, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=_cast_type, | |||
| stream=False, | |||
| ) | |||
| @@ -1,3 +0,0 @@ | |||
| from .knowledge import Knowledge | |||
| __all__ = ["Knowledge"] | |||
| @@ -1,3 +0,0 @@ | |||
| from .document import Document | |||
| __all__ = ["Document"] | |||
| @@ -1,217 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections.abc import Mapping | |||
| from typing import TYPE_CHECKING, Literal, Optional, cast | |||
| import httpx | |||
| from ....core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| FileTypes, | |||
| Headers, | |||
| NotGiven, | |||
| deepcopy_minimal, | |||
| extract_files, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ....types.files import UploadDetail, file_create_params | |||
| from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params | |||
| from ....types.knowledge.document.document_list_resp import DocumentPage | |||
| if TYPE_CHECKING: | |||
| from ...._client import ZhipuAI | |||
| __all__ = ["Document"] | |||
| class Document(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def create( | |||
| self, | |||
| *, | |||
| file: Optional[FileTypes] = None, | |||
| custom_separator: Optional[list[str]] = None, | |||
| upload_detail: Optional[list[UploadDetail]] = None, | |||
| purpose: Literal["retrieval"], | |||
| knowledge_id: Optional[str] = None, | |||
| sentence_size: Optional[int] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> DocumentObject: | |||
| if not file and not upload_detail: | |||
| raise ValueError("At least one of `file` and `upload_detail` must be provided.") | |||
| body = deepcopy_minimal( | |||
| { | |||
| "file": file, | |||
| "upload_detail": upload_detail, | |||
| "purpose": purpose, | |||
| "custom_separator": custom_separator, | |||
| "knowledge_id": knowledge_id, | |||
| "sentence_size": sentence_size, | |||
| } | |||
| ) | |||
| files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) | |||
| if files: | |||
| # It should be noted that the actual Content-Type header that will be | |||
| # sent to the server will contain a `boundary` parameter, e.g. | |||
| # multipart/form-data; boundary=---abc-- | |||
| extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} | |||
| return self._post( | |||
| "/files", | |||
| body=maybe_transform(body, file_create_params.FileCreateParams), | |||
| files=files, | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=DocumentObject, | |||
| ) | |||
| def edit( | |||
| self, | |||
| document_id: str, | |||
| knowledge_type: str, | |||
| *, | |||
| custom_separator: Optional[list[str]] = None, | |||
| sentence_size: Optional[int] = None, | |||
| callback_url: Optional[str] = None, | |||
| callback_header: Optional[dict[str, str]] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> httpx.Response: | |||
| """ | |||
| Args: | |||
| document_id: 知识id | |||
| knowledge_type: 知识类型: | |||
| 1:文章知识: 支持pdf,url,docx | |||
| 2.问答知识-文档: 支持pdf,url,docx | |||
| 3.问答知识-表格: 支持xlsx | |||
| 4.商品库-表格: 支持xlsx | |||
| 5.自定义: 支持pdf,url,docx | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| :param knowledge_type: | |||
| :param document_id: | |||
| :param timeout: | |||
| :param extra_body: | |||
| :param callback_header: | |||
| :param sentence_size: | |||
| :param extra_headers: | |||
| :param callback_url: | |||
| :param custom_separator: | |||
| """ | |||
| if not document_id: | |||
| raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") | |||
| body = deepcopy_minimal( | |||
| { | |||
| "id": document_id, | |||
| "knowledge_type": knowledge_type, | |||
| "custom_separator": custom_separator, | |||
| "sentence_size": sentence_size, | |||
| "callback_url": callback_url, | |||
| "callback_header": callback_header, | |||
| } | |||
| ) | |||
| return self._put( | |||
| f"/document/{document_id}", | |||
| body=maybe_transform(body, document_edit_params.DocumentEditParams), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=httpx.Response, | |||
| ) | |||
| def list( | |||
| self, | |||
| knowledge_id: str, | |||
| *, | |||
| purpose: str | NotGiven = NOT_GIVEN, | |||
| page: str | NotGiven = NOT_GIVEN, | |||
| limit: str | NotGiven = NOT_GIVEN, | |||
| order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> DocumentPage: | |||
| return self._get( | |||
| "/files", | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query=maybe_transform( | |||
| { | |||
| "knowledge_id": knowledge_id, | |||
| "purpose": purpose, | |||
| "page": page, | |||
| "limit": limit, | |||
| "order": order, | |||
| }, | |||
| document_list_params.DocumentListParams, | |||
| ), | |||
| ), | |||
| cast_type=DocumentPage, | |||
| ) | |||
| def delete( | |||
| self, | |||
| document_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> httpx.Response: | |||
| """ | |||
| Delete a file. | |||
| Args: | |||
| document_id: 知识id | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not document_id: | |||
| raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") | |||
| return self._delete( | |||
| f"/document/{document_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=httpx.Response, | |||
| ) | |||
| def retrieve( | |||
| self, | |||
| document_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> DocumentData: | |||
| """ | |||
| Args: | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not document_id: | |||
| raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") | |||
| return self._get( | |||
| f"/document/{document_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=DocumentData, | |||
| ) | |||
| @@ -1,173 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Literal, Optional | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| cached_property, | |||
| deepcopy_minimal, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params | |||
| from ...types.knowledge.knowledge_list_resp import KnowledgePage | |||
| from .document import Document | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| __all__ = ["Knowledge"] | |||
| class Knowledge(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| @cached_property | |||
| def document(self) -> Document: | |||
| return Document(self._client) | |||
| def create( | |||
| self, | |||
| embedding_id: int, | |||
| name: str, | |||
| *, | |||
| customer_identifier: Optional[str] = None, | |||
| description: Optional[str] = None, | |||
| background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, | |||
| icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, | |||
| bucket_id: Optional[str] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> KnowledgeInfo: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "embedding_id": embedding_id, | |||
| "name": name, | |||
| "customer_identifier": customer_identifier, | |||
| "description": description, | |||
| "background": background, | |||
| "icon": icon, | |||
| "bucket_id": bucket_id, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/knowledge", | |||
| body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=KnowledgeInfo, | |||
| ) | |||
| def modify( | |||
| self, | |||
| knowledge_id: str, | |||
| embedding_id: int, | |||
| *, | |||
| name: str, | |||
| description: Optional[str] = None, | |||
| background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, | |||
| icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> httpx.Response: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "id": knowledge_id, | |||
| "embedding_id": embedding_id, | |||
| "name": name, | |||
| "description": description, | |||
| "background": background, | |||
| "icon": icon, | |||
| } | |||
| ) | |||
| return self._put( | |||
| f"/knowledge/{knowledge_id}", | |||
| body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=httpx.Response, | |||
| ) | |||
| def query( | |||
| self, | |||
| *, | |||
| page: int | NotGiven = 1, | |||
| size: int | NotGiven = 10, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> KnowledgePage: | |||
| return self._get( | |||
| "/knowledge", | |||
| options=make_request_options( | |||
| extra_headers=extra_headers, | |||
| extra_body=extra_body, | |||
| timeout=timeout, | |||
| query=maybe_transform( | |||
| { | |||
| "page": page, | |||
| "size": size, | |||
| }, | |||
| knowledge_list_params.KnowledgeListParams, | |||
| ), | |||
| ), | |||
| cast_type=KnowledgePage, | |||
| ) | |||
| def delete( | |||
| self, | |||
| knowledge_id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> httpx.Response: | |||
| """ | |||
| Delete a file. | |||
| Args: | |||
| knowledge_id: 知识库ID | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| if not knowledge_id: | |||
| raise ValueError("Expected a non-empty value for `knowledge_id`") | |||
| return self._delete( | |||
| f"/knowledge/{knowledge_id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=httpx.Response, | |||
| ) | |||
| def used( | |||
| self, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> KnowledgeUsed: | |||
| """ | |||
| Returns the contents of the specified file. | |||
| Args: | |||
| extra_headers: Send extra headers | |||
| extra_body: Add additional JSON properties to the request | |||
| timeout: Override the client-level default timeout for this request, in seconds | |||
| """ | |||
| return self._get( | |||
| "/knowledge/capacity", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=KnowledgeUsed, | |||
| ) | |||
| @@ -1,3 +0,0 @@ | |||
| from .tools import Tools | |||
| __all__ = ["Tools"] | |||
| @@ -1,65 +0,0 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| from typing import TYPE_CHECKING, Literal, Optional, Union | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| StreamResponse, | |||
| deepcopy_minimal, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params | |||
| logger = logging.getLogger(__name__) | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| __all__ = ["Tools"] | |||
| class Tools(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def web_search( | |||
| self, | |||
| *, | |||
| model: str, | |||
| request_id: Optional[str] | NotGiven = NOT_GIVEN, | |||
| stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, | |||
| messages: Union[str, list[str], list[int], object, None], | |||
| scope: Optional[str] | NotGiven = NOT_GIVEN, | |||
| location: Optional[str] | NotGiven = NOT_GIVEN, | |||
| recent_days: Optional[int] | NotGiven = NOT_GIVEN, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> WebSearch | StreamResponse[WebSearchChunk]: | |||
| body = deepcopy_minimal( | |||
| { | |||
| "model": model, | |||
| "request_id": request_id, | |||
| "messages": messages, | |||
| "stream": stream, | |||
| "scope": scope, | |||
| "location": location, | |||
| "recent_days": recent_days, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/tools", | |||
| body=maybe_transform(body, tools_web_search_params.WebSearchParams), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=WebSearch, | |||
| stream=stream or False, | |||
| stream_cls=StreamResponse[WebSearchChunk], | |||
| ) | |||
| @@ -1,7 +0,0 @@ | |||
| from .videos import ( | |||
| Videos, | |||
| ) | |||
| __all__ = [ | |||
| "Videos", | |||
| ] | |||
| @@ -1,77 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING, Optional | |||
| import httpx | |||
| from ...core import ( | |||
| NOT_GIVEN, | |||
| BaseAPI, | |||
| Body, | |||
| Headers, | |||
| NotGiven, | |||
| deepcopy_minimal, | |||
| make_request_options, | |||
| maybe_transform, | |||
| ) | |||
| from ...types.sensitive_word_check import SensitiveWordCheckRequest | |||
| from ...types.video import VideoObject, video_create_params | |||
| if TYPE_CHECKING: | |||
| from ..._client import ZhipuAI | |||
| __all__ = ["Videos"] | |||
| class Videos(BaseAPI): | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| super().__init__(client) | |||
| def generations( | |||
| self, | |||
| model: str, | |||
| *, | |||
| prompt: Optional[str] = None, | |||
| image_url: Optional[str] = None, | |||
| sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, | |||
| request_id: Optional[str] = None, | |||
| user_id: Optional[str] = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> VideoObject: | |||
| if not model and not model: | |||
| raise ValueError("At least one of `model` and `prompt` must be provided.") | |||
| body = deepcopy_minimal( | |||
| { | |||
| "model": model, | |||
| "prompt": prompt, | |||
| "image_url": image_url, | |||
| "sensitive_word_check": sensitive_word_check, | |||
| "request_id": request_id, | |||
| "user_id": user_id, | |||
| } | |||
| ) | |||
| return self._post( | |||
| "/videos/generations", | |||
| body=maybe_transform(body, video_create_params.VideoCreateParams), | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=VideoObject, | |||
| ) | |||
| def retrieve_videos_result( | |||
| self, | |||
| id: str, | |||
| *, | |||
| extra_headers: Headers | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| ) -> VideoObject: | |||
| if not id: | |||
| raise ValueError("At least one of `id` must be provided.") | |||
| return self._get( | |||
| f"/async-result/{id}", | |||
| options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), | |||
| cast_type=VideoObject, | |||
| ) | |||
| @@ -1,108 +0,0 @@ | |||
| from ._base_api import BaseAPI | |||
| from ._base_compat import ( | |||
| PYDANTIC_V2, | |||
| ConfigDict, | |||
| GenericModel, | |||
| cached_property, | |||
| field_get_default, | |||
| get_args, | |||
| get_model_config, | |||
| get_model_fields, | |||
| get_origin, | |||
| is_literal_type, | |||
| is_union, | |||
| parse_obj, | |||
| ) | |||
| from ._base_models import BaseModel, construct_type | |||
| from ._base_type import ( | |||
| NOT_GIVEN, | |||
| Body, | |||
| FileTypes, | |||
| Headers, | |||
| IncEx, | |||
| ModelT, | |||
| NotGiven, | |||
| Query, | |||
| ) | |||
| from ._constants import ( | |||
| ZHIPUAI_DEFAULT_LIMITS, | |||
| ZHIPUAI_DEFAULT_MAX_RETRIES, | |||
| ZHIPUAI_DEFAULT_TIMEOUT, | |||
| ) | |||
| from ._errors import ( | |||
| APIAuthenticationError, | |||
| APIConnectionError, | |||
| APIInternalError, | |||
| APIReachLimitError, | |||
| APIRequestFailedError, | |||
| APIResponseError, | |||
| APIResponseValidationError, | |||
| APIServerFlowExceedError, | |||
| APIStatusError, | |||
| APITimeoutError, | |||
| ZhipuAIError, | |||
| ) | |||
| from ._files import is_file_content | |||
| from ._http_client import HttpClient, make_request_options | |||
| from ._sse_client import StreamResponse | |||
| from ._utils import ( | |||
| deepcopy_minimal, | |||
| drop_prefix_image_data, | |||
| extract_files, | |||
| is_given, | |||
| is_list, | |||
| is_mapping, | |||
| maybe_transform, | |||
| parse_date, | |||
| parse_datetime, | |||
| ) | |||
| __all__ = [ | |||
| "BaseModel", | |||
| "construct_type", | |||
| "BaseAPI", | |||
| "NOT_GIVEN", | |||
| "Headers", | |||
| "NotGiven", | |||
| "Body", | |||
| "IncEx", | |||
| "ModelT", | |||
| "Query", | |||
| "FileTypes", | |||
| "PYDANTIC_V2", | |||
| "ConfigDict", | |||
| "GenericModel", | |||
| "get_args", | |||
| "is_union", | |||
| "parse_obj", | |||
| "get_origin", | |||
| "is_literal_type", | |||
| "get_model_config", | |||
| "get_model_fields", | |||
| "field_get_default", | |||
| "is_file_content", | |||
| "ZhipuAIError", | |||
| "APIStatusError", | |||
| "APIRequestFailedError", | |||
| "APIAuthenticationError", | |||
| "APIReachLimitError", | |||
| "APIInternalError", | |||
| "APIServerFlowExceedError", | |||
| "APIResponseError", | |||
| "APIResponseValidationError", | |||
| "APITimeoutError", | |||
| "make_request_options", | |||
| "HttpClient", | |||
| "ZHIPUAI_DEFAULT_TIMEOUT", | |||
| "ZHIPUAI_DEFAULT_MAX_RETRIES", | |||
| "ZHIPUAI_DEFAULT_LIMITS", | |||
| "is_list", | |||
| "is_mapping", | |||
| "parse_date", | |||
| "parse_datetime", | |||
| "is_given", | |||
| "maybe_transform", | |||
| "deepcopy_minimal", | |||
| "extract_files", | |||
| "StreamResponse", | |||
| ] | |||
| @@ -1,19 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| if TYPE_CHECKING: | |||
| from .._client import ZhipuAI | |||
| class BaseAPI: | |||
| _client: ZhipuAI | |||
| def __init__(self, client: ZhipuAI) -> None: | |||
| self._client = client | |||
| self._delete = client.delete | |||
| self._get = client.get | |||
| self._post = client.post | |||
| self._put = client.put | |||
| self._patch = client.patch | |||
| self._get_api_list = client.get_api_list | |||
| @@ -1,209 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections.abc import Callable | |||
| from datetime import date, datetime | |||
| from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload | |||
| import pydantic | |||
| from pydantic.fields import FieldInfo | |||
| from typing_extensions import Self | |||
| from ._base_type import StrBytesIntFloat | |||
| _T = TypeVar("_T") | |||
| _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) | |||
| # --------------- Pydantic v2 compatibility --------------- | |||
| # Pyright incorrectly reports some of our functions as overriding a method when they don't | |||
| # pyright: reportIncompatibleMethodOverride=false | |||
| PYDANTIC_V2 = pydantic.VERSION.startswith("2.") | |||
| # v1 re-exports | |||
| if TYPE_CHECKING: | |||
| def parse_date(value: date | StrBytesIntFloat) -> date: ... | |||
| def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ... | |||
| def get_args(t: type[Any]) -> tuple[Any, ...]: ... | |||
| def is_union(tp: type[Any] | None) -> bool: ... | |||
| def get_origin(t: type[Any]) -> type[Any] | None: ... | |||
| def is_literal_type(type_: type[Any]) -> bool: ... | |||
| def is_typeddict(type_: type[Any]) -> bool: ... | |||
| else: | |||
| if PYDANTIC_V2: | |||
| from pydantic.v1.typing import ( # noqa: I001 | |||
| get_args as get_args, # noqa: PLC0414 | |||
| is_union as is_union, # noqa: PLC0414 | |||
| get_origin as get_origin, # noqa: PLC0414 | |||
| is_typeddict as is_typeddict, # noqa: PLC0414 | |||
| is_literal_type as is_literal_type, # noqa: PLC0414 | |||
| ) | |||
| from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 | |||
| else: | |||
| from pydantic.typing import ( # noqa: I001 | |||
| get_args as get_args, # noqa: PLC0414 | |||
| is_union as is_union, # noqa: PLC0414 | |||
| get_origin as get_origin, # noqa: PLC0414 | |||
| is_typeddict as is_typeddict, # noqa: PLC0414 | |||
| is_literal_type as is_literal_type, # noqa: PLC0414 | |||
| ) | |||
| from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 | |||
| # refactored config | |||
| if TYPE_CHECKING: | |||
| from pydantic import ConfigDict | |||
| else: | |||
| if PYDANTIC_V2: | |||
| from pydantic import ConfigDict | |||
| else: | |||
| # TODO: provide an error message here? | |||
| ConfigDict = None | |||
| # renamed methods / properties | |||
| def parse_obj(model: type[_ModelT], value: object) -> _ModelT: | |||
| if PYDANTIC_V2: | |||
| return model.model_validate(value) | |||
| else: | |||
| # pyright: ignore[reportDeprecated, reportUnnecessaryCast] | |||
| return cast(_ModelT, model.parse_obj(value)) | |||
| def field_is_required(field: FieldInfo) -> bool: | |||
| if PYDANTIC_V2: | |||
| return field.is_required() | |||
| return field.required # type: ignore | |||
| def field_get_default(field: FieldInfo) -> Any: | |||
| value = field.get_default() | |||
| if PYDANTIC_V2: | |||
| from pydantic_core import PydanticUndefined | |||
| if value == PydanticUndefined: | |||
| return None | |||
| return value | |||
| return value | |||
| def field_outer_type(field: FieldInfo) -> Any: | |||
| if PYDANTIC_V2: | |||
| return field.annotation | |||
| return field.outer_type_ # type: ignore | |||
| def get_model_config(model: type[pydantic.BaseModel]) -> Any: | |||
| if PYDANTIC_V2: | |||
| return model.model_config | |||
| return model.__config__ # type: ignore | |||
| def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: | |||
| if PYDANTIC_V2: | |||
| return model.model_fields | |||
| return model.__fields__ # type: ignore | |||
| def model_copy(model: _ModelT) -> _ModelT: | |||
| if PYDANTIC_V2: | |||
| return model.model_copy() | |||
| return model.copy() # type: ignore | |||
| def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: | |||
| if PYDANTIC_V2: | |||
| return model.model_dump_json(indent=indent) | |||
| return model.json(indent=indent) # type: ignore | |||
| def model_dump( | |||
| model: pydantic.BaseModel, | |||
| *, | |||
| exclude_unset: bool = False, | |||
| exclude_defaults: bool = False, | |||
| ) -> dict[str, Any]: | |||
| if PYDANTIC_V2: | |||
| return model.model_dump( | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| ) | |||
| return cast( | |||
| "dict[str, Any]", | |||
| model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| ), | |||
| ) | |||
| def model_parse(model: type[_ModelT], data: Any) -> _ModelT: | |||
| if PYDANTIC_V2: | |||
| return model.model_validate(data) | |||
| return model.parse_obj(data) # pyright: ignore[reportDeprecated] | |||
| # generic models | |||
| if TYPE_CHECKING: | |||
| class GenericModel(pydantic.BaseModel): ... | |||
| else: | |||
| if PYDANTIC_V2: | |||
| # there no longer needs to be a distinction in v2 but | |||
| # we still have to create our own subclass to avoid | |||
| # inconsistent MRO ordering errors | |||
| class GenericModel(pydantic.BaseModel): ... | |||
| else: | |||
| import pydantic.generics | |||
| class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... | |||
| # cached properties | |||
| if TYPE_CHECKING: | |||
| cached_property = property | |||
| # we define a separate type (copied from typeshed) | |||
| # that represents that `cached_property` is `set`able | |||
| # at runtime, which differs from `@property`. | |||
| # | |||
| # this is a separate type as editors likely special case | |||
| # `@property` and we don't want to cause issues just to have | |||
| # more helpful internal types. | |||
| class typed_cached_property(Generic[_T]): # noqa: N801 | |||
| func: Callable[[Any], _T] | |||
| attrname: str | None | |||
| def __init__(self, func: Callable[[Any], _T]) -> None: ... | |||
| @overload | |||
| def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... | |||
| @overload | |||
| def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... | |||
| def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: | |||
| raise NotImplementedError() | |||
| def __set_name__(self, owner: type[Any], name: str) -> None: ... | |||
| # __set__ is not defined at runtime, but @cached_property is designed to be settable | |||
| def __set__(self, instance: object, value: _T) -> None: ... | |||
| else: | |||
| try: | |||
| from functools import cached_property | |||
| except ImportError: | |||
| from cached_property import cached_property | |||
| typed_cached_property = cached_property | |||
| @@ -1,670 +0,0 @@ | |||
| from __future__ import annotations | |||
| import inspect | |||
| import os | |||
| from collections.abc import Callable | |||
| from datetime import date, datetime | |||
| from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast | |||
| import pydantic | |||
| import pydantic.generics | |||
| from pydantic.fields import FieldInfo | |||
| from typing_extensions import ( | |||
| ParamSpec, | |||
| Protocol, | |||
| override, | |||
| runtime_checkable, | |||
| ) | |||
| from ._base_compat import ( | |||
| PYDANTIC_V2, | |||
| ConfigDict, | |||
| field_get_default, | |||
| get_args, | |||
| get_model_config, | |||
| get_model_fields, | |||
| get_origin, | |||
| is_literal_type, | |||
| is_union, | |||
| parse_obj, | |||
| ) | |||
| from ._base_compat import ( | |||
| GenericModel as BaseGenericModel, | |||
| ) | |||
| from ._base_type import ( | |||
| IncEx, | |||
| ModelT, | |||
| ) | |||
| from ._utils import ( | |||
| PropertyInfo, | |||
| coerce_boolean, | |||
| extract_type_arg, | |||
| is_annotated_type, | |||
| is_list, | |||
| is_mapping, | |||
| parse_date, | |||
| parse_datetime, | |||
| strip_annotated_type, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from pydantic_core.core_schema import ModelField | |||
| __all__ = ["BaseModel", "GenericModel"] | |||
| _BaseModelT = TypeVar("_BaseModelT", bound="BaseModel") | |||
| _T = TypeVar("_T") | |||
| P = ParamSpec("P") | |||
| @runtime_checkable | |||
| class _ConfigProtocol(Protocol): | |||
| allow_population_by_field_name: bool | |||
| class BaseModel(pydantic.BaseModel): | |||
| if PYDANTIC_V2: | |||
| model_config: ClassVar[ConfigDict] = ConfigDict( | |||
| extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) | |||
| ) | |||
| else: | |||
| @property | |||
| @override | |||
| def model_fields_set(self) -> set[str]: | |||
| # a forwards-compat shim for pydantic v2 | |||
| return self.__fields_set__ # type: ignore | |||
| class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] | |||
| extra: Any = pydantic.Extra.allow # type: ignore | |||
| def to_dict( | |||
| self, | |||
| *, | |||
| mode: Literal["json", "python"] = "python", | |||
| use_api_names: bool = True, | |||
| exclude_unset: bool = True, | |||
| exclude_defaults: bool = False, | |||
| exclude_none: bool = False, | |||
| warnings: bool = True, | |||
| ) -> dict[str, object]: | |||
| """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. | |||
| By default, fields that were not set by the API will not be included, | |||
| and keys will match the API response, *not* the property names from the model. | |||
| For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, | |||
| the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). | |||
| Args: | |||
| mode: | |||
| If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`. | |||
| If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)` | |||
| use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. | |||
| exclude_unset: Whether to exclude fields that have not been explicitly set. | |||
| exclude_defaults: Whether to exclude fields that are set to their default value from the output. | |||
| exclude_none: Whether to exclude fields that have a value of `None` from the output. | |||
| warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. | |||
| """ # noqa: E501 | |||
| return self.model_dump( | |||
| mode=mode, | |||
| by_alias=use_api_names, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| warnings=warnings, | |||
| ) | |||
| def to_json( | |||
| self, | |||
| *, | |||
| indent: int | None = 2, | |||
| use_api_names: bool = True, | |||
| exclude_unset: bool = True, | |||
| exclude_defaults: bool = False, | |||
| exclude_none: bool = False, | |||
| warnings: bool = True, | |||
| ) -> str: | |||
| """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). | |||
| By default, fields that were not set by the API will not be included, | |||
| and keys will match the API response, *not* the property names from the model. | |||
| For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, | |||
| the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). | |||
| Args: | |||
| indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` | |||
| use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. | |||
| exclude_unset: Whether to exclude fields that have not been explicitly set. | |||
| exclude_defaults: Whether to exclude fields that have the default value. | |||
| exclude_none: Whether to exclude fields that have a value of `None`. | |||
| warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. | |||
| """ # noqa: E501 | |||
| return self.model_dump_json( | |||
| indent=indent, | |||
| by_alias=use_api_names, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| warnings=warnings, | |||
| ) | |||
| @override | |||
| def __str__(self) -> str: | |||
| # mypy complains about an invalid self arg | |||
| return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] | |||
| # Override the 'construct' method in a way that supports recursive parsing without validation. | |||
| # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. | |||
| @classmethod | |||
| @override | |||
| def construct( | |||
| cls: type[ModelT], | |||
| _fields_set: set[str] | None = None, | |||
| **values: object, | |||
| ) -> ModelT: | |||
| m = cls.__new__(cls) | |||
| fields_values: dict[str, object] = {} | |||
| config = get_model_config(cls) | |||
| populate_by_name = ( | |||
| config.allow_population_by_field_name | |||
| if isinstance(config, _ConfigProtocol) | |||
| else config.get("populate_by_name") | |||
| ) | |||
| if _fields_set is None: | |||
| _fields_set = set() | |||
| model_fields = get_model_fields(cls) | |||
| for name, field in model_fields.items(): | |||
| key = field.alias | |||
| if key is None or (key not in values and populate_by_name): | |||
| key = name | |||
| if key in values: | |||
| fields_values[name] = _construct_field(value=values[key], field=field, key=key) | |||
| _fields_set.add(name) | |||
| else: | |||
| fields_values[name] = field_get_default(field) | |||
| _extra = {} | |||
| for key, value in values.items(): | |||
| if key not in model_fields: | |||
| if PYDANTIC_V2: | |||
| _extra[key] = value | |||
| else: | |||
| _fields_set.add(key) | |||
| fields_values[key] = value | |||
| object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801 | |||
| if PYDANTIC_V2: | |||
| # these properties are copied from Pydantic's `model_construct()` method | |||
| object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801 | |||
| object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801 | |||
| object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801 | |||
| else: | |||
| # init_private_attributes() does not exist in v2 | |||
| m._init_private_attributes() # type: ignore | |||
| # copied from Pydantic v1's `construct()` method | |||
| object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801 | |||
| return m | |||
| if not TYPE_CHECKING: | |||
| # type checkers incorrectly complain about this assignment | |||
| # because the type signatures are technically different | |||
| # although not in practice | |||
| model_construct = construct | |||
| if not PYDANTIC_V2: | |||
| # we define aliases for some of the new pydantic v2 methods so | |||
| # that we can just document these methods without having to specify | |||
| # a specific pydantic version as some users may not know which | |||
| # pydantic version they are currently using | |||
| @override | |||
| def model_dump( | |||
| self, | |||
| *, | |||
| mode: Literal["json", "python"] | str = "python", | |||
| include: IncEx = None, | |||
| exclude: IncEx = None, | |||
| by_alias: bool = False, | |||
| exclude_unset: bool = False, | |||
| exclude_defaults: bool = False, | |||
| exclude_none: bool = False, | |||
| round_trip: bool = False, | |||
| warnings: bool | Literal["none", "warn", "error"] = True, | |||
| context: dict[str, Any] | None = None, | |||
| serialize_as_any: bool = False, | |||
| ) -> dict[str, Any]: | |||
| """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump | |||
| Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. | |||
| Args: | |||
| mode: The mode in which `to_python` should run. | |||
| If mode is 'json', the dictionary will only contain JSON serializable types. | |||
| If mode is 'python', the dictionary may contain any Python objects. | |||
| include: A list of fields to include in the output. | |||
| exclude: A list of fields to exclude from the output. | |||
| by_alias: Whether to use the field's alias in the dictionary key if defined. | |||
| exclude_unset: Whether to exclude fields that are unset or None from the output. | |||
| exclude_defaults: Whether to exclude fields that are set to their default value from the output. | |||
| exclude_none: Whether to exclude fields that have a value of `None` from the output. | |||
| round_trip: Whether to enable serialization and deserialization round-trip support. | |||
| warnings: Whether to log warnings when invalid fields are encountered. | |||
| Returns: | |||
| A dictionary representation of the model. | |||
| """ | |||
| if mode != "python": | |||
| raise ValueError("mode is only supported in Pydantic v2") | |||
| if round_trip != False: | |||
| raise ValueError("round_trip is only supported in Pydantic v2") | |||
| if warnings != True: | |||
| raise ValueError("warnings is only supported in Pydantic v2") | |||
| if context is not None: | |||
| raise ValueError("context is only supported in Pydantic v2") | |||
| if serialize_as_any != False: | |||
| raise ValueError("serialize_as_any is only supported in Pydantic v2") | |||
| return super().dict( # pyright: ignore[reportDeprecated] | |||
| include=include, | |||
| exclude=exclude, | |||
| by_alias=by_alias, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| ) | |||
| @override | |||
| def model_dump_json( | |||
| self, | |||
| *, | |||
| indent: int | None = None, | |||
| include: IncEx = None, | |||
| exclude: IncEx = None, | |||
| by_alias: bool = False, | |||
| exclude_unset: bool = False, | |||
| exclude_defaults: bool = False, | |||
| exclude_none: bool = False, | |||
| round_trip: bool = False, | |||
| warnings: bool | Literal["none", "warn", "error"] = True, | |||
| context: dict[str, Any] | None = None, | |||
| serialize_as_any: bool = False, | |||
| ) -> str: | |||
| """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json | |||
| Generates a JSON representation of the model using Pydantic's `to_json` method. | |||
| Args: | |||
| indent: Indentation to use in the JSON output. If None is passed, the output will be compact. | |||
| include: Field(s) to include in the JSON output. Can take either a string or set of strings. | |||
| exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. | |||
| by_alias: Whether to serialize using field aliases. | |||
| exclude_unset: Whether to exclude fields that have not been explicitly set. | |||
| exclude_defaults: Whether to exclude fields that have the default value. | |||
| exclude_none: Whether to exclude fields that have a value of `None`. | |||
| round_trip: Whether to use serialization/deserialization between JSON and class instance. | |||
| warnings: Whether to show any warnings that occurred during serialization. | |||
| Returns: | |||
| A JSON string representation of the model. | |||
| """ | |||
| if round_trip != False: | |||
| raise ValueError("round_trip is only supported in Pydantic v2") | |||
| if warnings != True: | |||
| raise ValueError("warnings is only supported in Pydantic v2") | |||
| if context is not None: | |||
| raise ValueError("context is only supported in Pydantic v2") | |||
| if serialize_as_any != False: | |||
| raise ValueError("serialize_as_any is only supported in Pydantic v2") | |||
| return super().json( # type: ignore[reportDeprecated] | |||
| indent=indent, | |||
| include=include, | |||
| exclude=exclude, | |||
| by_alias=by_alias, | |||
| exclude_unset=exclude_unset, | |||
| exclude_defaults=exclude_defaults, | |||
| exclude_none=exclude_none, | |||
| ) | |||
| def _construct_field(value: object, field: FieldInfo, key: str) -> object: | |||
| if value is None: | |||
| return field_get_default(field) | |||
| if PYDANTIC_V2: | |||
| type_ = field.annotation | |||
| else: | |||
| type_ = cast(type, field.outer_type_) # type: ignore | |||
| if type_ is None: | |||
| raise RuntimeError(f"Unexpected field type is None for {key}") | |||
| return construct_type(value=value, type_=type_) | |||
| def is_basemodel(type_: type) -> bool: | |||
| """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" | |||
| if is_union(type_): | |||
| return any(is_basemodel(variant) for variant in get_args(type_)) | |||
| return is_basemodel_type(type_) | |||
| def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: | |||
| origin = get_origin(type_) or type_ | |||
| return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) | |||
| def build( | |||
| base_model_cls: Callable[P, _BaseModelT], | |||
| *args: P.args, | |||
| **kwargs: P.kwargs, | |||
| ) -> _BaseModelT: | |||
| """Construct a BaseModel class without validation. | |||
| This is useful for cases where you need to instantiate a `BaseModel` | |||
| from an API response as this provides type-safe params which isn't supported | |||
| by helpers like `construct_type()`. | |||
| ```py | |||
| build(MyModel, my_field_a="foo", my_field_b=123) | |||
| ``` | |||
| """ | |||
| if args: | |||
| raise TypeError( | |||
| "Received positional arguments which are not supported; Keyword arguments must be used instead", | |||
| ) | |||
| return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) | |||
| def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: | |||
| """Loose coercion to the expected type with construction of nested values. | |||
| Note: the returned value from this function is not guaranteed to match the | |||
| given type. | |||
| """ | |||
| return cast(_T, construct_type(value=value, type_=type_)) | |||
| def construct_type(*, value: object, type_: type) -> object: | |||
| """Loose coercion to the expected type with construction of nested values. | |||
| If the given value does not match the expected type then it is returned as-is. | |||
| """ | |||
| # we allow `object` as the input type because otherwise, passing things like | |||
| # `Literal['value']` will be reported as a type error by type checkers | |||
| type_ = cast("type[object]", type_) | |||
| # unwrap `Annotated[T, ...]` -> `T` | |||
| if is_annotated_type(type_): | |||
| meta: tuple[Any, ...] = get_args(type_)[1:] | |||
| type_ = extract_type_arg(type_, 0) | |||
| else: | |||
| meta = () | |||
| # we need to use the origin class for any types that are subscripted generics | |||
| # e.g. Dict[str, object] | |||
| origin = get_origin(type_) or type_ | |||
| args = get_args(type_) | |||
| if is_union(origin): | |||
| try: | |||
| return validate_type(type_=cast("type[object]", type_), value=value) | |||
| except Exception: | |||
| pass | |||
| # if the type is a discriminated union then we want to construct the right variant | |||
| # in the union, even if the data doesn't match exactly, otherwise we'd break code | |||
| # that relies on the constructed class types, e.g. | |||
| # | |||
| # class FooType: | |||
| # kind: Literal['foo'] | |||
| # value: str | |||
| # | |||
| # class BarType: | |||
| # kind: Literal['bar'] | |||
| # value: int | |||
| # | |||
| # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then | |||
| # we'd end up constructing `FooType` when it should be `BarType`. | |||
| discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) | |||
| if discriminator and is_mapping(value): | |||
| variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) | |||
| if variant_value and isinstance(variant_value, str): | |||
| variant_type = discriminator.mapping.get(variant_value) | |||
| if variant_type: | |||
| return construct_type(type_=variant_type, value=value) | |||
| # if the data is not valid, use the first variant that doesn't fail while deserializing | |||
| for variant in args: | |||
| try: | |||
| return construct_type(value=value, type_=variant) | |||
| except Exception: | |||
| continue | |||
| raise RuntimeError(f"Could not convert data into a valid instance of {type_}") | |||
| if origin == dict: | |||
| if not is_mapping(value): | |||
| return value | |||
| _, items_type = get_args(type_) # Dict[_, items_type] | |||
| return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} | |||
| if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): | |||
| if is_list(value): | |||
| return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] | |||
| if is_mapping(value): | |||
| if issubclass(type_, BaseModel): | |||
| return type_.construct(**value) # type: ignore[arg-type] | |||
| return cast(Any, type_).construct(**value) | |||
| if origin == list: | |||
| if not is_list(value): | |||
| return value | |||
| inner_type = args[0] # List[inner_type] | |||
| return [construct_type(value=entry, type_=inner_type) for entry in value] | |||
| if origin == float: | |||
| if isinstance(value, int): | |||
| coerced = float(value) | |||
| if coerced != value: | |||
| return value | |||
| return coerced | |||
| return value | |||
| if type_ == datetime: | |||
| try: | |||
| return parse_datetime(value) # type: ignore | |||
| except Exception: | |||
| return value | |||
| if type_ == date: | |||
| try: | |||
| return parse_date(value) # type: ignore | |||
| except Exception: | |||
| return value | |||
| return value | |||
| @runtime_checkable | |||
| class CachedDiscriminatorType(Protocol): | |||
| __discriminator__: DiscriminatorDetails | |||
| class DiscriminatorDetails: | |||
| field_name: str | |||
| """The name of the discriminator field in the variant class, e.g. | |||
| ```py | |||
| class Foo(BaseModel): | |||
| type: Literal['foo'] | |||
| ``` | |||
| Will result in field_name='type' | |||
| """ | |||
| field_alias_from: str | None | |||
| """The name of the discriminator field in the API response, e.g. | |||
| ```py | |||
| class Foo(BaseModel): | |||
| type: Literal['foo'] = Field(alias='type_from_api') | |||
| ``` | |||
| Will result in field_alias_from='type_from_api' | |||
| """ | |||
| mapping: dict[str, type] | |||
| """Mapping of discriminator value to variant type, e.g. | |||
| {'foo': FooVariant, 'bar': BarVariant} | |||
| """ | |||
| def __init__( | |||
| self, | |||
| *, | |||
| mapping: dict[str, type], | |||
| discriminator_field: str, | |||
| discriminator_alias: str | None, | |||
| ) -> None: | |||
| self.mapping = mapping | |||
| self.field_name = discriminator_field | |||
| self.field_alias_from = discriminator_alias | |||
| def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: | |||
| if isinstance(union, CachedDiscriminatorType): | |||
| return union.__discriminator__ | |||
| discriminator_field_name: str | None = None | |||
| for annotation in meta_annotations: | |||
| if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: | |||
| discriminator_field_name = annotation.discriminator | |||
| break | |||
| if not discriminator_field_name: | |||
| return None | |||
| mapping: dict[str, type] = {} | |||
| discriminator_alias: str | None = None | |||
| for variant in get_args(union): | |||
| variant = strip_annotated_type(variant) | |||
| if is_basemodel_type(variant): | |||
| if PYDANTIC_V2: | |||
| field = _extract_field_schema_pv2(variant, discriminator_field_name) | |||
| if not field: | |||
| continue | |||
| # Note: if one variant defines an alias then they all should | |||
| discriminator_alias = field.get("serialization_alias") | |||
| field_schema = field["schema"] | |||
| if field_schema["type"] == "literal": | |||
| for entry in cast("LiteralSchema", field_schema)["expected"]: | |||
| if isinstance(entry, str): | |||
| mapping[entry] = variant | |||
| else: | |||
| field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] | |||
| if not field_info: | |||
| continue | |||
| # Note: if one variant defines an alias then they all should | |||
| discriminator_alias = field_info.alias | |||
| if field_info.annotation and is_literal_type(field_info.annotation): | |||
| for entry in get_args(field_info.annotation): | |||
| if isinstance(entry, str): | |||
| mapping[entry] = variant | |||
| if not mapping: | |||
| return None | |||
| details = DiscriminatorDetails( | |||
| mapping=mapping, | |||
| discriminator_field=discriminator_field_name, | |||
| discriminator_alias=discriminator_alias, | |||
| ) | |||
| cast(CachedDiscriminatorType, union).__discriminator__ = details | |||
| return details | |||
| def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: | |||
| schema = model.__pydantic_core_schema__ | |||
| if schema["type"] != "model": | |||
| return None | |||
| fields_schema = schema["schema"] | |||
| if fields_schema["type"] != "model-fields": | |||
| return None | |||
| fields_schema = cast("ModelFieldsSchema", fields_schema) | |||
| field = fields_schema["fields"].get(field_name) | |||
| if not field: | |||
| return None | |||
| return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] | |||
| def validate_type(*, type_: type[_T], value: object) -> _T: | |||
| """Strict validation that the given value matches the expected type""" | |||
| if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): | |||
| return cast(_T, parse_obj(type_, value)) | |||
| return cast(_T, _validate_non_model_type(type_=type_, value=value)) | |||
| # Subclassing here confuses type checkers, so we treat this class as non-inheriting. | |||
| if TYPE_CHECKING: | |||
| GenericModel = BaseModel | |||
| else: | |||
| class GenericModel(BaseGenericModel, BaseModel): | |||
| pass | |||
| if PYDANTIC_V2: | |||
| from pydantic import TypeAdapter | |||
| def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: | |||
| return TypeAdapter(type_).validate_python(value) | |||
| elif not TYPE_CHECKING: | |||
| class TypeAdapter(Generic[_T]): | |||
| """Used as a placeholder to easily convert runtime types to a Pydantic format | |||
| to provide validation. | |||
| For example: | |||
| ```py | |||
| validated = RootModel[int](__root__="5").__root__ | |||
| # validated: 5 | |||
| ``` | |||
| """ | |||
| def __init__(self, type_: type[_T]): | |||
| self.type_ = type_ | |||
| def validate_python(self, value: Any) -> _T: | |||
| if not isinstance(value, self.type_): | |||
| raise ValueError(f"Invalid type: {value} is not of type {self.type_}") | |||
| return value | |||
| def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: | |||
| return TypeAdapter(type_).validate_python(value) | |||
| @@ -1,170 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections.abc import Callable, Mapping, Sequence | |||
| from os import PathLike | |||
| from typing import ( | |||
| IO, | |||
| TYPE_CHECKING, | |||
| Any, | |||
| Literal, | |||
| Optional, | |||
| TypeAlias, | |||
| TypeVar, | |||
| Union, | |||
| ) | |||
| import pydantic | |||
| from httpx import Response | |||
| from typing_extensions import Protocol, TypedDict, override, runtime_checkable | |||
| Query = Mapping[str, object] | |||
| Body = object | |||
| AnyMapping = Mapping[str, object] | |||
| PrimitiveData = Union[str, int, float, bool, None] | |||
| Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] | |||
| ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) | |||
| _T = TypeVar("_T") | |||
| if TYPE_CHECKING: | |||
| NoneType: type[None] | |||
| else: | |||
| NoneType = type(None) | |||
| # Sentinel class used until PEP 0661 is accepted | |||
| class NotGiven: | |||
| """ | |||
| A sentinel singleton class used to distinguish omitted keyword arguments | |||
| from those passed in with the value None (which may have different behavior). | |||
| For example: | |||
| ```py | |||
| def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... | |||
| get(timeout=1) # 1s timeout | |||
| get(timeout=None) # No timeout | |||
| get() # Default timeout behavior, which may not be statically known at the method definition. | |||
| ``` | |||
| """ | |||
| def __bool__(self) -> Literal[False]: | |||
| return False | |||
| @override | |||
| def __repr__(self) -> str: | |||
| return "NOT_GIVEN" | |||
| NotGivenOr = Union[_T, NotGiven] | |||
| NOT_GIVEN = NotGiven() | |||
| class Omit: | |||
| """In certain situations you need to be able to represent a case where a default value has | |||
| to be explicitly removed and `None` is not an appropriate substitute, for example: | |||
| ```py | |||
| # as the default `Content-Type` header is `application/json` that will be sent | |||
| client.post('/upload/files', files={'file': b'my raw file content'}) | |||
| # you can't explicitly override the header as it has to be dynamically generated | |||
| # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' | |||
| client.post(..., headers={'Content-Type': 'multipart/form-data'}) | |||
| # instead you can remove the default `application/json` header by passing Omit | |||
| client.post(..., headers={'Content-Type': Omit()}) | |||
| ``` | |||
| """ | |||
| def __bool__(self) -> Literal[False]: | |||
| return False | |||
| @runtime_checkable | |||
| class ModelBuilderProtocol(Protocol): | |||
| @classmethod | |||
| def build( | |||
| cls: type[_T], | |||
| *, | |||
| response: Response, | |||
| data: object, | |||
| ) -> _T: ... | |||
| Headers = Mapping[str, Union[str, Omit]] | |||
| class HeadersLikeProtocol(Protocol): | |||
| def get(self, __key: str) -> str | None: ... | |||
| HeadersLike = Union[Headers, HeadersLikeProtocol] | |||
| ResponseT = TypeVar( | |||
| "ResponseT", | |||
| bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501 | |||
| ) | |||
| StrBytesIntFloat = Union[str, bytes, int, float] | |||
| # Note: copied from Pydantic | |||
| # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 | |||
| IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" | |||
| PostParser = Callable[[Any], Any] | |||
| @runtime_checkable | |||
| class InheritsGeneric(Protocol): | |||
| """Represents a type that has inherited from `Generic` | |||
| The `__orig_bases__` property can be used to determine the resolved | |||
| type variable for a given base class. | |||
| """ | |||
| __orig_bases__: tuple[_GenericAlias] | |||
| class _GenericAlias(Protocol): | |||
| __origin__: type[object] | |||
| class HttpxSendArgs(TypedDict, total=False): | |||
| auth: httpx.Auth | |||
| # for user input files | |||
| if TYPE_CHECKING: | |||
| Base64FileInput = Union[IO[bytes], PathLike[str]] | |||
| FileContent = Union[IO[bytes], bytes, PathLike[str]] | |||
| else: | |||
| Base64FileInput = Union[IO[bytes], PathLike] | |||
| FileContent = Union[IO[bytes], bytes, PathLike] | |||
| FileTypes = Union[ | |||
| # file (or bytes) | |||
| FileContent, | |||
| # (filename, file (or bytes)) | |||
| tuple[Optional[str], FileContent], | |||
| # (filename, file (or bytes), content_type) | |||
| tuple[Optional[str], FileContent, Optional[str]], | |||
| # (filename, file (or bytes), content_type, headers) | |||
| tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], | |||
| ] | |||
| RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] | |||
| # duplicate of the above but without our custom file support | |||
| HttpxFileContent = Union[bytes, IO[bytes]] | |||
| HttpxFileTypes = Union[ | |||
| # file (or bytes) | |||
| HttpxFileContent, | |||
| # (filename, file (or bytes)) | |||
| tuple[Optional[str], HttpxFileContent], | |||
| # (filename, file (or bytes), content_type) | |||
| tuple[Optional[str], HttpxFileContent, Optional[str]], | |||
| # (filename, file (or bytes), content_type, headers) | |||
| tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], | |||
| ] | |||
| HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] | |||
| @@ -1,12 +0,0 @@ | |||
| import httpx | |||
| RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" | |||
| # 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0` | |||
| ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) | |||
| # 通过 `retry` 参数控制重试次数,默认为3次 | |||
| ZHIPUAI_DEFAULT_MAX_RETRIES = 3 | |||
| # 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10` | |||
| ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) | |||
| INITIAL_RETRY_DELAY = 0.5 | |||
| MAX_RETRY_DELAY = 8.0 | |||
| @@ -1,86 +0,0 @@ | |||
| from __future__ import annotations | |||
| import httpx | |||
| __all__ = [ | |||
| "ZhipuAIError", | |||
| "APIStatusError", | |||
| "APIRequestFailedError", | |||
| "APIAuthenticationError", | |||
| "APIReachLimitError", | |||
| "APIInternalError", | |||
| "APIServerFlowExceedError", | |||
| "APIResponseError", | |||
| "APIResponseValidationError", | |||
| "APITimeoutError", | |||
| "APIConnectionError", | |||
| ] | |||
| class ZhipuAIError(Exception): | |||
| def __init__( | |||
| self, | |||
| message: str, | |||
| ) -> None: | |||
| super().__init__(message) | |||
| class APIStatusError(ZhipuAIError): | |||
| response: httpx.Response | |||
| status_code: int | |||
| def __init__(self, message: str, *, response: httpx.Response) -> None: | |||
| super().__init__(message) | |||
| self.response = response | |||
| self.status_code = response.status_code | |||
| class APIRequestFailedError(APIStatusError): ... | |||
| class APIAuthenticationError(APIStatusError): ... | |||
| class APIReachLimitError(APIStatusError): ... | |||
| class APIInternalError(APIStatusError): ... | |||
| class APIServerFlowExceedError(APIStatusError): ... | |||
| class APIResponseError(ZhipuAIError): | |||
| message: str | |||
| request: httpx.Request | |||
| json_data: object | |||
| def __init__(self, message: str, request: httpx.Request, json_data: object): | |||
| self.message = message | |||
| self.request = request | |||
| self.json_data = json_data | |||
| super().__init__(message) | |||
| class APIResponseValidationError(APIResponseError): | |||
| status_code: int | |||
| response: httpx.Response | |||
| def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: | |||
| super().__init__( | |||
| message=message or "Data returned by API invalid for expected schema.", | |||
| request=response.request, | |||
| json_data=json_data, | |||
| ) | |||
| self.response = response | |||
| self.status_code = response.status_code | |||
| class APIConnectionError(APIResponseError): | |||
| def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: | |||
| super().__init__(message, request, json_data=None) | |||
| class APITimeoutError(APIConnectionError): | |||
| def __init__(self, request: httpx.Request) -> None: | |||
| super().__init__(message="Request timed out.", request=request) | |||
| @@ -1,75 +0,0 @@ | |||
| from __future__ import annotations | |||
| import io | |||
| import os | |||
| import pathlib | |||
| from typing import TypeGuard, overload | |||
| from ._base_type import ( | |||
| Base64FileInput, | |||
| FileContent, | |||
| FileTypes, | |||
| HttpxFileContent, | |||
| HttpxFileTypes, | |||
| HttpxRequestFiles, | |||
| RequestFiles, | |||
| ) | |||
| from ._utils import is_mapping_t, is_sequence_t, is_tuple_t | |||
| def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: | |||
| return isinstance(obj, io.IOBase | os.PathLike) | |||
| def is_file_content(obj: object) -> TypeGuard[FileContent]: | |||
| return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) | |||
| def assert_is_file_content(obj: object, *, key: str | None = None) -> None: | |||
| if not is_file_content(obj): | |||
| prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" | |||
| raise RuntimeError( | |||
| f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" | |||
| ) from None | |||
| @overload | |||
| def to_httpx_files(files: None) -> None: ... | |||
| @overload | |||
| def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... | |||
| def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: | |||
| if files is None: | |||
| return None | |||
| if is_mapping_t(files): | |||
| files = {key: _transform_file(file) for key, file in files.items()} | |||
| elif is_sequence_t(files): | |||
| files = [(key, _transform_file(file)) for key, file in files] | |||
| else: | |||
| raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") | |||
| return files | |||
| def _transform_file(file: FileTypes) -> HttpxFileTypes: | |||
| if is_file_content(file): | |||
| if isinstance(file, os.PathLike): | |||
| path = pathlib.Path(file) | |||
| return (path.name, path.read_bytes()) | |||
| return file | |||
| if is_tuple_t(file): | |||
| return (file[0], _read_file_content(file[1]), *file[2:]) | |||
| raise TypeError("Expected file types input to be a FileContent type or to be a tuple") | |||
| def _read_file_content(file: FileContent) -> HttpxFileContent: | |||
| if isinstance(file, os.PathLike): | |||
| return pathlib.Path(file).read_bytes() | |||
| return file | |||
| @@ -1,910 +0,0 @@ | |||
| from __future__ import annotations | |||
| import inspect | |||
| import logging | |||
| import time | |||
| import warnings | |||
| from collections.abc import Iterator, Mapping | |||
| from itertools import starmap | |||
| from random import random | |||
| from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload | |||
| import httpx | |||
| import pydantic | |||
| from httpx import URL, Timeout | |||
| from . import _errors, get_origin | |||
| from ._base_compat import model_copy | |||
| from ._base_models import GenericModel, construct_type, validate_type | |||
| from ._base_type import ( | |||
| NOT_GIVEN, | |||
| AnyMapping, | |||
| Body, | |||
| Data, | |||
| Headers, | |||
| HttpxSendArgs, | |||
| ModelBuilderProtocol, | |||
| NotGiven, | |||
| Omit, | |||
| PostParser, | |||
| Query, | |||
| RequestFiles, | |||
| ResponseT, | |||
| ) | |||
| from ._constants import ( | |||
| INITIAL_RETRY_DELAY, | |||
| MAX_RETRY_DELAY, | |||
| RAW_RESPONSE_HEADER, | |||
| ZHIPUAI_DEFAULT_LIMITS, | |||
| ZHIPUAI_DEFAULT_MAX_RETRIES, | |||
| ZHIPUAI_DEFAULT_TIMEOUT, | |||
| ) | |||
| from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError | |||
| from ._files import to_httpx_files | |||
| from ._legacy_response import LegacyAPIResponse | |||
| from ._request_opt import FinalRequestOptions, UserRequestInput | |||
| from ._response import APIResponse, BaseAPIResponse, extract_response_type | |||
| from ._sse_client import StreamResponse | |||
| from ._utils import flatten, is_given, is_mapping | |||
| log: logging.Logger = logging.getLogger(__name__) | |||
| # TODO: make base page type vars covariant | |||
| SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") | |||
| # AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") | |||
| _T = TypeVar("_T") | |||
| _T_co = TypeVar("_T_co", covariant=True) | |||
| if TYPE_CHECKING: | |||
| from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT | |||
| else: | |||
| try: | |||
| from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT | |||
| except ImportError: | |||
| # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366 | |||
| HTTPX_DEFAULT_TIMEOUT = Timeout(5.0) | |||
| headers = { | |||
| "Accept": "application/json", | |||
| "Content-Type": "application/json; charset=UTF-8", | |||
| } | |||
| class PageInfo: | |||
| """Stores the necessary information to build the request to retrieve the next page. | |||
| Either `url` or `params` must be set. | |||
| """ | |||
| url: URL | NotGiven | |||
| params: Query | NotGiven | |||
| @overload | |||
| def __init__( | |||
| self, | |||
| *, | |||
| url: URL, | |||
| ) -> None: ... | |||
| @overload | |||
| def __init__( | |||
| self, | |||
| *, | |||
| params: Query, | |||
| ) -> None: ... | |||
| def __init__( | |||
| self, | |||
| *, | |||
| url: URL | NotGiven = NOT_GIVEN, | |||
| params: Query | NotGiven = NOT_GIVEN, | |||
| ) -> None: | |||
| self.url = url | |||
| self.params = params | |||
| class BasePage(GenericModel, Generic[_T]): | |||
| """ | |||
| Defines the core interface for pagination. | |||
| Type Args: | |||
| ModelT: The pydantic model that represents an item in the response. | |||
| Methods: | |||
| has_next_page(): Check if there is another page available | |||
| next_page_info(): Get the necessary information to make a request for the next page | |||
| """ | |||
| _options: FinalRequestOptions = pydantic.PrivateAttr() | |||
| _model: type[_T] = pydantic.PrivateAttr() | |||
| def has_next_page(self) -> bool: | |||
| items = self._get_page_items() | |||
| if not items: | |||
| return False | |||
| return self.next_page_info() is not None | |||
| def next_page_info(self) -> Optional[PageInfo]: ... | |||
| def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] | |||
| ... | |||
| def _params_from_url(self, url: URL) -> httpx.QueryParams: | |||
| # TODO: do we have to preprocess params here? | |||
| return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params) | |||
| def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: | |||
| options = model_copy(self._options) | |||
| options._strip_raw_response_header() | |||
| if not isinstance(info.params, NotGiven): | |||
| options.params = {**options.params, **info.params} | |||
| return options | |||
| if not isinstance(info.url, NotGiven): | |||
| params = self._params_from_url(info.url) | |||
| url = info.url.copy_with(params=params) | |||
| options.params = dict(url.params) | |||
| options.url = str(url) | |||
| return options | |||
| raise ValueError("Unexpected PageInfo state") | |||
| class BaseSyncPage(BasePage[_T], Generic[_T]): | |||
| _client: HttpClient = pydantic.PrivateAttr() | |||
| def _set_private_attributes( | |||
| self, | |||
| client: HttpClient, | |||
| model: type[_T], | |||
| options: FinalRequestOptions, | |||
| ) -> None: | |||
| self._model = model | |||
| self._client = client | |||
| self._options = options | |||
| # Pydantic uses a custom `__iter__` method to support casting BaseModels | |||
| # to dictionaries. e.g. dict(model). | |||
| # As we want to support `for item in page`, this is inherently incompatible | |||
| # with the default pydantic behavior. It is not possible to support both | |||
| # use cases at once. Fortunately, this is not a big deal as all other pydantic | |||
| # methods should continue to work as expected as there is an alternative method | |||
| # to cast a model to a dictionary, model.dict(), which is used internally | |||
| # by pydantic. | |||
| def __iter__(self) -> Iterator[_T]: # type: ignore | |||
| for page in self.iter_pages(): | |||
| yield from page._get_page_items() | |||
| def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]: | |||
| page = self | |||
| while True: | |||
| yield page | |||
| if page.has_next_page(): | |||
| page = page.get_next_page() | |||
| else: | |||
| return | |||
| def get_next_page(self: SyncPageT) -> SyncPageT: | |||
| info = self.next_page_info() | |||
| if not info: | |||
| raise RuntimeError( | |||
| "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`." | |||
| ) | |||
| options = self._info_to_options(info) | |||
| return self._client._request_api_list(self._model, page=self.__class__, options=options) | |||
| class HttpClient: | |||
| _client: httpx.Client | |||
| _version: str | |||
| _base_url: URL | |||
| max_retries: int | |||
| timeout: Union[float, Timeout, None] | |||
| _limits: httpx.Limits | |||
| _has_custom_http_client: bool | |||
| _default_stream_cls: type[StreamResponse[Any]] | None = None | |||
| _strict_response_validation: bool | |||
| def __init__( | |||
| self, | |||
| *, | |||
| version: str, | |||
| base_url: URL, | |||
| _strict_response_validation: bool, | |||
| max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, | |||
| timeout: Union[float, Timeout, None], | |||
| limits: httpx.Limits | None = None, | |||
| custom_httpx_client: httpx.Client | None = None, | |||
| custom_headers: Mapping[str, str] | None = None, | |||
| ) -> None: | |||
| if limits is not None: | |||
| warnings.warn( | |||
| "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", # noqa: E501 | |||
| category=DeprecationWarning, | |||
| stacklevel=3, | |||
| ) | |||
| if custom_httpx_client is not None: | |||
| raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") | |||
| else: | |||
| limits = ZHIPUAI_DEFAULT_LIMITS | |||
| if not is_given(timeout): | |||
| if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: | |||
| timeout = custom_httpx_client.timeout | |||
| else: | |||
| timeout = ZHIPUAI_DEFAULT_TIMEOUT | |||
| self.max_retries = max_retries | |||
| self.timeout = timeout | |||
| self._limits = limits | |||
| self._has_custom_http_client = bool(custom_httpx_client) | |||
| self._client = custom_httpx_client or httpx.Client( | |||
| base_url=base_url, | |||
| timeout=self.timeout, | |||
| limits=limits, | |||
| ) | |||
| self._version = version | |||
| url = URL(url=base_url) | |||
| if not url.raw_path.endswith(b"/"): | |||
| url = url.copy_with(raw_path=url.raw_path + b"/") | |||
| self._base_url = url | |||
| self._custom_headers = custom_headers or {} | |||
| self._strict_response_validation = _strict_response_validation | |||
| def _prepare_url(self, url: str) -> URL: | |||
| sub_url = URL(url) | |||
| if sub_url.is_relative_url: | |||
| request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") | |||
| return self._base_url.copy_with(raw_path=request_raw_url) | |||
| return sub_url | |||
| @property | |||
| def _default_headers(self): | |||
| return { | |||
| "Accept": "application/json", | |||
| "Content-Type": "application/json; charset=UTF-8", | |||
| "ZhipuAI-SDK-Ver": self._version, | |||
| "source_type": "zhipu-sdk-python", | |||
| "x-request-sdk": "zhipu-sdk-python", | |||
| **self.auth_headers, | |||
| **self._custom_headers, | |||
| } | |||
| @property | |||
| def custom_auth(self) -> httpx.Auth | None: | |||
| return None | |||
| @property | |||
| def auth_headers(self): | |||
| return {} | |||
| def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers: | |||
| custom_headers = options.headers or {} | |||
| headers_dict = _merge_mappings(self._default_headers, custom_headers) | |||
| httpx_headers = httpx.Headers(headers_dict) | |||
| return httpx_headers | |||
| def _remaining_retries( | |||
| self, | |||
| remaining_retries: Optional[int], | |||
| options: FinalRequestOptions, | |||
| ) -> int: | |||
| return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) | |||
| def _calculate_retry_timeout( | |||
| self, | |||
| remaining_retries: int, | |||
| options: FinalRequestOptions, | |||
| response_headers: Optional[httpx.Headers] = None, | |||
| ) -> float: | |||
| max_retries = options.get_max_retries(self.max_retries) | |||
| # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. | |||
| # retry_after = self._parse_retry_after_header(response_headers) | |||
| # if retry_after is not None and 0 < retry_after <= 60: | |||
| # return retry_after | |||
| nb_retries = max_retries - remaining_retries | |||
| # Apply exponential backoff, but not more than the max. | |||
| sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) | |||
| # Apply some jitter, plus-or-minus half a second. | |||
| jitter = 1 - 0.25 * random() | |||
| timeout = sleep_seconds * jitter | |||
| return max(timeout, 0) | |||
| def _build_request(self, options: FinalRequestOptions) -> httpx.Request: | |||
| kwargs: dict[str, Any] = {} | |||
| headers = self._prepare_headers(options) | |||
| url = self._prepare_url(options.url) | |||
| json_data = options.json_data | |||
| if options.extra_json is not None: | |||
| if json_data is None: | |||
| json_data = cast(Body, options.extra_json) | |||
| elif is_mapping(json_data): | |||
| json_data = _merge_mappings(json_data, options.extra_json) | |||
| else: | |||
| raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") | |||
| content_type = headers.get("Content-Type") | |||
| # multipart/form-data; boundary=---abc-- | |||
| if headers.get("Content-Type") == "multipart/form-data": | |||
| if "boundary" not in content_type: | |||
| # only remove the header if the boundary hasn't been explicitly set | |||
| # as the caller doesn't want httpx to come up with their own boundary | |||
| headers.pop("Content-Type") | |||
| if json_data: | |||
| kwargs["data"] = self._make_multipartform(json_data) | |||
| return self._client.build_request( | |||
| headers=headers, | |||
| timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, | |||
| method=options.method, | |||
| url=url, | |||
| json=json_data, | |||
| files=options.files, | |||
| params=options.params, | |||
| **kwargs, | |||
| ) | |||
| def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: | |||
| items = [] | |||
| if isinstance(value, Mapping): | |||
| for k, v in value.items(): | |||
| items.extend(self._object_to_formdata(f"{key}[{k}]", v)) | |||
| return items | |||
| if isinstance(value, list | tuple): | |||
| for v in value: | |||
| items.extend(self._object_to_formdata(key + "[]", v)) | |||
| return items | |||
| def _primitive_value_to_str(val) -> str: | |||
| # copied from httpx | |||
| if val is True: | |||
| return "true" | |||
| elif val is False: | |||
| return "false" | |||
| elif val is None: | |||
| return "" | |||
| return str(val) | |||
| str_data = _primitive_value_to_str(value) | |||
| if not str_data: | |||
| return [] | |||
| return [(key, str_data)] | |||
| def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: | |||
| items = flatten(list(starmap(self._object_to_formdata, data.items()))) | |||
| serialized: dict[str, object] = {} | |||
| for key, value in items: | |||
| if key in serialized: | |||
| raise ValueError(f"存在重复的键: {key};") | |||
| serialized[key] = value | |||
| return serialized | |||
| def _process_response_data( | |||
| self, | |||
| *, | |||
| data: object, | |||
| cast_type: type[ResponseT], | |||
| response: httpx.Response, | |||
| ) -> ResponseT: | |||
| if data is None: | |||
| return cast(ResponseT, None) | |||
| if cast_type is object: | |||
| return cast(ResponseT, data) | |||
| try: | |||
| if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol): | |||
| return cast(ResponseT, cast_type.build(response=response, data=data)) | |||
| if self._strict_response_validation: | |||
| return cast(ResponseT, validate_type(type_=cast_type, value=data)) | |||
| return cast(ResponseT, construct_type(type_=cast_type, value=data)) | |||
| except pydantic.ValidationError as err: | |||
| raise APIResponseValidationError(response=response, json_data=data) from err | |||
| def _should_stream_response_body(self, request: httpx.Request) -> bool: | |||
| return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] | |||
| def _should_retry(self, response: httpx.Response) -> bool: | |||
| # Note: this is not a standard header | |||
| should_retry_header = response.headers.get("x-should-retry") | |||
| # If the server explicitly says whether or not to retry, obey. | |||
| if should_retry_header == "true": | |||
| log.debug("Retrying as header `x-should-retry` is set to `true`") | |||
| return True | |||
| if should_retry_header == "false": | |||
| log.debug("Not retrying as header `x-should-retry` is set to `false`") | |||
| return False | |||
| # Retry on request timeouts. | |||
| if response.status_code == 408: | |||
| log.debug("Retrying due to status code %i", response.status_code) | |||
| return True | |||
| # Retry on lock timeouts. | |||
| if response.status_code == 409: | |||
| log.debug("Retrying due to status code %i", response.status_code) | |||
| return True | |||
| # Retry on rate limits. | |||
| if response.status_code == 429: | |||
| log.debug("Retrying due to status code %i", response.status_code) | |||
| return True | |||
| # Retry internal errors. | |||
| if response.status_code >= 500: | |||
| log.debug("Retrying due to status code %i", response.status_code) | |||
| return True | |||
| log.debug("Not retrying") | |||
| return False | |||
| def is_closed(self) -> bool: | |||
| return self._client.is_closed | |||
| def close(self): | |||
| self._client.close() | |||
| def __enter__(self): | |||
| return self | |||
| def __exit__(self, exc_type, exc_val, exc_tb): | |||
| self.close() | |||
| def request( | |||
| self, | |||
| cast_type: type[ResponseT], | |||
| options: FinalRequestOptions, | |||
| remaining_retries: Optional[int] = None, | |||
| *, | |||
| stream: bool = False, | |||
| stream_cls: type[StreamResponse] | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| return self._request( | |||
| cast_type=cast_type, | |||
| options=options, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| remaining_retries=remaining_retries, | |||
| ) | |||
| def _request( | |||
| self, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: FinalRequestOptions, | |||
| remaining_retries: int | None, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse] | None, | |||
| ) -> ResponseT | StreamResponse: | |||
| retries = self._remaining_retries(remaining_retries, options) | |||
| request = self._build_request(options) | |||
| kwargs: HttpxSendArgs = {} | |||
| if self.custom_auth is not None: | |||
| kwargs["auth"] = self.custom_auth | |||
| try: | |||
| response = self._client.send( | |||
| request, | |||
| stream=stream or self._should_stream_response_body(request=request), | |||
| **kwargs, | |||
| ) | |||
| except httpx.TimeoutException as err: | |||
| log.debug("Encountered httpx.TimeoutException", exc_info=True) | |||
| if retries > 0: | |||
| return self._retry_request( | |||
| options, | |||
| cast_type, | |||
| retries, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| response_headers=None, | |||
| ) | |||
| log.debug("Raising timeout error") | |||
| raise APITimeoutError(request=request) from err | |||
| except Exception as err: | |||
| log.debug("Encountered Exception", exc_info=True) | |||
| if retries > 0: | |||
| return self._retry_request( | |||
| options, | |||
| cast_type, | |||
| retries, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| response_headers=None, | |||
| ) | |||
| log.debug("Raising connection error") | |||
| raise APIConnectionError(request=request) from err | |||
| log.debug( | |||
| 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase | |||
| ) | |||
| try: | |||
| response.raise_for_status() | |||
| except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code | |||
| log.debug("Encountered httpx.HTTPStatusError", exc_info=True) | |||
| if retries > 0 and self._should_retry(err.response): | |||
| err.response.close() | |||
| return self._retry_request( | |||
| options, | |||
| cast_type, | |||
| retries, | |||
| err.response.headers, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| ) | |||
| # If the response is streamed then we need to explicitly read the response | |||
| # to completion before attempting to access the response text. | |||
| if not err.response.is_closed: | |||
| err.response.read() | |||
| log.debug("Re-raising status error") | |||
| raise self._make_status_error(err.response) from None | |||
| # return self._parse_response( | |||
| # cast_type=cast_type, | |||
| # options=options, | |||
| # response=response, | |||
| # stream=stream, | |||
| # stream_cls=stream_cls, | |||
| # ) | |||
| return self._process_response( | |||
| cast_type=cast_type, | |||
| options=options, | |||
| response=response, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| ) | |||
| def _retry_request( | |||
| self, | |||
| options: FinalRequestOptions, | |||
| cast_type: type[ResponseT], | |||
| remaining_retries: int, | |||
| response_headers: httpx.Headers | None, | |||
| *, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse] | None, | |||
| ) -> ResponseT | StreamResponse: | |||
| remaining = remaining_retries - 1 | |||
| if remaining == 1: | |||
| log.debug("1 retry left") | |||
| else: | |||
| log.debug("%i retries left", remaining) | |||
| timeout = self._calculate_retry_timeout(remaining, options, response_headers) | |||
| log.info("Retrying request to %s in %f seconds", options.url, timeout) | |||
| # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a | |||
| # different thread if necessary. | |||
| time.sleep(timeout) | |||
| return self._request( | |||
| options=options, | |||
| cast_type=cast_type, | |||
| remaining_retries=remaining, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| ) | |||
| def _process_response( | |||
| self, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: FinalRequestOptions, | |||
| response: httpx.Response, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse] | None, | |||
| ) -> ResponseT: | |||
| # _legacy_response with raw_response_header to parser method | |||
| if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": | |||
| return cast( | |||
| ResponseT, | |||
| LegacyAPIResponse( | |||
| raw=response, | |||
| client=self, | |||
| cast_type=cast_type, | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| options=options, | |||
| ), | |||
| ) | |||
| origin = get_origin(cast_type) or cast_type | |||
| if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): | |||
| if not issubclass(origin, APIResponse): | |||
| raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") | |||
| response_cls = cast("type[BaseAPIResponse[Any]]", cast_type) | |||
| return cast( | |||
| ResponseT, | |||
| response_cls( | |||
| raw=response, | |||
| client=self, | |||
| cast_type=extract_response_type(response_cls), | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| options=options, | |||
| ), | |||
| ) | |||
| if cast_type == httpx.Response: | |||
| return cast(ResponseT, response) | |||
| api_response = APIResponse( | |||
| raw=response, | |||
| client=self, | |||
| cast_type=cast("type[ResponseT]", cast_type), # pyright: ignore[reportUnnecessaryCast] | |||
| stream=stream, | |||
| stream_cls=stream_cls, | |||
| options=options, | |||
| ) | |||
| if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): | |||
| return cast(ResponseT, api_response) | |||
| return api_response.parse() | |||
| def _request_api_list( | |||
| self, | |||
| model: type[object], | |||
| page: type[SyncPageT], | |||
| options: FinalRequestOptions, | |||
| ) -> SyncPageT: | |||
| def _parser(resp: SyncPageT) -> SyncPageT: | |||
| resp._set_private_attributes( | |||
| client=self, | |||
| model=model, | |||
| options=options, | |||
| ) | |||
| return resp | |||
| options.post_parser = _parser | |||
| return self.request(page, options, stream=False) | |||
| @overload | |||
| def get( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| stream: Literal[False] = False, | |||
| ) -> ResponseT: ... | |||
| @overload | |||
| def get( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| stream: Literal[True], | |||
| stream_cls: type[StreamResponse], | |||
| ) -> StreamResponse: ... | |||
| @overload | |||
| def get( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse] | None = None, | |||
| ) -> ResponseT | StreamResponse: ... | |||
| def get( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| options: UserRequestInput = {}, | |||
| stream: bool = False, | |||
| stream_cls: type[StreamResponse] | None = None, | |||
| ) -> ResponseT: | |||
| opts = FinalRequestOptions.construct(method="get", url=path, **options) | |||
| return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) | |||
| @overload | |||
| def post( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| stream: Literal[False] = False, | |||
| ) -> ResponseT: ... | |||
| @overload | |||
| def post( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| stream: Literal[True], | |||
| stream_cls: type[StreamResponse], | |||
| ) -> StreamResponse: ... | |||
| @overload | |||
| def post( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse] | None = None, | |||
| ) -> ResponseT | StreamResponse: ... | |||
| def post( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| stream: bool = False, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = FinalRequestOptions.construct( | |||
| method="post", url=path, json_data=body, files=to_httpx_files(files), **options | |||
| ) | |||
| return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) | |||
| def patch( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| ) -> ResponseT: | |||
| opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) | |||
| return self.request( | |||
| cast_type=cast_type, | |||
| options=opts, | |||
| ) | |||
| def put( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| files: RequestFiles | None = None, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = FinalRequestOptions.construct( | |||
| method="put", url=path, json_data=body, files=to_httpx_files(files), **options | |||
| ) | |||
| return self.request( | |||
| cast_type=cast_type, | |||
| options=opts, | |||
| ) | |||
| def delete( | |||
| self, | |||
| path: str, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| ) -> ResponseT | StreamResponse: | |||
| opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options) | |||
| return self.request( | |||
| cast_type=cast_type, | |||
| options=opts, | |||
| ) | |||
| def get_api_list( | |||
| self, | |||
| path: str, | |||
| *, | |||
| model: type[object], | |||
| page: type[SyncPageT], | |||
| body: Body | None = None, | |||
| options: UserRequestInput = {}, | |||
| method: str = "get", | |||
| ) -> SyncPageT: | |||
| opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options) | |||
| return self._request_api_list(model, page, opts) | |||
| def _make_status_error(self, response) -> APIStatusError: | |||
| response_text = response.text.strip() | |||
| status_code = response.status_code | |||
| error_msg = f"Error code: {status_code}, with error text {response_text}" | |||
| if status_code == 400: | |||
| return _errors.APIRequestFailedError(message=error_msg, response=response) | |||
| elif status_code == 401: | |||
| return _errors.APIAuthenticationError(message=error_msg, response=response) | |||
| elif status_code == 429: | |||
| return _errors.APIReachLimitError(message=error_msg, response=response) | |||
| elif status_code == 500: | |||
| return _errors.APIInternalError(message=error_msg, response=response) | |||
| elif status_code == 503: | |||
| return _errors.APIServerFlowExceedError(message=error_msg, response=response) | |||
| return APIStatusError(message=error_msg, response=response) | |||
| def make_request_options( | |||
| *, | |||
| query: Query | None = None, | |||
| extra_headers: Headers | None = None, | |||
| extra_query: Query | None = None, | |||
| extra_body: Body | None = None, | |||
| timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, | |||
| post_parser: PostParser | NotGiven = NOT_GIVEN, | |||
| ) -> UserRequestInput: | |||
| """Create a dict of type RequestOptions without keys of NotGiven values.""" | |||
| options: UserRequestInput = {} | |||
| if extra_headers is not None: | |||
| options["headers"] = extra_headers | |||
| if extra_body is not None: | |||
| options["extra_json"] = cast(AnyMapping, extra_body) | |||
| if query is not None: | |||
| options["params"] = query | |||
| if extra_query is not None: | |||
| options["params"] = {**options.get("params", {}), **extra_query} | |||
| if not isinstance(timeout, NotGiven): | |||
| options["timeout"] = timeout | |||
| if is_given(post_parser): | |||
| # internal | |||
| options["post_parser"] = post_parser # type: ignore | |||
| return options | |||
| def _merge_mappings( | |||
| obj1: Mapping[_T_co, Union[_T, Omit]], | |||
| obj2: Mapping[_T_co, Union[_T, Omit]], | |||
| ) -> dict[_T_co, _T]: | |||
| """Merge two mappings of the same type, removing any values that are instances of `Omit`. | |||
| In cases with duplicate keys the second mapping takes precedence. | |||
| """ | |||
| merged = {**obj1, **obj2} | |||
| return {key: value for key, value in merged.items() if not isinstance(value, Omit)} | |||
| @@ -1,31 +0,0 @@ | |||
| import time | |||
| import cachetools.func | |||
| import jwt | |||
| # 缓存时间 3分钟 | |||
| CACHE_TTL_SECONDS = 3 * 60 | |||
| # token 有效期比缓存时间 多30秒 | |||
| API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30 | |||
| @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) | |||
| def generate_token(apikey: str): | |||
| try: | |||
| api_key, secret = apikey.split(".") | |||
| except Exception as e: | |||
| raise Exception("invalid api_key", e) | |||
| payload = { | |||
| "api_key": api_key, | |||
| "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, | |||
| "timestamp": int(round(time.time() * 1000)), | |||
| } | |||
| ret = jwt.encode( | |||
| payload, | |||
| secret, | |||
| algorithm="HS256", | |||
| headers={"alg": "HS256", "sign_type": "SIGN"}, | |||
| ) | |||
| return ret | |||
| @@ -1,207 +0,0 @@ | |||
| from __future__ import annotations | |||
| import os | |||
| from collections.abc import AsyncIterator, Iterator | |||
| from typing import Any | |||
| import httpx | |||
| class HttpxResponseContent: | |||
| @property | |||
| def content(self) -> bytes: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| @property | |||
| def text(self) -> str: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| @property | |||
| def encoding(self) -> str | None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| @property | |||
| def charset_encoding(self) -> str | None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def json(self, **kwargs: Any) -> Any: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def read(self) -> bytes: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def iter_lines(self) -> Iterator[str]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def write_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| ) -> None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def stream_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| *, | |||
| chunk_size: int | None = None, | |||
| ) -> None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| def close(self) -> None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aread(self) -> bytes: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aiter_lines(self) -> AsyncIterator[str]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def astream_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| *, | |||
| chunk_size: int | None = None, | |||
| ) -> None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| async def aclose(self) -> None: | |||
| raise NotImplementedError("This method is not implemented for this class.") | |||
| class HttpxBinaryResponseContent(HttpxResponseContent): | |||
| response: httpx.Response | |||
| def __init__(self, response: httpx.Response) -> None: | |||
| self.response = response | |||
| @property | |||
| def content(self) -> bytes: | |||
| return self.response.content | |||
| @property | |||
| def encoding(self) -> str | None: | |||
| return self.response.encoding | |||
| @property | |||
| def charset_encoding(self) -> str | None: | |||
| return self.response.charset_encoding | |||
| def read(self) -> bytes: | |||
| return self.response.read() | |||
| def text(self) -> str: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| def json(self, **kwargs: Any) -> Any: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| def iter_lines(self) -> Iterator[str]: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| async def aiter_lines(self) -> AsyncIterator[str]: | |||
| raise NotImplementedError("Not implemented for binary response content") | |||
| def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: | |||
| return self.response.iter_bytes(chunk_size) | |||
| def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: | |||
| return self.response.iter_raw(chunk_size) | |||
| def write_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| ) -> None: | |||
| """Write the output to the given file. | |||
| Accepts a filename or any path-like object, e.g. pathlib.Path | |||
| Note: if you want to stream the data to the file instead of writing | |||
| all at once then you should use `.with_streaming_response` when making | |||
| the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')` | |||
| """ | |||
| with open(file, mode="wb") as f: | |||
| for data in self.response.iter_bytes(): | |||
| f.write(data) | |||
| def stream_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| *, | |||
| chunk_size: int | None = None, | |||
| ) -> None: | |||
| with open(file, mode="wb") as f: | |||
| for data in self.response.iter_bytes(chunk_size): | |||
| f.write(data) | |||
| def close(self) -> None: | |||
| return self.response.close() | |||
| async def aread(self) -> bytes: | |||
| return await self.response.aread() | |||
| async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: | |||
| return self.response.aiter_bytes(chunk_size) | |||
| async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: | |||
| return self.response.aiter_raw(chunk_size) | |||
| async def astream_to_file( | |||
| self, | |||
| file: str | os.PathLike[str], | |||
| *, | |||
| chunk_size: int | None = None, | |||
| ) -> None: | |||
| path = anyio.Path(file) | |||
| async with await path.open(mode="wb") as f: | |||
| async for data in self.response.aiter_bytes(chunk_size): | |||
| await f.write(data) | |||
| async def aclose(self) -> None: | |||
| return await self.response.aclose() | |||
| class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent): | |||
| response: httpx.Response | |||
| @property | |||
| def text(self) -> str: | |||
| return self.response.text | |||
| def json(self, **kwargs: Any) -> Any: | |||
| return self.response.json(**kwargs) | |||
| def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: | |||
| return self.response.iter_text(chunk_size) | |||
| def iter_lines(self) -> Iterator[str]: | |||
| return self.response.iter_lines() | |||
| async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: | |||
| return self.response.aiter_text(chunk_size) | |||
| async def aiter_lines(self) -> AsyncIterator[str]: | |||
| return self.response.aiter_lines() | |||
| @@ -1,341 +0,0 @@ | |||
| from __future__ import annotations | |||
| import datetime | |||
| import functools | |||
| import inspect | |||
| import logging | |||
| from collections.abc import Callable | |||
| from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload | |||
| import httpx | |||
| import pydantic | |||
| from typing_extensions import ParamSpec, override | |||
| from ._base_models import BaseModel, is_basemodel | |||
| from ._base_type import NoneType | |||
| from ._constants import RAW_RESPONSE_HEADER | |||
| from ._errors import APIResponseValidationError | |||
| from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent | |||
| from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type | |||
| from ._utils import extract_type_arg, is_annotated_type, is_given | |||
| if TYPE_CHECKING: | |||
| from ._http_client import HttpClient | |||
| from ._request_opt import FinalRequestOptions | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| _T = TypeVar("_T") | |||
| log: logging.Logger = logging.getLogger(__name__) | |||
| class LegacyAPIResponse(Generic[R]): | |||
| """This is a legacy class as it will be replaced by `APIResponse` | |||
| and `AsyncAPIResponse` in the `_response.py` file in the next major | |||
| release. | |||
| For the sync client this will mostly be the same with the exception | |||
| of `content` & `text` will be methods instead of properties. In the | |||
| async client, all methods will be async. | |||
| A migration script will be provided & the migration in general should | |||
| be smooth. | |||
| """ | |||
| _cast_type: type[R] | |||
| _client: HttpClient | |||
| _parsed_by_type: dict[type[Any], Any] | |||
| _stream: bool | |||
| _stream_cls: type[StreamResponse[Any]] | None | |||
| _options: FinalRequestOptions | |||
| http_response: httpx.Response | |||
| def __init__( | |||
| self, | |||
| *, | |||
| raw: httpx.Response, | |||
| cast_type: type[R], | |||
| client: HttpClient, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse[Any]] | None, | |||
| options: FinalRequestOptions, | |||
| ) -> None: | |||
| self._cast_type = cast_type | |||
| self._client = client | |||
| self._parsed_by_type = {} | |||
| self._stream = stream | |||
| self._stream_cls = stream_cls | |||
| self._options = options | |||
| self.http_response = raw | |||
| @property | |||
| def request_id(self) -> str | None: | |||
| return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] | |||
| @overload | |||
| def parse(self, *, to: type[_T]) -> _T: ... | |||
| @overload | |||
| def parse(self) -> R: ... | |||
| def parse(self, *, to: type[_T] | None = None) -> R | _T: | |||
| """Returns the rich python representation of this response's data. | |||
| NOTE: For the async client: this will become a coroutine in the next major version. | |||
| For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. | |||
| You can customize the type that the response is parsed into through | |||
| the `to` argument, e.g. | |||
| ```py | |||
| from zhipuai import BaseModel | |||
| class MyModel(BaseModel): | |||
| foo: str | |||
| obj = response.parse(to=MyModel) | |||
| print(obj.foo) | |||
| ``` | |||
| We support parsing: | |||
| - `BaseModel` | |||
| - `dict` | |||
| - `list` | |||
| - `Union` | |||
| - `str` | |||
| - `int` | |||
| - `float` | |||
| - `httpx.Response` | |||
| """ | |||
| cache_key = to if to is not None else self._cast_type | |||
| cached = self._parsed_by_type.get(cache_key) | |||
| if cached is not None: | |||
| return cached # type: ignore[no-any-return] | |||
| parsed = self._parse(to=to) | |||
| if is_given(self._options.post_parser): | |||
| parsed = self._options.post_parser(parsed) | |||
| self._parsed_by_type[cache_key] = parsed | |||
| return parsed | |||
| @property | |||
| def headers(self) -> httpx.Headers: | |||
| return self.http_response.headers | |||
| @property | |||
| def http_request(self) -> httpx.Request: | |||
| return self.http_response.request | |||
| @property | |||
| def status_code(self) -> int: | |||
| return self.http_response.status_code | |||
| @property | |||
| def url(self) -> httpx.URL: | |||
| return self.http_response.url | |||
| @property | |||
| def method(self) -> str: | |||
| return self.http_request.method | |||
| @property | |||
| def content(self) -> bytes: | |||
| """Return the binary response content. | |||
| NOTE: this will be removed in favour of `.read()` in the | |||
| next major version. | |||
| """ | |||
| return self.http_response.content | |||
| @property | |||
| def text(self) -> str: | |||
| """Return the decoded response content. | |||
| NOTE: this will be turned into a method in the next major version. | |||
| """ | |||
| return self.http_response.text | |||
| @property | |||
| def http_version(self) -> str: | |||
| return self.http_response.http_version | |||
| @property | |||
| def is_closed(self) -> bool: | |||
| return self.http_response.is_closed | |||
| @property | |||
| def elapsed(self) -> datetime.timedelta: | |||
| """The time taken for the complete request/response cycle to complete.""" | |||
| return self.http_response.elapsed | |||
| def _parse(self, *, to: type[_T] | None = None) -> R | _T: | |||
| # unwrap `Annotated[T, ...]` -> `T` | |||
| if to and is_annotated_type(to): | |||
| to = extract_type_arg(to, 0) | |||
| if self._stream: | |||
| if to: | |||
| if not is_stream_class_type(to): | |||
| raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") | |||
| return cast( | |||
| _T, | |||
| to( | |||
| cast_type=extract_stream_chunk_type( | |||
| to, | |||
| failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 | |||
| ), | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| if self._stream_cls: | |||
| return cast( | |||
| R, | |||
| self._stream_cls( | |||
| cast_type=extract_stream_chunk_type(self._stream_cls), | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls) | |||
| if stream_cls is None: | |||
| raise MissingStreamClassError() | |||
| return cast( | |||
| R, | |||
| stream_cls( | |||
| cast_type=self._cast_type, | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| cast_type = to if to is not None else self._cast_type | |||
| # unwrap `Annotated[T, ...]` -> `T` | |||
| if is_annotated_type(cast_type): | |||
| cast_type = extract_type_arg(cast_type, 0) | |||
| if cast_type is NoneType: | |||
| return cast(R, None) | |||
| response = self.http_response | |||
| if cast_type == str: | |||
| return cast(R, response.text) | |||
| if cast_type == int: | |||
| return cast(R, int(response.text)) | |||
| if cast_type == float: | |||
| return cast(R, float(response.text)) | |||
| origin = get_origin(cast_type) or cast_type | |||
| if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent): | |||
| # in the response, e.g. mime file | |||
| *_, filename = response.headers.get("content-disposition", "").split("filename=") | |||
| # 判断文件类型是jsonl类型的使用HttpxTextBinaryResponseContent | |||
| if filename and filename.endswith(".jsonl") or filename and filename.endswith(".xlsx"): | |||
| return cast(R, HttpxTextBinaryResponseContent(response)) | |||
| else: | |||
| return cast(R, cast_type(response)) # type: ignore | |||
| if origin == LegacyAPIResponse: | |||
| raise RuntimeError("Unexpected state - cast_type is `APIResponse`") | |||
| if inspect.isclass(origin) and issubclass(origin, httpx.Response): | |||
| # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response | |||
| # and pass that class to our request functions. We cannot change the variance to be either | |||
| # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct | |||
| # the response class ourselves but that is something that should be supported directly in httpx | |||
| # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. | |||
| if cast_type != httpx.Response: | |||
| raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") | |||
| return cast(R, response) | |||
| if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): | |||
| raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") | |||
| if ( | |||
| cast_type is not object | |||
| and origin is not list | |||
| and origin is not dict | |||
| and origin is not Union | |||
| and not issubclass(origin, BaseModel) | |||
| ): | |||
| raise RuntimeError( | |||
| f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 | |||
| ) | |||
| # split is required to handle cases where additional information is included | |||
| # in the response, e.g. application/json; charset=utf-8 | |||
| content_type, *_ = response.headers.get("content-type", "*").split(";") | |||
| if content_type != "application/json": | |||
| if is_basemodel(cast_type): | |||
| try: | |||
| data = response.json() | |||
| except Exception as exc: | |||
| log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) | |||
| else: | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=response, | |||
| ) | |||
| if self._client._strict_response_validation: | |||
| raise APIResponseValidationError( | |||
| response=response, | |||
| message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 | |||
| json_data=response.text, | |||
| ) | |||
| # If the API responds with content that isn't JSON then we just return | |||
| # the (decoded) text without performing any parsing so that you can still | |||
| # handle the response however you need to. | |||
| return response.text # type: ignore | |||
| data = response.json() | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=response, | |||
| ) | |||
| @override | |||
| def __repr__(self) -> str: | |||
| return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" | |||
| class MissingStreamClassError(TypeError): | |||
| def __init__(self) -> None: | |||
| super().__init__( | |||
| "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 | |||
| ) | |||
| def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]: | |||
| """Higher order function that takes one of our bound API methods and wraps it | |||
| to support returning the raw `APIResponse` object directly. | |||
| """ | |||
| @functools.wraps(func) | |||
| def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: | |||
| extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} | |||
| extra_headers[RAW_RESPONSE_HEADER] = "true" | |||
| kwargs["extra_headers"] = extra_headers | |||
| return cast(LegacyAPIResponse[R], func(*args, **kwargs)) | |||
| return wrapped | |||
| @@ -1,97 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections.abc import Callable | |||
| from typing import TYPE_CHECKING, Any, ClassVar, Union, cast | |||
| import pydantic.generics | |||
| from httpx import Timeout | |||
| from typing_extensions import Required, TypedDict, Unpack, final | |||
| from ._base_compat import PYDANTIC_V2, ConfigDict | |||
| from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query | |||
| from ._constants import RAW_RESPONSE_HEADER | |||
| from ._utils import is_given, strip_not_given | |||
| class UserRequestInput(TypedDict, total=False): | |||
| headers: Headers | |||
| max_retries: int | |||
| timeout: float | Timeout | None | |||
| params: Query | |||
| extra_json: AnyMapping | |||
| class FinalRequestOptionsInput(TypedDict, total=False): | |||
| method: Required[str] | |||
| url: Required[str] | |||
| params: Query | |||
| headers: Headers | |||
| max_retries: int | |||
| timeout: float | Timeout | None | |||
| files: HttpxRequestFiles | None | |||
| json_data: Body | |||
| extra_json: AnyMapping | |||
| @final | |||
| class FinalRequestOptions(pydantic.BaseModel): | |||
| method: str | |||
| url: str | |||
| params: Query = {} | |||
| headers: Union[Headers, NotGiven] = NotGiven() | |||
| max_retries: Union[int, NotGiven] = NotGiven() | |||
| timeout: Union[float, Timeout, None, NotGiven] = NotGiven() | |||
| files: Union[HttpxRequestFiles, None] = None | |||
| idempotency_key: Union[str, None] = None | |||
| post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() | |||
| # It should be noted that we cannot use `json` here as that would override | |||
| # a BaseModel method in an incompatible fashion. | |||
| json_data: Union[Body, None] = None | |||
| extra_json: Union[AnyMapping, None] = None | |||
| if PYDANTIC_V2: | |||
| model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) | |||
| else: | |||
| class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] | |||
| arbitrary_types_allowed: bool = True | |||
| def get_max_retries(self, max_retries: int) -> int: | |||
| if isinstance(self.max_retries, NotGiven): | |||
| return max_retries | |||
| return self.max_retries | |||
| def _strip_raw_response_header(self) -> None: | |||
| if not is_given(self.headers): | |||
| return | |||
| if self.headers.get(RAW_RESPONSE_HEADER): | |||
| self.headers = {**self.headers} | |||
| self.headers.pop(RAW_RESPONSE_HEADER) | |||
| # override the `construct` method so that we can run custom transformations. | |||
| # this is necessary as we don't want to do any actual runtime type checking | |||
| # (which means we can't use validators) but we do want to ensure that `NotGiven` | |||
| # values are not present | |||
| # | |||
| # type ignore required because we're adding explicit types to `**values` | |||
| @classmethod | |||
| def construct( # type: ignore | |||
| cls, | |||
| _fields_set: set[str] | None = None, | |||
| **values: Unpack[UserRequestInput], | |||
| ) -> FinalRequestOptions: | |||
| kwargs: dict[str, Any] = { | |||
| # we unconditionally call `strip_not_given` on any value | |||
| # as it will just ignore any non-mapping types | |||
| key: strip_not_given(value) | |||
| for key, value in values.items() | |||
| } | |||
| if PYDANTIC_V2: | |||
| return super().model_construct(_fields_set, **kwargs) | |||
| return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] | |||
| if not TYPE_CHECKING: | |||
| # type checkers incorrectly complain about this assignment | |||
| model_construct = construct | |||
| @@ -1,398 +0,0 @@ | |||
| from __future__ import annotations | |||
| import datetime | |||
| import inspect | |||
| import logging | |||
| from collections.abc import Iterator | |||
| from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload | |||
| import httpx | |||
| import pydantic | |||
| from typing_extensions import ParamSpec, override | |||
| from ._base_models import BaseModel, is_basemodel | |||
| from ._base_type import NoneType | |||
| from ._errors import APIResponseValidationError, ZhipuAIError | |||
| from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type | |||
| from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given | |||
| if TYPE_CHECKING: | |||
| from ._http_client import HttpClient | |||
| from ._request_opt import FinalRequestOptions | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| _T = TypeVar("_T") | |||
| _APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") | |||
| log: logging.Logger = logging.getLogger(__name__) | |||
| class BaseAPIResponse(Generic[R]): | |||
| _cast_type: type[R] | |||
| _client: HttpClient | |||
| _parsed_by_type: dict[type[Any], Any] | |||
| _is_sse_stream: bool | |||
| _stream_cls: type[StreamResponse[Any]] | |||
| _options: FinalRequestOptions | |||
| http_response: httpx.Response | |||
| def __init__( | |||
| self, | |||
| *, | |||
| raw: httpx.Response, | |||
| cast_type: type[R], | |||
| client: HttpClient, | |||
| stream: bool, | |||
| stream_cls: type[StreamResponse[Any]] | None = None, | |||
| options: FinalRequestOptions, | |||
| ) -> None: | |||
| self._cast_type = cast_type | |||
| self._client = client | |||
| self._parsed_by_type = {} | |||
| self._is_sse_stream = stream | |||
| self._stream_cls = stream_cls | |||
| self._options = options | |||
| self.http_response = raw | |||
| def _parse(self, *, to: type[_T] | None = None) -> R | _T: | |||
| # unwrap `Annotated[T, ...]` -> `T` | |||
| if to and is_annotated_type(to): | |||
| to = extract_type_arg(to, 0) | |||
| if self._is_sse_stream: | |||
| if to: | |||
| if not is_stream_class_type(to): | |||
| raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") | |||
| return cast( | |||
| _T, | |||
| to( | |||
| cast_type=extract_stream_chunk_type( | |||
| to, | |||
| failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 | |||
| ), | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| if self._stream_cls: | |||
| return cast( | |||
| R, | |||
| self._stream_cls( | |||
| cast_type=extract_stream_chunk_type(self._stream_cls), | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| stream_cls = cast("type[Stream[Any]] | None", self._client._default_stream_cls) | |||
| if stream_cls is None: | |||
| raise MissingStreamClassError() | |||
| return cast( | |||
| R, | |||
| stream_cls( | |||
| cast_type=self._cast_type, | |||
| response=self.http_response, | |||
| client=cast(Any, self._client), | |||
| ), | |||
| ) | |||
| cast_type = to if to is not None else self._cast_type | |||
| # unwrap `Annotated[T, ...]` -> `T` | |||
| if is_annotated_type(cast_type): | |||
| cast_type = extract_type_arg(cast_type, 0) | |||
| if cast_type is NoneType: | |||
| return cast(R, None) | |||
| response = self.http_response | |||
| if cast_type == str: | |||
| return cast(R, response.text) | |||
| if cast_type == bytes: | |||
| return cast(R, response.content) | |||
| if cast_type == int: | |||
| return cast(R, int(response.text)) | |||
| if cast_type == float: | |||
| return cast(R, float(response.text)) | |||
| origin = get_origin(cast_type) or cast_type | |||
| # handle the legacy binary response case | |||
| if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent": | |||
| return cast(R, cast_type(response)) # type: ignore | |||
| if origin == APIResponse: | |||
| raise RuntimeError("Unexpected state - cast_type is `APIResponse`") | |||
| if inspect.isclass(origin) and issubclass(origin, httpx.Response): | |||
| # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response | |||
| # and pass that class to our request functions. We cannot change the variance to be either | |||
| # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct | |||
| # the response class ourselves but that is something that should be supported directly in httpx | |||
| # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. | |||
| if cast_type != httpx.Response: | |||
| raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") | |||
| return cast(R, response) | |||
| if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): | |||
| raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") | |||
| if ( | |||
| cast_type is not object | |||
| and origin is not list | |||
| and origin is not dict | |||
| and origin is not Union | |||
| and not issubclass(origin, BaseModel) | |||
| ): | |||
| raise RuntimeError( | |||
| f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 | |||
| ) | |||
| # split is required to handle cases where additional information is included | |||
| # in the response, e.g. application/json; charset=utf-8 | |||
| content_type, *_ = response.headers.get("content-type", "*").split(";") | |||
| if content_type != "application/json": | |||
| if is_basemodel(cast_type): | |||
| try: | |||
| data = response.json() | |||
| except Exception as exc: | |||
| log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) | |||
| else: | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=response, | |||
| ) | |||
| if self._client._strict_response_validation: | |||
| raise APIResponseValidationError( | |||
| response=response, | |||
| message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 | |||
| json_data=response.text, | |||
| ) | |||
| # If the API responds with content that isn't JSON then we just return | |||
| # the (decoded) text without performing any parsing so that you can still | |||
| # handle the response however you need to. | |||
| return response.text # type: ignore | |||
| data = response.json() | |||
| return self._client._process_response_data( | |||
| data=data, | |||
| cast_type=cast_type, # type: ignore | |||
| response=response, | |||
| ) | |||
| @property | |||
| def headers(self) -> httpx.Headers: | |||
| return self.http_response.headers | |||
| @property | |||
| def http_request(self) -> httpx.Request: | |||
| """Returns the httpx Request instance associated with the current response.""" | |||
| return self.http_response.request | |||
| @property | |||
| def status_code(self) -> int: | |||
| return self.http_response.status_code | |||
| @property | |||
| def url(self) -> httpx.URL: | |||
| """Returns the URL for which the request was made.""" | |||
| return self.http_response.url | |||
| @property | |||
| def method(self) -> str: | |||
| return self.http_request.method | |||
| @property | |||
| def http_version(self) -> str: | |||
| return self.http_response.http_version | |||
| @property | |||
| def elapsed(self) -> datetime.timedelta: | |||
| """The time taken for the complete request/response cycle to complete.""" | |||
| return self.http_response.elapsed | |||
| @property | |||
| def is_closed(self) -> bool: | |||
| """Whether or not the response body has been closed. | |||
| If this is False then there is response data that has not been read yet. | |||
| You must either fully consume the response body or call `.close()` | |||
| before discarding the response to prevent resource leaks. | |||
| """ | |||
| return self.http_response.is_closed | |||
| @override | |||
| def __repr__(self) -> str: | |||
| return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" # noqa: E501 | |||
| class APIResponse(BaseAPIResponse[R]): | |||
| @property | |||
| def request_id(self) -> str | None: | |||
| return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] | |||
| @overload | |||
| def parse(self, *, to: type[_T]) -> _T: ... | |||
| @overload | |||
| def parse(self) -> R: ... | |||
| def parse(self, *, to: type[_T] | None = None) -> R | _T: | |||
| """Returns the rich python representation of this response's data. | |||
| For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. | |||
| You can customize the type that the response is parsed into through | |||
| the `to` argument, e.g. | |||
| ```py | |||
| from openai import BaseModel | |||
| class MyModel(BaseModel): | |||
| foo: str | |||
| obj = response.parse(to=MyModel) | |||
| print(obj.foo) | |||
| ``` | |||
| We support parsing: | |||
| - `BaseModel` | |||
| - `dict` | |||
| - `list` | |||
| - `Union` | |||
| - `str` | |||
| - `int` | |||
| - `float` | |||
| - `httpx.Response` | |||
| """ | |||
| cache_key = to if to is not None else self._cast_type | |||
| cached = self._parsed_by_type.get(cache_key) | |||
| if cached is not None: | |||
| return cached # type: ignore[no-any-return] | |||
| if not self._is_sse_stream: | |||
| self.read() | |||
| parsed = self._parse(to=to) | |||
| if is_given(self._options.post_parser): | |||
| parsed = self._options.post_parser(parsed) | |||
| self._parsed_by_type[cache_key] = parsed | |||
| return parsed | |||
| def read(self) -> bytes: | |||
| """Read and return the binary response content.""" | |||
| try: | |||
| return self.http_response.read() | |||
| except httpx.StreamConsumed as exc: | |||
| # The default error raised by httpx isn't very | |||
| # helpful in our case so we re-raise it with | |||
| # a different error message. | |||
| raise StreamAlreadyConsumed() from exc | |||
| def text(self) -> str: | |||
| """Read and decode the response content into a string.""" | |||
| self.read() | |||
| return self.http_response.text | |||
| def json(self) -> object: | |||
| """Read and decode the JSON response content.""" | |||
| self.read() | |||
| return self.http_response.json() | |||
| def close(self) -> None: | |||
| """Close the response and release the connection. | |||
| Automatically called if the response body is read to completion. | |||
| """ | |||
| self.http_response.close() | |||
| def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: | |||
| """ | |||
| A byte-iterator over the decoded response content. | |||
| This automatically handles gzip, deflate and brotli encoded responses. | |||
| """ | |||
| yield from self.http_response.iter_bytes(chunk_size) | |||
| def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: | |||
| """A str-iterator over the decoded response content | |||
| that handles both gzip, deflate, etc but also detects the content's | |||
| string encoding. | |||
| """ | |||
| yield from self.http_response.iter_text(chunk_size) | |||
| def iter_lines(self) -> Iterator[str]: | |||
| """Like `iter_text()` but will only yield chunks for each line""" | |||
| yield from self.http_response.iter_lines() | |||
| class MissingStreamClassError(TypeError): | |||
| def __init__(self) -> None: | |||
| super().__init__( | |||
| "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 | |||
| ) | |||
| class StreamAlreadyConsumed(ZhipuAIError): # noqa: N818 | |||
| """ | |||
| Attempted to read or stream content, but the content has already | |||
| been streamed. | |||
| This can happen if you use a method like `.iter_lines()` and then attempt | |||
| to read th entire response body afterwards, e.g. | |||
| ```py | |||
| response = await client.post(...) | |||
| async for line in response.iter_lines(): | |||
| ... # do something with `line` | |||
| content = await response.read() | |||
| # ^ error | |||
| ``` | |||
| If you want this behavior you'll need to either manually accumulate the response | |||
| content or call `await response.read()` before iterating over the stream. | |||
| """ | |||
| def __init__(self) -> None: | |||
| message = ( | |||
| "Attempted to read or stream some content, but the content has " | |||
| "already been streamed. " | |||
| "This could be due to attempting to stream the response " | |||
| "content more than once." | |||
| "\n\n" | |||
| "You can fix this by manually accumulating the response content while streaming " | |||
| "or by calling `.read()` before starting to stream." | |||
| ) | |||
| super().__init__(message) | |||
| def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: | |||
| """Given a type like `APIResponse[T]`, returns the generic type variable `T`. | |||
| This also handles the case where a concrete subclass is given, e.g. | |||
| ```py | |||
| class MyResponse(APIResponse[bytes]): | |||
| ... | |||
| extract_response_type(MyResponse) -> bytes | |||
| ``` | |||
| """ | |||
| return extract_type_var_from_base( | |||
| typ, | |||
| generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)), | |||
| index=0, | |||
| ) | |||
| @@ -1,206 +0,0 @@ | |||
| from __future__ import annotations | |||
| import inspect | |||
| import json | |||
| from collections.abc import Iterator, Mapping | |||
| from typing import TYPE_CHECKING, Generic, TypeGuard, cast | |||
| import httpx | |||
| from . import get_origin | |||
| from ._base_type import ResponseT | |||
| from ._errors import APIResponseError | |||
| from ._utils import extract_type_var_from_base, is_mapping | |||
| _FIELD_SEPARATOR = ":" | |||
| if TYPE_CHECKING: | |||
| from ._http_client import HttpClient | |||
| class StreamResponse(Generic[ResponseT]): | |||
| response: httpx.Response | |||
| _cast_type: type[ResponseT] | |||
| def __init__( | |||
| self, | |||
| *, | |||
| cast_type: type[ResponseT], | |||
| response: httpx.Response, | |||
| client: HttpClient, | |||
| ) -> None: | |||
| self.response = response | |||
| self._cast_type = cast_type | |||
| self._data_process_func = client._process_response_data | |||
| self._stream_chunks = self.__stream__() | |||
| def __next__(self) -> ResponseT: | |||
| return self._stream_chunks.__next__() | |||
| def __iter__(self) -> Iterator[ResponseT]: | |||
| yield from self._stream_chunks | |||
| def __stream__(self) -> Iterator[ResponseT]: | |||
| sse_line_parser = SSELineParser() | |||
| iterator = sse_line_parser.iter_lines(self.response.iter_lines()) | |||
| for sse in iterator: | |||
| if sse.data.startswith("[DONE]"): | |||
| break | |||
| if sse.event is None: | |||
| data = sse.json_data() | |||
| if isinstance(data, Mapping) and data.get("error"): | |||
| raise APIResponseError( | |||
| message="An error occurred during streaming", | |||
| request=self.response.request, | |||
| json_data=data["error"], | |||
| ) | |||
| if sse.event is None: | |||
| data = sse.json_data() | |||
| if is_mapping(data) and data.get("error"): | |||
| message = None | |||
| error = data.get("error") | |||
| if is_mapping(error): | |||
| message = error.get("message") | |||
| if not message or not isinstance(message, str): | |||
| message = "An error occurred during streaming" | |||
| raise APIResponseError( | |||
| message=message, | |||
| request=self.response.request, | |||
| json_data=data["error"], | |||
| ) | |||
| yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) | |||
| else: | |||
| data = sse.json_data() | |||
| if sse.event == "error" and is_mapping(data) and data.get("error"): | |||
| message = None | |||
| error = data.get("error") | |||
| if is_mapping(error): | |||
| message = error.get("message") | |||
| if not message or not isinstance(message, str): | |||
| message = "An error occurred during streaming" | |||
| raise APIResponseError( | |||
| message=message, | |||
| request=self.response.request, | |||
| json_data=data["error"], | |||
| ) | |||
| yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) | |||
| for sse in iterator: | |||
| pass | |||
| class Event: | |||
| def __init__( | |||
| self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None | |||
| ): | |||
| self._event = event | |||
| self._data = data | |||
| self._id = id | |||
| self._retry = retry | |||
| def __repr__(self): | |||
| data_len = len(self._data) if self._data else 0 | |||
| return ( | |||
| f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" | |||
| ) | |||
| @property | |||
| def event(self): | |||
| return self._event | |||
| @property | |||
| def data(self): | |||
| return self._data | |||
| def json_data(self): | |||
| return json.loads(self._data) | |||
| @property | |||
| def id(self): | |||
| return self._id | |||
| @property | |||
| def retry(self): | |||
| return self._retry | |||
| class SSELineParser: | |||
| _data: list[str] | |||
| _event: str | None | |||
| _retry: int | None | |||
| _id: str | None | |||
| def __init__(self): | |||
| self._event = None | |||
| self._data = [] | |||
| self._id = None | |||
| self._retry = None | |||
| def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: | |||
| for line in lines: | |||
| line = line.rstrip("\n") | |||
| if not line: | |||
| if self._event is None and not self._data and self._id is None and self._retry is None: | |||
| continue | |||
| sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) | |||
| self._event = None | |||
| self._data = [] | |||
| self._id = None | |||
| self._retry = None | |||
| yield sse_event | |||
| self.decode_line(line) | |||
| def decode_line(self, line: str): | |||
| if line.startswith(":") or not line: | |||
| return | |||
| field, _p, value = line.partition(":") | |||
| value = value.removeprefix(" ") | |||
| if field == "data": | |||
| self._data.append(value) | |||
| elif field == "event": | |||
| self._event = value | |||
| elif field == "retry": | |||
| try: | |||
| self._retry = int(value) | |||
| except (TypeError, ValueError): | |||
| pass | |||
| return | |||
| def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]: | |||
| """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" | |||
| origin = get_origin(typ) or typ | |||
| return inspect.isclass(origin) and issubclass(origin, StreamResponse) | |||
| def extract_stream_chunk_type( | |||
| stream_cls: type, | |||
| *, | |||
| failure_message: str | None = None, | |||
| ) -> type: | |||
| """Given a type like `StreamResponse[T]`, returns the generic type variable `T`. | |||
| This also handles the case where a concrete subclass is given, e.g. | |||
| ```py | |||
| class MyStream(StreamResponse[bytes]): | |||
| ... | |||
| extract_stream_chunk_type(MyStream) -> bytes | |||
| ``` | |||
| """ | |||
| return extract_type_var_from_base( | |||
| stream_cls, | |||
| index=0, | |||
| generic_bases=cast("tuple[type, ...]", (StreamResponse,)), | |||
| failure_message=failure_message, | |||
| ) | |||
| @@ -1,52 +0,0 @@ | |||
| from ._utils import ( # noqa: I001 | |||
| remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414 | |||
| flatten as flatten, # noqa: PLC0414 | |||
| is_dict as is_dict, # noqa: PLC0414 | |||
| is_list as is_list, # noqa: PLC0414 | |||
| is_given as is_given, # noqa: PLC0414 | |||
| is_tuple as is_tuple, # noqa: PLC0414 | |||
| is_mapping as is_mapping, # noqa: PLC0414 | |||
| is_tuple_t as is_tuple_t, # noqa: PLC0414 | |||
| parse_date as parse_date, # noqa: PLC0414 | |||
| is_iterable as is_iterable, # noqa: PLC0414 | |||
| is_sequence as is_sequence, # noqa: PLC0414 | |||
| coerce_float as coerce_float, # noqa: PLC0414 | |||
| is_mapping_t as is_mapping_t, # noqa: PLC0414 | |||
| removeprefix as removeprefix, # noqa: PLC0414 | |||
| removesuffix as removesuffix, # noqa: PLC0414 | |||
| extract_files as extract_files, # noqa: PLC0414 | |||
| is_sequence_t as is_sequence_t, # noqa: PLC0414 | |||
| required_args as required_args, # noqa: PLC0414 | |||
| coerce_boolean as coerce_boolean, # noqa: PLC0414 | |||
| coerce_integer as coerce_integer, # noqa: PLC0414 | |||
| file_from_path as file_from_path, # noqa: PLC0414 | |||
| parse_datetime as parse_datetime, # noqa: PLC0414 | |||
| strip_not_given as strip_not_given, # noqa: PLC0414 | |||
| deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414 | |||
| get_async_library as get_async_library, # noqa: PLC0414 | |||
| maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414 | |||
| get_required_header as get_required_header, # noqa: PLC0414 | |||
| maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414 | |||
| maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414 | |||
| drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414 | |||
| ) | |||
| from ._typing import ( | |||
| is_list_type as is_list_type, # noqa: PLC0414 | |||
| is_union_type as is_union_type, # noqa: PLC0414 | |||
| extract_type_arg as extract_type_arg, # noqa: PLC0414 | |||
| is_iterable_type as is_iterable_type, # noqa: PLC0414 | |||
| is_required_type as is_required_type, # noqa: PLC0414 | |||
| is_annotated_type as is_annotated_type, # noqa: PLC0414 | |||
| strip_annotated_type as strip_annotated_type, # noqa: PLC0414 | |||
| extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414 | |||
| ) | |||
| from ._transform import ( | |||
| PropertyInfo as PropertyInfo, # noqa: PLC0414 | |||
| transform as transform, # noqa: PLC0414 | |||
| async_transform as async_transform, # noqa: PLC0414 | |||
| maybe_transform as maybe_transform, # noqa: PLC0414 | |||
| async_maybe_transform as async_maybe_transform, # noqa: PLC0414 | |||
| ) | |||
| @@ -1,383 +0,0 @@ | |||
| from __future__ import annotations | |||
| import base64 | |||
| import io | |||
| import pathlib | |||
| from collections.abc import Mapping | |||
| from datetime import date, datetime | |||
| from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints | |||
| import anyio | |||
| import pydantic | |||
| from typing_extensions import override | |||
| from .._base_compat import is_typeddict, model_dump | |||
| from .._files import is_base64_file_input | |||
| from ._typing import ( | |||
| extract_type_arg, | |||
| is_annotated_type, | |||
| is_iterable_type, | |||
| is_list_type, | |||
| is_required_type, | |||
| is_union_type, | |||
| strip_annotated_type, | |||
| ) | |||
| from ._utils import ( | |||
| is_iterable, | |||
| is_list, | |||
| is_mapping, | |||
| ) | |||
| _T = TypeVar("_T") | |||
| # TODO: support for drilling globals() and locals() | |||
| # TODO: ensure works correctly with forward references in all cases | |||
| PropertyFormat = Literal["iso8601", "base64", "custom"] | |||
| class PropertyInfo: | |||
| """Metadata class to be used in Annotated types to provide information about a given type. | |||
| For example: | |||
| class MyParams(TypedDict): | |||
| account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')] | |||
| This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API. | |||
| """ # noqa: E501 | |||
| alias: str | None | |||
| format: PropertyFormat | None | |||
| format_template: str | None | |||
| discriminator: str | None | |||
| def __init__( | |||
| self, | |||
| *, | |||
| alias: str | None = None, | |||
| format: PropertyFormat | None = None, | |||
| format_template: str | None = None, | |||
| discriminator: str | None = None, | |||
| ) -> None: | |||
| self.alias = alias | |||
| self.format = format | |||
| self.format_template = format_template | |||
| self.discriminator = discriminator | |||
| @override | |||
| def __repr__(self) -> str: | |||
| return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" # noqa: E501 | |||
| def maybe_transform( | |||
| data: object, | |||
| expected_type: object, | |||
| ) -> Any | None: | |||
| """Wrapper over `transform()` that allows `None` to be passed. | |||
| See `transform()` for more details. | |||
| """ | |||
| if data is None: | |||
| return None | |||
| return transform(data, expected_type) | |||
| # Wrapper over _transform_recursive providing fake types | |||
| def transform( | |||
| data: _T, | |||
| expected_type: object, | |||
| ) -> _T: | |||
| """Transform dictionaries based off of type information from the given type, for example: | |||
| ```py | |||
| class Params(TypedDict, total=False): | |||
| card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] | |||
| transformed = transform({"card_id": "<my card ID>"}, Params) | |||
| # {'cardID': '<my card ID>'} | |||
| ``` | |||
| Any keys / data that does not have type information given will be included as is. | |||
| It should be noted that the transformations that this function does are not represented in the type system. | |||
| """ | |||
| transformed = _transform_recursive(data, annotation=cast(type, expected_type)) | |||
| return cast(_T, transformed) | |||
| def _get_annotated_type(type_: type) -> type | None: | |||
| """If the given type is an `Annotated` type then it is returned, if not `None` is returned. | |||
| This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]` | |||
| """ | |||
| if is_required_type(type_): | |||
| # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` | |||
| type_ = get_args(type_)[0] | |||
| if is_annotated_type(type_): | |||
| return type_ | |||
| return None | |||
| def _maybe_transform_key(key: str, type_: type) -> str: | |||
| """Transform the given `data` based on the annotations provided in `type_`. | |||
| Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. | |||
| """ | |||
| annotated_type = _get_annotated_type(type_) | |||
| if annotated_type is None: | |||
| # no `Annotated` definition for this type, no transformation needed | |||
| return key | |||
| # ignore the first argument as it is the actual type | |||
| annotations = get_args(annotated_type)[1:] | |||
| for annotation in annotations: | |||
| if isinstance(annotation, PropertyInfo) and annotation.alias is not None: | |||
| return annotation.alias | |||
| return key | |||
| def _transform_recursive( | |||
| data: object, | |||
| *, | |||
| annotation: type, | |||
| inner_type: type | None = None, | |||
| ) -> object: | |||
| """Transform the given data against the expected type. | |||
| Args: | |||
| annotation: The direct type annotation given to the particular piece of data. | |||
| This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc | |||
| inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type | |||
| is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in | |||
| the list can be transformed using the metadata from the container type. | |||
| Defaults to the same value as the `annotation` argument. | |||
| """ | |||
| if inner_type is None: | |||
| inner_type = annotation | |||
| stripped_type = strip_annotated_type(inner_type) | |||
| if is_typeddict(stripped_type) and is_mapping(data): | |||
| return _transform_typeddict(data, stripped_type) | |||
| if ( | |||
| # List[T] | |||
| (is_list_type(stripped_type) and is_list(data)) | |||
| # Iterable[T] | |||
| or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) | |||
| ): | |||
| inner_type = extract_type_arg(stripped_type, 0) | |||
| return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] | |||
| if is_union_type(stripped_type): | |||
| # For union types we run the transformation against all subtypes to ensure that everything is transformed. | |||
| # | |||
| # TODO: there may be edge cases where the same normalized field name will transform to two different names | |||
| # in different subtypes. | |||
| for subtype in get_args(stripped_type): | |||
| data = _transform_recursive(data, annotation=annotation, inner_type=subtype) | |||
| return data | |||
| if isinstance(data, pydantic.BaseModel): | |||
| return model_dump(data, exclude_unset=True) | |||
| annotated_type = _get_annotated_type(annotation) | |||
| if annotated_type is None: | |||
| return data | |||
| # ignore the first argument as it is the actual type | |||
| annotations = get_args(annotated_type)[1:] | |||
| for annotation in annotations: | |||
| if isinstance(annotation, PropertyInfo) and annotation.format is not None: | |||
| return _format_data(data, annotation.format, annotation.format_template) | |||
| return data | |||
| def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: | |||
| if isinstance(data, date | datetime): | |||
| if format_ == "iso8601": | |||
| return data.isoformat() | |||
| if format_ == "custom" and format_template is not None: | |||
| return data.strftime(format_template) | |||
| if format_ == "base64" and is_base64_file_input(data): | |||
| binary: str | bytes | None = None | |||
| if isinstance(data, pathlib.Path): | |||
| binary = data.read_bytes() | |||
| elif isinstance(data, io.IOBase): | |||
| binary = data.read() | |||
| if isinstance(binary, str): # type: ignore[unreachable] | |||
| binary = binary.encode() | |||
| if not isinstance(binary, bytes): | |||
| raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") | |||
| return base64.b64encode(binary).decode("ascii") | |||
| return data | |||
| def _transform_typeddict( | |||
| data: Mapping[str, object], | |||
| expected_type: type, | |||
| ) -> Mapping[str, object]: | |||
| result: dict[str, object] = {} | |||
| annotations = get_type_hints(expected_type, include_extras=True) | |||
| for key, value in data.items(): | |||
| type_ = annotations.get(key) | |||
| if type_ is None: | |||
| # we do not have a type annotation for this field, leave it as is | |||
| result[key] = value | |||
| else: | |||
| result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) | |||
| return result | |||
| async def async_maybe_transform( | |||
| data: object, | |||
| expected_type: object, | |||
| ) -> Any | None: | |||
| """Wrapper over `async_transform()` that allows `None` to be passed. | |||
| See `async_transform()` for more details. | |||
| """ | |||
| if data is None: | |||
| return None | |||
| return await async_transform(data, expected_type) | |||
| async def async_transform( | |||
| data: _T, | |||
| expected_type: object, | |||
| ) -> _T: | |||
| """Transform dictionaries based off of type information from the given type, for example: | |||
| ```py | |||
| class Params(TypedDict, total=False): | |||
| card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] | |||
| transformed = transform({"card_id": "<my card ID>"}, Params) | |||
| # {'cardID': '<my card ID>'} | |||
| ``` | |||
| Any keys / data that does not have type information given will be included as is. | |||
| It should be noted that the transformations that this function does are not represented in the type system. | |||
| """ | |||
| transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) | |||
| return cast(_T, transformed) | |||
| async def _async_transform_recursive( | |||
| data: object, | |||
| *, | |||
| annotation: type, | |||
| inner_type: type | None = None, | |||
| ) -> object: | |||
| """Transform the given data against the expected type. | |||
| Args: | |||
| annotation: The direct type annotation given to the particular piece of data. | |||
| This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc | |||
| inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type | |||
| is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in | |||
| the list can be transformed using the metadata from the container type. | |||
| Defaults to the same value as the `annotation` argument. | |||
| """ | |||
| if inner_type is None: | |||
| inner_type = annotation | |||
| stripped_type = strip_annotated_type(inner_type) | |||
| if is_typeddict(stripped_type) and is_mapping(data): | |||
| return await _async_transform_typeddict(data, stripped_type) | |||
| if ( | |||
| # List[T] | |||
| (is_list_type(stripped_type) and is_list(data)) | |||
| # Iterable[T] | |||
| or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) | |||
| ): | |||
| inner_type = extract_type_arg(stripped_type, 0) | |||
| return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] | |||
| if is_union_type(stripped_type): | |||
| # For union types we run the transformation against all subtypes to ensure that everything is transformed. | |||
| # | |||
| # TODO: there may be edge cases where the same normalized field name will transform to two different names | |||
| # in different subtypes. | |||
| for subtype in get_args(stripped_type): | |||
| data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) | |||
| return data | |||
| if isinstance(data, pydantic.BaseModel): | |||
| return model_dump(data, exclude_unset=True) | |||
| annotated_type = _get_annotated_type(annotation) | |||
| if annotated_type is None: | |||
| return data | |||
| # ignore the first argument as it is the actual type | |||
| annotations = get_args(annotated_type)[1:] | |||
| for annotation in annotations: | |||
| if isinstance(annotation, PropertyInfo) and annotation.format is not None: | |||
| return await _async_format_data(data, annotation.format, annotation.format_template) | |||
| return data | |||
| async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: | |||
| if isinstance(data, date | datetime): | |||
| if format_ == "iso8601": | |||
| return data.isoformat() | |||
| if format_ == "custom" and format_template is not None: | |||
| return data.strftime(format_template) | |||
| if format_ == "base64" and is_base64_file_input(data): | |||
| binary: str | bytes | None = None | |||
| if isinstance(data, pathlib.Path): | |||
| binary = await anyio.Path(data).read_bytes() | |||
| elif isinstance(data, io.IOBase): | |||
| binary = data.read() | |||
| if isinstance(binary, str): # type: ignore[unreachable] | |||
| binary = binary.encode() | |||
| if not isinstance(binary, bytes): | |||
| raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") | |||
| return base64.b64encode(binary).decode("ascii") | |||
| return data | |||
| async def _async_transform_typeddict( | |||
| data: Mapping[str, object], | |||
| expected_type: type, | |||
| ) -> Mapping[str, object]: | |||
| result: dict[str, object] = {} | |||
| annotations = get_type_hints(expected_type, include_extras=True) | |||
| for key, value in data.items(): | |||
| type_ = annotations.get(key) | |||
| if type_ is None: | |||
| # we do not have a type annotation for this field, leave it as is | |||
| result[key] = value | |||
| else: | |||
| result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) | |||
| return result | |||
| @@ -1,122 +0,0 @@ | |||
| from __future__ import annotations | |||
| from collections import abc as _c_abc | |||
| from collections.abc import Iterable | |||
| from typing import Annotated, Any, TypeVar, cast, get_args, get_origin | |||
| from typing_extensions import Required | |||
| from .._base_compat import is_union as _is_union | |||
| from .._base_type import InheritsGeneric | |||
| def is_annotated_type(typ: type) -> bool: | |||
| return get_origin(typ) == Annotated | |||
| def is_list_type(typ: type) -> bool: | |||
| return (get_origin(typ) or typ) == list | |||
| def is_iterable_type(typ: type) -> bool: | |||
| """If the given type is `typing.Iterable[T]`""" | |||
| origin = get_origin(typ) or typ | |||
| return origin in {Iterable, _c_abc.Iterable} | |||
| def is_union_type(typ: type) -> bool: | |||
| return _is_union(get_origin(typ)) | |||
| def is_required_type(typ: type) -> bool: | |||
| return get_origin(typ) == Required | |||
| def is_typevar(typ: type) -> bool: | |||
| # type ignore is required because type checkers | |||
| # think this expression will always return False | |||
| return type(typ) == TypeVar # type: ignore | |||
| # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] | |||
| def strip_annotated_type(typ: type) -> type: | |||
| if is_required_type(typ) or is_annotated_type(typ): | |||
| return strip_annotated_type(cast(type, get_args(typ)[0])) | |||
| return typ | |||
| def extract_type_arg(typ: type, index: int) -> type: | |||
| args = get_args(typ) | |||
| try: | |||
| return cast(type, args[index]) | |||
| except IndexError as err: | |||
| raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err | |||
| def extract_type_var_from_base( | |||
| typ: type, | |||
| *, | |||
| generic_bases: tuple[type, ...], | |||
| index: int, | |||
| failure_message: str | None = None, | |||
| ) -> type: | |||
| """Given a type like `Foo[T]`, returns the generic type variable `T`. | |||
| This also handles the case where a concrete subclass is given, e.g. | |||
| ```py | |||
| class MyResponse(Foo[bytes]): | |||
| ... | |||
| extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes | |||
| ``` | |||
| And where a generic subclass is given: | |||
| ```py | |||
| _T = TypeVar('_T') | |||
| class MyResponse(Foo[_T]): | |||
| ... | |||
| extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes | |||
| ``` | |||
| """ | |||
| cls = cast(object, get_origin(typ) or typ) | |||
| if cls in generic_bases: | |||
| # we're given the class directly | |||
| return extract_type_arg(typ, index) | |||
| # if a subclass is given | |||
| # --- | |||
| # this is needed as __orig_bases__ is not present in the typeshed stubs | |||
| # because it is intended to be for internal use only, however there does | |||
| # not seem to be a way to resolve generic TypeVars for inherited subclasses | |||
| # without using it. | |||
| if isinstance(cls, InheritsGeneric): | |||
| target_base_class: Any | None = None | |||
| for base in cls.__orig_bases__: | |||
| if base.__origin__ in generic_bases: | |||
| target_base_class = base | |||
| break | |||
| if target_base_class is None: | |||
| raise RuntimeError( | |||
| "Could not find the generic base class;\n" | |||
| "This should never happen;\n" | |||
| f"Does {cls} inherit from one of {generic_bases} ?" | |||
| ) | |||
| extracted = extract_type_arg(target_base_class, index) | |||
| if is_typevar(extracted): | |||
| # If the extracted type argument is itself a type variable | |||
| # then that means the subclass itself is generic, so we have | |||
| # to resolve the type argument from the class itself, not | |||
| # the base class. | |||
| # | |||
| # Note: if there is more than 1 type argument, the subclass could | |||
| # change the ordering of the type arguments, this is not currently | |||
| # supported. | |||
| return extract_type_arg(typ, index) | |||
| return extracted | |||
| raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") | |||
| @@ -1,409 +0,0 @@ | |||
| from __future__ import annotations | |||
| import functools | |||
| import inspect | |||
| import os | |||
| import re | |||
| from collections.abc import Callable, Iterable, Mapping, Sequence | |||
| from pathlib import Path | |||
| from typing import ( | |||
| Any, | |||
| TypeGuard, | |||
| TypeVar, | |||
| Union, | |||
| cast, | |||
| overload, | |||
| ) | |||
| import sniffio | |||
| from .._base_compat import parse_date as parse_date # noqa: PLC0414 | |||
| from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414 | |||
| from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr | |||
| def remove_notgiven_indict(obj): | |||
| if obj is None or (not isinstance(obj, Mapping)): | |||
| return obj | |||
| return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} | |||
| _T = TypeVar("_T") | |||
| _TupleT = TypeVar("_TupleT", bound=tuple[object, ...]) | |||
| _MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) | |||
| _SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) | |||
| CallableT = TypeVar("CallableT", bound=Callable[..., Any]) | |||
| def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: | |||
| return [item for sublist in t for item in sublist] | |||
| def extract_files( | |||
| # TODO: this needs to take Dict but variance issues..... | |||
| # create protocol type ? | |||
| query: Mapping[str, object], | |||
| *, | |||
| paths: Sequence[Sequence[str]], | |||
| ) -> list[tuple[str, FileTypes]]: | |||
| """Recursively extract files from the given dictionary based on specified paths. | |||
| A path may look like this ['foo', 'files', '<array>', 'data']. | |||
| Note: this mutates the given dictionary. | |||
| """ | |||
| files: list[tuple[str, FileTypes]] = [] | |||
| for path in paths: | |||
| files.extend(_extract_items(query, path, index=0, flattened_key=None)) | |||
| return files | |||
| def _extract_items( | |||
| obj: object, | |||
| path: Sequence[str], | |||
| *, | |||
| index: int, | |||
| flattened_key: str | None, | |||
| ) -> list[tuple[str, FileTypes]]: | |||
| try: | |||
| key = path[index] | |||
| except IndexError: | |||
| if isinstance(obj, NotGiven): | |||
| # no value was provided - we can safely ignore | |||
| return [] | |||
| # cyclical import | |||
| from .._files import assert_is_file_content | |||
| # We have exhausted the path, return the entry we found. | |||
| assert_is_file_content(obj, key=flattened_key) | |||
| assert flattened_key is not None | |||
| return [(flattened_key, cast(FileTypes, obj))] | |||
| index += 1 | |||
| if is_dict(obj): | |||
| try: | |||
| # We are at the last entry in the path so we must remove the field | |||
| if (len(path)) == index: | |||
| item = obj.pop(key) | |||
| else: | |||
| item = obj[key] | |||
| except KeyError: | |||
| # Key was not present in the dictionary, this is not indicative of an error | |||
| # as the given path may not point to a required field. We also do not want | |||
| # to enforce required fields as the API may differ from the spec in some cases. | |||
| return [] | |||
| if flattened_key is None: | |||
| flattened_key = key | |||
| else: | |||
| flattened_key += f"[{key}]" | |||
| return _extract_items( | |||
| item, | |||
| path, | |||
| index=index, | |||
| flattened_key=flattened_key, | |||
| ) | |||
| elif is_list(obj): | |||
| if key != "<array>": | |||
| return [] | |||
| return flatten( | |||
| [ | |||
| _extract_items( | |||
| item, | |||
| path, | |||
| index=index, | |||
| flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", | |||
| ) | |||
| for item in obj | |||
| ] | |||
| ) | |||
| # Something unexpected was passed, just ignore it. | |||
| return [] | |||
| def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: | |||
| return not isinstance(obj, NotGiven) | |||
| # Type safe methods for narrowing types with TypeVars. | |||
| # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], | |||
| # however this cause Pyright to rightfully report errors. As we know we don't | |||
| # care about the contained types we can safely use `object` in it's place. | |||
| # | |||
| # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. | |||
| # `is_*` is for when you're dealing with an unknown input | |||
| # `is_*_t` is for when you're narrowing a known union type to a specific subset | |||
| def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: | |||
| return isinstance(obj, tuple) | |||
| def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: | |||
| return isinstance(obj, tuple) | |||
| def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: | |||
| return isinstance(obj, Sequence) | |||
| def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: | |||
| return isinstance(obj, Sequence) | |||
| def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: | |||
| return isinstance(obj, Mapping) | |||
| def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: | |||
| return isinstance(obj, Mapping) | |||
| def is_dict(obj: object) -> TypeGuard[dict[object, object]]: | |||
| return isinstance(obj, dict) | |||
| def is_list(obj: object) -> TypeGuard[list[object]]: | |||
| return isinstance(obj, list) | |||
| def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: | |||
| return isinstance(obj, Iterable) | |||
| def deepcopy_minimal(item: _T) -> _T: | |||
| """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: | |||
| - mappings, e.g. `dict` | |||
| - list | |||
| This is done for performance reasons. | |||
| """ | |||
| if is_mapping(item): | |||
| return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) | |||
| if is_list(item): | |||
| return cast(_T, [deepcopy_minimal(entry) for entry in item]) | |||
| return item | |||
| # copied from https://github.com/Rapptz/RoboDanny | |||
| def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: | |||
| size = len(seq) | |||
| if size == 0: | |||
| return "" | |||
| if size == 1: | |||
| return seq[0] | |||
| if size == 2: | |||
| return f"{seq[0]} {final} {seq[1]}" | |||
| return delim.join(seq[:-1]) + f" {final} {seq[-1]}" | |||
| def quote(string: str) -> str: | |||
| """Add single quotation marks around the given string. Does *not* do any escaping.""" | |||
| return f"'{string}'" | |||
| def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: | |||
| """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. | |||
| Useful for enforcing runtime validation of overloaded functions. | |||
| Example usage: | |||
| ```py | |||
| @overload | |||
| def foo(*, a: str) -> str: | |||
| ... | |||
| @overload | |||
| def foo(*, b: bool) -> str: | |||
| ... | |||
| # This enforces the same constraints that a static type checker would | |||
| # i.e. that either a or b must be passed to the function | |||
| @required_args(["a"], ["b"]) | |||
| def foo(*, a: str | None = None, b: bool | None = None) -> str: | |||
| ... | |||
| ``` | |||
| """ | |||
| def inner(func: CallableT) -> CallableT: | |||
| params = inspect.signature(func).parameters | |||
| positional = [ | |||
| name | |||
| for name, param in params.items() | |||
| if param.kind | |||
| in { | |||
| param.POSITIONAL_ONLY, | |||
| param.POSITIONAL_OR_KEYWORD, | |||
| } | |||
| ] | |||
| @functools.wraps(func) | |||
| def wrapper(*args: object, **kwargs: object) -> object: | |||
| given_params: set[str] = set() | |||
| for i in range(len(args)): | |||
| try: | |||
| given_params.add(positional[i]) | |||
| except IndexError: | |||
| raise TypeError( | |||
| f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" | |||
| ) from None | |||
| given_params.update(kwargs.keys()) | |||
| for variant in variants: | |||
| matches = all(param in given_params for param in variant) | |||
| if matches: | |||
| break | |||
| else: # no break | |||
| if len(variants) > 1: | |||
| variations = human_join( | |||
| ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] | |||
| ) | |||
| msg = f"Missing required arguments; Expected either {variations} arguments to be given" | |||
| else: | |||
| # TODO: this error message is not deterministic | |||
| missing = list(set(variants[0]) - given_params) | |||
| if len(missing) > 1: | |||
| msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" | |||
| else: | |||
| msg = f"Missing required argument: {quote(missing[0])}" | |||
| raise TypeError(msg) | |||
| return func(*args, **kwargs) | |||
| return wrapper # type: ignore | |||
| return inner | |||
| _K = TypeVar("_K") | |||
| _V = TypeVar("_V") | |||
| @overload | |||
| def strip_not_given(obj: None) -> None: ... | |||
| @overload | |||
| def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... | |||
| @overload | |||
| def strip_not_given(obj: object) -> object: ... | |||
| def strip_not_given(obj: object | None) -> object: | |||
| """Remove all top-level keys where their values are instances of `NotGiven`""" | |||
| if obj is None: | |||
| return None | |||
| if not is_mapping(obj): | |||
| return obj | |||
| return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} | |||
| def coerce_integer(val: str) -> int: | |||
| return int(val, base=10) | |||
| def coerce_float(val: str) -> float: | |||
| return float(val) | |||
| def coerce_boolean(val: str) -> bool: | |||
| return val in {"true", "1", "on"} | |||
| def maybe_coerce_integer(val: str | None) -> int | None: | |||
| if val is None: | |||
| return None | |||
| return coerce_integer(val) | |||
| def maybe_coerce_float(val: str | None) -> float | None: | |||
| if val is None: | |||
| return None | |||
| return coerce_float(val) | |||
| def maybe_coerce_boolean(val: str | None) -> bool | None: | |||
| if val is None: | |||
| return None | |||
| return coerce_boolean(val) | |||
| def removeprefix(string: str, prefix: str) -> str: | |||
| """Remove a prefix from a string. | |||
| Backport of `str.removeprefix` for Python < 3.9 | |||
| """ | |||
| if string.startswith(prefix): | |||
| return string[len(prefix) :] | |||
| return string | |||
| def removesuffix(string: str, suffix: str) -> str: | |||
| """Remove a suffix from a string. | |||
| Backport of `str.removesuffix` for Python < 3.9 | |||
| """ | |||
| if string.endswith(suffix): | |||
| return string[: -len(suffix)] | |||
| return string | |||
| def file_from_path(path: str) -> FileTypes: | |||
| contents = Path(path).read_bytes() | |||
| file_name = os.path.basename(path) | |||
| return (file_name, contents) | |||
| def get_required_header(headers: HeadersLike, header: str) -> str: | |||
| lower_header = header.lower() | |||
| if isinstance(headers, Mapping): | |||
| headers = cast(Headers, headers) | |||
| for k, v in headers.items(): | |||
| if k.lower() == lower_header and isinstance(v, str): | |||
| return v | |||
| """ to deal with the case where the header looks like Stainless-Event-Id """ | |||
| intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) | |||
| for normalized_header in [header, lower_header, header.upper(), intercaps_header]: | |||
| value = headers.get(normalized_header) | |||
| if value: | |||
| return value | |||
| raise ValueError(f"Could not find {header} header") | |||
| def get_async_library() -> str: | |||
| try: | |||
| return sniffio.current_async_library() | |||
| except Exception: | |||
| return "false" | |||
| def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]: | |||
| """ | |||
| 删除 ;base64, 前缀 | |||
| :param image_data: | |||
| :return: | |||
| """ | |||
| if isinstance(content, list): | |||
| for data in content: | |||
| if data.get("type") == "image_url": | |||
| image_data = data.get("image_url").get("url") | |||
| if image_data.startswith("data:image/"): | |||
| image_data = image_data.split("base64,")[-1] | |||
| data["image_url"]["url"] = image_data | |||
| return content | |||
| @@ -1,78 +0,0 @@ | |||
| import logging | |||
| import os | |||
| import time | |||
| logger = logging.getLogger(__name__) | |||
| class LoggerNameFilter(logging.Filter): | |||
| def filter(self, record): | |||
| # return record.name.startswith("loom_core") or record.name in "ERROR" or ( | |||
| # record.name.startswith("uvicorn.error") | |||
| # and record.getMessage().startswith("Uvicorn running on") | |||
| # ) | |||
| return True | |||
| def get_log_file(log_path: str, sub_dir: str): | |||
| """ | |||
| sub_dir should contain a timestamp. | |||
| """ | |||
| log_dir = os.path.join(log_path, sub_dir) | |||
| # Here should be creating a new directory each time, so `exist_ok=False` | |||
| os.makedirs(log_dir, exist_ok=False) | |||
| return os.path.join(log_dir, "zhipuai.log") | |||
| def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict: | |||
| # for windows, the path should be a raw string. | |||
| log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path | |||
| log_level = log_level.upper() | |||
| config_dict = { | |||
| "version": 1, | |||
| "disable_existing_loggers": False, | |||
| "formatters": { | |||
| "formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")}, | |||
| }, | |||
| "filters": { | |||
| "logger_name_filter": { | |||
| "()": __name__ + ".LoggerNameFilter", | |||
| }, | |||
| }, | |||
| "handlers": { | |||
| "stream_handler": { | |||
| "class": "logging.StreamHandler", | |||
| "formatter": "formatter", | |||
| "level": log_level, | |||
| # "stream": "ext://sys.stdout", | |||
| # "filters": ["logger_name_filter"], | |||
| }, | |||
| "file_handler": { | |||
| "class": "logging.handlers.RotatingFileHandler", | |||
| "formatter": "formatter", | |||
| "level": log_level, | |||
| "filename": log_file_path, | |||
| "mode": "a", | |||
| "maxBytes": log_max_bytes, | |||
| "backupCount": log_backup_count, | |||
| "encoding": "utf8", | |||
| }, | |||
| }, | |||
| "loggers": { | |||
| "loom_core": { | |||
| "handlers": ["stream_handler", "file_handler"], | |||
| "level": log_level, | |||
| "propagate": False, | |||
| } | |||
| }, | |||
| "root": { | |||
| "level": log_level, | |||
| "handlers": ["stream_handler", "file_handler"], | |||
| }, | |||
| } | |||
| return config_dict | |||
| def get_timestamp_ms(): | |||
| t = time.time() | |||
| return int(round(t * 1000)) | |||
| @@ -1,62 +0,0 @@ | |||
| # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. | |||
| from typing import Any, Generic, Optional, TypeVar, cast | |||
| from typing_extensions import Protocol, override, runtime_checkable | |||
| from ._http_client import BasePage, BaseSyncPage, PageInfo | |||
| __all__ = ["SyncPage", "SyncCursorPage"] | |||
| _T = TypeVar("_T") | |||
| @runtime_checkable | |||
| class CursorPageItem(Protocol): | |||
| id: Optional[str] | |||
| class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): | |||
| """Note: no pagination actually occurs yet, this is for forwards-compatibility.""" | |||
| data: list[_T] | |||
| object: str | |||
| @override | |||
| def _get_page_items(self) -> list[_T]: | |||
| data = self.data | |||
| if not data: | |||
| return [] | |||
| return data | |||
| @override | |||
| def next_page_info(self) -> None: | |||
| """ | |||
| This page represents a response that isn't actually paginated at the API level | |||
| so there will never be a next page. | |||
| """ | |||
| return None | |||
| class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): | |||
| data: list[_T] | |||
| @override | |||
| def _get_page_items(self) -> list[_T]: | |||
| data = self.data | |||
| if not data: | |||
| return [] | |||
| return data | |||
| @override | |||
| def next_page_info(self) -> Optional[PageInfo]: | |||
| data = self.data | |||
| if not data: | |||
| return None | |||
| item = cast(Any, data[-1]) | |||
| if not isinstance(item, CursorPageItem) or item.id is None: | |||
| # TODO emit warning log | |||
| return None | |||
| return PageInfo(params={"after": item.id}) | |||
| @@ -1,5 +0,0 @@ | |||
| from .assistant_completion import AssistantCompletion | |||
| __all__ = [ | |||
| "AssistantCompletion", | |||
| ] | |||
| @@ -1,40 +0,0 @@ | |||
| from typing import Any, Optional | |||
| from ...core import BaseModel | |||
| from .message import MessageContent | |||
| __all__ = ["AssistantCompletion", "CompletionUsage"] | |||
| class ErrorInfo(BaseModel): | |||
| code: str # 错误码 | |||
| message: str # 错误信息 | |||
| class AssistantChoice(BaseModel): | |||
| index: int # 结果下标 | |||
| delta: MessageContent # 当前会话输出消息体 | |||
| finish_reason: str | |||
| """ | |||
| # 推理结束原因 stop代表推理自然结束或触发停止词。 sensitive 代表模型推理内容被安全审核接口拦截。请注意,针对此类内容,请用户自行判断并决定是否撤回已公开的内容。 | |||
| # network_error 代表模型推理服务异常。 | |||
| """ # noqa: E501 | |||
| metadata: dict # 元信息,拓展字段 | |||
| class CompletionUsage(BaseModel): | |||
| prompt_tokens: int # 输入的 tokens 数量 | |||
| completion_tokens: int # 输出的 tokens 数量 | |||
| total_tokens: int # 总 tokens 数量 | |||
| class AssistantCompletion(BaseModel): | |||
| id: str # 请求 ID | |||
| conversation_id: str # 会话 ID | |||
| assistant_id: str # 智能体 ID | |||
| created: int # 请求创建时间,Unix 时间戳 | |||
| status: str # 返回状态,包括:`completed` 表示生成结束`in_progress`表示生成中 `failed` 表示生成异常 | |||
| last_error: Optional[ErrorInfo] # 异常信息 | |||
| choices: list[AssistantChoice] # 增量返回的信息 | |||
| metadata: Optional[dict[str, Any]] # 元信息,拓展字段 | |||
| usage: Optional[CompletionUsage] # tokens 数量统计 | |||
| @@ -1,7 +0,0 @@ | |||
| from typing import TypedDict | |||
| class ConversationParameters(TypedDict, total=False): | |||
| assistant_id: str # 智能体 ID | |||
| page: int # 当前分页 | |||
| page_size: int # 分页数量 | |||
| @@ -1,29 +0,0 @@ | |||
| from ...core import BaseModel | |||
| __all__ = ["ConversationUsageListResp"] | |||
| class Usage(BaseModel): | |||
| prompt_tokens: int # 用户输入的 tokens 数量 | |||
| completion_tokens: int # 模型输入的 tokens 数量 | |||
| total_tokens: int # 总 tokens 数量 | |||
| class ConversationUsage(BaseModel): | |||
| id: str # 会话 id | |||
| assistant_id: str # 智能体Assistant id | |||
| create_time: int # 创建时间 | |||
| update_time: int # 更新时间 | |||
| usage: Usage # 会话中 tokens 数量统计 | |||
| class ConversationUsageList(BaseModel): | |||
| assistant_id: str # 智能体id | |||
| has_more: bool # 是否还有更多页 | |||
| conversation_list: list[ConversationUsage] # 返回的 | |||
| class ConversationUsageListResp(BaseModel): | |||
| code: int | |||
| msg: str | |||
| data: ConversationUsageList | |||
| @@ -1,32 +0,0 @@ | |||
| from typing import Optional, TypedDict, Union | |||
| class AssistantAttachments: | |||
| file_id: str | |||
| class MessageTextContent: | |||
| type: str # 目前支持 type = text | |||
| text: str | |||
| MessageContent = Union[MessageTextContent] | |||
| class ConversationMessage(TypedDict): | |||
| """会话消息体""" | |||
| role: str # 用户的输入角色,例如 'user' | |||
| content: list[MessageContent] # 会话消息体的内容 | |||
| class AssistantParameters(TypedDict, total=False): | |||
| """智能体参数类""" | |||
| assistant_id: str # 智能体 ID | |||
| conversation_id: Optional[str] # 会话 ID,不传则创建新会话 | |||
| model: str # 模型名称,默认为 'GLM-4-Assistant' | |||
| stream: bool # 是否支持流式 SSE,需要传入 True | |||
| messages: list[ConversationMessage] # 会话消息体 | |||
| attachments: Optional[list[AssistantAttachments]] # 会话指定的文件,非必填 | |||
| metadata: Optional[dict] # 元信息,拓展字段,非必填 | |||
| @@ -1,21 +0,0 @@ | |||
| from ...core import BaseModel | |||
| __all__ = ["AssistantSupportResp"] | |||
| class AssistantSupport(BaseModel): | |||
| assistant_id: str # 智能体的 Assistant id,用于智能体会话 | |||
| created_at: int # 创建时间 | |||
| updated_at: int # 更新时间 | |||
| name: str # 智能体名称 | |||
| avatar: str # 智能体头像 | |||
| description: str # 智能体描述 | |||
| status: str # 智能体状态,目前只有 publish | |||
| tools: list[str] # 智能体支持的工具名 | |||
| starter_prompts: list[str] # 智能体启动推荐的 prompt | |||
| class AssistantSupportResp(BaseModel): | |||
| code: int | |||
| msg: str | |||
| data: list[AssistantSupport] # 智能体列表 | |||
| @@ -1,3 +0,0 @@ | |||
| from .message_content import MessageContent | |||
| __all__ = ["MessageContent"] | |||
| @@ -1,13 +0,0 @@ | |||
| from typing import Annotated, TypeAlias, Union | |||
| from ....core._utils import PropertyInfo | |||
| from .text_content_block import TextContentBlock | |||
| from .tools_delta_block import ToolsDeltaBlock | |||
| __all__ = ["MessageContent"] | |||
| MessageContent: TypeAlias = Annotated[ | |||
| Union[ToolsDeltaBlock, TextContentBlock], | |||
| PropertyInfo(discriminator="type"), | |||
| ] | |||
| @@ -1,14 +0,0 @@ | |||
| from typing import Literal | |||
| from ....core import BaseModel | |||
| __all__ = ["TextContentBlock"] | |||
| class TextContentBlock(BaseModel): | |||
| content: str | |||
| role: str = "assistant" | |||
| type: Literal["content"] = "content" | |||
| """Always `content`.""" | |||
| @@ -1,27 +0,0 @@ | |||
| from typing import Literal | |||
| __all__ = ["CodeInterpreterToolBlock"] | |||
| from .....core import BaseModel | |||
| class CodeInterpreterToolOutput(BaseModel): | |||
| """代码工具输出结果""" | |||
| type: str # 代码执行日志,目前只有 logs | |||
| logs: str # 代码执行的日志结果 | |||
| error_msg: str # 错误信息 | |||
| class CodeInterpreter(BaseModel): | |||
| """代码解释器""" | |||
| input: str # 生成的代码片段,输入给代码沙盒 | |||
| outputs: list[CodeInterpreterToolOutput] # 代码执行后的输出结果 | |||
| class CodeInterpreterToolBlock(BaseModel): | |||
| """代码工具块""" | |||
| code_interpreter: CodeInterpreter # 代码解释器对象 | |||
| type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter` | |||
| @@ -1,21 +0,0 @@ | |||
| from typing import Literal | |||
| from .....core import BaseModel | |||
| __all__ = ["DrawingToolBlock"] | |||
| class DrawingToolOutput(BaseModel): | |||
| image: str | |||
| class DrawingTool(BaseModel): | |||
| input: str | |||
| outputs: list[DrawingToolOutput] | |||
| class DrawingToolBlock(BaseModel): | |||
| drawing_tool: DrawingTool | |||
| type: Literal["drawing_tool"] | |||
| """Always `drawing_tool`.""" | |||
| @@ -1,22 +0,0 @@ | |||
| from typing import Literal, Union | |||
| __all__ = ["FunctionToolBlock"] | |||
| from .....core import BaseModel | |||
| class FunctionToolOutput(BaseModel): | |||
| content: str | |||
| class FunctionTool(BaseModel): | |||
| name: str | |||
| arguments: Union[str, dict] | |||
| outputs: list[FunctionToolOutput] | |||
| class FunctionToolBlock(BaseModel): | |||
| function: FunctionTool | |||
| type: Literal["function"] | |||
| """Always `drawing_tool`.""" | |||
| @@ -1,41 +0,0 @@ | |||
| from typing import Literal | |||
| from .....core import BaseModel | |||
| class RetrievalToolOutput(BaseModel): | |||
| """ | |||
| This class represents the output of a retrieval tool. | |||
| Attributes: | |||
| - text (str): The text snippet retrieved from the knowledge base. | |||
| - document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration. | |||
| """ # noqa: E501 | |||
| text: str | |||
| document: str | |||
| class RetrievalTool(BaseModel): | |||
| """ | |||
| This class represents the outputs of a retrieval tool. | |||
| Attributes: | |||
| - outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base. | |||
| """ # noqa: E501 | |||
| outputs: list[RetrievalToolOutput] | |||
| class RetrievalToolBlock(BaseModel): | |||
| """ | |||
| This class represents a block for invoking the retrieval tool. | |||
| Attributes: | |||
| - retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs. | |||
| - type (Literal["retrieval"]): The type of tool being used, always set to "retrieval". | |||
| """ | |||
| retrieval: RetrievalTool | |||
| type: Literal["retrieval"] | |||
| """Always `retrieval`.""" | |||
| @@ -1,16 +0,0 @@ | |||
| from typing import Annotated, TypeAlias, Union | |||
| from .....core._utils import PropertyInfo | |||
| from .code_interpreter_delta_block import CodeInterpreterToolBlock | |||
| from .drawing_tool_delta_block import DrawingToolBlock | |||
| from .function_delta_block import FunctionToolBlock | |||
| from .retrieval_delta_black import RetrievalToolBlock | |||
| from .web_browser_delta_block import WebBrowserToolBlock | |||
| __all__ = ["ToolsType"] | |||
| ToolsType: TypeAlias = Annotated[ | |||
| Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock], | |||
| PropertyInfo(discriminator="type"), | |||
| ] | |||
| @@ -1,48 +0,0 @@ | |||
| from typing import Literal | |||
| from .....core import BaseModel | |||
| __all__ = ["WebBrowserToolBlock"] | |||
| class WebBrowserOutput(BaseModel): | |||
| """ | |||
| This class represents the output of a web browser search result. | |||
| Attributes: | |||
| - title (str): The title of the search result. | |||
| - link (str): The URL link to the search result's webpage. | |||
| - content (str): The textual content extracted from the search result. | |||
| - error_msg (str): Any error message encountered during the search or retrieval process. | |||
| """ | |||
| title: str | |||
| link: str | |||
| content: str | |||
| error_msg: str | |||
| class WebBrowser(BaseModel): | |||
| """ | |||
| This class represents the input and outputs of a web browser search. | |||
| Attributes: | |||
| - input (str): The input query for the web browser search. | |||
| - outputs (List[WebBrowserOutput]): A list of search results returned by the web browser. | |||
| """ | |||
| input: str | |||
| outputs: list[WebBrowserOutput] | |||
| class WebBrowserToolBlock(BaseModel): | |||
| """ | |||
| This class represents a block for invoking the web browser tool. | |||
| Attributes: | |||
| - web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs. | |||
| - type (Literal["web_browser"]): The type of tool being used, always set to "web_browser". | |||
| """ | |||
| web_browser: WebBrowser | |||
| type: Literal["web_browser"] | |||
| @@ -1,16 +0,0 @@ | |||
| from typing import Literal | |||
| from ....core import BaseModel | |||
| from .tools.tools_type import ToolsType | |||
| __all__ = ["ToolsDeltaBlock"] | |||
| class ToolsDeltaBlock(BaseModel): | |||
| tool_calls: list[ToolsType] | |||
| """The index of the content part in the message.""" | |||
| role: str = "tool" | |||
| type: Literal["tool_calls"] = "tool_calls" | |||
| """Always `tool_calls`.""" | |||
| @@ -1,82 +0,0 @@ | |||
| # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. | |||
| import builtins | |||
| from typing import Literal, Optional | |||
| from ..core import BaseModel | |||
| from .batch_error import BatchError | |||
| from .batch_request_counts import BatchRequestCounts | |||
| __all__ = ["Batch", "Errors"] | |||
| class Errors(BaseModel): | |||
| data: Optional[list[BatchError]] = None | |||
| object: Optional[str] = None | |||
| """这个类型,一直是`list`。""" | |||
| class Batch(BaseModel): | |||
| id: str | |||
| completion_window: str | |||
| """用于执行请求的地址信息。""" | |||
| created_at: int | |||
| """这是 Unix timestamp (in seconds) 表示的创建时间。""" | |||
| endpoint: str | |||
| """这是ZhipuAI endpoint的地址。""" | |||
| input_file_id: str | |||
| """标记为batch的输入文件的ID。""" | |||
| object: Literal["batch"] | |||
| """这个类型,一直是`batch`.""" | |||
| status: Literal[ | |||
| "validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled" | |||
| ] | |||
| """batch 的状态。""" | |||
| cancelled_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的取消时间。""" | |||
| cancelling_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示发起取消的请求时间 """ | |||
| completed_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的完成时间。""" | |||
| error_file_id: Optional[str] = None | |||
| """这个文件id包含了执行请求失败的请求的输出。""" | |||
| errors: Optional[Errors] = None | |||
| expired_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的将在过期时间。""" | |||
| expires_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 触发过期""" | |||
| failed_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的失败时间。""" | |||
| finalizing_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的最终时间。""" | |||
| in_progress_at: Optional[int] = None | |||
| """Unix timestamp (in seconds) 表示的开始处理时间。""" | |||
| metadata: Optional[builtins.object] = None | |||
| """ | |||
| key:value形式的元数据,以便将信息存储 | |||
| 结构化格式。键的长度是64个字符,值最长512个字符 | |||
| """ | |||
| output_file_id: Optional[str] = None | |||
| """完成请求的输出文件的ID。""" | |||
| request_counts: Optional[BatchRequestCounts] = None | |||
| """批次中不同状态的请求计数""" | |||
| @@ -1,37 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Literal, Optional | |||
| from typing_extensions import Required, TypedDict | |||
| __all__ = ["BatchCreateParams"] | |||
| class BatchCreateParams(TypedDict, total=False): | |||
| completion_window: Required[str] | |||
| """The time frame within which the batch should be processed. | |||
| Currently only `24h` is supported. | |||
| """ | |||
| endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]] | |||
| """The endpoint to be used for all requests in the batch. | |||
| Currently `/v1/chat/completions` and `/v1/embeddings` are supported. | |||
| """ | |||
| input_file_id: Required[str] | |||
| """The ID of an uploaded file that contains requests for the new batch. | |||
| See [upload file](https://platform.openai.com/docs/api-reference/files/create) | |||
| for how to upload a file. | |||
| Your input file must be formatted as a | |||
| [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput), | |||
| and must be uploaded with the purpose `batch`. | |||
| """ | |||
| metadata: Optional[dict[str, str]] | |||
| """Optional custom metadata for the batch.""" | |||
| auto_delete_input_file: Optional[bool] | |||
| @@ -1,21 +0,0 @@ | |||
| # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. | |||
| from typing import Optional | |||
| from ..core import BaseModel | |||
| __all__ = ["BatchError"] | |||
| class BatchError(BaseModel): | |||
| code: Optional[str] = None | |||
| """定义的业务错误码""" | |||
| line: Optional[int] = None | |||
| """文件中的行号""" | |||
| message: Optional[str] = None | |||
| """关于对话文件中的错误的描述""" | |||
| param: Optional[str] = None | |||
| """参数名称,如果有的话""" | |||
| @@ -1,20 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing_extensions import TypedDict | |||
| __all__ = ["BatchListParams"] | |||
| class BatchListParams(TypedDict, total=False): | |||
| after: str | |||
| """分页的游标,用于获取下一页的数据。 | |||
| `after` 是一个指向当前页面的游标,用于获取下一页的数据。如果没有提供 `after`,则返回第一页的数据。 | |||
| list. | |||
| """ | |||
| limit: int | |||
| """这个参数用于限制返回的结果数量。 | |||
| Limit 用于限制返回的结果数量。默认值为 10 | |||
| """ | |||
| @@ -1,14 +0,0 @@ | |||
| from ..core import BaseModel | |||
| __all__ = ["BatchRequestCounts"] | |||
| class BatchRequestCounts(BaseModel): | |||
| completed: int | |||
| """这个数字表示已经完成的请求。""" | |||
| failed: int | |||
| """这个数字表示失败的请求。""" | |||
| total: int | |||
| """这个数字表示总的请求。""" | |||
| @@ -1,22 +0,0 @@ | |||
| from typing import Optional | |||
| from ...core import BaseModel | |||
| from .chat_completion import CompletionChoice, CompletionUsage | |||
| __all__ = ["AsyncTaskStatus", "AsyncCompletion"] | |||
| class AsyncTaskStatus(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| model: Optional[str] = None | |||
| task_status: Optional[str] = None | |||
| class AsyncCompletion(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| model: Optional[str] = None | |||
| task_status: str | |||
| choices: list[CompletionChoice] | |||
| usage: CompletionUsage | |||
| @@ -1,43 +0,0 @@ | |||
| from typing import Optional | |||
| from ...core import BaseModel | |||
| __all__ = ["Completion", "CompletionUsage"] | |||
| class Function(BaseModel): | |||
| arguments: str | |||
| name: str | |||
| class CompletionMessageToolCall(BaseModel): | |||
| id: str | |||
| function: Function | |||
| type: str | |||
| class CompletionMessage(BaseModel): | |||
| content: Optional[str] = None | |||
| role: str | |||
| tool_calls: Optional[list[CompletionMessageToolCall]] = None | |||
| class CompletionUsage(BaseModel): | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| total_tokens: int | |||
| class CompletionChoice(BaseModel): | |||
| index: int | |||
| finish_reason: str | |||
| message: CompletionMessage | |||
| class Completion(BaseModel): | |||
| model: Optional[str] = None | |||
| created: Optional[int] = None | |||
| choices: list[CompletionChoice] | |||
| request_id: Optional[str] = None | |||
| id: Optional[str] = None | |||
| usage: CompletionUsage | |||
| @@ -1,57 +0,0 @@ | |||
| from typing import Any, Optional | |||
| from ...core import BaseModel | |||
| __all__ = [ | |||
| "CompletionUsage", | |||
| "ChatCompletionChunk", | |||
| "Choice", | |||
| "ChoiceDelta", | |||
| "ChoiceDeltaFunctionCall", | |||
| "ChoiceDeltaToolCall", | |||
| "ChoiceDeltaToolCallFunction", | |||
| ] | |||
| class ChoiceDeltaFunctionCall(BaseModel): | |||
| arguments: Optional[str] = None | |||
| name: Optional[str] = None | |||
| class ChoiceDeltaToolCallFunction(BaseModel): | |||
| arguments: Optional[str] = None | |||
| name: Optional[str] = None | |||
| class ChoiceDeltaToolCall(BaseModel): | |||
| index: int | |||
| id: Optional[str] = None | |||
| function: Optional[ChoiceDeltaToolCallFunction] = None | |||
| type: Optional[str] = None | |||
| class ChoiceDelta(BaseModel): | |||
| content: Optional[str] = None | |||
| role: Optional[str] = None | |||
| tool_calls: Optional[list[ChoiceDeltaToolCall]] = None | |||
| class Choice(BaseModel): | |||
| delta: ChoiceDelta | |||
| finish_reason: Optional[str] = None | |||
| index: int | |||
| class CompletionUsage(BaseModel): | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| total_tokens: int | |||
| class ChatCompletionChunk(BaseModel): | |||
| id: Optional[str] = None | |||
| choices: list[Choice] | |||
| created: Optional[int] = None | |||
| model: Optional[str] = None | |||
| usage: Optional[CompletionUsage] = None | |||
| extra_json: dict[str, Any] | |||
| @@ -1,8 +0,0 @@ | |||
| from typing import Optional | |||
| from typing_extensions import TypedDict | |||
| class Reference(TypedDict, total=False): | |||
| enable: Optional[bool] | |||
| search_query: Optional[str] | |||
| @@ -1,146 +0,0 @@ | |||
| from typing import Literal, Optional | |||
| from typing_extensions import Required, TypedDict | |||
| __all__ = [ | |||
| "CodeGeexTarget", | |||
| "CodeGeexContext", | |||
| "CodeGeexExtra", | |||
| ] | |||
| class CodeGeexTarget(TypedDict, total=False): | |||
| """补全的内容参数""" | |||
| path: Optional[str] | |||
| """文件路径""" | |||
| language: Required[ | |||
| Literal[ | |||
| "c", | |||
| "c++", | |||
| "cpp", | |||
| "c#", | |||
| "csharp", | |||
| "c-sharp", | |||
| "css", | |||
| "cuda", | |||
| "dart", | |||
| "lua", | |||
| "objectivec", | |||
| "objective-c", | |||
| "objective-c++", | |||
| "python", | |||
| "perl", | |||
| "prolog", | |||
| "swift", | |||
| "lisp", | |||
| "java", | |||
| "scala", | |||
| "tex", | |||
| "jsx", | |||
| "tsx", | |||
| "vue", | |||
| "markdown", | |||
| "html", | |||
| "php", | |||
| "js", | |||
| "javascript", | |||
| "typescript", | |||
| "go", | |||
| "shell", | |||
| "rust", | |||
| "sql", | |||
| "kotlin", | |||
| "vb", | |||
| "ruby", | |||
| "pascal", | |||
| "r", | |||
| "fortran", | |||
| "lean", | |||
| "matlab", | |||
| "delphi", | |||
| "scheme", | |||
| "basic", | |||
| "assembly", | |||
| "groovy", | |||
| "abap", | |||
| "gdscript", | |||
| "haskell", | |||
| "julia", | |||
| "elixir", | |||
| "excel", | |||
| "clojure", | |||
| "actionscript", | |||
| "solidity", | |||
| "powershell", | |||
| "erlang", | |||
| "cobol", | |||
| "alloy", | |||
| "awk", | |||
| "thrift", | |||
| "sparql", | |||
| "augeas", | |||
| "cmake", | |||
| "f-sharp", | |||
| "stan", | |||
| "isabelle", | |||
| "dockerfile", | |||
| "rmarkdown", | |||
| "literate-agda", | |||
| "tcl", | |||
| "glsl", | |||
| "antlr", | |||
| "verilog", | |||
| "racket", | |||
| "standard-ml", | |||
| "elm", | |||
| "yaml", | |||
| "smalltalk", | |||
| "ocaml", | |||
| "idris", | |||
| "visual-basic", | |||
| "protocol-buffer", | |||
| "bluespec", | |||
| "applescript", | |||
| "makefile", | |||
| "tcsh", | |||
| "maple", | |||
| "systemverilog", | |||
| "literate-coffeescript", | |||
| "vhdl", | |||
| "restructuredtext", | |||
| "sas", | |||
| "literate-haskell", | |||
| "java-server-pages", | |||
| "coffeescript", | |||
| "emacs-lisp", | |||
| "mathematica", | |||
| "xslt", | |||
| "zig", | |||
| "common-lisp", | |||
| "stata", | |||
| "agda", | |||
| "ada", | |||
| ] | |||
| ] | |||
| """代码语言类型,如python""" | |||
| code_prefix: Required[str] | |||
| """补全位置的前文""" | |||
| code_suffix: Required[str] | |||
| """补全位置的后文""" | |||
| class CodeGeexContext(TypedDict, total=False): | |||
| """附加代码""" | |||
| path: Required[str] | |||
| """附加代码文件的路径""" | |||
| code: Required[str] | |||
| """附加的代码内容""" | |||
| class CodeGeexExtra(TypedDict, total=False): | |||
| target: Required[CodeGeexTarget] | |||
| """补全的内容参数""" | |||
| contexts: Optional[list[CodeGeexContext]] | |||
| """附加代码""" | |||
| @@ -1,21 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional | |||
| from ..core import BaseModel | |||
| from .chat.chat_completion import CompletionUsage | |||
| __all__ = ["Embedding", "EmbeddingsResponded"] | |||
| class Embedding(BaseModel): | |||
| object: str | |||
| index: Optional[int] = None | |||
| embedding: list[float] | |||
| class EmbeddingsResponded(BaseModel): | |||
| object: str | |||
| data: list[Embedding] | |||
| model: str | |||
| usage: CompletionUsage | |||
| @@ -1,5 +0,0 @@ | |||
| from .file_deleted import FileDeleted | |||
| from .file_object import FileObject, ListOfFileObject | |||
| from .upload_detail import UploadDetail | |||
| __all__ = ["FileObject", "ListOfFileObject", "UploadDetail", "FileDeleted"] | |||
| @@ -1,38 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Literal, Optional | |||
| from typing_extensions import Required, TypedDict | |||
| __all__ = ["FileCreateParams"] | |||
| from ...core import FileTypes | |||
| from . import UploadDetail | |||
| class FileCreateParams(TypedDict, total=False): | |||
| file: FileTypes | |||
| """file和 upload_detail二选一必填""" | |||
| upload_detail: list[UploadDetail] | |||
| """file和 upload_detail二选一必填""" | |||
| purpose: Required[Literal["fine-tune", "retrieval", "batch"]] | |||
| """ | |||
| 上传文件的用途,支持 "fine-tune和 "retrieval" | |||
| retrieval支持上传Doc、Docx、PDF、Xlsx、URL类型文件,且单个文件的大小不超过 5MB。 | |||
| fine-tune支持上传.jsonl文件且当前单个文件的大小最大可为 100 MB ,文件中语料格式需满足微调指南中所描述的格式。 | |||
| """ | |||
| custom_separator: Optional[list[str]] | |||
| """ | |||
| 当 purpose 为 retrieval 且文件类型为 pdf, url, docx 时上传,切片规则默认为 `\n`。 | |||
| """ | |||
| knowledge_id: str | |||
| """ | |||
| 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 | |||
| """ | |||
| sentence_size: int | |||
| """ | |||
| 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 | |||
| """ | |||
| @@ -1,13 +0,0 @@ | |||
| from typing import Literal | |||
| from ...core import BaseModel | |||
| __all__ = ["FileDeleted"] | |||
| class FileDeleted(BaseModel): | |||
| id: str | |||
| deleted: bool | |||
| object: Literal["file"] | |||
| @@ -1,22 +0,0 @@ | |||
| from typing import Optional | |||
| from ...core import BaseModel | |||
| __all__ = ["FileObject", "ListOfFileObject"] | |||
| class FileObject(BaseModel): | |||
| id: Optional[str] = None | |||
| bytes: Optional[int] = None | |||
| created_at: Optional[int] = None | |||
| filename: Optional[str] = None | |||
| object: Optional[str] = None | |||
| purpose: Optional[str] = None | |||
| status: Optional[str] = None | |||
| status_details: Optional[str] = None | |||
| class ListOfFileObject(BaseModel): | |||
| object: Optional[str] = None | |||
| data: list[FileObject] | |||
| has_more: Optional[bool] = None | |||
| @@ -1,13 +0,0 @@ | |||
| from typing import Optional | |||
| from ...core import BaseModel | |||
| class UploadDetail(BaseModel): | |||
| url: str | |||
| knowledge_type: int | |||
| file_name: Optional[str] = None | |||
| sentence_size: Optional[int] = None | |||
| custom_separator: Optional[list[str]] = None | |||
| callback_url: Optional[str] = None | |||
| callback_header: Optional[dict[str, str]] = None | |||
| @@ -1,4 +0,0 @@ | |||
| from __future__ import annotations | |||
| from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob | |||
| from .fine_tuning_job_event import FineTuningJobEvent | |||
| @@ -1,51 +0,0 @@ | |||
| from typing import Optional, Union | |||
| from ...core import BaseModel | |||
| __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] | |||
| class Error(BaseModel): | |||
| code: str | |||
| message: str | |||
| param: Optional[str] = None | |||
| class Hyperparameters(BaseModel): | |||
| n_epochs: Union[str, int, None] = None | |||
| class FineTuningJob(BaseModel): | |||
| id: Optional[str] = None | |||
| request_id: Optional[str] = None | |||
| created_at: Optional[int] = None | |||
| error: Optional[Error] = None | |||
| fine_tuned_model: Optional[str] = None | |||
| finished_at: Optional[int] = None | |||
| hyperparameters: Optional[Hyperparameters] = None | |||
| model: Optional[str] = None | |||
| object: Optional[str] = None | |||
| result_files: list[str] | |||
| status: str | |||
| trained_tokens: Optional[int] = None | |||
| training_file: str | |||
| validation_file: Optional[str] = None | |||
| class ListOfFineTuningJob(BaseModel): | |||
| object: Optional[str] = None | |||
| data: list[FineTuningJob] | |||
| has_more: Optional[bool] = None | |||
| @@ -1,35 +0,0 @@ | |||
| from typing import Optional, Union | |||
| from ...core import BaseModel | |||
| __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] | |||
| class Metric(BaseModel): | |||
| epoch: Optional[Union[str, int, float]] = None | |||
| current_steps: Optional[int] = None | |||
| total_steps: Optional[int] = None | |||
| elapsed_time: Optional[str] = None | |||
| remaining_time: Optional[str] = None | |||
| trained_tokens: Optional[int] = None | |||
| loss: Optional[Union[str, int, float]] = None | |||
| eval_loss: Optional[Union[str, int, float]] = None | |||
| acc: Optional[Union[str, int, float]] = None | |||
| eval_acc: Optional[Union[str, int, float]] = None | |||
| learning_rate: Optional[Union[str, int, float]] = None | |||
| class JobEvent(BaseModel): | |||
| object: Optional[str] = None | |||
| id: Optional[str] = None | |||
| type: Optional[str] = None | |||
| created_at: Optional[int] = None | |||
| level: Optional[str] = None | |||
| message: Optional[str] = None | |||
| data: Optional[Metric] = None | |||
| class FineTuningJobEvent(BaseModel): | |||
| object: Optional[str] = None | |||
| data: list[JobEvent] | |||
| has_more: Optional[bool] = None | |||
| @@ -1,15 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Literal, Union | |||
| from typing_extensions import TypedDict | |||
| __all__ = ["Hyperparameters"] | |||
| class Hyperparameters(TypedDict, total=False): | |||
| batch_size: Union[Literal["auto"], int] | |||
| learning_rate_multiplier: Union[Literal["auto"], float] | |||
| n_epochs: Union[Literal["auto"], int] | |||
| @@ -1 +0,0 @@ | |||
| from .fine_tuned_models import FineTunedModelsStatus | |||
| @@ -1,13 +0,0 @@ | |||
| from typing import ClassVar | |||
| from ....core import PYDANTIC_V2, BaseModel, ConfigDict | |||
| __all__ = ["FineTunedModelsStatus"] | |||
| class FineTunedModelsStatus(BaseModel): | |||
| if PYDANTIC_V2: | |||
| model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=()) | |||
| request_id: str # 请求id | |||
| model_name: str # 模型名称 | |||
| delete_status: str # 删除状态 deleting(删除中), deleted (已删除) | |||
| @@ -1,18 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional | |||
| from ..core import BaseModel | |||
| __all__ = ["GeneratedImage", "ImagesResponded"] | |||
| class GeneratedImage(BaseModel): | |||
| b64_json: Optional[str] = None | |||
| url: Optional[str] = None | |||
| revised_prompt: Optional[str] = None | |||
| class ImagesResponded(BaseModel): | |||
| created: int | |||
| data: list[GeneratedImage] | |||
| @@ -1,8 +0,0 @@ | |||
| from .knowledge import KnowledgeInfo | |||
| from .knowledge_used import KnowledgeStatistics, KnowledgeUsed | |||
| __all__ = [ | |||
| "KnowledgeInfo", | |||
| "KnowledgeStatistics", | |||
| "KnowledgeUsed", | |||
| ] | |||
| @@ -1,8 +0,0 @@ | |||
| from .document import DocumentData, DocumentFailedInfo, DocumentObject, DocumentSuccessInfo | |||
| __all__ = [ | |||
| "DocumentData", | |||
| "DocumentObject", | |||
| "DocumentSuccessInfo", | |||
| "DocumentFailedInfo", | |||
| ] | |||
| @@ -1,51 +0,0 @@ | |||
| from typing import Optional | |||
| from ....core import BaseModel | |||
| __all__ = ["DocumentData", "DocumentObject", "DocumentSuccessInfo", "DocumentFailedInfo"] | |||
| class DocumentSuccessInfo(BaseModel): | |||
| documentId: Optional[str] = None | |||
| """文件id""" | |||
| filename: Optional[str] = None | |||
| """文件名称""" | |||
| class DocumentFailedInfo(BaseModel): | |||
| failReason: Optional[str] = None | |||
| """上传失败的原因,包括:文件格式不支持、文件大小超出限制、知识库容量已满、容量上限为 50 万字。""" | |||
| filename: Optional[str] = None | |||
| """文件名称""" | |||
| documentId: Optional[str] = None | |||
| """知识库id""" | |||
| class DocumentObject(BaseModel): | |||
| """文档信息""" | |||
| successInfos: Optional[list[DocumentSuccessInfo]] = None | |||
| """上传成功的文件信息""" | |||
| failedInfos: Optional[list[DocumentFailedInfo]] = None | |||
| """上传失败的文件信息""" | |||
| class DocumentDataFailInfo(BaseModel): | |||
| """失败原因""" | |||
| embedding_code: Optional[int] = ( | |||
| None # 失败码 10001:知识不可用,知识库空间已达上限 10002:知识不可用,知识库空间已达上限(字数超出限制) | |||
| ) | |||
| embedding_msg: Optional[str] = None # 失败原因 | |||
| class DocumentData(BaseModel): | |||
| id: str = None # 知识唯一id | |||
| custom_separator: list[str] = None # 切片规则 | |||
| sentence_size: str = None # 切片大小 | |||
| length: int = None # 文件大小(字节) | |||
| word_num: int = None # 文件字数 | |||
| name: str = None # 文件名 | |||
| url: str = None # 文件下载链接 | |||
| embedding_stat: int = None # 0:向量化中 1:向量化完成 2:向量化失败 | |||
| failInfo: Optional[DocumentDataFailInfo] = None # 失败原因 向量化失败embedding_stat=2的时候 会有此值 | |||
| @@ -1,29 +0,0 @@ | |||
| from typing import Optional, TypedDict | |||
| __all__ = ["DocumentEditParams"] | |||
| class DocumentEditParams(TypedDict): | |||
| """ | |||
| 知识参数类型定义 | |||
| Attributes: | |||
| id (str): 知识ID | |||
| knowledge_type (int): 知识类型: | |||
| 1:文章知识: 支持pdf,url,docx | |||
| 2.问答知识-文档: 支持pdf,url,docx | |||
| 3.问答知识-表格: 支持xlsx | |||
| 4.商品库-表格: 支持xlsx | |||
| 5.自定义: 支持pdf,url,docx | |||
| custom_separator (Optional[List[str]]): 当前知识类型为自定义(knowledge_type=5)时的切片规则,默认\n | |||
| sentence_size (Optional[int]): 当前知识类型为自定义(knowledge_type=5)时的切片字数,取值范围: 20-2000,默认300 | |||
| callback_url (Optional[str]): 回调地址 | |||
| callback_header (Optional[dict]): 回调时携带的header | |||
| """ | |||
| id: str | |||
| knowledge_type: int | |||
| custom_separator: Optional[list[str]] | |||
| sentence_size: Optional[int] | |||
| callback_url: Optional[str] | |||
| callback_header: Optional[dict[str, str]] | |||
| @@ -1,26 +0,0 @@ | |||
| from __future__ import annotations | |||
| from typing import Optional | |||
| from typing_extensions import TypedDict | |||
| class DocumentListParams(TypedDict, total=False): | |||
| """ | |||
| 文件查询参数类型定义 | |||
| Attributes: | |||
| purpose (Optional[str]): 文件用途 | |||
| knowledge_id (Optional[str]): 当文件用途为 retrieval 时,需要提供查询的知识库ID | |||
| page (Optional[int]): 页,默认1 | |||
| limit (Optional[int]): 查询文件列表数,默认10 | |||
| after (Optional[str]): 查询指定fileID之后的文件列表(当文件用途为 fine-tune 时需要) | |||
| order (Optional[str]): 排序规则,可选值['desc', 'asc'],默认desc(当文件用途为 fine-tune 时需要) | |||
| """ | |||
| purpose: Optional[str] | |||
| knowledge_id: Optional[str] | |||
| page: Optional[int] | |||
| limit: Optional[int] | |||
| after: Optional[str] | |||
| order: Optional[str] | |||
| @@ -1,11 +0,0 @@ | |||
| from __future__ import annotations | |||
| from ....core import BaseModel | |||
| from . import DocumentData | |||
| __all__ = ["DocumentPage"] | |||
| class DocumentPage(BaseModel): | |||
| list: list[DocumentData] | |||
| object: str | |||
| @@ -1,21 +0,0 @@ | |||
| from typing import Optional | |||
| from ...core import BaseModel | |||
| __all__ = ["KnowledgeInfo"] | |||
| class KnowledgeInfo(BaseModel): | |||
| id: Optional[str] = None | |||
| """知识库唯一 id""" | |||
| embedding_id: Optional[str] = ( | |||
| None # 知识库绑定的向量化模型 见模型列表 [内部服务开放接口文档](https://lslfd0slxc.feishu.cn/docx/YauWdbBiMopV0FxB7KncPWCEn8f#H15NduiQZo3ugmxnWQFcfAHpnQ4) | |||
| ) | |||
| name: Optional[str] = None # 知识库名称 100字限制 | |||
| customer_identifier: Optional[str] = None # 用户标识 长度32位以内 | |||
| description: Optional[str] = None # 知识库描述 500字限制 | |||
| background: Optional[str] = None # 背景颜色(给枚举)'blue', 'red', 'orange', 'purple', 'sky' | |||
| icon: Optional[str] = ( | |||
| None # 知识库图标(给枚举) question: 问号、book: 书籍、seal: 印章、wrench: 扳手、tag: 标签、horn: 喇叭、house: 房子 # noqa: E501 | |||
| ) | |||
| bucket_id: Optional[str] = None # 桶id 限制32位 | |||