Browse files

fix(api/model_runtime/azure/llm): Switch to tool_call. (#5541)

tags/0.6.12
-LAN- 1 year ago
parent
commit
ba67206bb9
No account linked to committer's email address
2 changed files with 138 additions and 143 deletions
1. api/core/model_runtime/model_providers/azure_openai/llm/llm.py (+134 −137)
2. api/tests/integration_tests/model_runtime/__mock/openai_chat.py (+4 −6)
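This commit moves the Azure OpenAI runtime off OpenAI's deprecated functions / function_call request fields and onto the current tools / tool_calls interface. For orientation before the diff, a minimal sketch of a new-style call against the openai SDK; the endpoint, credentials, deployment name, and get_weather tool below are hypothetical placeholders, not values from this repository:

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="...",                                   # placeholder credentials
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
)

tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Old style sent functions=[...] and read the reply from message.function_call.
# New style (what this commit adopts) sends tools=[...] and reads message.tool_calls.
response = client.chat.completions.create(
    model="gpt-4o",                                  # deployment name; hypothetical
    messages=[{"role": "user", "content": "Weather in Sofia?"}],
    tools=[tool],
)
for tool_call in response.choices[0].message.tool_calls or []:
    print(tool_call.id, tool_call.function.name, tool_call.function.arguments)

Unlike function_call, tool_calls is a list with stable ids, which is why the handlers below accumulate into a list and thread ids through to the tool-result messages.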

+134 −137  api/core/model_runtime/model_providers/azure_openai/llm/llm.py  (View full file)

 import copy
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast

 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
-from openai.types.chat.chat_completion_message import FunctionCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageFunction,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
+from core.model_runtime.utils import helper

 logger = logging.getLogger(__name__)


                   stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+        if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
             # chat model
             return self._chat_generate(
                 model=model,
                 user=user
             )


-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
+        if not model_entity:
+            raise ValueError(f'Base Model Name {base_model_name} is invalid')
+        model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)

         if model_mode == LLMMode.CHAT.value:
             # chat model
             return self._num_tokens_from_messages(credentials, prompt_messages, tools)
         else:
             # text completion model, do not support tool calling
-            return self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            return self._num_tokens_from_string(credentials, content)


     def validate_credentials(self, model: str, credentials: dict) -> None:
         if 'openai_api_base' not in credentials:
             if 'base_model_name' not in credentials:
                 raise CredentialsValidateFailedError('Base Model Name is required')

-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise CredentialsValidateFailedError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
             raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
         raise CredentialsValidateFailedError(str(ex))


     def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        base_model_name = credentials.get('base_model_name')
+        if not base_model_name:
+            raise ValueError('Base Model Name is required')
+        ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
         return ai_model_entity.entity if ai_model_entity else None


     def _generate(self, model: str, credentials: dict,

         return self._handle_generate_response(model, credentials, response, prompt_messages)


-    def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
-                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+    def _handle_generate_response(
+            self, model: str, credentials: dict, response: Completion,
+            prompt_messages: list[PromptMessage]
+    ):
         assistant_text = response.choices[0].text

         # transform assistant message to prompt message
             completion_tokens = response.usage.completion_tokens
         else:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+            content = prompt_messages[0].content
+            assert isinstance(content, str)
+            prompt_tokens = self._num_tokens_from_string(credentials, content)
             completion_tokens = self._num_tokens_from_string(credentials, assistant_text)

         # transform usage

         return result


-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+            self, model: str, credentials: dict, response: Stream[Completion],
+            prompt_messages: list[PromptMessage]
+    ) -> Generator:
         full_text = ''
         for chunk in response:
             if len(chunk.choices) == 0:
                 completion_tokens = chunk.usage.completion_tokens
             else:
                 # calculate num tokens
-                prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
+                content = prompt_messages[0].content
+                assert isinstance(content, str)
+                prompt_tokens = self._num_tokens_from_string(credentials, content)
                 completion_tokens = self._num_tokens_from_string(credentials, full_text)

         # transform usage
         extra_model_kwargs = {}


         if tools:
-            # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
-            extra_model_kwargs['functions'] = [{
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters
-            } for tool in tools]
+            extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
+            # extra_model_kwargs['functions'] = [{
+            #     "name": tool.name,
+            #     "description": tool.description,
+            #     "parameters": tool.parameters
+            # } for tool in tools]
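The rewritten branch above serializes each PromptMessageTool through PromptMessageFunction into the OpenAI tools parameter. A minimal sketch of the dict shape this is assumed to produce; FakeTool and to_tool_param are hypothetical stand-ins for illustration, not dify APIs:

from dataclasses import dataclass, field

@dataclass
class FakeTool:  # stand-in for PromptMessageTool, illustration only
    name: str
    description: str
    parameters: dict = field(default_factory=dict)

def to_tool_param(tool: FakeTool) -> dict:
    # Mirrors what helper.dump_model(PromptMessageFunction(function=tool))
    # is assumed to emit: one entry of the OpenAI "tools" request parameter.
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.parameters,
        },
    }

print(to_tool_param(FakeTool("get_weather", "Look up weather", {"type": "object"})))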


         if stop:
             extra_model_kwargs['stop'] = stop
         extra_model_kwargs['user'] = user

         # chat model
+        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,

         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)


-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
-                                       prompt_messages: list[PromptMessage],
-                                       tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
+    def _handle_chat_generate_response(
+            self, model: str, credentials: dict, response: ChatCompletion,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         assistant_message = response.choices[0].message
-        # assistant_message_tool_calls = assistant_message.tool_calls
-        assistant_message_function_call = assistant_message.function_call
+        assistant_message_tool_calls = assistant_message.tool_calls

         # extract tool calls from response
-        # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-        function_call = self._extract_response_function_call(assistant_message_function_call)
-        tool_calls = [function_call] if function_call else []
+        tool_calls = []
+        self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)

         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
             usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

         # transform response
-        response = LLMResult(
+        result = LLMResult(
             model=response.model or model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
             system_fingerprint=response.system_fingerprint,
         )

-        return response
+        return result


-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
-                                              response: Stream[ChatCompletionChunk],
-                                              prompt_messages: list[PromptMessage],
-                                              tools: Optional[list[PromptMessageTool]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(
+            self,
+            model: str,
+            credentials: dict,
+            response: Stream[ChatCompletionChunk],
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ):
         index = 0
         full_assistant_content = ''
-        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
+        tool_calls = []
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue

             delta = chunk.choices[0]

-            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
-            if delta.delta is None or (
-                delta.finish_reason is None
-                and (delta.delta.content is None or delta.delta.content == '')
-                and delta.delta.function_call is None
-            ):
-                continue
-            # assistant_message_tool_calls = delta.delta.tool_calls
-            assistant_message_function_call = delta.delta.function_call

             # extract tool calls from response
-            if delta_assistant_message_function_call_storage is not None:
-                # handle process of stream function call
-                if assistant_message_function_call:
-                    # message has not ended ever
-                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
-                    continue
-                else:
-                    # message has ended
-                    assistant_message_function_call = delta_assistant_message_function_call_storage
-                    delta_assistant_message_function_call_storage = None
-            else:
-                if assistant_message_function_call:
-                    # start of stream function call
-                    delta_assistant_message_function_call_storage = assistant_message_function_call
-                    if delta_assistant_message_function_call_storage.arguments is None:
-                        delta_assistant_message_function_call_storage.arguments = ''
-                    continue
-
-            # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-            function_call = self._extract_response_function_call(assistant_message_function_call)
-            tool_calls = [function_call] if function_call else []
+            self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
+
+            # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
+            if delta.finish_reason is None and not delta.delta.content:
+                continue

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
             )


     @staticmethod
-    def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
-            -> list[AssistantPromptMessage.ToolCall]:
-
-        tool_calls = []
-        if response_tool_calls:
-            for response_tool_call in response_tool_calls:
-                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call.function.name,
-                    arguments=response_tool_call.function.arguments
-                )
-
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call.id,
-                    type=response_tool_call.type,
-                    function=function
-                )
-                tool_calls.append(tool_call)
-
-        return tool_calls
-
-    @staticmethod
-    def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
-            -> AssistantPromptMessage.ToolCall:
-
-        tool_call = None
-        if response_function_call:
-            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call.name,
-                arguments=response_function_call.arguments
-            )
-
-            tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call.name,
-                type="function",
-                function=function
-            )
-
-        return tool_call
+    def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None:
+        if tool_calls_response:
+            for response_tool_call in tool_calls_response:
+                if isinstance(response_tool_call, ChatCompletionMessageToolCall):
+                    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=response_tool_call.function.name,
+                        arguments=response_tool_call.function.arguments
+                    )
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=response_tool_call.id,
+                        type=response_tool_call.type,
+                        function=function
+                    )
+                    tool_calls.append(tool_call)
+                elif isinstance(response_tool_call, ChoiceDeltaToolCall):
+                    index = response_tool_call.index
+                    if index < len(tool_calls):
+                        tool_calls[index].id = response_tool_call.id or tool_calls[index].id
+                        tool_calls[index].type = response_tool_call.type or tool_calls[index].type
+                        if response_tool_call.function:
+                            tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name
+                            tool_calls[index].function.arguments += response_tool_call.function.arguments or ''
+                    else:
+                        assert response_tool_call.id is not None
+                        assert response_tool_call.type is not None
+                        assert response_tool_call.function is not None
+                        assert response_tool_call.function.name is not None
+                        assert response_tool_call.function.arguments is not None
+
+                        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=response_tool_call.function.name,
+                            arguments=response_tool_call.function.arguments
+                        )
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=response_tool_call.id,
+                            type=response_tool_call.type,
+                            function=function
+                        )
+                        tool_calls.append(tool_call)
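The new _update_tool_calls helper above handles both blocking responses (complete ChatCompletionMessageToolCall objects) and streaming deltas (ChoiceDeltaToolCall), where the first chunk at a given index creates the call and later chunks append argument fragments. A standalone sketch of that merge rule, using simplified stand-in types rather than the real SDK classes:

from dataclasses import dataclass

@dataclass
class Delta:  # mimics ChoiceDeltaToolCall: an index plus partial fields
    index: int
    id: str | None = None
    name: str | None = None
    arguments: str = ""

def merge(chunks: list[Delta]) -> list[dict]:
    calls: list[dict] = []
    for d in chunks:
        if d.index < len(calls):   # continuation chunk: extend the arguments
            calls[d.index]["arguments"] += d.arguments
        else:                      # first chunk for this index: create the call
            calls.append({"id": d.id, "name": d.name, "arguments": d.arguments})
    return calls

chunks = [
    Delta(0, id="call_1", name="get_weather", arguments=""),
    Delta(0, arguments='{"city": '),
    Delta(0, arguments='"Sofia"}'),
]
print(merge(chunks))
# [{'id': 'call_1', 'name': 'get_weather', 'arguments': '{"city": "Sofia"}'}]

Keyed accumulation by index is what lets a single streamed response carry several parallel tool calls, something the old single function_call storage variable could not represent.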


     @staticmethod
-    def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(message: PromptMessage):
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
             else:
                 sub_messages = []
+                assert message.content is not None
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
                         }
                     }
                     sub_messages.append(sub_message_dict)

                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in
-                #                               message.tool_calls]
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
             message_dict = {
-                "role": "function",
+                "role": "tool",
+                "name": message.name,
                 "content": message.content,
-                "name": message.tool_call_id
+                "tool_call_id": message.tool_call_id
             }
         else:
             raise ValueError(f"Got unknown type {message}")


         return num_tokens
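With the switch to role "tool" in the ToolPromptMessage branch, the tool result must echo the tool_call_id issued in the assistant's tool_calls entry; the old role "function" shape was keyed on the function name instead. A sketch of the resulting message sequence, with all values hypothetical:

messages = [
    {"role": "user", "content": "Weather in Sofia?"},
    {   # assistant turn that requested the tool
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Sofia"}'},
        }],
    },
    {   # tool result: must reference the tool_call_id it answers
        "role": "tool",
        "name": "get_weather",
        "content": '{"temp_c": 21}',
        "tool_call_id": "call_1",
    },
]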


-    def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
-                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def _num_tokens_from_messages(
+            self, credentials: dict, messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/

                 if key == "tool_calls":
                     for tool_call in value:
+                        assert isinstance(tool_call, dict)
                         for t_key, t_value in tool_call.items():
                             num_tokens += len(encoding.encode(t_key))
                             if t_key == "function":
         num_tokens += len(encoding.encode('parameters'))
         if 'title' in parameters:
             num_tokens += len(encoding.encode('title'))
-            num_tokens += len(encoding.encode(parameters.get("title")))
+            num_tokens += len(encoding.encode(parameters['title']))
         num_tokens += len(encoding.encode('type'))
-        num_tokens += len(encoding.encode(parameters.get("type")))
+        num_tokens += len(encoding.encode(parameters['type']))
         if 'properties' in parameters:
             num_tokens += len(encoding.encode('properties'))
-            for key, value in parameters.get('properties').items():
+            for key, value in parameters['properties'].items():
                 num_tokens += len(encoding.encode(key))
                 for field_key, field_value in value.items():
                     num_tokens += len(encoding.encode(field_key))
         return num_tokens
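The counting logic above relies on one tiktoken primitive: the token count of a string is the length of its encoded token list. A minimal, runnable illustration (cl100k_base is the encoding used by the gpt-3.5 and gpt-4 families):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
# Each key and value of a tool schema is encoded separately and summed,
# which is what the walk over 'title', 'type', and 'properties' above does.
num_tokens = len(encoding.encode('{"city": {"type": "string"}}'))
print(num_tokens)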


     @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str):
         for ai_model_entity in LLM_BASE_MODELS:
             if ai_model_entity.base_model_name == base_model_name:
                 ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                 ai_model_entity_copy.entity.label.en_US = model
                 ai_model_entity_copy.entity.label.zh_Hans = model
                 return ai_model_entity_copy

+        return None

+4 −6  api/tests/integration_tests/model_runtime/__mock/openai_chat.py  (View full file)

         return FunctionCall(name=function_name, arguments=dumps(parameters))

     @staticmethod
-    def generate_tool_calls(
-        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-    ) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
-        tool: ChatCompletionToolParam = tools[0]
+        tool = tools[0]

-        if tools['type'] != 'function':
+        if 'type' in tools and tools['type'] != 'function':
             return None
         function = tool['function']

         function_call = MockChatClass.generate_function_call(functions=[function])
