| @@ -1,6 +1,6 @@ | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from typing import TYPE_CHECKING, Optional, Union | |||
| from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu | |||
| from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature | |||
| from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature | |||
| from core.external_data_tool.external_data_fetch import ExternalDataFetch | |||
| from core.file.file_obj import FileVar | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| @@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp | |||
| from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform | |||
| from models.model import App, AppMode, Message, MessageAnnotation | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| class AppRunner: | |||
| def get_pre_calculate_rest_tokens(self, app_record: App, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| query: Optional[str] = None) -> int: | |||
| """ | |||
| Get pre calculate rest tokens | |||
| @@ -126,7 +128,7 @@ class AppRunner: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| query: Optional[str] = None, | |||
| context: Optional[str] = None, | |||
| memory: Optional[TokenBufferMemory] = None) \ | |||
| @@ -366,7 +368,7 @@ class AppRunner: | |||
| message_id=message_id, | |||
| trace_manager=app_generate_entity.trace_manager | |||
| ) | |||
| def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| prompt_messages: list[PromptMessage]) -> bool: | |||
| @@ -418,7 +420,7 @@ class AppRunner: | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| def query_app_annotations_to_reply(self, app_record: App, | |||
| message: Message, | |||
| query: str, | |||
| @@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| node_id: str | |||
| inputs: dict | |||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | |||
| single_iteration_run: Optional[SingleIterationRunEntity] = None | |||
| @@ -99,7 +99,7 @@ class MessageFileParser: | |||
| # return all file objs | |||
| return new_files | |||
| def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: | |||
| def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): | |||
| """ | |||
| transform message files | |||
| @@ -144,7 +144,7 @@ class MessageFileParser: | |||
| return type_file_objs | |||
| def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: | |||
| def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): | |||
| """ | |||
| transform file to file obj | |||
| @@ -1,11 +1,10 @@ | |||
| import enum | |||
| import json | |||
| import os | |||
| from typing import Optional | |||
| from typing import TYPE_CHECKING, Optional | |||
| from core.app.app_config.entities import PromptTemplateEntity | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.file.file_obj import FileVar | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.message_entities import ( | |||
| PromptMessage, | |||
| @@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from models.model import AppMode | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| class ModelMode(enum.Enum): | |||
| COMPLETION = 'completion' | |||
| @@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform): | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict, | |||
| query: str, | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| context: Optional[str], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) -> \ | |||
| @@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform): | |||
| inputs: dict, | |||
| query: str, | |||
| context: Optional[str], | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) \ | |||
| -> tuple[list[PromptMessage], Optional[list[str]]]: | |||
| @@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform): | |||
| inputs: dict, | |||
| query: str, | |||
| context: Optional[str], | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) \ | |||
| -> tuple[list[PromptMessage], Optional[list[str]]]: | |||
| @@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform): | |||
| return [self.get_last_user_message(prompt, files)], stops | |||
| def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: | |||
| def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: | |||
| if files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=prompt)] | |||
| for file in files: | |||
| @@ -2,13 +2,12 @@ from abc import ABC, abstractmethod | |||
| from collections.abc import Mapping | |||
| from copy import deepcopy | |||
| from enum import Enum | |||
| from typing import Any, Optional, Union | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from pydantic import BaseModel, ConfigDict, field_validator | |||
| from pydantic_core.core_schema import ValidationInfo | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file.file_obj import FileVar | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolDescription, | |||
| ToolIdentity, | |||
| @@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import ( | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| class Tool(BaseModel, ABC): | |||
| identity: Optional[ToolIdentity] = None | |||
| @@ -76,7 +78,7 @@ class Tool(BaseModel, ABC): | |||
| description=self.description.model_copy() if self.description else None, | |||
| runtime=Tool.Runtime(**runtime), | |||
| ) | |||
| @abstractmethod | |||
| def tool_provider_type(self) -> ToolProviderType: | |||
| """ | |||
| @@ -84,7 +86,7 @@ class Tool(BaseModel, ABC): | |||
| :return: the tool provider type | |||
| """ | |||
| def load_variables(self, variables: ToolRuntimeVariablePool): | |||
| """ | |||
| load variables from database | |||
| @@ -99,7 +101,7 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return | |||
| self.variables.set_file(self.identity.name, variable_name, image_key) | |||
| def set_text_variable(self, variable_name: str, text: str) -> None: | |||
| @@ -108,9 +110,9 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return | |||
| self.variables.set_text(self.identity.name, variable_name, text) | |||
| def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: | |||
| """ | |||
| get a variable | |||
| @@ -120,14 +122,14 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return None | |||
| if isinstance(name, Enum): | |||
| name = name.value | |||
| for variable in self.variables.pool: | |||
| if variable.name == name: | |||
| return variable | |||
| return None | |||
| def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: | |||
| @@ -138,9 +140,9 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return None | |||
| return self.get_variable(self.VARIABLE_KEY.IMAGE) | |||
| def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: | |||
| """ | |||
| get a variable file | |||
| @@ -151,7 +153,7 @@ class Tool(BaseModel, ABC): | |||
| variable = self.get_variable(name) | |||
| if not variable: | |||
| return None | |||
| if not isinstance(variable, ToolRuntimeImageVariable): | |||
| return None | |||
| @@ -160,9 +162,9 @@ class Tool(BaseModel, ABC): | |||
| file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) | |||
| if not file_binary: | |||
| return None | |||
| return file_binary[0] | |||
| def list_variables(self) -> list[ToolRuntimeVariable]: | |||
| """ | |||
| list all variables | |||
| @@ -171,9 +173,9 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return [] | |||
| return self.variables.pool | |||
| def list_default_image_variables(self) -> list[ToolRuntimeVariable]: | |||
| """ | |||
| list all image variables | |||
| @@ -182,9 +184,9 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| if not self.variables: | |||
| return [] | |||
| result = [] | |||
| for variable in self.variables.pool: | |||
| if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): | |||
| result.append(variable) | |||
| @@ -225,7 +227,7 @@ class Tool(BaseModel, ABC): | |||
| @abstractmethod | |||
| def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||
| pass | |||
| def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: | |||
| """ | |||
| validate the credentials | |||
| @@ -244,7 +246,7 @@ class Tool(BaseModel, ABC): | |||
| :return: the runtime parameters | |||
| """ | |||
| return self.parameters or [] | |||
| def get_all_runtime_parameters(self) -> list[ToolParameter]: | |||
| """ | |||
| get all runtime parameters | |||
| @@ -278,7 +280,7 @@ class Tool(BaseModel, ABC): | |||
| parameters.append(parameter) | |||
| return parameters | |||
| def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: | |||
| """ | |||
| create an image message | |||
| @@ -286,18 +288,18 @@ class Tool(BaseModel, ABC): | |||
| :param image: the url of the image | |||
| :return: the image message | |||
| """ | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, | |||
| message=image, | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, | |||
| message=image, | |||
| save_as=save_as) | |||
| def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: | |||
| def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, | |||
| message='', | |||
| meta={ | |||
| 'file_var': file_var | |||
| }, | |||
| save_as='') | |||
| def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: | |||
| """ | |||
| create a link message | |||
| @@ -305,10 +307,10 @@ class Tool(BaseModel, ABC): | |||
| :param link: the url of the link | |||
| :return: the link message | |||
| """ | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, | |||
| message=link, | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, | |||
| message=link, | |||
| save_as=save_as) | |||
| def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: | |||
| """ | |||
| create a text message | |||
| @@ -321,7 +323,7 @@ class Tool(BaseModel, ABC): | |||
| message=text, | |||
| save_as=save_as | |||
| ) | |||
| def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: | |||
| """ | |||
| create a blob message | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| from mimetypes import guess_extension | |||
| from core.file.file_obj import FileTransferMethod, FileType, FileVar | |||
| from core.file.file_obj import FileTransferMethod, FileType | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| @@ -27,12 +27,12 @@ class ToolFileMessageTransformer: | |||
| # try to download image | |||
| try: | |||
| file = ToolFileManager.create_file_by_url( | |||
| user_id=user_id, | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| conversation_id=conversation_id, | |||
| file_url=message.message | |||
| ) | |||
| url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' | |||
| result.append(ToolInvokeMessage( | |||
| @@ -55,14 +55,14 @@ class ToolFileMessageTransformer: | |||
| # if message is str, encode it to bytes | |||
| if isinstance(message.message, str): | |||
| message.message = message.message.encode('utf-8') | |||
| file = ToolFileManager.create_file_by_raw( | |||
| user_id=user_id, tenant_id=tenant_id, | |||
| conversation_id=conversation_id, | |||
| file_binary=message.message, | |||
| mimetype=mimetype | |||
| ) | |||
| url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) | |||
| # check if file is image | |||
| @@ -81,7 +81,7 @@ class ToolFileMessageTransformer: | |||
| meta=message.meta.copy() if message.meta is not None else {}, | |||
| )) | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: | |||
| file_var: FileVar = message.meta.get('file_var') | |||
| file_var = message.meta.get('file_var') | |||
| if file_var: | |||
| if file_var.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| url = cls.get_tool_file_url(file_var.related_id, file_var.extension) | |||
| @@ -103,7 +103,7 @@ class ToolFileMessageTransformer: | |||
| result.append(message) | |||
| return result | |||
| @classmethod | |||
| def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: | |||
| return f'/files/tools/{tool_file_id}{extension or ".bin"}' | |||
| return f'/files/tools/{tool_file_id}{extension or ".bin"}' | |||
| @@ -1,14 +1,13 @@ | |||
| import json | |||
| from collections.abc import Generator | |||
| from copy import deepcopy | |||
| from typing import Optional, cast | |||
| from typing import TYPE_CHECKING, Optional, cast | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.file.file_obj import FileVar | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| @@ -39,6 +38,10 @@ from models.model import Conversation | |||
| from models.provider import Provider, ProviderType | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| class LLMNode(BaseNode): | |||
| _node_data_cls = LLMNodeData | |||
| @@ -71,7 +74,7 @@ class LLMNode(BaseNode): | |||
| node_inputs = {} | |||
| # fetch files | |||
| files: list[FileVar] = self._fetch_files(node_data, variable_pool) | |||
| files = self._fetch_files(node_data, variable_pool) | |||
| if files: | |||
| node_inputs['#files#'] = [file.to_dict() for file in files] | |||
| @@ -322,7 +325,7 @@ class LLMNode(BaseNode): | |||
| return inputs | |||
| def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: | |||
| def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: | |||
| """ | |||
| Fetch files | |||
| :param node_data: node data | |||
| @@ -521,7 +524,7 @@ class LLMNode(BaseNode): | |||
| query: Optional[str], | |||
| query_prompt_template: Optional[str], | |||
| inputs: dict[str, str], | |||
| files: list[FileVar], | |||
| files: list["FileVar"], | |||
| context: Optional[str], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity) \ | |||