| @@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30 | |||
| UPLOAD_FILE_SIZE_LIMIT=15 | |||
| UPLOAD_FILE_BATCH_LIMIT=5 | |||
| UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 | |||
| UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 | |||
| UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 | |||
| # Model Configuration | |||
| MULTIMODAL_SEND_IMAGE_FORMAT=base64 | |||
| @@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 | |||
| WORKFLOW_MAX_EXECUTION_STEPS=500 | |||
| WORKFLOW_MAX_EXECUTION_TIME=1200 | |||
| WORKFLOW_CALL_MAX_DEPTH=5 | |||
| MAX_VARIABLE_SIZE=204800 | |||
| # App configuration | |||
| APP_MAX_EXECUTION_TIME=1200 | |||
| @@ -1,8 +1,15 @@ | |||
| { | |||
| "version": "0.2.0", | |||
| "compounds": [ | |||
| { | |||
| "name": "Launch Flask and Celery", | |||
| "configurations": ["Python: Flask", "Python: Celery"] | |||
| } | |||
| ], | |||
| "configurations": [ | |||
| { | |||
| "name": "Python: Flask", | |||
| "consoleName": "Flask", | |||
| "type": "debugpy", | |||
| "request": "launch", | |||
| "python": "${workspaceFolder}/.venv/bin/python", | |||
| @@ -17,12 +24,12 @@ | |||
| }, | |||
| "args": [ | |||
| "run", | |||
| "--host=0.0.0.0", | |||
| "--port=5001" | |||
| ] | |||
| }, | |||
| { | |||
| "name": "Python: Celery", | |||
| "consoleName": "Celery", | |||
| "type": "debugpy", | |||
| "request": "launch", | |||
| "python": "${workspaceFolder}/.venv/bin/python", | |||
| @@ -45,10 +52,10 @@ | |||
| "-c", | |||
| "1", | |||
| "--loglevel", | |||
| "info", | |||
| "DEBUG", | |||
| "-Q", | |||
| "dataset,generation,mail,ops_trace,app_deletion" | |||
| ] | |||
| }, | |||
| } | |||
| ] | |||
| } | |||
| } | |||
| @@ -118,7 +118,7 @@ def create_app() -> Flask: | |||
| logging.basicConfig( | |||
| level=app.config.get("LOG_LEVEL"), | |||
| format=app.config.get("LOG_FORMAT"), | |||
| format=app.config["LOG_FORMAT"], | |||
| datefmt=app.config.get("LOG_DATEFORMAT"), | |||
| handlers=log_handlers, | |||
| force=True, | |||
| @@ -135,6 +135,7 @@ def create_app() -> Flask: | |||
| return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() | |||
| for handler in logging.root.handlers: | |||
| assert handler.formatter | |||
| handler.formatter.converter = time_converter | |||
| initialize_extensions(app) | |||
| register_blueprints(app) | |||
| @@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client | |||
| from libs.helper import email as email_validate | |||
| from libs.password import hash_password, password_pattern, valid_password | |||
| from libs.rsa import generate_key_pair | |||
| from models.account import Tenant | |||
| from models import Tenant | |||
| from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation | |||
| @@ -457,14 +457,14 @@ def convert_to_agent_apps(): | |||
| # fetch first 1000 apps | |||
| sql_query = """SELECT a.id AS id FROM apps a | |||
| INNER JOIN app_model_configs am ON a.app_model_config_id=am.id | |||
| WHERE a.mode = 'chat' | |||
| AND am.agent_mode is not null | |||
| WHERE a.mode = 'chat' | |||
| AND am.agent_mode is not null | |||
| AND ( | |||
| am.agent_mode like '%"strategy": "function_call"%' | |||
| am.agent_mode like '%"strategy": "function_call"%' | |||
| OR am.agent_mode like '%"strategy": "react"%' | |||
| ) | |||
| ) | |||
| AND ( | |||
| am.agent_mode like '{"enabled": true%' | |||
| am.agent_mode like '{"enabled": true%' | |||
| OR am.agent_mode like '{"max_iteration": %' | |||
| ) ORDER BY a.created_at DESC LIMIT 1000 | |||
| """ | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Annotated, Optional | |||
| from typing import Annotated, Literal, Optional | |||
| from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field | |||
| from pydantic_settings import BaseSettings | |||
| @@ -11,11 +11,11 @@ class SecurityConfig(BaseSettings): | |||
| Security-related configurations for the application | |||
| """ | |||
| SECRET_KEY: Optional[str] = Field( | |||
| SECRET_KEY: str = Field( | |||
| description="Secret key for secure session cookie signing." | |||
| "Make sure you are changing this key for your deployment with a strong key." | |||
| "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", | |||
| default=None, | |||
| default="", | |||
| ) | |||
| RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( | |||
| @@ -177,6 +177,16 @@ class FileUploadConfig(BaseSettings): | |||
| default=10, | |||
| ) | |||
| UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field( | |||
| description="video file size limit in Megabytes for uploading files", | |||
| default=100, | |||
| ) | |||
| UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field( | |||
| description="audio file size limit in Megabytes for uploading files", | |||
| default=50, | |||
| ) | |||
| BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( | |||
| description="Maximum number of files allowed in a batch upload operation", | |||
| default=20, | |||
| @@ -355,8 +365,8 @@ class WorkflowConfig(BaseSettings): | |||
| ) | |||
| MAX_VARIABLE_SIZE: PositiveInt = Field( | |||
| description="Maximum size in bytes for a single variable in workflows. Default to 5KB.", | |||
| default=5 * 1024, | |||
| description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", | |||
| default=200 * 1024, | |||
| ) | |||
| @@ -479,6 +489,7 @@ class RagEtlConfig(BaseSettings): | |||
| Configuration for RAG ETL processes | |||
| """ | |||
| # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config | |||
| ETL_TYPE: str = Field( | |||
| description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", | |||
| default="dify", | |||
| @@ -540,7 +551,7 @@ class IndexingConfig(BaseSettings): | |||
| class ImageFormatConfig(BaseSettings): | |||
| MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( | |||
| MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( | |||
| description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", | |||
| default="base64", | |||
| ) | |||
| @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): | |||
| CURRENT_VERSION: str = Field( | |||
| description="Dify version", | |||
| default="0.9.1", | |||
| default="0.10.0-beta2", | |||
| ) | |||
| COMMIT_SHA: str = Field( | |||
| @@ -1,2 +1,21 @@ | |||
| from configs import dify_config | |||
| HIDDEN_VALUE = "[__HIDDEN__]" | |||
| UUID_NIL = "00000000-0000-0000-0000-000000000000" | |||
| IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] | |||
| IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) | |||
| VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] | |||
| VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) | |||
| AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] | |||
| AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) | |||
| DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] | |||
| DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) | |||
| if dify_config.ETL_TYPE == "Unstructured": | |||
| DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"] | |||
| DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub")) | |||
| DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) | |||
| @@ -1,7 +1,9 @@ | |||
| from contextvars import ContextVar | |||
| from typing import TYPE_CHECKING | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| if TYPE_CHECKING: | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| tenant_id: ContextVar[str] = ContextVar("tenant_id") | |||
| workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") | |||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | |||
| @@ -22,7 +22,8 @@ from fields.conversation_fields import ( | |||
| ) | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation | |||
| from models import Conversation, EndUser, Message, MessageAnnotation | |||
| from models.model import AppMode | |||
| class CompletionConversationApi(Resource): | |||
| @@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required | |||
| from extensions.ext_database import db | |||
| from fields.app_fields import app_site_fields | |||
| from libs.login import login_required | |||
| from models.model import Site | |||
| from models import Site | |||
| def parse_app_site_args(): | |||
| @@ -13,14 +13,14 @@ from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.segments import factory | |||
| from core.errors.error import AppInvokeQuotaExceededError | |||
| from factories import variable_factory | |||
| from fields.workflow_fields import workflow_fields | |||
| from fields.workflow_run_fields import workflow_run_node_execution_fields | |||
| from libs import helper | |||
| from libs.helper import TimestampField, uuid_value | |||
| from libs.login import current_user, login_required | |||
| from models.model import App, AppMode | |||
| from models import App | |||
| from models.model import AppMode | |||
| from services.app_dsl_service import AppDslService | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| @@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource): | |||
| try: | |||
| environment_variables_list = args.get("environment_variables") or [] | |||
| environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] | |||
| environment_variables = [ | |||
| variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list | |||
| ] | |||
| conversation_variables_list = args.get("conversation_variables") or [] | |||
| conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] | |||
| conversation_variables = [ | |||
| variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list | |||
| ] | |||
| workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app_model, | |||
| graph=args["graph"], | |||
| @@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource): | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True | |||
| ) | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=True, | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return helper.compact_generate_response(response) | |||
| class WorkflowTaskStopApi(Resource): | |||
| @@ -7,7 +7,8 @@ from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from libs.login import login_required | |||
| from models.model import App, AppMode | |||
| from models import App | |||
| from models.model import AppMode | |||
| from services.workflow_app_service import WorkflowAppService | |||
| @@ -13,7 +13,8 @@ from fields.workflow_run_fields import ( | |||
| ) | |||
| from libs.helper import uuid_value | |||
| from libs.login import login_required | |||
| from models.model import App, AppMode | |||
| from models import App | |||
| from models.model import AppMode | |||
| from services.workflow_run_service import WorkflowRunService | |||
| @@ -10,11 +10,11 @@ from controllers.console import api | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from enums import WorkflowRunTriggeredFrom | |||
| from extensions.ext_database import db | |||
| from libs.helper import DatetimeString | |||
| from libs.login import login_required | |||
| from models.model import AppMode | |||
| from models.workflow import WorkflowRunTriggeredFrom | |||
| class WorkflowDailyRunsStatistic(Resource): | |||
| @@ -5,7 +5,8 @@ from typing import Optional, Union | |||
| from controllers.console.app.error import AppNotFoundError | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models.model import App, AppMode | |||
| from models import App | |||
| from models.model import AppMode | |||
| def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): | |||
| @@ -15,7 +15,7 @@ from controllers.console.setup import setup_required | |||
| from extensions.ext_database import db | |||
| from libs.helper import email as email_validate | |||
| from libs.password import hash_password, valid_password | |||
| from models.account import Account | |||
| from models import Account | |||
| from services.account_service import AccountService | |||
| from services.errors.account import RateLimitExceededError | |||
| @@ -9,7 +9,7 @@ from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from libs.helper import email, extract_remote_ip | |||
| from libs.password import valid_password | |||
| from models.account import Account | |||
| from models import Account | |||
| from services.account_service import AccountService, TenantService | |||
| @@ -11,7 +11,8 @@ from constants.languages import languages | |||
| from extensions.ext_database import db | |||
| from libs.helper import extract_remote_ip | |||
| from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo | |||
| from models.account import Account, AccountStatus | |||
| from models import Account | |||
| from models.account import AccountStatus | |||
| from services.account_service import AccountService, RegisterService, TenantService | |||
| from .. import api | |||
| @@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from extensions.ext_database import db | |||
| from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields | |||
| from libs.login import login_required | |||
| from models.dataset import Document | |||
| from models.source import DataSourceOauthBinding | |||
| from models import DataSourceOauthBinding, Document | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from tasks.document_indexing_sync_task import document_indexing_sync_task | |||
| @@ -24,8 +24,8 @@ from fields.app_fields import related_app_list | |||
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |||
| from fields.document_fields import document_status_fields | |||
| from libs.login import login_required | |||
| from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment | |||
| from models.model import ApiToken, UploadFile | |||
| from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile | |||
| from models.dataset import DatasetPermissionEnum | |||
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |||
| @@ -46,8 +46,7 @@ from fields.document_fields import ( | |||
| document_with_segments_fields, | |||
| ) | |||
| from libs.login import login_required | |||
| from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment | |||
| from models.model import UploadFile | |||
| from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from tasks.add_document_to_index_task import add_document_to_index_task | |||
| from tasks.remove_document_from_index_task import remove_document_from_index_task | |||
| @@ -24,7 +24,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from fields.segment_fields import segment_fields | |||
| from libs.login import login_required | |||
| from models.dataset import DocumentSegment | |||
| from models import DocumentSegment | |||
| from services.dataset_service import DatasetService, DocumentService, SegmentService | |||
| from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task | |||
| from tasks.disable_segment_from_index_task import disable_segment_from_index_task | |||
| @@ -1,9 +1,12 @@ | |||
| import urllib.parse | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, marshal_with | |||
| import services | |||
| from configs import dify_config | |||
| from constants import DOCUMENT_EXTENSIONS | |||
| from controllers.console import api | |||
| from controllers.console.datasets.error import ( | |||
| FileTooLargeError, | |||
| @@ -13,9 +16,10 @@ from controllers.console.datasets.error import ( | |||
| ) | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from fields.file_fields import file_fields, upload_config_fields | |||
| from core.helper import ssrf_proxy | |||
| from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields | |||
| from libs.login import login_required | |||
| from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService | |||
| from services.file_service import FileService | |||
| PREVIEW_WORDS_LIMIT = 3000 | |||
| @@ -51,7 +55,7 @@ class FileApi(Resource): | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| try: | |||
| upload_file = FileService.upload_file(file, current_user) | |||
| upload_file = FileService.upload_file(file=file, user=current_user) | |||
| except services.errors.file.FileTooLargeError as file_too_large_error: | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| @@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| etl_type = dify_config.ETL_TYPE | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS | |||
| return {"allowed_extensions": allowed_extensions} | |||
| return {"allowed_extensions": DOCUMENT_EXTENSIONS} | |||
| class RemoteFileInfoApi(Resource): | |||
| @marshal_with(remote_file_info_fields) | |||
| def get(self, url): | |||
| decoded_url = urllib.parse.unquote(url) | |||
| try: | |||
| response = ssrf_proxy.head(decoded_url) | |||
| return { | |||
| "file_type": response.headers.get("Content-Type", "application/octet-stream"), | |||
| "file_length": int(response.headers.get("Content-Length", 0)), | |||
| } | |||
| except Exception as e: | |||
| return {"error": str(e)}, 400 | |||
| api.add_resource(FileApi, "/files/upload") | |||
| api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") | |||
| api.add_resource(FileSupportTypeApi, "/files/support-type") | |||
| api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") | |||
| @@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi | |||
| from extensions.ext_database import db | |||
| from fields.installed_app_fields import installed_app_list_fields | |||
| from libs.login import login_required | |||
| from models.model import App, InstalledApp, RecommendedApp | |||
| from models import App, InstalledApp, RecommendedApp | |||
| from services.account_service import TenantService | |||
| @@ -18,7 +18,7 @@ message_fields = { | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String, | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields)), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "created_at": TimestampField, | |||
| } | |||
| @@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console.wraps import account_initialization_required | |||
| from extensions.ext_database import db | |||
| from libs.login import login_required | |||
| from models.model import InstalledApp | |||
| from models import InstalledApp | |||
| def installed_app_required(view=None): | |||
| @@ -20,7 +20,7 @@ from extensions.ext_database import db | |||
| from fields.member_fields import account_fields | |||
| from libs.helper import TimestampField, timezone | |||
| from libs.login import login_required | |||
| from models.account import AccountIntegrate, InvitationCode | |||
| from models import AccountIntegrate, InvitationCode | |||
| from services.account_service import AccountService | |||
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError | |||
| @@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource): | |||
| args = reqparser.parse_args() | |||
| return WorkflowToolManageService.create_workflow_tool( | |||
| user_id, | |||
| tenant_id, | |||
| args["workflow_app_id"], | |||
| args["name"], | |||
| args["label"], | |||
| args["icon"], | |||
| args["description"], | |||
| args["parameters"], | |||
| args["privacy_policy"], | |||
| args.get("labels", []), | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| workflow_app_id=args["workflow_app_id"], | |||
| name=args["name"], | |||
| label=args["label"], | |||
| icon=args["icon"], | |||
| description=args["description"], | |||
| parameters=args["parameters"], | |||
| privacy_policy=args["privacy_policy"], | |||
| ) | |||
| @@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource): | |||
| raise UnsupportedFileTypeError() | |||
| try: | |||
| upload_file = FileService.upload_file(file, current_user, True) | |||
| upload_file = FileService.upload_file(file=file, user=current_user) | |||
| except services.errors.file.FileTooLargeError as file_too_large_error: | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| @@ -21,7 +21,36 @@ class ImagePreviewApi(Resource): | |||
| return {"content": "Invalid request."}, 400 | |||
| try: | |||
| generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) | |||
| generator, mimetype = FileService.get_image_preview( | |||
| file_id=file_id, | |||
| timestamp=timestamp, | |||
| nonce=nonce, | |||
| sign=sign, | |||
| ) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| return Response(generator, mimetype=mimetype) | |||
| class FilePreviewApi(Resource): | |||
| def get(self, file_id): | |||
| file_id = str(file_id) | |||
| timestamp = request.args.get("timestamp") | |||
| nonce = request.args.get("nonce") | |||
| sign = request.args.get("sign") | |||
| if not timestamp or not nonce or not sign: | |||
| return {"content": "Invalid request."}, 400 | |||
| try: | |||
| generator, mimetype = FileService.get_signed_file_preview( | |||
| file_id=file_id, | |||
| timestamp=timestamp, | |||
| nonce=nonce, | |||
| sign=sign, | |||
| ) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| @@ -49,4 +78,5 @@ class WorkspaceWebappLogoApi(Resource): | |||
| api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview") | |||
| api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview") | |||
| api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo") | |||
| @@ -48,7 +48,7 @@ class MessageListApi(Resource): | |||
| "tool_input": fields.String, | |||
| "created_at": TimestampField, | |||
| "observation": fields.String, | |||
| "message_files": fields.List(fields.String, attribute="files"), | |||
| "message_files": fields.List(fields.String), | |||
| } | |||
| message_fields = { | |||
| @@ -58,7 +58,7 @@ class MessageListApi(Resource): | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String(attribute="re_sign_file_url_answer"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields)), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), | |||
| "created_at": TimestampField, | |||
| @@ -1,3 +1,5 @@ | |||
| import urllib.parse | |||
| from flask import request | |||
| from flask_restful import marshal_with | |||
| @@ -5,7 +7,8 @@ import services | |||
| from controllers.web import api | |||
| from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError | |||
| from controllers.web.wraps import WebApiResource | |||
| from fields.file_fields import file_fields | |||
| from core.helper import ssrf_proxy | |||
| from fields.file_fields import file_fields, remote_file_info_fields | |||
| from services.file_service import FileService | |||
| @@ -31,4 +34,19 @@ class FileApi(WebApiResource): | |||
| return upload_file, 201 | |||
| class RemoteFileInfoApi(WebApiResource): | |||
| @marshal_with(remote_file_info_fields) | |||
| def get(self, url): | |||
| decoded_url = urllib.parse.unquote(url) | |||
| try: | |||
| response = ssrf_proxy.head(decoded_url) | |||
| return { | |||
| "file_type": response.headers.get("Content-Type", "application/octet-stream"), | |||
| "file_length": int(response.headers.get("Content-Length", 0)), | |||
| } | |||
| except Exception as e: | |||
| return {"error": str(e)}, 400 | |||
| api.add_resource(FileApi, "/files/upload") | |||
| api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") | |||
| @@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from fields.conversation_fields import message_file_fields | |||
| from fields.message_fields import agent_thought_fields | |||
| from fields.raws import FilesContainedField | |||
| from libs import helper | |||
| from libs.helper import TimestampField, uuid_value | |||
| from models.model import AppMode | |||
| @@ -58,10 +59,10 @@ class MessageListApi(WebApiResource): | |||
| "id": fields.String, | |||
| "conversation_id": fields.String, | |||
| "parent_message_id": fields.String, | |||
| "inputs": fields.Raw, | |||
| "inputs": FilesContainedField, | |||
| "query": fields.String, | |||
| "answer": fields.String(attribute="re_sign_file_url_answer"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields)), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), | |||
| "created_at": TimestampField, | |||
| @@ -17,7 +17,7 @@ message_fields = { | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String, | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields)), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "created_at": TimestampField, | |||
| } | |||
| @@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import ( | |||
| ) | |||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.file import file_manager | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| LLMUsage, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageTool, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| @@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import ( | |||
| from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message, MessageAgentThought | |||
| from factories import file_factory | |||
| from models.model import Conversation, Message, MessageAgentThought, MessageFile | |||
| from models.tools import ToolConversationVariables | |||
| logger = logging.getLogger(__name__) | |||
| @@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner): | |||
| db_variables: Optional[ToolConversationVariables] = None, | |||
| model_instance: ModelInstance = None, | |||
| ) -> None: | |||
| """ | |||
| Agent runner | |||
| :param tenant_id: tenant id | |||
| :param application_generate_entity: application generate entity | |||
| :param conversation: conversation | |||
| :param app_config: app generate entity | |||
| :param model_config: model config | |||
| :param config: dataset config | |||
| :param queue_manager: queue manager | |||
| :param message: message | |||
| :param user_id: user id | |||
| :param memory: memory | |||
| :param prompt_messages: prompt messages | |||
| :param variables_pool: variables pool | |||
| :param db_variables: db variables | |||
| :param model_instance: model instance | |||
| """ | |||
| self.tenant_id = tenant_id | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| @@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner): | |||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | |||
| continue | |||
| parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) | |||
| parameter_type = parameter.type.as_normal_type() | |||
| enum = [] | |||
| if parameter.type == ToolParameter.ToolParameterType.SELECT: | |||
| enum = [option.value for option in parameter.options] | |||
| @@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner): | |||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | |||
| continue | |||
| parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) | |||
| parameter_type = parameter.type.as_normal_type() | |||
| enum = [] | |||
| if parameter.type == ToolParameter.ToolParameterType.SELECT: | |||
| enum = [option.value for option in parameter.options] | |||
| @@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner): | |||
| return result | |||
| def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: | |||
| message_file_parser = MessageFileParser( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_config.app_id, | |||
| ) | |||
| files = message.message_files | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| if files: | |||
| file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.transform_message_files(files, file_extra_config) | |||
| file_objs = file_factory.build_from_message_files( | |||
| message_files=files, tenant_id=self.tenant_id, config=file_extra_config | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| if not file_objs: | |||
| return UserPromptMessage(content=message.query) | |||
| else: | |||
| prompt_message_contents = [TextPromptMessageContent(data=message.query)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | |||
| for file_obj in file_objs: | |||
| prompt_message_contents.append(file_obj.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) | |||
| return UserPromptMessage(content=prompt_message_contents) | |||
| else: | |||
| @@ -1,9 +1,11 @@ | |||
| import json | |||
| from core.agent.cot_agent_runner import CotAgentRunner | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.file import file_manager | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| UserPromptMessage, | |||
| @@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner): | |||
| Organize user query | |||
| """ | |||
| if self.files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=query)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||
| for file_obj in self.files: | |||
| prompt_message_contents.append(file_obj.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| @@ -7,10 +7,15 @@ from typing import Any, Optional, Union | |||
| from core.agent.base_agent_runner import BaseAgentRunner | |||
| from core.app.apps.base_app_queue_manager import PublishFrom | |||
| from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.file import file_manager | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| LLMResult, | |||
| LLMResultChunk, | |||
| LLMResultChunkDelta, | |||
| LLMUsage, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageContentType, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| @@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| Organize user query | |||
| """ | |||
| if self.files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=query)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||
| for file_obj in self.files: | |||
| prompt_message_contents.append(file_obj.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| @@ -53,12 +53,11 @@ class BasicVariablesConfigManager: | |||
| VariableEntity( | |||
| type=variable_type, | |||
| variable=variable.get("variable"), | |||
| description=variable.get("description"), | |||
| description=variable.get("description", ""), | |||
| label=variable.get("label"), | |||
| required=variable.get("required", False), | |||
| max_length=variable.get("max_length"), | |||
| options=variable.get("options"), | |||
| default=variable.get("default"), | |||
| options=variable.get("options", []), | |||
| ) | |||
| ) | |||
| @@ -1,11 +1,12 @@ | |||
| from collections.abc import Sequence | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, Field | |||
| from core.file.file_obj import FileExtraConfig | |||
| from core.file import FileExtraConfig, FileTransferMethod, FileType | |||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||
| from models import AppMode | |||
| from models.model import AppMode | |||
| class ModelConfigEntity(BaseModel): | |||
| @@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel): | |||
| ADVANCED = "advanced" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "PromptType": | |||
| def value_of(cls, value: str): | |||
| """ | |||
| Get value of given mode. | |||
| @@ -93,6 +94,8 @@ class VariableEntityType(str, Enum): | |||
| PARAGRAPH = "paragraph" | |||
| NUMBER = "number" | |||
| EXTERNAL_DATA_TOOL = "external_data_tool" | |||
| FILE = "file" | |||
| FILE_LIST = "file-list" | |||
| class VariableEntity(BaseModel): | |||
| @@ -102,13 +105,14 @@ class VariableEntity(BaseModel): | |||
| variable: str | |||
| label: str | |||
| description: Optional[str] = None | |||
| description: str = "" | |||
| type: VariableEntityType | |||
| required: bool = False | |||
| max_length: Optional[int] = None | |||
| options: Optional[list[str]] = None | |||
| default: Optional[str] = None | |||
| hint: Optional[str] = None | |||
| options: Sequence[str] = Field(default_factory=list) | |||
| allowed_file_types: Sequence[FileType] = Field(default_factory=list) | |||
| allowed_file_extensions: Sequence[str] = Field(default_factory=list) | |||
| allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) | |||
| class ExternalDataVariableEntity(BaseModel): | |||
| @@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel): | |||
| MULTIPLE = "multiple" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "RetrieveStrategy": | |||
| def value_of(cls, value: str): | |||
| """ | |||
| Get value of given mode. | |||
| @@ -1,12 +1,13 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from typing import Any | |||
| from core.file.file_obj import FileExtraConfig | |||
| from core.file.models import FileExtraConfig | |||
| from models import FileUploadConfig | |||
| class FileUploadConfigManager: | |||
| @classmethod | |||
| def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: | |||
| def convert(cls, config: Mapping[str, Any], is_vision: bool = True): | |||
| """ | |||
| Convert model config to model config | |||
| @@ -15,19 +16,18 @@ class FileUploadConfigManager: | |||
| """ | |||
| file_upload_dict = config.get("file_upload") | |||
| if file_upload_dict: | |||
| if file_upload_dict.get("image"): | |||
| if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: | |||
| image_config = { | |||
| "number_limits": file_upload_dict["image"]["number_limits"], | |||
| "transfer_methods": file_upload_dict["image"]["transfer_methods"], | |||
| if file_upload_dict.get("enabled"): | |||
| data = { | |||
| "image_config": { | |||
| "number_limits": file_upload_dict["number_limits"], | |||
| "transfer_methods": file_upload_dict["allowed_file_upload_methods"], | |||
| } | |||
| } | |||
| if is_vision: | |||
| image_config["detail"] = file_upload_dict["image"]["detail"] | |||
| if is_vision: | |||
| data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") | |||
| return FileExtraConfig(image_config=image_config) | |||
| return None | |||
| return FileExtraConfig.model_validate(data) | |||
| @classmethod | |||
| def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: | |||
| @@ -39,29 +39,7 @@ class FileUploadConfigManager: | |||
| """ | |||
| if not config.get("file_upload"): | |||
| config["file_upload"] = {} | |||
| if not isinstance(config["file_upload"], dict): | |||
| raise ValueError("file_upload must be of dict type") | |||
| # check image config | |||
| if not config["file_upload"].get("image"): | |||
| config["file_upload"]["image"] = {"enabled": False} | |||
| if config["file_upload"]["image"]["enabled"]: | |||
| number_limits = config["file_upload"]["image"]["number_limits"] | |||
| if number_limits < 1 or number_limits > 6: | |||
| raise ValueError("number_limits must be in [1, 6]") | |||
| if is_vision: | |||
| detail = config["file_upload"]["image"]["detail"] | |||
| if detail not in {"high", "low"}: | |||
| raise ValueError("detail must be in ['high', 'low']") | |||
| transfer_methods = config["file_upload"]["image"]["transfer_methods"] | |||
| if not isinstance(transfer_methods, list): | |||
| raise ValueError("transfer_methods must be of list type") | |||
| for method in transfer_methods: | |||
| if method not in {"remote_url", "local_file"}: | |||
| raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") | |||
| else: | |||
| FileUploadConfig.model_validate(config["file_upload"]) | |||
| return config, ["file_upload"] | |||
| @@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager: | |||
| # variables | |||
| for variable in user_input_form: | |||
| variables.append(VariableEntity(**variable)) | |||
| variables.append(VariableEntity.model_validate(variable)) | |||
| return variables | |||
| @@ -20,10 +20,11 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.account import Account | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import Workflow | |||
| @@ -95,10 +96,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # parse files | |||
| files = args["files"] if args.get("files") else [] | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| @@ -106,8 +113,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) | |||
| # get tracing instance | |||
| user_id = user.id if isinstance(user, Account) else user.session_id | |||
| trace_manager = TraceQueueManager(app_model.id, user_id) | |||
| trace_manager = TraceQueueManager( | |||
| app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id | |||
| ) | |||
| if invoke_from == InvokeFrom.DEBUGGER: | |||
| # always enable retriever resource in debugger mode | |||
| @@ -119,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| conversation_id=conversation.id if conversation else None, | |||
| inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), | |||
| inputs=conversation.inputs | |||
| if conversation | |||
| else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), | |||
| query=query, | |||
| files=file_objs, | |||
| parent_message_id=args.get("parent_message_id"), | |||
| @@ -1,30 +1,26 @@ | |||
| import logging | |||
| import os | |||
| from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAnnotationReplyEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| ) | |||
| from core.moderation.base import ModerationError | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from enums import UserFrom | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import ConversationVariable, WorkflowType | |||
| @@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| """ | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| """ | |||
| super().__init__(queue_manager) | |||
| self.application_generate_entity = application_generate_entity | |||
| @@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| self.message = message | |||
| def run(self) -> None: | |||
| """ | |||
| Run application | |||
| :return: | |||
| """ | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| @@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| user_id = self.application_generate_entity.user_id | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if bool(os.environ.get("DEBUG", "False").lower() == "true"): | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| if self.application_generate_entity.single_iteration_run: | |||
| @@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| query: str, | |||
| message_id: str, | |||
| ) -> bool: | |||
| """ | |||
| Handle input moderation | |||
| :param app_record: app record | |||
| :param app_generate_entity: application generate entity | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :param message_id: message id | |||
| :return: | |||
| """ | |||
| try: | |||
| # process sensitive_word_avoidance | |||
| _, inputs, query = self.moderation_for_inputs( | |||
| @@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| def handle_annotation_reply( | |||
| self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity | |||
| ) -> bool: | |||
| """ | |||
| Handle annotation reply | |||
| :param app_record: app record | |||
| :param message: message | |||
| :param query: query | |||
| :param app_generate_entity: application generate entity | |||
| """ | |||
| # annotation reply | |||
| annotation_reply = self.query_app_annotations_to_reply( | |||
| app_record=app_record, | |||
| message=message, | |||
| @@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: | |||
| """ | |||
| Direct output | |||
| :param text: text | |||
| :return: | |||
| """ | |||
| self._publish_event(QueueTextChunkEvent(text=text)) | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Optional, Union | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| @@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAdvancedChatMessageEndEvent, | |||
| @@ -50,10 +51,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from enums.workflow_nodes import NodeType | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models import Conversation, EndUser, Message, MessageFile | |||
| from models.account import Account | |||
| from models.model import Conversation, EndUser, Message | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecution, | |||
| @@ -120,6 +122,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._wip_workflow_node_executions = {} | |||
| self._conversation_name_generate_thread = None | |||
| self._recorded_files: list[Mapping[str, Any]] = [] | |||
| def process(self): | |||
| """ | |||
| @@ -298,6 +301,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(event) | |||
| # Record files if it's an answer node or end node | |||
| if event.node_type in [NodeType.ANSWER, NodeType.END]: | |||
| self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -364,7 +371,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=json.dumps(event.outputs) if event.outputs else None, | |||
| outputs=event.outputs, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| @@ -490,10 +497,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| """ | |||
| Save message. | |||
| :return: | |||
| """ | |||
| self._refetch_message() | |||
| self._message.answer = self._task_state.answer | |||
| @@ -501,6 +504,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message_files = [ | |||
| MessageFile( | |||
| message_id=self._message.id, | |||
| type=file["type"], | |||
| transfer_method=file["transfer_method"], | |||
| url=file["remote_url"], | |||
| belongs_to="assistant", | |||
| upload_file_id=file["related_id"], | |||
| created_by_role="account" | |||
| if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else "end_user", | |||
| created_by=self._message.from_account_id or self._message.from_end_user_id or "", | |||
| ) | |||
| for file in self._recorded_files | |||
| ] | |||
| db.session.add_all(message_files) | |||
| if graph_runtime_state and graph_runtime_state.llm_usage: | |||
| usage = graph_runtime_state.llm_usage | |||
| @@ -540,7 +559,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| del extras["metadata"]["annotation_reply"] | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, id=self._message.id, **extras | |||
| task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras | |||
| ) | |||
| def _handle_output_moderation_chunk(self, text: str) -> bool: | |||
| @@ -17,12 +17,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import App, EndUser | |||
| from factories import file_factory | |||
| from models import Account, App, EndUser | |||
| logger = logging.getLogger(__name__) | |||
| @@ -49,7 +49,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| ) -> dict: ... | |||
| def generate( | |||
| self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True | |||
| self, | |||
| app_model: App, | |||
| user: Union[Account, EndUser], | |||
| args: Any, | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| ) -> Union[dict, Generator[dict, None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -97,12 +102,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| # always enable retriever resource in debugger mode | |||
| override_model_config_dict["retriever_resource"] = {"enabled": True} | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| # parse files | |||
| files = args["files"] if args.get("files") else [] | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| files = args.get("files") or [] | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| @@ -115,8 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| # get tracing instance | |||
| user_id = user.id if isinstance(user, Account) else user.session_id | |||
| trace_manager = TraceQueueManager(app_model.id, user_id) | |||
| trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) | |||
| # init application generate entity | |||
| application_generate_entity = AgentChatAppGenerateEntity( | |||
| @@ -124,7 +135,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| app_config=app_config, | |||
| model_conf=ModelConfigConverter.convert(app_config), | |||
| conversation_id=conversation.id if conversation else None, | |||
| inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), | |||
| inputs=conversation.inputs | |||
| if conversation | |||
| else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), | |||
| query=query, | |||
| files=file_objs, | |||
| parent_message_id=args.get("parent_message_id"), | |||
| @@ -1,35 +1,92 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from typing import TYPE_CHECKING, Any, Optional | |||
| from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType | |||
| from core.app.app_config.entities import VariableEntityType | |||
| from core.file import File, FileExtraConfig | |||
| from factories import file_factory | |||
| if TYPE_CHECKING: | |||
| from core.app.app_config.entities import AppConfig, VariableEntity | |||
| from enums import CreatedByRole | |||
| class BaseAppGenerator: | |||
| def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: | |||
| def _prepare_user_inputs( | |||
| self, | |||
| *, | |||
| user_inputs: Optional[Mapping[str, Any]], | |||
| app_config: "AppConfig", | |||
| user_id: str, | |||
| role: "CreatedByRole", | |||
| ) -> Mapping[str, Any]: | |||
| user_inputs = user_inputs or {} | |||
| # Filter input variables from form configuration, handle required fields, default values, and option values | |||
| variables = app_config.variables | |||
| filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||
| filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} | |||
| return filtered_inputs | |||
| user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||
| user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} | |||
| # Convert files in inputs to File | |||
| entity_dictionary = {item.variable: item for item in app_config.variables} | |||
| # Convert single file to File | |||
| files_inputs = { | |||
| k: file_factory.build_from_mapping( | |||
| mapping=v, | |||
| tenant_id=app_config.tenant_id, | |||
| user_id=user_id, | |||
| role=role, | |||
| config=FileExtraConfig( | |||
| allowed_file_types=entity_dictionary[k].allowed_file_types, | |||
| allowed_extensions=entity_dictionary[k].allowed_file_extensions, | |||
| allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, | |||
| ), | |||
| ) | |||
| for k, v in user_inputs.items() | |||
| if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE | |||
| } | |||
| # Convert list of files to File | |||
| file_list_inputs = { | |||
| k: file_factory.build_from_mappings( | |||
| mappings=v, | |||
| tenant_id=app_config.tenant_id, | |||
| user_id=user_id, | |||
| role=role, | |||
| config=FileExtraConfig( | |||
| allowed_file_types=entity_dictionary[k].allowed_file_types, | |||
| allowed_extensions=entity_dictionary[k].allowed_file_extensions, | |||
| allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, | |||
| ), | |||
| ) | |||
| for k, v in user_inputs.items() | |||
| if isinstance(v, list) | |||
| # Ensure skip List<File> | |||
| and all(isinstance(item, dict) for item in v) | |||
| and entity_dictionary[k].type == VariableEntityType.FILE_LIST | |||
| } | |||
| # Merge all inputs | |||
| user_inputs = {**user_inputs, **files_inputs, **file_list_inputs} | |||
| def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): | |||
| user_input_value = inputs.get(var.variable) | |||
| if var.required and not user_input_value: | |||
| raise ValueError(f"{var.variable} is required in input form") | |||
| if not var.required and not user_input_value: | |||
| # TODO: should we return None here if the default value is None? | |||
| return var.default or "" | |||
| if ( | |||
| var.type | |||
| in { | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| } | |||
| and user_input_value | |||
| and not isinstance(user_input_value, str) | |||
| # Check if all files are converted to File | |||
| if any(filter(lambda v: isinstance(v, dict), user_inputs.values())): | |||
| raise ValueError("Invalid input type") | |||
| if any( | |||
| filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values())) | |||
| ): | |||
| raise ValueError("Invalid input type") | |||
| return user_inputs | |||
| def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"): | |||
| user_input_value = inputs.get(var.variable) | |||
| if not user_input_value: | |||
| if var.required: | |||
| raise ValueError(f"{var.variable} is required in input form") | |||
| else: | |||
| return None | |||
| if var.type in { | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| } and not isinstance(user_input_value, str): | |||
| raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") | |||
| if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): | |||
| # may raise ValueError if user_input_value is not a valid number | |||
| @@ -41,12 +98,24 @@ class BaseAppGenerator: | |||
| except ValueError: | |||
| raise ValueError(f"{var.variable} in input form must be a valid number") | |||
| if var.type == VariableEntityType.SELECT: | |||
| options = var.options or [] | |||
| options = var.options | |||
| if user_input_value not in options: | |||
| raise ValueError(f"{var.variable} in input form must be one of the following: {options}") | |||
| elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: | |||
| if var.max_length and user_input_value and len(user_input_value) > var.max_length: | |||
| if var.max_length and len(user_input_value) > var.max_length: | |||
| raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") | |||
| elif var.type == VariableEntityType.FILE: | |||
| if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File): | |||
| raise ValueError(f"{var.variable} in input form must be a file") | |||
| elif var.type == VariableEntityType.FILE_LIST: | |||
| if not ( | |||
| isinstance(user_input_value, list) | |||
| and ( | |||
| all(isinstance(item, dict) for item in user_input_value) | |||
| or all(isinstance(item, File) for item in user_input_value) | |||
| ) | |||
| ): | |||
| raise ValueError(f"{var.variable} in input form must be a list of files") | |||
| return user_input_value | |||
| @@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform | |||
| from models.model import App, AppMode, Message, MessageAnnotation | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| from core.file.models import File | |||
| class AppRunner: | |||
| @@ -37,7 +37,7 @@ class AppRunner: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list["FileVar"], | |||
| files: list["File"], | |||
| query: Optional[str] = None, | |||
| ) -> int: | |||
| """ | |||
| @@ -137,7 +137,7 @@ class AppRunner: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list["FileVar"], | |||
| files: list["File"], | |||
| query: Optional[str] = None, | |||
| context: Optional[str] = None, | |||
| memory: Optional[TokenBufferMemory] = None, | |||
| @@ -17,10 +17,11 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.account import Account | |||
| from models.model import App, EndUser | |||
| @@ -99,12 +100,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| # always enable retriever resource in debugger mode | |||
| override_model_config_dict["retriever_resource"] = {"enabled": True} | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| # parse files | |||
| files = args["files"] if args.get("files") else [] | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| @@ -117,7 +125,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| # get tracing instance | |||
| trace_manager = TraceQueueManager(app_model.id) | |||
| trace_manager = TraceQueueManager(app_id=app_model.id) | |||
| # init application generate entity | |||
| application_generate_entity = ChatAppGenerateEntity( | |||
| @@ -125,15 +133,17 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| app_config=app_config, | |||
| model_conf=ModelConfigConverter.convert(app_config), | |||
| conversation_id=conversation.id if conversation else None, | |||
| inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), | |||
| inputs=conversation.inputs | |||
| if conversation | |||
| else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), | |||
| query=query, | |||
| files=file_objs, | |||
| parent_message_id=args.get("parent_message_id"), | |||
| user_id=user.id, | |||
| stream=stream, | |||
| invoke_from=invoke_from, | |||
| extras=extras, | |||
| trace_manager=trace_manager, | |||
| stream=stream, | |||
| ) | |||
| # init generate records | |||
| @@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import App, EndUser, Message | |||
| from factories import file_factory | |||
| from models import Account, App, EndUser, Message | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| from services.errors.message import MessageNotExistsError | |||
| @@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| tenant_id=app_model.tenant_id, config=args.get("model_config") | |||
| ) | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| # parse files | |||
| files = args["files"] if args.get("files") else [] | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| @@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| ) | |||
| # get tracing instance | |||
| user_id = user.id if isinstance(user, Account) else user.session_id | |||
| trace_manager = TraceQueueManager(app_model.id) | |||
| # init application generate entity | |||
| @@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| model_conf=ModelConfigConverter.convert(app_config), | |||
| inputs=self._get_cleaned_inputs(inputs, app_config), | |||
| inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), | |||
| query=query, | |||
| files=file_objs, | |||
| user_id=user.id, | |||
| @@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| override_model_config_dict["model"] = model_dict | |||
| # parse files | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) | |||
| file_objs = file_factory.build_from_mappings( | |||
| mappings=message.message_files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| @@ -26,7 +26,7 @@ from core.app.entities.task_entities import ( | |||
| from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models import Account | |||
| from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile | |||
| from services.errors.app_model_config import AppModelConfigBrokenError | |||
| from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError | |||
| @@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| for file in application_generate_entity.files: | |||
| message_file = MessageFile( | |||
| message_id=message.id, | |||
| type=file.type.value, | |||
| transfer_method=file.transfer_method.value, | |||
| type=file.type, | |||
| transfer_method=file.transfer_method, | |||
| belongs_to="user", | |||
| url=file.url, | |||
| url=file.remote_url, | |||
| upload_file_id=file.related_id, | |||
| created_by_role=("account" if account_id else "end_user"), | |||
| created_by=account_id or end_user_id, | |||
| created_by=account_id or end_user_id or "", | |||
| ) | |||
| db.session.add(message_file) | |||
| db.session.commit() | |||
| @@ -3,7 +3,7 @@ import logging | |||
| import os | |||
| import threading | |||
| import uuid | |||
| from collections.abc import Generator | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| @@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera | |||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import App, EndUser | |||
| from models.workflow import Workflow | |||
| from factories import file_factory | |||
| from models import Account, App, EndUser, Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| app_model: App, | |||
| workflow: Workflow, | |||
| user: Union[Account, EndUser], | |||
| args: dict, | |||
| args: Mapping[str, Any], | |||
| invoke_from: InvokeFrom, | |||
| stream: bool = True, | |||
| call_depth: int = 0, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ): | |||
| """ | |||
| Generate App response. | |||
| files: Sequence[Mapping[str, Any]] = args.get("files") or [] | |||
| :param app_model: App | |||
| :param workflow: Workflow | |||
| :param user: account or end user | |||
| :param args: request args | |||
| :param invoke_from: invoke from source | |||
| :param stream: is stream | |||
| :param call_depth: call depth | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| """ | |||
| inputs = args["inputs"] | |||
| role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER | |||
| # parse files | |||
| files = args["files"] if args.get("files") else [] | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) | |||
| file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) | |||
| else: | |||
| file_objs = [] | |||
| system_files = file_factory.build_from_mappings( | |||
| mappings=files, | |||
| tenant_id=app_model.tenant_id, | |||
| user_id=user.id, | |||
| role=role, | |||
| config=file_extra_config, | |||
| ) | |||
| # convert to app config | |||
| app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) | |||
| app_config = WorkflowAppConfigManager.get_app_config( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| ) | |||
| # get tracing instance | |||
| user_id = user.id if isinstance(user, Account) else user.session_id | |||
| trace_manager = TraceQueueManager(app_model.id, user_id) | |||
| trace_manager = TraceQueueManager( | |||
| app_id=app_model.id, | |||
| user_id=user.id if isinstance(user, Account) else user.session_id, | |||
| ) | |||
| inputs: Mapping[str, Any] = args["inputs"] | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| # init application generate entity | |||
| application_generate_entity = WorkflowAppGenerateEntity( | |||
| task_id=str(uuid.uuid4()), | |||
| app_config=app_config, | |||
| inputs=self._get_cleaned_inputs(inputs, app_config), | |||
| files=file_objs, | |||
| inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), | |||
| files=system_files, | |||
| user_id=user.id, | |||
| stream=stream, | |||
| invoke_from=invoke_from, | |||
| @@ -1,20 +1,19 @@ | |||
| import logging | |||
| import os | |||
| from typing import Optional, cast | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfig | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback | |||
| from core.app.entities.app_invoke_entities import ( | |||
| InvokeFrom, | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import UserFrom | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from enums import UserFrom | |||
| from extensions.ext_database import db | |||
| from models.model import App, EndUser | |||
| from models.workflow import WorkflowType | |||
| @@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| db.session.close() | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if bool(os.environ.get("DEBUG", "False").lower() == "true"): | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # if only single iteration run is requested | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| @@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=json.dumps(event.outputs) | |||
| if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs | |||
| else None, | |||
| outputs=event.outputs, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| @@ -20,7 +20,6 @@ from core.app.entities.queue_entities import ( | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| @@ -45,6 +44,7 @@ from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData | |||
| from core.workflow.nodes.node_mapping import node_classes | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| @@ -1,4 +1,4 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict | |||
| from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig | |||
| from core.entities.provider_configuration import ProviderModelBundle | |||
| from core.file.file_obj import FileVar | |||
| from core.file.models import File | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| @@ -22,7 +22,7 @@ class InvokeFrom(Enum): | |||
| DEBUGGER = "debugger" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "InvokeFrom": | |||
| def value_of(cls, value: str): | |||
| """ | |||
| Get value of given mode. | |||
| @@ -81,7 +81,7 @@ class AppGenerateEntity(BaseModel): | |||
| app_config: AppConfig | |||
| inputs: Mapping[str, Any] | |||
| files: list[FileVar] = [] | |||
| files: Sequence[File] | |||
| user_id: str | |||
| # extras | |||
| @@ -6,8 +6,9 @@ from pydantic import BaseModel, field_validator | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from enums import NodeType | |||
| class QueueEvent(str, Enum): | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| @@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse): | |||
| event: StreamEvent = StreamEvent.MESSAGE_END | |||
| id: str | |||
| metadata: dict = {} | |||
| files: Optional[Sequence[Mapping[str, Any]]] = None | |||
| class MessageFileStreamResponse(StreamResponse): | |||
| @@ -1,18 +0,0 @@ | |||
| import re | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from . import SegmentGroup, factory | |||
| VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") | |||
| def convert_template(*, template: str, variable_pool: VariablePool): | |||
| parts = re.split(VARIABLE_PATTERN, template) | |||
| segments = [] | |||
| for part in filter(lambda x: x, parts): | |||
| if "." in part and (value := variable_pool.get(part.split("."))): | |||
| segments.append(value) | |||
| else: | |||
| segments.append(factory.build_segment(part)) | |||
| return SegmentGroup(value=segments) | |||
| @@ -1,5 +1,6 @@ | |||
| import json | |||
| import time | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime, timezone | |||
| from typing import Any, Optional, Union, cast | |||
| @@ -27,15 +28,15 @@ from core.app.entities.task_entities import ( | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.file.file_obj import FileVar | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from enums import NodeType, WorkflowRunTriggeredFrom | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser | |||
| @@ -47,7 +48,6 @@ from models.workflow import ( | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| WorkflowRunTriggeredFrom, | |||
| ) | |||
| @@ -117,7 +117,7 @@ class WorkflowCycleManage: | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| outputs: Optional[str] = None, | |||
| outputs: Mapping[str, Any] | None = None, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| @@ -133,8 +133,10 @@ class WorkflowCycleManage: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| outputs = WorkflowEntry.handle_special_values(outputs) | |||
| workflow_run.status = WorkflowRunStatus.SUCCEEDED.value | |||
| workflow_run.outputs = outputs | |||
| workflow_run.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| @@ -286,10 +288,11 @@ class WorkflowCycleManage: | |||
| db.session.commit() | |||
| db.session.close() | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| workflow_node_execution.finished_at = finished_at | |||
| @@ -326,11 +329,12 @@ class WorkflowCycleManage: | |||
| db.session.commit() | |||
| db.session.close() | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| @@ -637,7 +641,7 @@ class WorkflowCycleManage: | |||
| ), | |||
| ) | |||
| def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: | |||
| def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: | |||
| """ | |||
| Fetch files from node outputs | |||
| :param outputs_dict: node outputs dict | |||
| @@ -646,15 +650,15 @@ class WorkflowCycleManage: | |||
| if not outputs_dict: | |||
| return [] | |||
| files = [] | |||
| for output_var, output_value in outputs_dict.items(): | |||
| file_vars = self._fetch_files_from_variable_value(output_value) | |||
| if file_vars: | |||
| files.extend(file_vars) | |||
| files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] | |||
| # Remove None | |||
| files = [file for file in files if file] | |||
| # Flatten list | |||
| files = [file for sublist in files for file in sublist] | |||
| return files | |||
| def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: | |||
| def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: | |||
| """ | |||
| Fetch files from variable value | |||
| :param value: variable value | |||
| @@ -666,17 +670,17 @@ class WorkflowCycleManage: | |||
| files = [] | |||
| if isinstance(value, list): | |||
| for item in value: | |||
| file_var = self._get_file_var_from_value(item) | |||
| if file_var: | |||
| files.append(file_var) | |||
| file = self._get_file_var_from_value(item) | |||
| if file: | |||
| files.append(file) | |||
| elif isinstance(value, dict): | |||
| file_var = self._get_file_var_from_value(value) | |||
| if file_var: | |||
| files.append(file_var) | |||
| file = self._get_file_var_from_value(value) | |||
| if file: | |||
| files.append(file) | |||
| return files | |||
| def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: | |||
| def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: | |||
| """ | |||
| Get file var from value | |||
| :param value: variable value | |||
| @@ -685,14 +689,11 @@ class WorkflowCycleManage: | |||
| if not value: | |||
| return None | |||
| if isinstance(value, dict): | |||
| if "__variant" in value and value["__variant"] == FileVar.__name__: | |||
| return value | |||
| elif isinstance(value, FileVar): | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| return value | |||
| elif isinstance(value, File): | |||
| return value.to_dict() | |||
| return None | |||
| def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: | |||
| """ | |||
| Refetch workflow run | |||
| @@ -1,29 +0,0 @@ | |||
| import enum | |||
| from typing import Any | |||
| from pydantic import BaseModel | |||
| class PromptMessageFileType(enum.Enum): | |||
| IMAGE = "image" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in PromptMessageFileType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class PromptMessageFile(BaseModel): | |||
| type: PromptMessageFileType | |||
| data: Any = None | |||
| class ImagePromptMessageFile(PromptMessageFile): | |||
| class DETAIL(enum.Enum): | |||
| LOW = "low" | |||
| HIGH = "high" | |||
| type: PromptMessageFileType = PromptMessageFileType.IMAGE | |||
| detail: DETAIL = DETAIL.LOW | |||
| @@ -0,0 +1,19 @@ | |||
| from .constants import FILE_MODEL_IDENTITY | |||
| from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType | |||
| from .models import ( | |||
| File, | |||
| FileExtraConfig, | |||
| ImageConfig, | |||
| ) | |||
| __all__ = [ | |||
| "FileType", | |||
| "FileExtraConfig", | |||
| "FileTransferMethod", | |||
| "FileBelongsTo", | |||
| "File", | |||
| "ImageConfig", | |||
| "FileAttribute", | |||
| "ArrayFileAttribute", | |||
| "FILE_MODEL_IDENTITY", | |||
| ] | |||
| @@ -0,0 +1 @@ | |||
| FILE_MODEL_IDENTITY = "__dify__file__" | |||
| @@ -0,0 +1,55 @@ | |||
| from enum import Enum | |||
| class FileType(str, Enum): | |||
| IMAGE = "image" | |||
| DOCUMENT = "document" | |||
| AUDIO = "audio" | |||
| VIDEO = "video" | |||
| CUSTOM = "custom" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileTransferMethod(str, Enum): | |||
| REMOTE_URL = "remote_url" | |||
| LOCAL_FILE = "local_file" | |||
| TOOL_FILE = "tool_file" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileTransferMethod: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileBelongsTo(str, Enum): | |||
| USER = "user" | |||
| ASSISTANT = "assistant" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileBelongsTo: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileAttribute(str, Enum): | |||
| TYPE = "type" | |||
| SIZE = "size" | |||
| NAME = "name" | |||
| MIME_TYPE = "mime_type" | |||
| TRANSFER_METHOD = "transfer_method" | |||
| URL = "url" | |||
| EXTENSION = "extension" | |||
| class ArrayFileAttribute(str, Enum): | |||
| LENGTH = "length" | |||
| @@ -0,0 +1,136 @@ | |||
| import base64 | |||
| from configs import dify_config | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models import UploadFile | |||
| from . import helpers | |||
| from .enums import FileAttribute | |||
| from .models import File, FileTransferMethod, FileType | |||
| from .tool_file_parser import ToolFileParser | |||
| def get_attr(*, file: "File", attr: "FileAttribute"): | |||
| match attr: | |||
| case FileAttribute.TYPE: | |||
| return file.type.value | |||
| case FileAttribute.SIZE: | |||
| return file.size | |||
| case FileAttribute.NAME: | |||
| return file.filename | |||
| case FileAttribute.MIME_TYPE: | |||
| return file.mime_type | |||
| case FileAttribute.TRANSFER_METHOD: | |||
| return file.transfer_method.value | |||
| case FileAttribute.URL: | |||
| return file.remote_url | |||
| case FileAttribute.EXTENSION: | |||
| return file.extension | |||
| case _: | |||
| raise ValueError(f"Invalid file attribute: {attr}") | |||
| def to_prompt_message_content(file: "File", /): | |||
| """ | |||
| Convert a File object to an ImagePromptMessageContent object. | |||
| This function takes a File object and converts it to an ImagePromptMessageContent | |||
| object, which can be used as a prompt for image-based AI models. | |||
| Args: | |||
| file (File): The File object to convert. Must be of type FileType.IMAGE. | |||
| Returns: | |||
| ImagePromptMessageContent: An object containing the image data and detail level. | |||
| Raises: | |||
| ValueError: If the file is not an image or if the file data is missing. | |||
| Note: | |||
| The detail level of the image prompt is determined by the file's extra_config. | |||
| If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. | |||
| """ | |||
| if file.type != FileType.IMAGE: | |||
| raise ValueError("Only image file can convert to prompt message content") | |||
| url_or_b64_data = _get_url_or_b64_data(file=file) | |||
| if url_or_b64_data is None: | |||
| raise ValueError("Missing file data") | |||
| # decide the detail of image prompt message content | |||
| if file._extra_config and file._extra_config.image_config and file._extra_config.image_config.detail: | |||
| detail = file._extra_config.image_config.detail | |||
| else: | |||
| detail = ImagePromptMessageContent.DETAIL.LOW | |||
| return ImagePromptMessageContent(data=url_or_b64_data, detail=detail) | |||
| def download(*, upload_file_id: str, tenant_id: str): | |||
| upload_file = ( | |||
| db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first() | |||
| ) | |||
| if not upload_file: | |||
| raise ValueError("upload file not found") | |||
| return _download(upload_file.key) | |||
| def _download(path: str, /): | |||
| """ | |||
| Download and return the contents of a file as bytes. | |||
| This function loads the file from storage and ensures it's in bytes format. | |||
| Args: | |||
| path (str): The path to the file in storage. | |||
| Returns: | |||
| bytes: The contents of the file as a bytes object. | |||
| Raises: | |||
| ValueError: If the loaded file is not a bytes object. | |||
| """ | |||
| data = storage.load(path, stream=False) | |||
| if not isinstance(data, bytes): | |||
| raise ValueError(f"file {path} is not a bytes object") | |||
| return data | |||
| def _get_base64(*, upload_file_id: str, tenant_id: str) -> str | None: | |||
| upload_file = ( | |||
| db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first() | |||
| ) | |||
| if not upload_file: | |||
| return None | |||
| data = _download(upload_file.key) | |||
| if data is None: | |||
| return None | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return f"data:{upload_file.mime_type};base64,{encoded_string}" | |||
| def _get_url_or_b64_data(file: "File"): | |||
| if file.type == FileType.IMAGE: | |||
| if file.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| return file.remote_url | |||
| elif file.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| if file.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": | |||
| return helpers.get_signed_image_url(upload_file_id=file.related_id) | |||
| return _get_base64(upload_file_id=file.related_id, tenant_id=file.tenant_id) | |||
| elif file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| # add sign url | |||
| if file.related_id is None or file.extension is None: | |||
| raise ValueError("Missing file related_id or extension") | |||
| return ToolFileParser.get_tool_file_manager().sign_file( | |||
| tool_file_id=file.related_id, extension=file.extension | |||
| ) | |||
| @@ -1,145 +0,0 @@ | |||
| import enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.file.tool_file_parser import ToolFileParser | |||
| from core.file.upload_file_parser import UploadFileParser | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from extensions.ext_database import db | |||
| class FileExtraConfig(BaseModel): | |||
| """ | |||
| File Upload Entity. | |||
| """ | |||
| image_config: Optional[dict[str, Any]] = None | |||
| class FileType(enum.Enum): | |||
| IMAGE = "image" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileTransferMethod(enum.Enum): | |||
| REMOTE_URL = "remote_url" | |||
| LOCAL_FILE = "local_file" | |||
| TOOL_FILE = "tool_file" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileTransferMethod: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileBelongsTo(enum.Enum): | |||
| USER = "user" | |||
| ASSISTANT = "assistant" | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in FileBelongsTo: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class FileVar(BaseModel): | |||
| id: Optional[str] = None # message file id | |||
| tenant_id: str | |||
| type: FileType | |||
| transfer_method: FileTransferMethod | |||
| url: Optional[str] = None # remote url | |||
| related_id: Optional[str] = None | |||
| extra_config: Optional[FileExtraConfig] = None | |||
| filename: Optional[str] = None | |||
| extension: Optional[str] = None | |||
| mime_type: Optional[str] = None | |||
| def to_dict(self) -> dict: | |||
| return { | |||
| "__variant": self.__class__.__name__, | |||
| "tenant_id": self.tenant_id, | |||
| "type": self.type.value, | |||
| "transfer_method": self.transfer_method.value, | |||
| "url": self.preview_url, | |||
| "remote_url": self.url, | |||
| "related_id": self.related_id, | |||
| "filename": self.filename, | |||
| "extension": self.extension, | |||
| "mime_type": self.mime_type, | |||
| } | |||
| def to_markdown(self) -> str: | |||
| """ | |||
| Convert file to markdown | |||
| :return: | |||
| """ | |||
| preview_url = self.preview_url | |||
| if self.type == FileType.IMAGE: | |||
| text = f'' | |||
| else: | |||
| text = f"[{self.filename or preview_url}]({preview_url})" | |||
| return text | |||
| @property | |||
| def data(self) -> Optional[str]: | |||
| """ | |||
| Get image data, file signed url or base64 data | |||
| depending on config MULTIMODAL_SEND_IMAGE_FORMAT | |||
| :return: | |||
| """ | |||
| return self._get_data() | |||
| @property | |||
| def preview_url(self) -> Optional[str]: | |||
| """ | |||
| Get signed preview url | |||
| :return: | |||
| """ | |||
| return self._get_data(force_url=True) | |||
| @property | |||
| def prompt_message_content(self) -> ImagePromptMessageContent: | |||
| if self.type == FileType.IMAGE: | |||
| image_config = self.extra_config.image_config | |||
| return ImagePromptMessageContent( | |||
| data=self.data, | |||
| detail=ImagePromptMessageContent.DETAIL.HIGH | |||
| if image_config.get("detail") == "high" | |||
| else ImagePromptMessageContent.DETAIL.LOW, | |||
| ) | |||
| def _get_data(self, force_url: bool = False) -> Optional[str]: | |||
| from models.model import UploadFile | |||
| if self.type == FileType.IMAGE: | |||
| if self.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| return self.url | |||
| elif self.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| upload_file = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) | |||
| .first() | |||
| ) | |||
| return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) | |||
| elif self.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| extension = self.extension | |||
| # add sign url | |||
| return ToolFileParser.get_tool_file_manager().sign_file( | |||
| tool_file_id=self.related_id, extension=extension | |||
| ) | |||
| return None | |||
| @@ -0,0 +1,61 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import os | |||
| import time | |||
| from configs import dify_config | |||
| def get_signed_image_url(upload_file_id: str) -> str: | |||
| url = f"{dify_config.FILES_URL}/files/{upload_file_id}/image-preview" | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| key = dify_config.SECRET_KEY.encode() | |||
| msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | |||
| def get_signed_file_url(upload_file_id: str) -> str: | |||
| url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| key = dify_config.SECRET_KEY.encode() | |||
| msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | |||
| def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||
| # verify signature | |||
| if sign != recalculated_encoded_sign: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||
| # verify signature | |||
| if sign != recalculated_encoded_sign: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| @@ -1,243 +0,0 @@ | |||
| import re | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Union | |||
| from urllib.parse import parse_qs, urlparse | |||
| import requests | |||
| from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser, MessageFile, UploadFile | |||
| from services.file_service import IMAGE_EXTENSIONS | |||
| class MessageFileParser: | |||
| def __init__(self, tenant_id: str, app_id: str) -> None: | |||
| self.tenant_id = tenant_id | |||
| self.app_id = app_id | |||
| def validate_and_transform_files_arg( | |||
| self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] | |||
| ) -> list[FileVar]: | |||
| """ | |||
| validate and transform files arg | |||
| :param files: | |||
| :param file_extra_config: | |||
| :param user: | |||
| :return: | |||
| """ | |||
| for file in files: | |||
| if not isinstance(file, dict): | |||
| raise ValueError("Invalid file format, must be dict") | |||
| if not file.get("type"): | |||
| raise ValueError("Missing file type") | |||
| FileType.value_of(file.get("type")) | |||
| if not file.get("transfer_method"): | |||
| raise ValueError("Missing file transfer method") | |||
| FileTransferMethod.value_of(file.get("transfer_method")) | |||
| if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: | |||
| if not file.get("url"): | |||
| raise ValueError("Missing file url") | |||
| if not file.get("url").startswith("http"): | |||
| raise ValueError("Invalid file url") | |||
| if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): | |||
| raise ValueError("Missing file upload_file_id") | |||
| if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): | |||
| raise ValueError("Missing file tool_file_id") | |||
| # transform files to file objs | |||
| type_file_objs = self._to_file_objs(files, file_extra_config) | |||
| # validate files | |||
| new_files = [] | |||
| for file_type, file_objs in type_file_objs.items(): | |||
| if file_type == FileType.IMAGE: | |||
| # parse and validate files | |||
| image_config = file_extra_config.image_config | |||
| # check if image file feature is enabled | |||
| if not image_config: | |||
| continue | |||
| # Validate number of files | |||
| if len(files) > image_config["number_limits"]: | |||
| raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") | |||
| for file_obj in file_objs: | |||
| # Validate transfer method | |||
| if file_obj.transfer_method.value not in image_config["transfer_methods"]: | |||
| raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") | |||
| # Validate file type | |||
| if file_obj.type != FileType.IMAGE: | |||
| raise ValueError(f"Invalid file type: {file_obj.type}") | |||
| if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| # check remote url valid and is image | |||
| result, error = self._check_image_remote_url(file_obj.url) | |||
| if result is False: | |||
| raise ValueError(error) | |||
| elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| # get upload file from upload_file_id | |||
| upload_file = ( | |||
| db.session.query(UploadFile) | |||
| .filter( | |||
| UploadFile.id == file_obj.related_id, | |||
| UploadFile.tenant_id == self.tenant_id, | |||
| UploadFile.created_by == user.id, | |||
| UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| UploadFile.extension.in_(IMAGE_EXTENSIONS), | |||
| ) | |||
| .first() | |||
| ) | |||
| # check upload file is belong to tenant and user | |||
| if not upload_file: | |||
| raise ValueError("Invalid upload file") | |||
| new_files.append(file_obj) | |||
| # return all file objs | |||
| return new_files | |||
| def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): | |||
| """ | |||
| transform message files | |||
| :param files: | |||
| :param file_extra_config: | |||
| :return: | |||
| """ | |||
| # transform files to file objs | |||
| type_file_objs = self._to_file_objs(files, file_extra_config) | |||
| # return all file objs | |||
| return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] | |||
| def _to_file_objs( | |||
| self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig | |||
| ) -> dict[FileType, list[FileVar]]: | |||
| """ | |||
| transform files to file objs | |||
| :param files: | |||
| :param file_extra_config: | |||
| :return: | |||
| """ | |||
| type_file_objs: dict[FileType, list[FileVar]] = { | |||
| # Currently only support image | |||
| FileType.IMAGE: [] | |||
| } | |||
| if not files: | |||
| return type_file_objs | |||
| # group by file type and convert file args or message files to FileObj | |||
| for file in files: | |||
| if isinstance(file, MessageFile): | |||
| if file.belongs_to == FileBelongsTo.ASSISTANT.value: | |||
| continue | |||
| file_obj = self._to_file_obj(file, file_extra_config) | |||
| if file_obj.type not in type_file_objs: | |||
| continue | |||
| type_file_objs[file_obj.type].append(file_obj) | |||
| return type_file_objs | |||
| def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): | |||
| """ | |||
| transform file to file obj | |||
| :param file: | |||
| :return: | |||
| """ | |||
| if isinstance(file, dict): | |||
| transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) | |||
| if transfer_method != FileTransferMethod.TOOL_FILE: | |||
| return FileVar( | |||
| tenant_id=self.tenant_id, | |||
| type=FileType.value_of(file.get("type")), | |||
| transfer_method=transfer_method, | |||
| url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, | |||
| related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, | |||
| extra_config=file_extra_config, | |||
| ) | |||
| return FileVar( | |||
| tenant_id=self.tenant_id, | |||
| type=FileType.value_of(file.get("type")), | |||
| transfer_method=transfer_method, | |||
| url=None, | |||
| related_id=file.get("tool_file_id"), | |||
| extra_config=file_extra_config, | |||
| ) | |||
| else: | |||
| return FileVar( | |||
| id=file.id, | |||
| tenant_id=self.tenant_id, | |||
| type=FileType.value_of(file.type), | |||
| transfer_method=FileTransferMethod.value_of(file.transfer_method), | |||
| url=file.url, | |||
| related_id=file.upload_file_id or None, | |||
| extra_config=file_extra_config, | |||
| ) | |||
| def _check_image_remote_url(self, url): | |||
| try: | |||
| headers = { | |||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" | |||
| " Chrome/91.0.4472.124 Safari/537.36" | |||
| } | |||
| def is_s3_presigned_url(url): | |||
| try: | |||
| parsed_url = urlparse(url) | |||
| if "amazonaws.com" not in parsed_url.netloc: | |||
| return False | |||
| query_params = parse_qs(parsed_url.query) | |||
| def check_presign_v2(query_params): | |||
| required_params = ["Signature", "Expires"] | |||
| for param in required_params: | |||
| if param not in query_params: | |||
| return False | |||
| if not query_params["Expires"][0].isdigit(): | |||
| return False | |||
| signature = query_params["Signature"][0] | |||
| if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): | |||
| return False | |||
| return True | |||
| def check_presign_v4(query_params): | |||
| required_params = ["X-Amz-Signature", "X-Amz-Expires"] | |||
| for param in required_params: | |||
| if param not in query_params: | |||
| return False | |||
| if not query_params["X-Amz-Expires"][0].isdigit(): | |||
| return False | |||
| signature = query_params["X-Amz-Signature"][0] | |||
| if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): | |||
| return False | |||
| return True | |||
| return check_presign_v4(query_params) or check_presign_v2(query_params) | |||
| except Exception: | |||
| return False | |||
| if is_s3_presigned_url(url): | |||
| response = requests.get(url, headers=headers, allow_redirects=True) | |||
| if response.status_code in {200, 304}: | |||
| return True, "" | |||
| response = requests.head(url, headers=headers, allow_redirects=True) | |||
| if response.status_code in {200, 304}: | |||
| return True, "" | |||
| else: | |||
| return False, "URL does not exist." | |||
| except requests.RequestException as e: | |||
| return False, f"Error checking URL: {e}" | |||
| @@ -0,0 +1,140 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field, model_validator | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from . import helpers | |||
| from .constants import FILE_MODEL_IDENTITY | |||
| from .enums import FileTransferMethod, FileType | |||
| from .tool_file_parser import ToolFileParser | |||
| class ImageConfig(BaseModel): | |||
| """ | |||
| NOTE: This part of validation is deprecated, but still used in app features "Image Upload". | |||
| """ | |||
| number_limits: int = 0 | |||
| transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) | |||
| detail: ImagePromptMessageContent.DETAIL | None = None | |||
| class FileExtraConfig(BaseModel): | |||
| """ | |||
| File Upload Entity. | |||
| """ | |||
| image_config: Optional[ImageConfig] = None | |||
| allowed_file_types: Sequence[FileType] = Field(default_factory=list) | |||
| allowed_extensions: Sequence[str] = Field(default_factory=list) | |||
| allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) | |||
| number_limits: int = 0 | |||
| class File(BaseModel): | |||
| dify_model_identity: str = FILE_MODEL_IDENTITY | |||
| id: Optional[str] = None # message file id | |||
| tenant_id: str | |||
| type: FileType | |||
| transfer_method: FileTransferMethod | |||
| remote_url: Optional[str] = None # remote url | |||
| related_id: Optional[str] = None | |||
| filename: Optional[str] = None | |||
| extension: Optional[str] = Field(default=None, description="File extension, should contains dot") | |||
| mime_type: Optional[str] = None | |||
| size: int = -1 | |||
| _extra_config: FileExtraConfig | None = None | |||
| def to_dict(self) -> Mapping[str, str | int | None]: | |||
| data = self.model_dump(mode="json") | |||
| return { | |||
| **data, | |||
| "url": self.generate_url(), | |||
| } | |||
| @property | |||
| def markdown(self) -> str: | |||
| url = self.generate_url() | |||
| if self.type == FileType.IMAGE: | |||
| text = f'' | |||
| else: | |||
| text = f"[{self.filename or url}]({url})" | |||
| return text | |||
| def generate_url(self) -> Optional[str]: | |||
| if self.type == FileType.IMAGE: | |||
| if self.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| return self.remote_url | |||
| elif self.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| if self.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| return helpers.get_signed_image_url(upload_file_id=self.related_id) | |||
| elif self.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| assert self.related_id is not None | |||
| assert self.extension is not None | |||
| return ToolFileParser.get_tool_file_manager().sign_file( | |||
| tool_file_id=self.related_id, extension=self.extension | |||
| ) | |||
| else: | |||
| if self.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| return self.remote_url | |||
| elif self.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| if self.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| return helpers.get_signed_file_url(upload_file_id=self.related_id) | |||
| elif self.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| assert self.related_id is not None | |||
| assert self.extension is not None | |||
| return ToolFileParser.get_tool_file_manager().sign_file( | |||
| tool_file_id=self.related_id, extension=self.extension | |||
| ) | |||
| @model_validator(mode="after") | |||
| def validate_after(self): | |||
| match self.transfer_method: | |||
| case FileTransferMethod.REMOTE_URL: | |||
| if not self.remote_url: | |||
| raise ValueError("Missing file url") | |||
| if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): | |||
| raise ValueError("Invalid file url") | |||
| case FileTransferMethod.LOCAL_FILE: | |||
| if not self.related_id: | |||
| raise ValueError("Missing file related_id") | |||
| case FileTransferMethod.TOOL_FILE: | |||
| if not self.related_id: | |||
| raise ValueError("Missing file related_id") | |||
| # Validate the extra config. | |||
| if not self._extra_config: | |||
| return self | |||
| if self._extra_config.allowed_file_types: | |||
| if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: | |||
| raise ValueError(f"Invalid file type: {self.type}") | |||
| if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: | |||
| raise ValueError(f"Invalid file extension: {self.extension}") | |||
| if ( | |||
| self._extra_config.allowed_upload_methods | |||
| and self.transfer_method not in self._extra_config.allowed_upload_methods | |||
| ): | |||
| raise ValueError(f"Invalid transfer method: {self.transfer_method}") | |||
| match self.type: | |||
| case FileType.IMAGE: | |||
| # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". | |||
| if not self._extra_config.image_config: | |||
| return self | |||
| # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field | |||
| if ( | |||
| self._extra_config.image_config.transfer_methods | |||
| and self.transfer_method not in self._extra_config.image_config.transfer_methods | |||
| ): | |||
| raise ValueError(f"Invalid transfer method: {self.transfer_method}") | |||
| return self | |||
| @@ -1,4 +1,9 @@ | |||
| tool_file_manager = {"manager": None} | |||
| from typing import TYPE_CHECKING, Any | |||
| if TYPE_CHECKING: | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| tool_file_manager: dict[str, Any] = {"manager": None} | |||
| class ToolFileParser: | |||
| @@ -1,79 +0,0 @@ | |||
| import base64 | |||
| import hashlib | |||
| import hmac | |||
| import logging | |||
| import os | |||
| import time | |||
| from typing import Optional | |||
| from configs import dify_config | |||
| from extensions.ext_storage import storage | |||
| IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] | |||
| IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) | |||
| class UploadFileParser: | |||
| @classmethod | |||
| def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: | |||
| if not upload_file: | |||
| return None | |||
| if upload_file.extension not in IMAGE_EXTENSIONS: | |||
| return None | |||
| if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: | |||
| return cls.get_signed_temp_image_url(upload_file.id) | |||
| else: | |||
| # get image file base64 | |||
| try: | |||
| data = storage.load(upload_file.key) | |||
| except FileNotFoundError: | |||
| logging.error(f"File not found: {upload_file.key}") | |||
| return None | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return f"data:{upload_file.mime_type};base64,{encoded_string}" | |||
| @classmethod | |||
| def get_signed_temp_image_url(cls, upload_file_id) -> str: | |||
| """ | |||
| get signed url from upload file | |||
| :param upload_file: UploadFile object | |||
| :return: | |||
| """ | |||
| base_url = dify_config.FILES_URL | |||
| image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" | |||
| timestamp = str(int(time.time())) | |||
| nonce = os.urandom(16).hex() | |||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" | |||
| @classmethod | |||
| def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: | |||
| """ | |||
| verify signature | |||
| :param upload_file_id: file id | |||
| :param timestamp: timestamp | |||
| :param nonce: nonce | |||
| :param sign: signature | |||
| :return: | |||
| """ | |||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() | |||
| recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() | |||
| # verify signature | |||
| if sign != recalculated_encoded_sign: | |||
| return False | |||
| current_time = int(time.time()) | |||
| return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT | |||
| @@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") | |||
| SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") | |||
| SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) | |||
| proxies = ( | |||
| {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} | |||
| proxy_mounts = ( | |||
| { | |||
| "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL), | |||
| "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL), | |||
| } | |||
| if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL | |||
| else None | |||
| ) | |||
| @@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| while retries <= max_retries: | |||
| try: | |||
| if SSRF_PROXY_ALL_URL: | |||
| response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) | |||
| elif proxies: | |||
| response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) | |||
| with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client: | |||
| response = client.request(method=method, url=url, **kwargs) | |||
| elif proxy_mounts: | |||
| with httpx.Client(mounts=proxy_mounts) as client: | |||
| response = client.request(method=method, url=url, **kwargs) | |||
| else: | |||
| response = httpx.request(method=method, url=url, **kwargs) | |||
| with httpx.Client() as client: | |||
| response = client.request(method=method, url=url, **kwargs) | |||
| if response.status_code not in STATUS_FORCELIST: | |||
| return response | |||
| @@ -1,18 +1,20 @@ | |||
| from typing import Optional | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.file import file_manager | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageRole, | |||
| TextPromptMessageContent, | |||
| UserPromptMessage, | |||
| ) | |||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.model import AppMode, Conversation, Message, MessageFile | |||
| from models.workflow import WorkflowRun | |||
| @@ -65,7 +67,6 @@ class TokenBufferMemory: | |||
| messages = list(reversed(thread_messages)) | |||
| message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| @@ -84,17 +85,20 @@ class TokenBufferMemory: | |||
| workflow_run.workflow.features_dict, is_vision=False | |||
| ) | |||
| if file_extra_config: | |||
| file_objs = message_file_parser.transform_message_files(files, file_extra_config) | |||
| if file_extra_config and app_record: | |||
| file_objs = file_factory.build_from_message_files( | |||
| message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config | |||
| ) | |||
| else: | |||
| file_objs = [] | |||
| if not file_objs: | |||
| prompt_messages.append(UserPromptMessage(content=message.query)) | |||
| else: | |||
| prompt_message_contents = [TextPromptMessageContent(data=message.query)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | |||
| for file_obj in file_objs: | |||
| prompt_message_contents.append(file_obj.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| import os | |||
| from collections.abc import Callable, Generator, Sequence | |||
| from typing import IO, Optional, Union, cast | |||
| from collections.abc import Callable, Generator, Iterable, Sequence | |||
| from typing import IO, Any, Optional, Union, cast | |||
| from core.embedding.embedding_constant import EmbeddingInputType | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |||
| @@ -274,7 +274,7 @@ class ModelInstance: | |||
| user=user, | |||
| ) | |||
| def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: | |||
| def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: | |||
| """ | |||
| Invoke large language tts model | |||
| @@ -298,7 +298,7 @@ class ModelInstance: | |||
| voice=voice, | |||
| ) | |||
| def _round_robin_invoke(self, function: Callable, *args, **kwargs): | |||
| def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): | |||
| """ | |||
| Round-robin invoke | |||
| :param function: function to invoke | |||
| @@ -0,0 +1,36 @@ | |||
| from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| from .message_entities import ( | |||
| AssistantPromptMessage, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageContentType, | |||
| PromptMessageRole, | |||
| PromptMessageTool, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| ToolPromptMessage, | |||
| UserPromptMessage, | |||
| ) | |||
| from .model_entities import ModelPropertyKey | |||
| __all__ = [ | |||
| "ImagePromptMessageContent", | |||
| "PromptMessage", | |||
| "PromptMessageRole", | |||
| "LLMUsage", | |||
| "ModelPropertyKey", | |||
| "AssistantPromptMessage", | |||
| "PromptMessage", | |||
| "PromptMessageContent", | |||
| "PromptMessageRole", | |||
| "SystemPromptMessage", | |||
| "TextPromptMessageContent", | |||
| "UserPromptMessage", | |||
| "PromptMessageTool", | |||
| "ToolPromptMessage", | |||
| "PromptMessageContentType", | |||
| "LLMResult", | |||
| "LLMResultChunk", | |||
| "LLMResultChunkDelta", | |||
| ] | |||
| @@ -79,7 +79,7 @@ class ImagePromptMessageContent(PromptMessageContent): | |||
| Model class for image prompt message content. | |||
| """ | |||
| class DETAIL(Enum): | |||
| class DETAIL(str, Enum): | |||
| LOW = "low" | |||
| HIGH = "high" | |||
| @@ -1,5 +1,4 @@ | |||
| import logging | |||
| import os | |||
| import re | |||
| import time | |||
| from abc import abstractmethod | |||
| @@ -8,6 +7,7 @@ from typing import Optional, Union | |||
| from pydantic import ConfigDict | |||
| from configs import dify_config | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.callbacks.logging_callback import LoggingCallback | |||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| @@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel): | |||
| callbacks = callbacks or [] | |||
| if bool(os.environ.get("DEBUG", "False").lower() == "true"): | |||
| if dify_config.DEBUG: | |||
| callbacks.append(LoggingCallback()) | |||
| # trigger before invoke callbacks | |||
| @@ -1,6 +1,7 @@ | |||
| import logging | |||
| import re | |||
| from abc import abstractmethod | |||
| from collections.abc import Iterable | |||
| from typing import Any, Optional | |||
| from pydantic import ConfigDict | |||
| @@ -22,8 +23,14 @@ class TTSModel(AIModel): | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| def invoke( | |||
| self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None | |||
| ): | |||
| self, | |||
| model: str, | |||
| tenant_id: str, | |||
| credentials: dict, | |||
| content_text: str, | |||
| voice: str, | |||
| user: Optional[str] = None, | |||
| ) -> Iterable[bytes]: | |||
| """ | |||
| Invoke large language model | |||
| @@ -50,8 +57,14 @@ class TTSModel(AIModel): | |||
| @abstractmethod | |||
| def _invoke( | |||
| self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None | |||
| ): | |||
| self, | |||
| model: str, | |||
| tenant_id: str, | |||
| credentials: dict, | |||
| content_text: str, | |||
| voice: str, | |||
| user: Optional[str] = None, | |||
| ) -> Iterable[bytes]: | |||
| """ | |||
| Invoke large language model | |||
| @@ -68,25 +81,25 @@ class TTSModel(AIModel): | |||
| def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: | |||
| """ | |||
| Get voice for given tts model voices | |||
| Retrieves the list of voices supported by a given text-to-speech (TTS) model. | |||
| :param language: tts language | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: voices lists | |||
| :param language: The language for which the voices are requested. | |||
| :param model: The name of the TTS model. | |||
| :param credentials: The credentials required to access the TTS model. | |||
| :return: A list of voices supported by the TTS model. | |||
| """ | |||
| model_schema = self.get_model_schema(model, credentials) | |||
| if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: | |||
| voices = model_schema.model_properties[ModelPropertyKey.VOICES] | |||
| if language: | |||
| return [ | |||
| {"name": d["name"], "value": d["mode"]} | |||
| for d in voices | |||
| if language and language in d.get("language") | |||
| ] | |||
| else: | |||
| return [{"name": d["name"], "value": d["mode"]} for d in voices] | |||
| if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties: | |||
| raise ValueError("this model does not support voice") | |||
| voices = model_schema.model_properties[ModelPropertyKey.VOICES] | |||
| if language: | |||
| return [ | |||
| {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language") | |||
| ] | |||
| else: | |||
| return [{"name": d["name"], "value": d["mode"]} for d in voices] | |||
| def _get_model_default_voice(self, model: str, credentials: dict) -> Any: | |||
| """ | |||
| @@ -111,8 +124,10 @@ class TTSModel(AIModel): | |||
| """ | |||
| model_schema = self.get_model_schema(model, credentials) | |||
| if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: | |||
| return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] | |||
| if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: | |||
| raise ValueError("this model does not support audio type") | |||
| return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] | |||
| def _get_model_word_limit(self, model: str, credentials: dict) -> int: | |||
| """ | |||
| @@ -121,8 +136,10 @@ class TTSModel(AIModel): | |||
| """ | |||
| model_schema = self.get_model_schema(model, credentials) | |||
| if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties: | |||
| return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] | |||
| if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: | |||
| raise ValueError("this model does not support word limit") | |||
| return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] | |||
| def _get_model_workers_limit(self, model: str, credentials: dict) -> int: | |||
| """ | |||
| @@ -131,8 +148,10 @@ class TTSModel(AIModel): | |||
| """ | |||
| model_schema = self.get_model_schema(model, credentials) | |||
| if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties: | |||
| return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] | |||
| if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: | |||
| raise ValueError("this model does not support max workers") | |||
| return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] | |||
| @staticmethod | |||
| def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): | |||
| @@ -1,12 +1,15 @@ | |||
| from typing import Optional, Union | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.file.file_obj import FileVar | |||
| from core.file import file_manager | |||
| from core.file.models import File | |||
| from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageRole, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| @@ -14,7 +17,6 @@ from core.model_runtime.entities.message_entities import ( | |||
| ) | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.prompt_transform import PromptTransform | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| @@ -28,22 +30,19 @@ class AdvancedPromptTransform(PromptTransform): | |||
| def get_prompt( | |||
| self, | |||
| prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], | |||
| inputs: dict, | |||
| *, | |||
| prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, | |||
| inputs: dict[str, str], | |||
| query: str, | |||
| files: list[FileVar], | |||
| files: Sequence[File], | |||
| context: Optional[str], | |||
| memory_config: Optional[MemoryConfig], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| query_prompt_template: Optional[str] = None, | |||
| ) -> list[PromptMessage]: | |||
| inputs = {key: str(value) for key, value in inputs.items()} | |||
| prompt_messages = [] | |||
| model_mode = ModelMode.value_of(model_config.mode) | |||
| if model_mode == ModelMode.COMPLETION: | |||
| if isinstance(prompt_template, CompletionModelPromptTemplate): | |||
| prompt_messages = self._get_completion_model_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| inputs=inputs, | |||
| @@ -54,12 +53,11 @@ class AdvancedPromptTransform(PromptTransform): | |||
| memory=memory, | |||
| model_config=model_config, | |||
| ) | |||
| elif model_mode == ModelMode.CHAT: | |||
| elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): | |||
| prompt_messages = self._get_chat_model_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| inputs=inputs, | |||
| query=query, | |||
| query_prompt_template=query_prompt_template, | |||
| files=files, | |||
| context=context, | |||
| memory_config=memory_config, | |||
| @@ -74,7 +72,7 @@ class AdvancedPromptTransform(PromptTransform): | |||
| prompt_template: CompletionModelPromptTemplate, | |||
| inputs: dict, | |||
| query: Optional[str], | |||
| files: list[FileVar], | |||
| files: Sequence[File], | |||
| context: Optional[str], | |||
| memory_config: Optional[MemoryConfig], | |||
| memory: Optional[TokenBufferMemory], | |||
| @@ -88,10 +86,10 @@ class AdvancedPromptTransform(PromptTransform): | |||
| prompt_messages = [] | |||
| if prompt_template.edition_type == "basic" or not prompt_template.edition_type: | |||
| prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} | |||
| prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) | |||
| prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) | |||
| if memory and memory_config: | |||
| role_prefix = memory_config.role_prefix | |||
| @@ -100,15 +98,15 @@ class AdvancedPromptTransform(PromptTransform): | |||
| memory_config=memory_config, | |||
| raw_prompt=raw_prompt, | |||
| role_prefix=role_prefix, | |||
| prompt_template=prompt_template, | |||
| parser=parser, | |||
| prompt_inputs=prompt_inputs, | |||
| model_config=model_config, | |||
| ) | |||
| if query: | |||
| prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) | |||
| prompt_inputs = self._set_query_variable(query, parser, prompt_inputs) | |||
| prompt = prompt_template.format(prompt_inputs) | |||
| prompt = parser.format(prompt_inputs) | |||
| else: | |||
| prompt = raw_prompt | |||
| prompt_inputs = inputs | |||
| @@ -116,9 +114,10 @@ class AdvancedPromptTransform(PromptTransform): | |||
| prompt = Jinja2Formatter.format(prompt, prompt_inputs) | |||
| if files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=prompt)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| @@ -131,35 +130,28 @@ class AdvancedPromptTransform(PromptTransform): | |||
| prompt_template: list[ChatModelMessage], | |||
| inputs: dict, | |||
| query: Optional[str], | |||
| files: list[FileVar], | |||
| files: Sequence[File], | |||
| context: Optional[str], | |||
| memory_config: Optional[MemoryConfig], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| query_prompt_template: Optional[str] = None, | |||
| ) -> list[PromptMessage]: | |||
| """ | |||
| Get chat model prompt messages. | |||
| """ | |||
| raw_prompt_list = prompt_template | |||
| prompt_messages = [] | |||
| for prompt_item in raw_prompt_list: | |||
| for prompt_item in prompt_template: | |||
| raw_prompt = prompt_item.text | |||
| if prompt_item.edition_type == "basic" or not prompt_item.edition_type: | |||
| prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) | |||
| prompt = prompt_template.format(prompt_inputs) | |||
| parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} | |||
| prompt_inputs = self._set_context_variable(context=context, parser=parser, prompt_inputs=prompt_inputs) | |||
| prompt = parser.format(prompt_inputs) | |||
| elif prompt_item.edition_type == "jinja2": | |||
| prompt = raw_prompt | |||
| prompt_inputs = inputs | |||
| prompt = Jinja2Formatter.format(prompt, prompt_inputs) | |||
| prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) | |||
| else: | |||
| raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") | |||
| @@ -170,25 +162,25 @@ class AdvancedPromptTransform(PromptTransform): | |||
| elif prompt_item.role == PromptMessageRole.ASSISTANT: | |||
| prompt_messages.append(AssistantPromptMessage(content=prompt)) | |||
| if query and query_prompt_template: | |||
| prompt_template = PromptTemplateParser( | |||
| template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl | |||
| if query and memory_config and memory_config.query_prompt_template: | |||
| parser = PromptTemplateParser( | |||
| template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl | |||
| ) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} | |||
| prompt_inputs["#sys.query#"] = query | |||
| prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) | |||
| prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) | |||
| query = prompt_template.format(prompt_inputs) | |||
| query = parser.format(prompt_inputs) | |||
| if memory and memory_config: | |||
| prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | |||
| if files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=query)] | |||
| if files and query is not None: | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=query)) | |||
| @@ -200,19 +192,19 @@ class AdvancedPromptTransform(PromptTransform): | |||
| # get last user message content and add files | |||
| prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| last_message.content = prompt_message_contents | |||
| else: | |||
| prompt_message_contents = [TextPromptMessageContent(data="")] # not for query | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| prompt_message_contents = [TextPromptMessageContent(data=query)] | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| elif query: | |||
| @@ -220,8 +212,8 @@ class AdvancedPromptTransform(PromptTransform): | |||
| return prompt_messages | |||
| def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: | |||
| if "#context#" in prompt_template.variable_keys: | |||
| def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: | |||
| if "#context#" in parser.variable_keys: | |||
| if context: | |||
| prompt_inputs["#context#"] = context | |||
| else: | |||
| @@ -229,8 +221,8 @@ class AdvancedPromptTransform(PromptTransform): | |||
| return prompt_inputs | |||
| def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: | |||
| if "#query#" in prompt_template.variable_keys: | |||
| def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: | |||
| if "#query#" in parser.variable_keys: | |||
| if query: | |||
| prompt_inputs["#query#"] = query | |||
| else: | |||
| @@ -244,16 +236,16 @@ class AdvancedPromptTransform(PromptTransform): | |||
| memory_config: MemoryConfig, | |||
| raw_prompt: str, | |||
| role_prefix: MemoryConfig.RolePrefix, | |||
| prompt_template: PromptTemplateParser, | |||
| parser: PromptTemplateParser, | |||
| prompt_inputs: dict, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| ) -> dict: | |||
| if "#histories#" in prompt_template.variable_keys: | |||
| if "#histories#" in parser.variable_keys: | |||
| if memory: | |||
| inputs = {"#histories#": "", **prompt_inputs} | |||
| prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) | |||
| parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) | |||
| prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} | |||
| tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) | |||
| rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) | |||
| @@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional | |||
| from core.app.app_config.entities import PromptTemplateEntity | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.file import file_manager | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.message_entities import ( | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| UserPromptMessage, | |||
| @@ -18,7 +20,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from models.model import AppMode | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| from core.file.models import File | |||
| class ModelMode(enum.Enum): | |||
| @@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform): | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict, | |||
| query: str, | |||
| files: list["FileVar"], | |||
| files: list["File"], | |||
| context: Optional[str], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| @@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform): | |||
| inputs: dict, | |||
| query: str, | |||
| context: Optional[str], | |||
| files: list["FileVar"], | |||
| files: list["File"], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| ) -> tuple[list[PromptMessage], Optional[list[str]]]: | |||
| @@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform): | |||
| inputs: dict, | |||
| query: str, | |||
| context: Optional[str], | |||
| files: list["FileVar"], | |||
| files: list["File"], | |||
| memory: Optional[TokenBufferMemory], | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| ) -> tuple[list[PromptMessage], Optional[list[str]]]: | |||
| @@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform): | |||
| return [self.get_last_user_message(prompt, files)], stops | |||
| def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: | |||
| def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: | |||
| if files: | |||
| prompt_message_contents = [TextPromptMessageContent(data=prompt)] | |||
| prompt_message_contents: list[PromptMessageContent] = [] | |||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||
| for file in files: | |||
| prompt_message_contents.append(file.prompt_message_content) | |||
| prompt_message_contents.append(file_manager.to_prompt_message_content(file)) | |||
| prompt_message = UserPromptMessage(content=prompt_message_contents) | |||
| else: | |||
| @@ -32,8 +32,8 @@ class UserToolProvider(BaseModel): | |||
| original_credentials: Optional[dict] = None | |||
| is_team_authorization: bool = False | |||
| allow_delete: bool = True | |||
| tools: list[UserTool] = None | |||
| labels: list[str] = None | |||
| tools: list[UserTool] | None = None | |||
| labels: list[str] | None = None | |||
| def to_dict(self) -> dict: | |||
| # ------------- | |||
| @@ -42,7 +42,7 @@ class UserToolProvider(BaseModel): | |||
| for tool in tools: | |||
| if tool.get("parameters"): | |||
| for parameter in tool.get("parameters"): | |||
| if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: | |||
| if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: | |||
| parameter["type"] = "files" | |||
| # ------------- | |||
| @@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel): | |||
| BLOB = "blob" | |||
| JSON = "json" | |||
| IMAGE_LINK = "image_link" | |||
| FILE_VAR = "file_var" | |||
| FILE = "file" | |||
| type: MessageType = MessageType.TEXT | |||
| """ | |||
| plain text, image url or link url | |||
| """ | |||
| message: str | bytes | dict | None = None | |||
| meta: dict[str, Any] | None = None | |||
| # TODO: Use a BaseModel for meta | |||
| meta: dict[str, Any] = Field(default_factory=dict) | |||
| save_as: str = "" | |||
| @@ -143,6 +144,67 @@ class ToolParameter(BaseModel): | |||
| SELECT = "select" | |||
| SECRET_INPUT = "secret-input" | |||
| FILE = "file" | |||
| FILES = "files" | |||
| # deprecated, should not use. | |||
| SYSTEM_FILES = "systme-files" | |||
| def as_normal_type(self): | |||
| if self in { | |||
| ToolParameter.ToolParameterType.SECRET_INPUT, | |||
| ToolParameter.ToolParameterType.SELECT, | |||
| }: | |||
| return "string" | |||
| return self.value | |||
| def cast_value(self, value: Any, /): | |||
| try: | |||
| match self: | |||
| case ( | |||
| ToolParameter.ToolParameterType.STRING | |||
| | ToolParameter.ToolParameterType.SECRET_INPUT | |||
| | ToolParameter.ToolParameterType.SELECT | |||
| ): | |||
| if value is None: | |||
| return "" | |||
| else: | |||
| return value if isinstance(value, str) else str(value) | |||
| case ToolParameter.ToolParameterType.BOOLEAN: | |||
| if value is None: | |||
| return False | |||
| elif isinstance(value, str): | |||
| # Allowed YAML boolean value strings: https://yaml.org/type/bool.html | |||
| # and also '0' for False and '1' for True | |||
| match value.lower(): | |||
| case "true" | "yes" | "y" | "1": | |||
| return True | |||
| case "false" | "no" | "n" | "0": | |||
| return False | |||
| case _: | |||
| return bool(value) | |||
| else: | |||
| return value if isinstance(value, bool) else bool(value) | |||
| case ToolParameter.ToolParameterType.NUMBER: | |||
| if isinstance(value, int | float): | |||
| return value | |||
| elif isinstance(value, str) and value: | |||
| if "." in value: | |||
| return float(value) | |||
| else: | |||
| return int(value) | |||
| case ( | |||
| ToolParameter.ToolParameterType.SYSTEM_FILES | |||
| | ToolParameter.ToolParameterType.FILE | |||
| | ToolParameter.ToolParameterType.FILES | |||
| ): | |||
| return value | |||
| case _: | |||
| return str(value) | |||
| except Exception: | |||
| raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") | |||
| class ToolParameterForm(Enum): | |||
| SCHEMA = "schema" # should be set while adding tool | |||
| @@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool): | |||
| for image in response.data: | |||
| mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) | |||
| blob_message = self.create_blob_message( | |||
| blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value | |||
| blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE | |||
| ) | |||
| result.append(blob_message) | |||
| return result | |||
| @@ -2,7 +2,7 @@ from typing import Any | |||
| from duckduckgo_search import DDGS | |||
| from core.file.file_obj import FileTransferMethod | |||
| from core.file.models import FileTransferMethod | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| @@ -13,7 +13,6 @@ from core.tools.errors import ( | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| @@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| # the parameter is not set currently, set the default value if needed | |||
| if parameter_schema.default is not None: | |||
| default_value = ToolParameterConverter.cast_parameter_by_type( | |||
| parameter_schema.default, parameter_schema.type | |||
| ) | |||
| default_value = parameter_schema.type.cast_value(parameter_schema.default) | |||
| tool_parameters[parameter] = default_value | |||
| def validate_credentials(self, credentials: dict[str, Any]) -> None: | |||
| @@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import ( | |||
| ) | |||
| from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| class ToolProviderController(BaseModel, ABC): | |||
| @@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC): | |||
| # the parameter is not set currently, set the default value if needed | |||
| if parameter_schema.default is not None: | |||
| tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type( | |||
| parameter_schema.default, parameter_schema.type | |||
| ) | |||
| tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default) | |||
| def validate_credentials_format(self, credentials: dict[str, Any]) -> None: | |||
| """ | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Optional | |||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||
| from core.app.app_config.entities import VariableEntityType | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ( | |||
| @@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { | |||
| VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, | |||
| VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, | |||
| VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, | |||
| VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE, | |||
| VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES, | |||
| } | |||
| @@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| if not app: | |||
| raise ValueError("app not found") | |||
| controller = WorkflowToolProviderController( | |||
| **{ | |||
| controller = WorkflowToolProviderController.model_validate( | |||
| { | |||
| "identity": { | |||
| "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", | |||
| "name": db_provider.label, | |||
| @@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| :param app: the app | |||
| :return: the tool | |||
| """ | |||
| workflow: Workflow = ( | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) | |||
| .first() | |||
| @@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| raise ValueError("workflow not found") | |||
| # fetch start node | |||
| graph: dict = workflow.graph_dict | |||
| features_dict: dict = workflow.features_dict | |||
| graph = workflow.graph_dict | |||
| features_dict = workflow.features_dict | |||
| features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) | |||
| parameters = db_provider.parameter_configurations | |||
| variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) | |||
| def fetch_workflow_variable(variable_name: str) -> VariableEntity: | |||
| def fetch_workflow_variable(variable_name: str): | |||
| return next(filter(lambda x: x.variable == variable_name, variables), None) | |||
| user = db_provider.user | |||
| @@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| llm_description=parameter.description, | |||
| required=variable.required, | |||
| options=options, | |||
| default=variable.default, | |||
| ) | |||
| ) | |||
| elif features.file_upload: | |||
| @@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| name=parameter.name, | |||
| label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), | |||
| human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), | |||
| type=ToolParameter.ToolParameterType.FILE, | |||
| type=ToolParameter.ToolParameterType.SYSTEM_FILES, | |||
| llm_description=parameter.description, | |||
| required=False, | |||
| form=parameter.form, | |||
| @@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import ( | |||
| ToolRuntimeVariablePool, | |||
| ) | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| if TYPE_CHECKING: | |||
| from core.file.file_obj import FileVar | |||
| from core.file.models import File | |||
| class Tool(BaseModel, ABC): | |||
| @@ -63,8 +62,12 @@ class Tool(BaseModel, ABC): | |||
| def __init__(self, **data: Any): | |||
| super().__init__(**data) | |||
| class VariableKey(Enum): | |||
| class VariableKey(str, Enum): | |||
| IMAGE = "image" | |||
| DOCUMENT = "document" | |||
| VIDEO = "video" | |||
| AUDIO = "audio" | |||
| CUSTOM = "custom" | |||
| def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": | |||
| """ | |||
| @@ -221,9 +224,7 @@ class Tool(BaseModel, ABC): | |||
| result = deepcopy(tool_parameters) | |||
| for parameter in self.parameters or []: | |||
| if parameter.name in tool_parameters: | |||
| result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( | |||
| tool_parameters[parameter.name], parameter.type | |||
| ) | |||
| result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) | |||
| return result | |||
| @@ -295,10 +296,8 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) | |||
| def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: | |||
| return ToolInvokeMessage( | |||
| type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as="" | |||
| ) | |||
| def create_file_message(self, file: "File") -> ToolInvokeMessage: | |||
| return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="") | |||
| def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: | |||
| """ | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from copy import deepcopy | |||
| from typing import Any, Optional, Union | |||
| from core.file.file_obj import FileTransferMethod, FileVar | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType | |||
| from core.tools.tool.tool import Tool | |||
| from extensions.ext_database import db | |||
| @@ -45,11 +45,13 @@ class WorkflowTool(Tool): | |||
| workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) | |||
| # transform the tool parameters | |||
| tool_parameters, files = self._transform_args(tool_parameters) | |||
| tool_parameters, files = self._transform_args(tool_parameters=tool_parameters) | |||
| from core.app.apps.workflow.app_generator import WorkflowAppGenerator | |||
| generator = WorkflowAppGenerator() | |||
| assert self.runtime is not None | |||
| assert self.runtime.invoke_from is not None | |||
| result = generator.generate( | |||
| app_model=app, | |||
| workflow=workflow, | |||
| @@ -74,7 +76,7 @@ class WorkflowTool(Tool): | |||
| else: | |||
| outputs, files = self._extract_files(outputs) | |||
| for file in files: | |||
| result.append(self.create_file_var_message(file)) | |||
| result.append(self.create_file_message(file)) | |||
| result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) | |||
| result.append(self.create_json_message(outputs)) | |||
| @@ -154,22 +156,22 @@ class WorkflowTool(Tool): | |||
| parameters_result = {} | |||
| files = [] | |||
| for parameter in parameter_rules: | |||
| if parameter.type == ToolParameter.ToolParameterType.FILE: | |||
| if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES: | |||
| file = tool_parameters.get(parameter.name) | |||
| if file: | |||
| try: | |||
| file_var_list = [FileVar(**f) for f in file] | |||
| for file_var in file_var_list: | |||
| file_dict = { | |||
| "transfer_method": file_var.transfer_method.value, | |||
| "type": file_var.type.value, | |||
| file_var_list = [File.model_validate(f) for f in file] | |||
| for file in file_var_list: | |||
| file_dict: dict[str, str | None] = { | |||
| "transfer_method": file.transfer_method.value, | |||
| "type": file.type.value, | |||
| } | |||
| if file_var.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| file_dict["tool_file_id"] = file_var.related_id | |||
| elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| file_dict["upload_file_id"] = file_var.related_id | |||
| elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| file_dict["url"] = file_var.preview_url | |||
| if file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| file_dict["tool_file_id"] = file.related_id | |||
| elif file.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| file_dict["upload_file_id"] = file.related_id | |||
| elif file.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| file_dict["url"] = file.generate_url() | |||
| files.append(file_dict) | |||
| except Exception as e: | |||
| @@ -179,7 +181,7 @@ class WorkflowTool(Tool): | |||
| return parameters_result, files | |||
| def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: | |||
| def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: | |||
| """ | |||
| extract files from the result | |||
| @@ -190,17 +192,13 @@ class WorkflowTool(Tool): | |||
| result = {} | |||
| for key, value in outputs.items(): | |||
| if isinstance(value, list): | |||
| has_file = False | |||
| for item in value: | |||
| if isinstance(item, dict) and item.get("__variant") == "FileVar": | |||
| try: | |||
| files.append(FileVar(**item)) | |||
| has_file = True | |||
| except Exception as e: | |||
| pass | |||
| if has_file: | |||
| continue | |||
| if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| file = File.model_validate(item) | |||
| files.append(file) | |||
| elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| file = File.model_validate(value) | |||
| files.append(file) | |||
| result[key] = value | |||
| return result, files | |||
| @@ -10,7 +10,8 @@ from yarl import URL | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler | |||
| from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler | |||
| from core.file.file_obj import FileTransferMethod | |||
| from core.file import FileType | |||
| from core.file.models import FileTransferMethod | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter | |||
| from core.tools.errors import ( | |||
| @@ -25,6 +26,7 @@ from core.tools.errors import ( | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.tool.workflow_tool import WorkflowTool | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from enums import CreatedByRole | |||
| from extensions.ext_database import db | |||
| from models.model import Message, MessageFile | |||
| @@ -128,6 +130,7 @@ class ToolEngine: | |||
| """ | |||
| try: | |||
| # hit the callback handler | |||
| assert tool.identity is not None | |||
| workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) | |||
| if isinstance(tool, WorkflowTool): | |||
| @@ -258,7 +261,10 @@ class ToolEngine: | |||
| @staticmethod | |||
| def _create_message_files( | |||
| tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str | |||
| tool_messages: list[ToolInvokeMessageBinary], | |||
| agent_message: Message, | |||
| invoke_from: InvokeFrom, | |||
| user_id: str, | |||
| ) -> list[tuple[Any, str]]: | |||
| """ | |||
| Create message file | |||
| @@ -269,29 +275,31 @@ class ToolEngine: | |||
| result = [] | |||
| for message in tool_messages: | |||
| file_type = "bin" | |||
| if "image" in message.mimetype: | |||
| file_type = "image" | |||
| file_type = FileType.IMAGE | |||
| elif "video" in message.mimetype: | |||
| file_type = "video" | |||
| file_type = FileType.VIDEO | |||
| elif "audio" in message.mimetype: | |||
| file_type = "audio" | |||
| elif "text" in message.mimetype: | |||
| file_type = "text" | |||
| elif "pdf" in message.mimetype: | |||
| file_type = "pdf" | |||
| elif "zip" in message.mimetype: | |||
| file_type = "archive" | |||
| # ... | |||
| file_type = FileType.AUDIO | |||
| elif "text" in message.mimetype or "pdf" in message.mimetype: | |||
| file_type = FileType.DOCUMENT | |||
| else: | |||
| file_type = FileType.CUSTOM | |||
| # extract tool file id from url | |||
| tool_file_id = message.url.split("/")[-1].split(".")[0] | |||
| message_file = MessageFile( | |||
| message_id=agent_message.id, | |||
| type=file_type, | |||
| transfer_method=FileTransferMethod.TOOL_FILE.value, | |||
| transfer_method=FileTransferMethod.TOOL_FILE, | |||
| belongs_to="assistant", | |||
| url=message.url, | |||
| upload_file_id=None, | |||
| created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), | |||
| upload_file_id=tool_file_id, | |||
| created_by_role=( | |||
| CreatedByRole.ACCOUNT | |||
| if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else CreatedByRole.END_USER | |||
| ), | |||
| created_by=user_id, | |||
| ) | |||
| @@ -57,22 +57,32 @@ class ToolFileManager: | |||
| @staticmethod | |||
| def create_file_by_raw( | |||
| user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str | |||
| *, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| conversation_id: Optional[str], | |||
| file_binary: bytes, | |||
| mimetype: str, | |||
| ) -> ToolFile: | |||
| """ | |||
| create file | |||
| """ | |||
| extension = guess_extension(mimetype) or ".bin" | |||
| unique_name = uuid4().hex | |||
| filename = f"tools/{tenant_id}/{unique_name}{extension}" | |||
| storage.save(filename, file_binary) | |||
| filename = f"{unique_name}{extension}" | |||
| filepath = f"tools/{tenant_id}/{filename}" | |||
| storage.save(filepath, file_binary) | |||
| tool_file = ToolFile( | |||
| user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| conversation_id=conversation_id, | |||
| file_key=filepath, | |||
| mimetype=mimetype, | |||
| name=filename, | |||
| size=len(file_binary), | |||
| ) | |||
| db.session.add(tool_file) | |||
| db.session.commit() | |||
| db.session.refresh(tool_file) | |||
| return tool_file | |||
| @@ -80,29 +90,34 @@ class ToolFileManager: | |||
| def create_file_by_url( | |||
| user_id: str, | |||
| tenant_id: str, | |||
| conversation_id: str, | |||
| conversation_id: str | None, | |||
| file_url: str, | |||
| ) -> ToolFile: | |||
| """ | |||
| create file | |||
| """ | |||
| # try to download image | |||
| response = get(file_url) | |||
| response.raise_for_status() | |||
| blob = response.content | |||
| try: | |||
| response = get(file_url) | |||
| response.raise_for_status() | |||
| blob = response.content | |||
| except Exception as e: | |||
| logger.error(f"Failed to download file from {file_url}: {e}") | |||
| raise | |||
| mimetype = guess_type(file_url)[0] or "octet/stream" | |||
| extension = guess_extension(mimetype) or ".bin" | |||
| unique_name = uuid4().hex | |||
| filename = f"tools/{tenant_id}/{unique_name}{extension}" | |||
| storage.save(filename, blob) | |||
| filename = f"{unique_name}{extension}" | |||
| filepath = f"tools/{tenant_id}/{filename}" | |||
| storage.save(filepath, blob) | |||
| tool_file = ToolFile( | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| conversation_id=conversation_id, | |||
| file_key=filename, | |||
| file_key=filepath, | |||
| mimetype=mimetype, | |||
| original_url=file_url, | |||
| name=filename, | |||
| size=len(blob), | |||
| ) | |||
| db.session.add(tool_file) | |||
| @@ -110,18 +125,6 @@ class ToolFileManager: | |||
| return tool_file | |||
| @staticmethod | |||
| def create_file_by_key( | |||
| user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str | |||
| ) -> ToolFile: | |||
| """ | |||
| create file | |||
| """ | |||
| tool_file = ToolFile( | |||
| user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype | |||
| ) | |||
| return tool_file | |||
| @staticmethod | |||
| def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: | |||
| """ | |||
| @@ -131,7 +134,7 @@ class ToolFileManager: | |||
| :return: the binary of the file, mime type | |||
| """ | |||
| tool_file: ToolFile = ( | |||
| tool_file = ( | |||
| db.session.query(ToolFile) | |||
| .filter( | |||
| ToolFile.id == id, | |||
| @@ -155,7 +158,7 @@ class ToolFileManager: | |||
| :return: the binary of the file, mime type | |||
| """ | |||
| message_file: MessageFile = ( | |||
| message_file = ( | |||
| db.session.query(MessageFile) | |||
| .filter( | |||
| MessageFile.id == id, | |||
| @@ -166,13 +169,16 @@ class ToolFileManager: | |||
| # Check if message_file is not None | |||
| if message_file is not None: | |||
| # get tool file id | |||
| tool_file_id = message_file.url.split("/")[-1] | |||
| # trim extension | |||
| tool_file_id = tool_file_id.split(".")[0] | |||
| if message_file.url is not None: | |||
| tool_file_id = message_file.url.split("/")[-1] | |||
| # trim extension | |||
| tool_file_id = tool_file_id.split(".")[0] | |||
| else: | |||
| tool_file_id = None | |||
| else: | |||
| tool_file_id = None | |||
| tool_file: ToolFile = ( | |||
| tool_file = ( | |||
| db.session.query(ToolFile) | |||
| .filter( | |||
| ToolFile.id == tool_file_id, | |||
| @@ -196,7 +202,7 @@ class ToolFileManager: | |||
| :return: the binary of the file, mime type | |||
| """ | |||
| tool_file: ToolFile = ( | |||
| tool_file = ( | |||
| db.session.query(ToolFile) | |||
| .filter( | |||
| ToolFile.id == tool_file_id, | |||
| @@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.tool_label_manager import ToolLabelManager | |||
| from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager | |||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | |||
| from extensions.ext_database import db | |||
| from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider | |||
| from services.tools.tools_transform_service import ToolTransformService | |||
| @@ -203,7 +202,7 @@ class ToolManager: | |||
| raise ToolProviderNotFoundError(f"provider type {provider_type} not found") | |||
| @classmethod | |||
| def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: | |||
| def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict): | |||
| """ | |||
| init runtime parameter | |||
| """ | |||
| @@ -222,7 +221,7 @@ class ToolManager: | |||
| f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" | |||
| ) | |||
| return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) | |||
| return parameter_rule.type.cast_value(parameter_value) | |||
| @classmethod | |||
| def get_agent_tool_runtime( | |||
| @@ -243,7 +242,11 @@ class ToolManager: | |||
| parameters = tool_entity.get_all_runtime_parameters() | |||
| for parameter in parameters: | |||
| # check file types | |||
| if parameter.type == ToolParameter.ToolParameterType.FILE: | |||
| if parameter.type in { | |||
| ToolParameter.ToolParameterType.SYSTEM_FILES, | |||
| ToolParameter.ToolParameterType.FILE, | |||
| ToolParameter.ToolParameterType.FILES, | |||
| }: | |||
| raise ValueError(f"file type parameter {parameter.name} not supported in agent") | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM: | |||
| @@ -1,7 +1,8 @@ | |||
| import logging | |||
| from mimetypes import guess_extension | |||
| from typing import Optional | |||
| from core.file.file_obj import FileTransferMethod, FileType | |||
| from core.file import File, FileTransferMethod, FileType | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| @@ -11,7 +12,7 @@ logger = logging.getLogger(__name__) | |||
| class ToolFileMessageTransformer: | |||
| @classmethod | |||
| def transform_tool_invoke_messages( | |||
| cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | |||
| cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None | |||
| ) -> list[ToolInvokeMessage]: | |||
| """ | |||
| Transform tool message and handle file download | |||
| @@ -21,7 +22,7 @@ class ToolFileMessageTransformer: | |||
| for message in messages: | |||
| if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: | |||
| result.append(message) | |||
| elif message.type == ToolInvokeMessage.MessageType.IMAGE: | |||
| elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str): | |||
| # try to download image | |||
| try: | |||
| file = ToolFileManager.create_file_by_url( | |||
| @@ -50,11 +51,14 @@ class ToolFileMessageTransformer: | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.BLOB: | |||
| # get mime type and save blob to storage | |||
| assert message.meta is not None | |||
| mimetype = message.meta.get("mime_type", "octet/stream") | |||
| # if message is str, encode it to bytes | |||
| if isinstance(message.message, str): | |||
| message.message = message.message.encode("utf-8") | |||
| # FIXME: should do a type check here. | |||
| assert isinstance(message.message, bytes) | |||
| file = ToolFileManager.create_file_by_raw( | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| @@ -63,7 +67,7 @@ class ToolFileMessageTransformer: | |||
| mimetype=mimetype, | |||
| ) | |||
| url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) | |||
| url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) | |||
| # check if file is image | |||
| if "image" in mimetype: | |||
| @@ -84,12 +88,14 @@ class ToolFileMessageTransformer: | |||
| meta=message.meta.copy() if message.meta is not None else {}, | |||
| ) | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: | |||
| file_var = message.meta.get("file_var") | |||
| if file_var: | |||
| if file_var.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| url = cls.get_tool_file_url(file_var.related_id, file_var.extension) | |||
| if file_var.type == FileType.IMAGE: | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE: | |||
| assert message.meta is not None | |||
| file = message.meta.get("file") | |||
| if isinstance(file, File): | |||
| if file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| assert file.related_id is not None | |||
| url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) | |||
| if file.type == FileType.IMAGE: | |||
| result.append( | |||
| ToolInvokeMessage( | |||
| type=ToolInvokeMessage.MessageType.IMAGE_LINK, | |||
| @@ -107,11 +113,13 @@ class ToolFileMessageTransformer: | |||
| meta=message.meta.copy() if message.meta is not None else {}, | |||
| ) | |||
| ) | |||
| else: | |||
| result.append(message) | |||
| else: | |||
| result.append(message) | |||
| return result | |||
| @classmethod | |||
| def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: | |||
| def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: | |||
| return f'/files/tools/{tool_file_id}{extension or ".bin"}' | |||
| @@ -1,71 +0,0 @@ | |||
| from typing import Any | |||
| from core.tools.entities.tool_entities import ToolParameter | |||
| class ToolParameterConverter: | |||
| @staticmethod | |||
| def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: | |||
| match parameter_type: | |||
| case ( | |||
| ToolParameter.ToolParameterType.STRING | |||
| | ToolParameter.ToolParameterType.SECRET_INPUT | |||
| | ToolParameter.ToolParameterType.SELECT | |||
| ): | |||
| return "string" | |||
| case ToolParameter.ToolParameterType.BOOLEAN: | |||
| return "boolean" | |||
| case ToolParameter.ToolParameterType.NUMBER: | |||
| return "number" | |||
| case _: | |||
| raise ValueError(f"Unsupported parameter type {parameter_type}") | |||
| @staticmethod | |||
| def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: | |||
| # convert tool parameter config to correct type | |||
| try: | |||
| match parameter_type: | |||
| case ( | |||
| ToolParameter.ToolParameterType.STRING | |||
| | ToolParameter.ToolParameterType.SECRET_INPUT | |||
| | ToolParameter.ToolParameterType.SELECT | |||
| ): | |||
| if value is None: | |||
| return "" | |||
| else: | |||
| return value if isinstance(value, str) else str(value) | |||
| case ToolParameter.ToolParameterType.BOOLEAN: | |||
| if value is None: | |||
| return False | |||
| elif isinstance(value, str): | |||
| # Allowed YAML boolean value strings: https://yaml.org/type/bool.html | |||
| # and also '0' for False and '1' for True | |||
| match value.lower(): | |||
| case "true" | "yes" | "y" | "1": | |||
| return True | |||
| case "false" | "no" | "n" | "0": | |||
| return False | |||
| case _: | |||
| return bool(value) | |||
| else: | |||
| return value if isinstance(value, bool) else bool(value) | |||
| case ToolParameter.ToolParameterType.NUMBER: | |||
| if isinstance(value, int) | isinstance(value, float): | |||
| return value | |||
| elif isinstance(value, str) and value != "": | |||
| if "." in value: | |||
| return float(value) | |||
| else: | |||
| return int(value) | |||
| case ToolParameter.ToolParameterType.FILE: | |||
| return value | |||
| case _: | |||
| return str(value) | |||
| except Exception: | |||
| raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") | |||
| @@ -1,19 +1,18 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from core.app.app_config.entities import VariableEntity | |||
| from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration | |||
| class WorkflowToolConfigurationUtils: | |||
| @classmethod | |||
| def check_parameter_configurations(cls, configurations: list[dict]): | |||
| """ | |||
| check parameter configurations | |||
| """ | |||
| def check_parameter_configurations(cls, configurations: Mapping[str, Any]): | |||
| for configuration in configurations: | |||
| if not WorkflowToolParameterConfiguration(**configuration): | |||
| raise ValueError("invalid parameter configuration") | |||
| WorkflowToolParameterConfiguration.model_validate(configuration) | |||
| @classmethod | |||
| def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: | |||
| def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: | |||
| """ | |||
| get workflow graph variables | |||
| """ | |||
| @@ -1,4 +1,5 @@ | |||
| import logging | |||
| from pathlib import Path | |||
| from typing import Any | |||
| import yaml | |||
| @@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any | |||
| :param default_value: the value returned when errors ignored | |||
| :return: an object of the YAML content | |||
| """ | |||
| try: | |||
| with open(file_path, encoding="utf-8") as yaml_file: | |||
| try: | |||
| yaml_content = yaml.safe_load(yaml_file) | |||
| return yaml_content or default_value | |||
| except Exception as e: | |||
| raise YAMLError(f"Failed to load YAML file {file_path}: {e}") | |||
| except Exception as e: | |||
| if not file_path or not Path(file_path).exists(): | |||
| if ignore_error: | |||
| return default_value | |||
| else: | |||
| raise e | |||
| raise FileNotFoundError(f"File not found: {file_path}") | |||
| with open(file_path, encoding="utf-8") as yaml_file: | |||
| try: | |||
| yaml_content = yaml.safe_load(yaml_file) | |||
| return yaml_content or default_value | |||
| except Exception as e: | |||
| if ignore_error: | |||
| return default_value | |||
| else: | |||
| raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e | |||
| @@ -1,7 +1,12 @@ | |||
| from .segment_group import SegmentGroup | |||
| from .segments import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| ArrayNumberSegment, | |||
| ArrayObjectSegment, | |||
| ArraySegment, | |||
| ArrayStringSegment, | |||
| FileSegment, | |||
| FloatSegment, | |||
| IntegerSegment, | |||
| NoneSegment, | |||
| @@ -15,6 +20,7 @@ from .variables import ( | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| FileVariable, | |||
| FloatVariable, | |||
| IntegerVariable, | |||
| NoneVariable, | |||
| @@ -46,4 +52,10 @@ __all__ = [ | |||
| "ArrayNumberVariable", | |||
| "ArrayObjectVariable", | |||
| "ArraySegment", | |||
| "ArrayFileSegment", | |||
| "ArrayNumberSegment", | |||
| "ArrayObjectSegment", | |||
| "ArrayStringSegment", | |||
| "FileSegment", | |||
| "FileVariable", | |||
| ] | |||
| @@ -5,6 +5,8 @@ from typing import Any | |||
| from pydantic import BaseModel, ConfigDict, field_validator | |||
| from core.file import File | |||
| from .types import SegmentType | |||
| @@ -39,6 +41,9 @@ class Segment(BaseModel): | |||
| @property | |||
| def size(self) -> int: | |||
| """ | |||
| Return the size of the value in bytes. | |||
| """ | |||
| return sys.getsizeof(self.value) | |||
| def to_object(self) -> Any: | |||
| @@ -99,13 +104,27 @@ class ArraySegment(Segment): | |||
| def markdown(self) -> str: | |||
| items = [] | |||
| for item in self.value: | |||
| if hasattr(item, "to_markdown"): | |||
| items.append(item.to_markdown()) | |||
| else: | |||
| items.append(str(item)) | |||
| items.append(str(item)) | |||
| return "\n".join(items) | |||
| class FileSegment(Segment): | |||
| value_type: SegmentType = SegmentType.FILE | |||
| value: File | |||
| @property | |||
| def markdown(self) -> str: | |||
| return self.value.markdown | |||
| @property | |||
| def log(self) -> str: | |||
| return str(self.value) | |||
| @property | |||
| def text(self) -> str: | |||
| return str(self.value) | |||
| class ArrayAnySegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_ANY | |||
| value: Sequence[Any] | |||
| @@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment): | |||
| class ArrayObjectSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_OBJECT | |||
| value: Sequence[Mapping[str, Any]] | |||
| class ArrayFileSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_FILE | |||
| value: Sequence[File] | |||
| @property | |||
| def markdown(self) -> str: | |||
| items = [] | |||
| for item in self.value: | |||
| items.append(item.markdown) | |||
| return "\n".join(items) | |||