import json
import logging
import re
from collections.abc import Generator
from typing import Any, Optional, Union, cast

        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

        # o1 compatibility
        block_as_stream = False
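        # o1 models take max_completion_tokens in place of max_tokens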
        if model.startswith("o1"):
            if "max_tokens" in model_parameters:
                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                del model_parameters["max_tokens"]
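            # base o1 models ("o1" or "o1-YYYY-MM-DD") do not stream; fall back to a
            # blocking request and replay the result as a single stream chunk below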
            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
                if stream:
                    block_as_stream = True
                    stream = False

                    if "stream_options" in extra_model_kwargs:
                        del extra_model_kwargs["stream_options"]
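                # o1 models reject the stop parameter; stop words are enforced
                # client-side in _handle_chat_block_as_stream_response instead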
| if "stop" in extra_model_kwargs: | if "stop" in extra_model_kwargs: | ||||
| del extra_model_kwargs["stop"] | del extra_model_kwargs["stop"] | ||||
        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

        if block_as_stream:
            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)

        return block_result

    def _handle_chat_block_as_stream_response(
        self,
        block_result: LLMResult,
        prompt_messages: list[PromptMessage],
        stop: Optional[list[str]] = None,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        Handle llm chat response by yielding a blocking result as a single stream chunk

        :param block_result: blocking chat generation result
        :param prompt_messages: prompt messages
        :param stop: stop words to enforce on the result text
        :return: llm response chunk generator
        """
        text = block_result.message.content
        text = cast(str, text)

        if stop:
            # the stop parameter was stripped from the blocking request,
            # so truncate at the first stop token here and write it back
            text = self.enforce_stop_tokens(text, stop)
            block_result.message.content = text
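        # emit the entire result as one chunk, with usage attached so callers
        # can still account for tokens despite the simulated stream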
        yield LLMResultChunk(
            model=block_result.model,
            prompt_messages=prompt_messages,
            system_fingerprint=block_result.system_fingerprint,
            delta=LLMResultChunkDelta(
                index=0,
                message=block_result.message,
                finish_reason="stop",
                usage=block_result.usage,
            ),
        )

    def _handle_chat_generate_response(
        self,