@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
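For context (not part of the patch): a minimal sketch of what the new `re.match` gate accepts, assuming only the pattern shown above. Bare "o1" and dated snapshots such as "o1-2024-12-17" fall into the block-as-stream fallback, while "o1-mini" and "o1-preview" keep native streaming.

    import re

    # Same pattern as in the hunk above.
    O1_BLOCKING = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")

    for name in ("o1", "o1-2024-12-17", "o1-mini", "o1-preview"):
        # Only "o1" and "o1-2024-12-17" match, so only they are downgraded to blocking calls.
        print(name, bool(O1_BLOCKING.match(name)))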
@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle a blocking llm chat response as a single-chunk stream
+
+        :param block_result: blocking chat completion result
+        :param prompt_messages: prompt messages
+        :param stop: stop words to enforce on the result text
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+            # Write the enforced text back so the single yielded chunk honors the stop words
+            block_result.message.content = text
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
+
     def _handle_chat_generate_response(
         self,
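As a usage illustration (hypothetical stand-in types, not Dify's actual classes): the new helper turns a blocking result into a generator that yields exactly one chunk, so callers that iterate over a stream keep working unchanged. A minimal sketch:

    from collections.abc import Generator
    from dataclasses import dataclass

    # Simplified stand-ins for LLMResult / LLMResultChunk (illustration only).
    @dataclass
    class BlockResult:
        text: str

    @dataclass
    class Chunk:
        index: int
        text: str
        finish_reason: str

    def block_as_stream(result: BlockResult) -> Generator[Chunk, None, None]:
        # Replay the whole blocking result as a single "stream" chunk,
        # mirroring _handle_chat_block_as_stream_response above.
        yield Chunk(index=0, text=result.text, finish_reason="stop")

    for chunk in block_as_stream(BlockResult(text="hello")):
        print(chunk)  # Chunk(index=0, text='hello', finish_reason='stop')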