
feat: support function call for ollama block chat api (#10784)

tags/0.12.0
GeorgeCaoJ, 11 months ago
Parent · Current commit fbfc811a44

api/core/model_runtime/model_providers/ollama/llm/llm.py (+63 -5)

     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import (
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user,
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, "api/chat")
             data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            if tools:
+                data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
         else:
             endpoint_url = urljoin(endpoint_url, "api/generate")
             first_prompt_message = prompt_messages[0]
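For context, a minimal sketch of the JSON body that the blocking chat call would now POST to Ollama's /api/chat once a tool is attached. The model name and the get_weather tool are illustrative, not part of the change:

    data = {
        "model": "llama3.1",  # illustrative
        "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",  # hypothetical tool
                    "description": "Look up the current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    }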
         if stream:
             return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

-        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
+        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)


     def _handle_generate_response(
         self,
         completion_type: LLMMode,
         response: requests.Response,
         prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]],
     ) -> LLMResult:
         """
         Handle llm completion response
         :return: llm result
         """
         response_json = response.json()
+        tool_calls = []
         if completion_type is LLMMode.CHAT:
             message = response_json.get("message", {})
             response_content = message.get("content", "")
+            response_tool_calls = message.get("tool_calls", [])
+            tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
         else:
             response_content = response_json["response"]

-        assistant_message = AssistantPromptMessage(content=response_content)
+        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)

         if "prompt_eval_count" in response_json and "eval_count" in response_json:
             # transform usage


         chunk_index += 1

+    def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
+        """
+        Convert PromptMessageTool to dict for Ollama API
+
+        :param tool: tool
+        :return: tool dict
+        """
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            },
+        }

     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for Ollama API
+
+        :param message: prompt message
+        :return: message dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content}
         else:
             raise ValueError(f"Got unknown type {message}")
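With the new ToolPromptMessage branch, a tool's output can be fed back to the model as a plain "tool" role message. A sketch with illustrative content; note the converted dict carries only role and content, so no call id is sent to Ollama:

    # Hypothetical result returned by a weather tool, serialized for /api/chat:
    message_dict = {"role": "tool", "content": '{"temperature_c": 21}'}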




         return num_tokens

+    def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract response tool call
+        """
+        tool_call = None
+        if response_tool_call and "function" in response_tool_call:
+            # Convert arguments to JSON string if it's a dict
+            arguments = response_tool_call.get("function").get("arguments")
+            if isinstance(arguments, dict):
+                arguments = json.dumps(arguments)
+
+            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                name=response_tool_call.get("function").get("name"),
+                arguments=arguments,
+            )
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_tool_call.get("function").get("name"),
+                type="function",
+                function=function,
+            )
+
+        return tool_call
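To make the mapping concrete: a hypothetical entry from message.tool_calls in an Ollama response, and the ToolCall built from it. Ollama returns arguments as a dict, so it is JSON-encoded here, and since Ollama returns no call id the function name doubles as the id:

    response_tool_call = {
        "function": {
            "name": "get_weather",            # illustrative
            "arguments": {"city": "Berlin"},  # a dict, not a JSON string
        }
    }
    tool_call = self._extract_response_tool_call(response_tool_call)
    # tool_call.id == "get_weather"
    # tool_call.type == "function"
    # tool_call.function.arguments == '{"city": "Berlin"}'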

     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         """
         Get customizable model schema.

         :return: model schema
         """
-        extras = {}
+        extras = {
+            "features": [],
+        }

         if "vision_support" in credentials and credentials["vision_support"] == "true":
-            extras["features"] = [ModelFeature.VISION]
+            extras["features"].append(ModelFeature.VISION)
+        if "function_call_support" in credentials and credentials["function_call_support"] == "true":
+            extras["features"].append(ModelFeature.TOOL_CALL)
+            extras["features"].append(ModelFeature.MULTI_TOOL_CALL)

         entity = AIModelEntity(
             model=model,
api/core/model_runtime/model_providers/ollama/ollama.yaml (+19 -0)

         label:
           en_US: 'No'
           zh_Hans: 否
+  - variable: function_call_support
+    label:
+      zh_Hans: 是否支持函数调用
+      en_US: Function call support
+    show_on:
+      - variable: __model_type
+        value: llm
+    default: 'false'
+    type: radio
+    required: false
+    options:
+      - value: 'true'
+        label:
+          en_US: 'Yes'
+          zh_Hans: 是
+      - value: 'false'
+        label:
+          en_US: 'No'
+          zh_Hans: 否
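The new credential feeds directly into get_customizable_model_schema above. A sketch of a saved credential set and the features it enables (other required credentials elided):

    credentials = {
        "vision_support": "false",
        "function_call_support": "true",  # the new radio field
    }
    # With these values, the customizable model schema advertises
    # ModelFeature.TOOL_CALL and ModelFeature.MULTI_TOOL_CALL,
    # so tool use can be enabled for the model in Dify.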
