@@ -1,7 +1,9 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union

+import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
@@ -13,9 +15,9 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
-    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -62,7 +64,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -94,6 +96,32 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         )

         return text.rstrip()

+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
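+                                # nested schemas use the proto-plus field name 'type_' ('type' shadows the
+                                # Python builtin); JSON Schema type names are upper-cased to match the
+                                # glm.Type enum members (e.g. 'string' -> 'STRING')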
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -105,7 +133,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """

         try:
-            ping_message = PromptMessage(content="ping", role="system")
+            ping_message = SystemPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})

         except Exception as ex:
@@ -114,8 +142,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[list[str]] = None, stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model

@@ -153,7 +182,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)

-
         # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])
@@ -167,14 +195,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
         }

         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(
                 **config_kwargs
             ),
             stream=stream,
-            safety_settings=safety_settings
+            safety_settings=safety_settings,
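+            # tools are optional: convert declarations only when the caller supplied any,
+            # so plain generation without tools is unchanged (tools=None)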
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
         )

         if stream:
@@ -228,43 +257,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         index = -1
         for chunk in response:
-            content = chunk.text
-            index += 1
-
-            assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
-            )
-
-            if not response._done:
-
-                # transform assistant message to prompt message
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
-                )
-            else:
-
-                # calculate num tokens
-                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message,
-                        finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
-                    )
-                )
+            for part in chunk.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
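+                # glm function_call parts carry only a name and args, no call id,
+                # so the function name doubles as the ToolCall id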
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+
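+                # assumption: response._done (private SDK state) flips once the last chunk
+                # has arrived, so only the final part carries finish_reason and usage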
+                if not response._done:
+
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )

     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -288,6 +335,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")

@@ -300,26 +349,53 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
-
-        parts = []
-        if (isinstance(message.content, str)):
-            parts.append(to_part(message.content))
-        else:
-            for c in message.content:
-                if c.type == PromptMessageContentType.TEXT:
-                    parts.append(to_part(c.data))
-                else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
-                    parts.append(blob)
-
-        glm_content = {
-            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
-            "parts": parts
-        }
-
-        return glm_content
+        if isinstance(message, UserPromptMessage):
+            glm_content = {
+                "role": "user",
+                "parts": []
+            }
+            if (isinstance(message.content, str)):
+                glm_content['parts'].append(to_part(message.content))
+            else:
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        glm_content['parts'].append(to_part(c.data))
+                    else:
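+                        # non-text content is expected as a data URL
+                        # ("data:<mime_type>;base64,<data>"), split into mime type and payload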
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                        glm_content['parts'].append(blob)
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            glm_content = {
+                "role": "model",
+                "parts": []
+            }
+            if message.content:
+                glm_content['parts'].append(to_part(message.content))
+            if message.tool_calls:
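+                # note: only the first tool call is replayed as a FunctionCall part;
+                # parallel tool calls are not handled here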
+                glm_content["parts"].append(to_part(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                )))
+            return glm_content
+        elif isinstance(message, SystemPromptMessage):
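+            # the glm chat format has no system role, so system prompts are sent as a user turn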
+            return {
+                "role": "user",
+                "parts": [to_part(message.content)]
+            }
+        elif isinstance(message, ToolPromptMessage):
+            return {
+                "role": "function",
+                "parts": [glm.Part(function_response=glm.FunctionResponse(
+                    name=message.name,
+                    response={
+                        "response": message.content
+                    }
+                ))]
+            }
+        else:
+            raise ValueError(f"Got unknown type {message}")

     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: