@@ -1,14 +1,13 @@
 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast

 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -16,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -26,7 +26,8 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelPrope
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper

 logger = logging.getLogger(__name__)

@@ -39,9 +40,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
@@ -65,18 +69,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             user=user
         )

-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)

         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
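
Note on the guard pattern above: a missing or empty base_model_name now fails fast with a ValueError instead of surfacing as an AttributeError deep inside _get_ai_model_entity. A minimal sketch of the expected call shape (the credential values and deployment name here are invented for illustration):

    # Hypothetical Azure OpenAI credentials; only base_model_name is
    # inspected by the new guard.
    credentials = {
        'openai_api_base': 'https://my-resource.openai.azure.com',
        'openai_api_key': '...',
        'base_model_name': 'gpt-35-turbo',  # must be present and non-empty
    }

    llm = AzureOpenAILargeLanguageModel()
    # Raises ValueError('Base Model Name is required') when the key is
    # missing or empty, and ValueError('... is invalid') when it matches
    # no entry in LLM_BASE_MODELS.
    num = llm.get_num_tokens('my-deployment', credentials,
                             [UserPromptMessage(content='Hello')])
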
@@ -88,7 +103,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if 'base_model_name' not in credentials:
             raise CredentialsValidateFailedError('Base Model Name is required')

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
@@ -118,7 +136,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))

     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None

     def _generate(self, model: str, credentials: dict,
@@ -149,8 +170,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_generate_response(model, credentials, response, prompt_messages)

-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+            self, model: str, credentials: dict, response: Completion,
+            prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text

         # transform assistant message to prompt message
@@ -165,7 +188,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)

         # transform usage
@@ -182,8 +207,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return result

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+            self, model: str, credentials: dict, response: Stream[Completion],
+            prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
@@ -210,7 +237,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                     completion_tokens = chunk.usage.completion_tokens
                 else:
                     # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+                    content = prompt_messages[0].content
+                    assert isinstance(content, str)
+                    prompt_tokens = self._num_tokens_from_string(credentials, content)
                     completion_tokens = self._num_tokens_from_string(credentials, full_text)

                 # transform usage
@@ -257,12 +286,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         extra_model_kwargs = {}

         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]

         if stop:
             extra_model_kwargs['stop'] = stop
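
The hunk above flips the request from the deprecated functions parameter to tools. The payload shapes differ only in nesting; roughly (tool definition invented, and the tools form is what helper.dump_model(PromptMessageFunction(function=tool)) is expected to serialize to):

    # Legacy function-calling payload (the removed branch):
    functions = [{
        "name": "get_weather",
        "description": "Look up current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }]

    # Tool-calling payload (the new branch): same schema, wrapped under "function".
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }]
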
@@ -271,8 +300,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             extra_model_kwargs['user'] = user

         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -284,18 +314,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)

-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
-
+    def _handle_chat_generate_response(
+            self, model: str, credentials: dict, response: ChatCompletion,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls

         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)

         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -317,7 +346,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
@@ -325,58 +354,34 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             system_fingerprint=response.system_fingerprint,
         )

-        return response
+        return result

-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+            self,
+            model: str,
+            credentials: dict,
+            response: Stream[ChatCompletionChunk],
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue

             delta = chunk.choices[0]

-            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
-                continue
-
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call
-
-            # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)

-            # extract tool calls from response
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
+            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
+            if delta.finish_reason is None and not delta.delta.content:
+                continue

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
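
Background for the streaming hunk above: with tools, a single tool call arrives fragmented across chunks. The first fragment carries index, id and the function name; later fragments carry only argument text for the same index. A self-contained sketch of the merge, using plain dicts with invented fragments:

    fragments = [
        {"index": 0, "id": "call_1", "name": "get_weather", "arguments": ""},
        {"index": 0, "id": None, "name": None, "arguments": '{"city": '},
        {"index": 0, "id": None, "name": None, "arguments": '"Berlin"}'},
    ]

    calls: list[dict] = []
    for frag in fragments:
        i = frag["index"]
        if i < len(calls):
            # continuation fragment: append argument text to the open call
            calls[i]["arguments"] += frag["arguments"] or ""
        else:
            # first fragment: opens a new call at this index
            calls.append(frag)

    assert calls[0]["arguments"] == '{"city": "Berlin"}'
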
@@ -426,54 +431,56 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         )

     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
-
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
-
-        return tool_call
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)

     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
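
_update_tool_calls now serves both handlers: complete ChatCompletionMessageToolCall objects (non-streaming) are appended wholesale, while ChoiceDeltaToolCall fragments (streaming) are merged into the entry at their index. A hedged usage sketch; the delta objects are constructed by hand here, which matches the openai-python 1.x types as far as I know, but is for illustration only:

    from openai.types.chat.chat_completion_chunk import (
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )

    tool_calls = []
    deltas = [
        # opening fragment: id, type and name present, arguments empty
        ChoiceDeltaToolCall(index=0, id='call_1', type='function',
                            function=ChoiceDeltaToolCallFunction(name='get_weather', arguments='')),
        # continuation fragment: only argument text
        ChoiceDeltaToolCall(index=0,
                            function=ChoiceDeltaToolCallFunction(arguments='{"city": "Berlin"}')),
    ]
    for delta in deltas:
        AzureOpenAILargeLanguageModel._update_tool_calls(
            tool_calls=tool_calls, tool_calls_response=[delta])

    assert tool_calls[0].function.arguments == '{"city": "Berlin"}'
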
@@ -492,33 +499,22 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                         }
                     }
                     sub_messages.append(sub_message_dict)

                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")
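
With the conversion above, an assistant turn that requested a tool and the matching tool result serialize to the standard OpenAI chat shapes; roughly (values invented):

    # Assistant turn carrying one tool call:
    assistant_dict = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }],
    }

    # Tool result, linked back via tool_call_id; this replaces the legacy
    # {"role": "function", "name": ...} encoding removed above:
    tool_dict = {
        "role": "tool",
        "name": "get_weather",
        "content": '{"temperature": 12}',
        "tool_call_id": "call_1",
    }
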
@@ -542,8 +538,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

         return num_tokens

-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+            self, credentials: dict, messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -591,6 +589,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                         for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
@@ -631,12 +630,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             num_tokens += len(encoding.encode('parameters'))
             if 'title' in parameters:
                 num_tokens += len(encoding.encode('title'))
-                num_tokens += len(encoding.encode(parameters.get("title")))
+                num_tokens += len(encoding.encode(parameters['title']))
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(parameters.get("type")))
+            num_tokens += len(encoding.encode(parameters['type']))
             if 'properties' in parameters:
                 num_tokens += len(encoding.encode('properties'))
-                for key, value in parameters.get('properties').items():
+                for key, value in parameters['properties'].items():
                     num_tokens += len(encoding.encode(key))
                     for field_key, field_value in value.items():
                         num_tokens += len(encoding.encode(field_key))
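
The switch from parameters.get(...) to direct indexing happens after the corresponding `in` checks (or, for 'type', where the key is treated as mandatory), so the value can no longer be None as far as type-checkers are concerned; encoding.encode(None) would otherwise be a latent TypeError. A minimal sketch of the counting idiom, with the cl100k_base encoding assumed for these models:

    import tiktoken

    encoding = tiktoken.get_encoding('cl100k_base')

    schema = {'title': 'Weather', 'type': 'object'}
    num_tokens = 0
    if 'title' in schema:
        num_tokens += len(encoding.encode('title'))
        # indexing (not .get) keeps the argument a plain str
        num_tokens += len(encoding.encode(schema['title']))
    num_tokens += len(encoding.encode('type'))
    num_tokens += len(encoding.encode(schema['type']))
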
@@ -656,7 +655,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return num_tokens

     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -664,5 +663,3 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy
-
-        return None