|
|
|
@@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import ( |
|
|
|
PromptMessageTool, |
|
|
|
SystemPromptMessage, |
|
|
|
TextPromptMessageContent, |
|
|
|
ToolPromptMessage, |
|
|
|
UserPromptMessage, |
|
|
|
) |
|
|
|
from core.model_runtime.entities.model_entities import ( |
|
|
|
@@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
credentials=credentials, |
|
|
|
prompt_messages=prompt_messages, |
|
|
|
model_parameters=model_parameters, |
|
|
|
tools=tools, |
|
|
|
stop=stop, |
|
|
|
stream=stream, |
|
|
|
user=user, |
|
|
|
@@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
credentials: dict, |
|
|
|
prompt_messages: list[PromptMessage], |
|
|
|
model_parameters: dict, |
|
|
|
tools: Optional[list[PromptMessageTool]] = None, |
|
|
|
stop: Optional[list[str]] = None, |
|
|
|
stream: bool = True, |
|
|
|
user: Optional[str] = None, |
|
|
|
@@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
if completion_type is LLMMode.CHAT: |
|
|
|
endpoint_url = urljoin(endpoint_url, "api/chat") |
|
|
|
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] |
|
|
|
if tools: |
|
|
|
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools] |
|
|
|
else: |
|
|
|
endpoint_url = urljoin(endpoint_url, "api/generate") |
|
|
|
first_prompt_message = prompt_messages[0] |
|
|
|
@@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
if stream: |
|
|
|
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) |
|
|
|
|
|
|
|
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) |
|
|
|
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools) |
|
|
|
|
|
|
|
def _handle_generate_response( |
|
|
|
self, |
|
|
|
@@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
completion_type: LLMMode, |
|
|
|
response: requests.Response, |
|
|
|
prompt_messages: list[PromptMessage], |
|
|
|
tools: Optional[list[PromptMessageTool]], |
|
|
|
) -> LLMResult: |
|
|
|
""" |
|
|
|
Handle llm completion response |
|
|
|
@@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
:return: llm result |
|
|
|
""" |
|
|
|
response_json = response.json() |
|
|
|
|
|
|
|
tool_calls = [] |
|
|
|
if completion_type is LLMMode.CHAT: |
|
|
|
message = response_json.get("message", {}) |
|
|
|
response_content = message.get("content", "") |
|
|
|
response_tool_calls = message.get("tool_calls", []) |
|
|
|
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls] |
|
|
|
else: |
|
|
|
response_content = response_json["response"] |
|
|
|
|
|
|
|
assistant_message = AssistantPromptMessage(content=response_content) |
|
|
|
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls) |
|
|
|
|
|
|
|
if "prompt_eval_count" in response_json and "eval_count" in response_json: |
|
|
|
# transform usage |
|
|
|
@@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
|
|
|
|
chunk_index += 1 |
|
|
|
|
|
|
|
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict: |
|
|
|
""" |
|
|
|
Convert PromptMessageTool to dict for Ollama API |
|
|
|
|
|
|
|
:param tool: tool |
|
|
|
:return: tool dict |
|
|
|
""" |
|
|
|
return { |
|
|
|
"type": "function", |
|
|
|
"function": { |
|
|
|
"name": tool.name, |
|
|
|
"description": tool.description, |
|
|
|
"parameters": tool.parameters, |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: |
|
|
|
""" |
|
|
|
Convert PromptMessage to dict for Ollama API |
|
|
|
|
|
|
|
:param message: prompt message |
|
|
|
:return: message dict |
|
|
|
""" |
|
|
|
if isinstance(message, UserPromptMessage): |
|
|
|
message = cast(UserPromptMessage, message) |
|
|
|
@@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
elif isinstance(message, SystemPromptMessage): |
|
|
|
message = cast(SystemPromptMessage, message) |
|
|
|
message_dict = {"role": "system", "content": message.content} |
|
|
|
elif isinstance(message, ToolPromptMessage): |
|
|
|
message = cast(ToolPromptMessage, message) |
|
|
|
message_dict = {"role": "tool", "content": message.content} |
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown type {message}") |
|
|
|
|
|
|
|
@@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
|
|
|
|
return num_tokens |
|
|
|
|
|
|
|
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall: |
|
|
|
""" |
|
|
|
Extract response tool call |
|
|
|
""" |
|
|
|
tool_call = None |
|
|
|
if response_tool_call and "function" in response_tool_call: |
|
|
|
# Convert arguments to JSON string if it's a dict |
|
|
|
arguments = response_tool_call.get("function").get("arguments") |
|
|
|
if isinstance(arguments, dict): |
|
|
|
arguments = json.dumps(arguments) |
|
|
|
|
|
|
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=response_tool_call.get("function").get("name"), |
|
|
|
arguments=arguments, |
|
|
|
) |
|
|
|
tool_call = AssistantPromptMessage.ToolCall( |
|
|
|
id=response_tool_call.get("function").get("name"), |
|
|
|
type="function", |
|
|
|
function=function, |
|
|
|
) |
|
|
|
|
|
|
|
return tool_call |
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: |
|
|
|
""" |
|
|
|
Get customizable model schema. |
|
|
|
@@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel): |
|
|
|
|
|
|
|
:return: model schema |
|
|
|
""" |
|
|
|
extras = {} |
|
|
|
extras = { |
|
|
|
"features": [], |
|
|
|
} |
|
|
|
|
|
|
|
if "vision_support" in credentials and credentials["vision_support"] == "true": |
|
|
|
extras["features"] = [ModelFeature.VISION] |
|
|
|
extras["features"].append(ModelFeature.VISION) |
|
|
|
if "function_call_support" in credentials and credentials["function_call_support"] == "true": |
|
|
|
extras["features"].append(ModelFeature.TOOL_CALL) |
|
|
|
extras["features"].append(ModelFeature.MULTI_TOOL_CALL) |
|
|
|
|
|
|
|
entity = AIModelEntity( |
|
|
|
model=model, |