@@ -113,7 +113,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))

-            if "o1" in model:
+            if model.startswith("o1"):
                 client.chat.completions.create(
                     messages=[{"role": "user", "content": "ping"}],
                     model=model,
@@ -311,7 +311,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

         block_as_stream = False
-        if "o1" in model:
+        if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
+            if stream:
+                block_as_stream = True
+                stream = False
@@ -404,7 +407,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                                 ]
                             )

-        if "o1" in model:
+        if model.startswith("o1"):
             system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
             if system_message_count > 0:
                 new_prompt_messages = []
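
For quick reference, a minimal sketch (not part of the diff; the model names below are hypothetical) of how the new prefix check differs from the old substring check:

```python
# Illustrative comparison only: `"o1" in name` also matches names that merely
# contain "o1" somewhere, while `name.startswith("o1")` matches only names
# that begin with it (e.g. o1, o1-mini, o1-preview). Names are hypothetical.
for name in ("o1-preview", "o1-mini", "gpt-4o-mini", "custom-o1-alias"):
    old_check = "o1" in name            # behavior before this change
    new_check = name.startswith("o1")   # behavior after this change
    print(f"{name}: substring={old_check}, prefix={new_check}")
```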