| @@ -2,13 +2,6 @@ name: Run Pytest | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| paths: | |||
| - api/** | |||
| - docker/** | |||
| - .github/workflows/api-tests.yml | |||
| concurrency: | |||
| group: api-tests-${{ github.head_ref || github.run_id }} | |||
| @@ -1,10 +1,6 @@ | |||
| name: autofix.ci | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: [ "main" ] | |||
| push: | |||
| branches: [ "main" ] | |||
| permissions: | |||
| contents: read | |||
| @@ -2,13 +2,6 @@ name: DB Migration Test | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| - plugins/beta | |||
| paths: | |||
| - api/migrations/** | |||
| - .github/workflows/db-migration-test.yml | |||
| concurrency: | |||
| group: db-migration-test-${{ github.ref }} | |||
| @@ -34,6 +27,12 @@ jobs: | |||
| - name: Install dependencies | |||
| run: uv sync --project api | |||
| - name: Ensure Offline migration are supported | |||
| run: | | |||
| # upgrade | |||
| uv run --directory api flask db upgrade 'base:head' --sql | |||
| # downgrade | |||
| uv run --directory api flask db downgrade 'head:base' --sql | |||
| - name: Prepare middleware env | |||
| run: | | |||
| @@ -3,11 +3,14 @@ name: Main CI Pipeline | |||
| on: | |||
| pull_request: | |||
| branches: [ "main" ] | |||
| push: | |||
| branches: [ "main" ] | |||
| permissions: | |||
| contents: write | |||
| pull-requests: write | |||
| checks: write | |||
| statuses: write | |||
| concurrency: | |||
| group: main-ci-${{ github.head_ref || github.run_id }} | |||
| @@ -2,9 +2,6 @@ name: Style check | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| concurrency: | |||
| group: style-${{ github.head_ref || github.run_id }} | |||
| @@ -2,15 +2,6 @@ name: Run VDB Tests | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| paths: | |||
| - api/core/rag/datasource/** | |||
| - docker/** | |||
| - .github/workflows/vdb-tests.yml | |||
| - api/uv.lock | |||
| - api/pyproject.toml | |||
| concurrency: | |||
| group: vdb-tests-${{ github.head_ref || github.run_id }} | |||
| @@ -2,11 +2,6 @@ name: Web Tests | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| paths: | |||
| - web/** | |||
| concurrency: | |||
| group: web-tests-${{ github.head_ref || github.run_id }} | |||
| @@ -1,18 +1,27 @@ | |||
| from typing import Optional, Union | |||
| from flask import Response | |||
| from flask_restx import Resource, reqparse | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import Session | |||
| from controllers.console.app.mcp_server import AppMCPServerStatus | |||
| from controllers.mcp import mcp_ns | |||
| from core.app.app_config.entities import VariableEntity | |||
| from core.mcp import types | |||
| from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler | |||
| from core.mcp.types import ClientNotification, ClientRequest | |||
| from core.mcp.utils import create_mcp_error_response | |||
| from core.mcp import types as mcp_types | |||
| from core.mcp.server.streamable_http import handle_mcp_request | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.model import App, AppMCPServer, AppMode | |||
| from models.model import App, AppMCPServer, AppMode, EndUser | |||
| class MCPRequestError(Exception): | |||
| """Custom exception for MCP request processing errors""" | |||
| def __init__(self, error_code: int, message: str): | |||
| self.error_code = error_code | |||
| self.message = message | |||
| super().__init__(message) | |||
| def int_or_str(value): | |||
| @@ -63,77 +72,173 @@ class MCPAppApi(Resource): | |||
| Raises: | |||
| ValidationError: Invalid request format or parameters | |||
| """ | |||
| # Parse and validate all arguments | |||
| args = mcp_request_parser.parse_args() | |||
| request_id: Optional[Union[int, str]] = args.get("id") | |||
| mcp_request = self._parse_mcp_request(args) | |||
| server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() | |||
| if not server: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") | |||
| ) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get MCP server and app | |||
| mcp_server, app = self._get_mcp_server_and_app(server_code, session) | |||
| self._validate_server_status(mcp_server) | |||
| if server.status != AppMCPServerStatus.ACTIVE: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") | |||
| ) | |||
| # Get user input form | |||
| user_input_form = self._get_user_input_form(app) | |||
| app = db.session.query(App).where(App.id == server.app_id).first() | |||
| if not app: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") | |||
| ) | |||
| # Handle notification vs request differently | |||
| return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session) | |||
| if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| workflow = app.workflow | |||
| if workflow is None: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") | |||
| ) | |||
| def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]: | |||
| """Get and validate MCP server and app in one query session""" | |||
| mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() | |||
| if not mcp_server: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found") | |||
| user_input_form = workflow.user_input_form(to_old_structure=True) | |||
| app = session.query(App).where(App.id == mcp_server.app_id).first() | |||
| if not app: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found") | |||
| return mcp_server, app | |||
| def _validate_server_status(self, mcp_server: AppMCPServer) -> None: | |||
| """Validate MCP server status""" | |||
| if mcp_server.status != AppMCPServerStatus.ACTIVE: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") | |||
| def _process_mcp_message( | |||
| self, | |||
| mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification, | |||
| request_id: Optional[Union[int, str]], | |||
| app: App, | |||
| mcp_server: AppMCPServer, | |||
| user_input_form: list[VariableEntity], | |||
| session: Session, | |||
| ) -> Response: | |||
| """Process MCP message (notification or request)""" | |||
| if isinstance(mcp_request, mcp_types.ClientNotification): | |||
| return self._handle_notification(mcp_request) | |||
| else: | |||
| app_model_config = app.app_model_config | |||
| if app_model_config is None: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable") | |||
| ) | |||
| features_dict = app_model_config.to_dict() | |||
| user_input_form = features_dict.get("user_input_form", []) | |||
| converted_user_input_form: list[VariableEntity] = [] | |||
| try: | |||
| for item in user_input_form: | |||
| variable_type = item.get("type", "") or list(item.keys())[0] | |||
| variable = item[variable_type] | |||
| converted_user_input_form.append( | |||
| VariableEntity( | |||
| type=variable_type, | |||
| variable=variable.get("variable"), | |||
| description=variable.get("description") or "", | |||
| label=variable.get("label"), | |||
| required=variable.get("required", False), | |||
| max_length=variable.get("max_length"), | |||
| options=variable.get("options") or [], | |||
| ) | |||
| ) | |||
| except ValidationError as e: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") | |||
| ) | |||
| return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session) | |||
| def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response: | |||
| """Handle MCP notification""" | |||
| # For notifications, only support init notification | |||
| if mcp_request.root.method != "notifications/initialized": | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method") | |||
| # Return HTTP 202 Accepted for notifications (no response body) | |||
| return Response("", status=202, content_type="application/json") | |||
| def _handle_request( | |||
| self, | |||
| mcp_request: mcp_types.ClientRequest, | |||
| request_id: Optional[Union[int, str]], | |||
| app: App, | |||
| mcp_server: AppMCPServer, | |||
| user_input_form: list[VariableEntity], | |||
| session: Session, | |||
| ) -> Response: | |||
| """Handle MCP request""" | |||
| if request_id is None: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required") | |||
| result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id) | |||
| if result is None: | |||
| # This shouldn't happen for requests, but handle gracefully | |||
| raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request") | |||
| return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True)) | |||
| def _get_user_input_form(self, app: App) -> list[VariableEntity]: | |||
| """Get and convert user input form""" | |||
| # Get raw user input form based on app mode | |||
| if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| if not app.workflow: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") | |||
| raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) | |||
| else: | |||
| if not app.app_model_config: | |||
| raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") | |||
| features_dict = app.app_model_config.to_dict() | |||
| raw_user_input_form = features_dict.get("user_input_form", []) | |||
| # Convert to VariableEntity objects | |||
| try: | |||
| request: ClientRequest | ClientNotification = ClientRequest.model_validate(args) | |||
| return self._convert_user_input_form(raw_user_input_form) | |||
| except ValidationError as e: | |||
| raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") | |||
| def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: | |||
| """Convert raw user input form to VariableEntity objects""" | |||
| return [self._create_variable_entity(item) for item in raw_form] | |||
| def _create_variable_entity(self, item: dict) -> VariableEntity: | |||
| """Create a single VariableEntity from raw form item""" | |||
| variable_type = item.get("type", "") or list(item.keys())[0] | |||
| variable = item[variable_type] | |||
| return VariableEntity( | |||
| type=variable_type, | |||
| variable=variable.get("variable"), | |||
| description=variable.get("description") or "", | |||
| label=variable.get("label"), | |||
| required=variable.get("required", False), | |||
| max_length=variable.get("max_length"), | |||
| options=variable.get("options") or [], | |||
| ) | |||
| def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: | |||
| """Parse and validate MCP request""" | |||
| try: | |||
| return mcp_types.ClientRequest.model_validate(args) | |||
| except ValidationError: | |||
| try: | |||
| notification = ClientNotification.model_validate(args) | |||
| request = notification | |||
| return mcp_types.ClientNotification.model_validate(args) | |||
| except ValidationError as e: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") | |||
| ) | |||
| mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) | |||
| response = mcp_server_handler.handle() | |||
| return helper.compact_generate_response(response) | |||
| raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") | |||
| def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None: | |||
| """Get end user from existing session - optimized query""" | |||
| return ( | |||
| session.query(EndUser) | |||
| .where(EndUser.tenant_id == tenant_id) | |||
| .where(EndUser.session_id == mcp_server_id) | |||
| .where(EndUser.type == "mcp") | |||
| .first() | |||
| ) | |||
| def _create_end_user( | |||
| self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session | |||
| ) -> EndUser: | |||
| """Create end user in existing session""" | |||
| end_user = EndUser( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| type="mcp", | |||
| name=client_name, | |||
| session_id=mcp_server_id, | |||
| ) | |||
| session.add(end_user) | |||
| session.flush() # Use flush instead of commit to keep transaction open | |||
| session.refresh(end_user) | |||
| return end_user | |||
| def _handle_mcp_request( | |||
| self, | |||
| app: App, | |||
| mcp_server: AppMCPServer, | |||
| mcp_request: mcp_types.ClientRequest, | |||
| user_input_form: list[VariableEntity], | |||
| session: Session, | |||
| request_id: Union[int, str], | |||
| ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None: | |||
| """Handle MCP request and return response""" | |||
| end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session) | |||
| if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest): | |||
| client_info = mcp_request.root.params.clientInfo | |||
| client_name = f"{client_info.name}@{client_info.version}" | |||
| # Commit the session before creating end user to avoid transaction conflicts | |||
| session.commit() | |||
| with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin(): | |||
| end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session) | |||
| return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id) | |||
| @@ -4,224 +4,259 @@ from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from configs import dify_config | |||
| from controllers.web.passport import generate_session_id | |||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | |||
| from core.mcp import types | |||
| from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND | |||
| from core.mcp.utils import create_mcp_error_response | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from extensions.ext_database import db | |||
| from core.mcp import types as mcp_types | |||
| from models.model import App, AppMCPServer, AppMode, EndUser | |||
| from services.app_generate_service import AppGenerateService | |||
| logger = logging.getLogger(__name__) | |||
| class MCPServerStreamableHTTPRequestHandler: | |||
| def handle_mcp_request( | |||
| app: App, | |||
| request: mcp_types.ClientRequest, | |||
| user_input_form: list[VariableEntity], | |||
| mcp_server: AppMCPServer, | |||
| end_user: EndUser | None = None, | |||
| request_id: int | str = 1, | |||
| ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError: | |||
| """ | |||
| Apply to MCP HTTP streamable server with stateless http | |||
| Handle MCP request and return JSON-RPC response | |||
| Args: | |||
| app: The Dify app instance | |||
| request: The JSON-RPC request message | |||
| user_input_form: List of variable entities for the app | |||
| mcp_server: The MCP server configuration | |||
| end_user: Optional end user | |||
| request_id: The request ID | |||
| Returns: | |||
| JSON-RPC response or error | |||
| """ | |||
| def __init__( | |||
| self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] | |||
| ): | |||
| self.app = app | |||
| self.request = request | |||
| mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() | |||
| if not mcp_server: | |||
| raise ValueError("MCP server not found") | |||
| self.mcp_server: AppMCPServer = mcp_server | |||
| self.end_user = self.retrieve_end_user() | |||
| self.user_input_form = user_input_form | |||
| @property | |||
| def request_type(self): | |||
| return type(self.request.root) | |||
| @property | |||
| def parameter_schema(self): | |||
| parameters, required = self._convert_input_form_to_parameters(self.user_input_form) | |||
| if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: | |||
| return { | |||
| "type": "object", | |||
| "properties": parameters, | |||
| "required": required, | |||
| } | |||
| return { | |||
| "type": "object", | |||
| "properties": { | |||
| "query": {"type": "string", "description": "User Input/Question content"}, | |||
| **parameters, | |||
| }, | |||
| "required": ["query", *required], | |||
| } | |||
| request_type = type(request.root) | |||
| @property | |||
| def capabilities(self): | |||
| return types.ServerCapabilities( | |||
| tools=types.ToolsCapability(listChanged=False), | |||
| def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: | |||
| """Create success response with business result data""" | |||
| return mcp_types.JSONRPCResponse( | |||
| jsonrpc="2.0", | |||
| id=request_id, | |||
| result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True), | |||
| ) | |||
| def response(self, response: types.Result | str): | |||
| if isinstance(response, str): | |||
| sse_content = f"event: ping\ndata: {response}\n\n".encode() | |||
| yield sse_content | |||
| return | |||
| json_response = types.JSONRPCResponse( | |||
| def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError: | |||
| """Create error response with error code and message""" | |||
| from core.mcp.types import ErrorData | |||
| error_data = ErrorData(code=code, message=message) | |||
| return mcp_types.JSONRPCError( | |||
| jsonrpc="2.0", | |||
| id=(self.request.root.model_extra or {}).get("id", 1), | |||
| result=response.model_dump(by_alias=True, mode="json", exclude_none=True), | |||
| id=request_id, | |||
| error=error_data, | |||
| ) | |||
| json_data = json.dumps(jsonable_encoder(json_response)) | |||
| sse_content = f"event: message\ndata: {json_data}\n\n".encode() | |||
| # Request handler mapping using functional approach | |||
| request_handlers = { | |||
| mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), | |||
| mcp_types.ListToolsRequest: lambda: handle_list_tools( | |||
| app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict | |||
| ), | |||
| mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), | |||
| mcp_types.PingRequest: lambda: handle_ping(), | |||
| } | |||
| yield sse_content | |||
| try: | |||
| # Dispatch request to appropriate handler | |||
| handler = request_handlers.get(request_type) | |||
| if handler: | |||
| return create_success_response(handler()) | |||
| else: | |||
| return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") | |||
| def error_response(self, code: int, message: str, data=None): | |||
| request_id = (self.request.root.model_extra or {}).get("id", 1) or 1 | |||
| return create_mcp_error_response(request_id, code, message, data) | |||
| except ValueError as e: | |||
| logger.exception("Invalid params") | |||
| return create_error_response(mcp_types.INVALID_PARAMS, str(e)) | |||
| except Exception as e: | |||
| logger.exception("Internal server error") | |||
| return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e)) | |||
| def handle(self): | |||
| handle_map = { | |||
| types.InitializeRequest: self.initialize, | |||
| types.ListToolsRequest: self.list_tools, | |||
| types.CallToolRequest: self.invoke_tool, | |||
| types.InitializedNotification: self.handle_notification, | |||
| types.PingRequest: self.handle_ping, | |||
| } | |||
| try: | |||
| if self.request_type in handle_map: | |||
| return self.response(handle_map[self.request_type]()) | |||
| else: | |||
| return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}") | |||
| except ValueError as e: | |||
| logger.exception("Invalid params") | |||
| return self.error_response(INVALID_PARAMS, str(e)) | |||
| except Exception as e: | |||
| logger.exception("Internal server error") | |||
| return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") | |||
| def handle_notification(self): | |||
| return "ping" | |||
| def handle_ping(self): | |||
| return types.EmptyResult() | |||
| def initialize(self): | |||
| request = cast(types.InitializeRequest, self.request.root) | |||
| client_info = request.params.clientInfo | |||
| client_name = f"{client_info.name}@{client_info.version}" | |||
| if not self.end_user: | |||
| end_user = EndUser( | |||
| tenant_id=self.app.tenant_id, | |||
| app_id=self.app.id, | |||
| type="mcp", | |||
| name=client_name, | |||
| session_id=generate_session_id(), | |||
| external_user_id=self.mcp_server.id, | |||
| def handle_ping() -> mcp_types.EmptyResult: | |||
| """Handle ping request""" | |||
| return mcp_types.EmptyResult() | |||
| def handle_initialize(description: str) -> mcp_types.InitializeResult: | |||
| """Handle initialize request""" | |||
| capabilities = mcp_types.ServerCapabilities( | |||
| tools=mcp_types.ToolsCapability(listChanged=False), | |||
| ) | |||
| return mcp_types.InitializeResult( | |||
| protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION, | |||
| capabilities=capabilities, | |||
| serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version), | |||
| instructions=description, | |||
| ) | |||
| def handle_list_tools( | |||
| app_name: str, | |||
| app_mode: str, | |||
| user_input_form: list[VariableEntity], | |||
| description: str, | |||
| parameters_dict: dict[str, str], | |||
| ) -> mcp_types.ListToolsResult: | |||
| """Handle list tools request""" | |||
| parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) | |||
| return mcp_types.ListToolsResult( | |||
| tools=[ | |||
| mcp_types.Tool( | |||
| name=app_name, | |||
| description=description, | |||
| inputSchema=parameter_schema, | |||
| ) | |||
| db.session.add(end_user) | |||
| db.session.commit() | |||
| return types.InitializeResult( | |||
| protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION, | |||
| capabilities=self.capabilities, | |||
| serverInfo=types.Implementation(name="Dify", version=dify_config.project.version), | |||
| instructions=self.mcp_server.description, | |||
| ) | |||
| ], | |||
| ) | |||
| def list_tools(self): | |||
| if not self.end_user: | |||
| raise ValueError("User not found") | |||
| return types.ListToolsResult( | |||
| tools=[ | |||
| types.Tool( | |||
| name=self.app.name, | |||
| description=self.mcp_server.description, | |||
| inputSchema=self.parameter_schema, | |||
| ) | |||
| ], | |||
| ) | |||
| def invoke_tool(self): | |||
| if not self.end_user: | |||
| raise ValueError("User not found") | |||
| request = cast(types.CallToolRequest, self.request.root) | |||
| args = request.params.arguments or {} | |||
| if self.app.mode in {AppMode.WORKFLOW.value}: | |||
| args = {"inputs": args} | |||
| elif self.app.mode in {AppMode.COMPLETION.value}: | |||
| args = {"query": "", "inputs": args} | |||
| else: | |||
| args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}} | |||
| response = AppGenerateService.generate( | |||
| self.app, | |||
| self.end_user, | |||
| args, | |||
| InvokeFrom.SERVICE_API, | |||
| streaming=self.app.mode == AppMode.AGENT_CHAT.value, | |||
| ) | |||
| answer = "" | |||
| if isinstance(response, RateLimitGenerator): | |||
| for item in response.generator: | |||
| data = item | |||
| if isinstance(data, str) and data.startswith("data: "): | |||
| try: | |||
| json_str = data[6:].strip() | |||
| parsed_data = json.loads(json_str) | |||
| if parsed_data.get("event") == "agent_thought": | |||
| answer += parsed_data.get("thought", "") | |||
| except json.JSONDecodeError: | |||
| continue | |||
| if isinstance(response, Mapping): | |||
| if self.app.mode in { | |||
| AppMode.ADVANCED_CHAT.value, | |||
| AppMode.COMPLETION.value, | |||
| AppMode.CHAT.value, | |||
| AppMode.AGENT_CHAT.value, | |||
| }: | |||
| answer = response["answer"] | |||
| elif self.app.mode in {AppMode.WORKFLOW.value}: | |||
| answer = json.dumps(response["data"]["outputs"], ensure_ascii=False) | |||
| else: | |||
| raise ValueError("Invalid app mode") | |||
| # Not support image yet | |||
| return types.CallToolResult(content=[types.TextContent(text=answer, type="text")]) | |||
| def retrieve_end_user(self): | |||
| return ( | |||
| db.session.query(EndUser) | |||
| .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") | |||
| .first() | |||
| ) | |||
| def handle_call_tool( | |||
| app: App, | |||
| request: mcp_types.ClientRequest, | |||
| user_input_form: list[VariableEntity], | |||
| end_user: EndUser | None, | |||
| ) -> mcp_types.CallToolResult: | |||
| """Handle call tool request""" | |||
| request_obj = cast(mcp_types.CallToolRequest, request.root) | |||
| args = prepare_tool_arguments(app, request_obj.params.arguments or {}) | |||
| def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): | |||
| parameters: dict[str, dict[str, Any]] = {} | |||
| required = [] | |||
| for item in user_input_form: | |||
| parameters[item.variable] = {} | |||
| if item.type in ( | |||
| VariableEntityType.FILE, | |||
| VariableEntityType.FILE_LIST, | |||
| VariableEntityType.EXTERNAL_DATA_TOOL, | |||
| ): | |||
| continue | |||
| if item.required: | |||
| required.append(item.variable) | |||
| # if the workflow republished, the parameters not changed | |||
| # we should not raise error here | |||
| if not end_user: | |||
| raise ValueError("End user not found") | |||
| response = AppGenerateService.generate( | |||
| app, | |||
| end_user, | |||
| args, | |||
| InvokeFrom.SERVICE_API, | |||
| streaming=app.mode == AppMode.AGENT_CHAT.value, | |||
| ) | |||
| answer = extract_answer_from_response(app, response) | |||
| return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")]) | |||
| def build_parameter_schema( | |||
| app_mode: str, | |||
| user_input_form: list[VariableEntity], | |||
| parameters_dict: dict[str, str], | |||
| ) -> dict[str, Any]: | |||
| """Build parameter schema for the tool""" | |||
| parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) | |||
| if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: | |||
| return { | |||
| "type": "object", | |||
| "properties": parameters, | |||
| "required": required, | |||
| } | |||
| return { | |||
| "type": "object", | |||
| "properties": { | |||
| "query": {"type": "string", "description": "User Input/Question content"}, | |||
| **parameters, | |||
| }, | |||
| "required": ["query", *required], | |||
| } | |||
| def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: | |||
| """Prepare arguments based on app mode""" | |||
| if app.mode == AppMode.WORKFLOW.value: | |||
| return {"inputs": arguments} | |||
| elif app.mode == AppMode.COMPLETION.value: | |||
| return {"query": "", "inputs": arguments} | |||
| else: | |||
| # Chat modes - create a copy to avoid modifying original dict | |||
| args_copy = arguments.copy() | |||
| query = args_copy.pop("query", "") | |||
| return {"query": query, "inputs": args_copy} | |||
| def extract_answer_from_response(app: App, response: Any) -> str: | |||
| """Extract answer from app generate response""" | |||
| answer = "" | |||
| if isinstance(response, RateLimitGenerator): | |||
| answer = process_streaming_response(response) | |||
| elif isinstance(response, Mapping): | |||
| answer = process_mapping_response(app, response) | |||
| else: | |||
| logger.warning("Unexpected response type: %s", type(response)) | |||
| return answer | |||
| def process_streaming_response(response: RateLimitGenerator) -> str: | |||
| """Process streaming response for agent chat mode""" | |||
| answer = "" | |||
| for item in response.generator: | |||
| if isinstance(item, str) and item.startswith("data: "): | |||
| try: | |||
| description = self.mcp_server.parameters_dict[item.variable] | |||
| except KeyError: | |||
| description = "" | |||
| parameters[item.variable]["description"] = description | |||
| if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): | |||
| parameters[item.variable]["type"] = "string" | |||
| elif item.type == VariableEntityType.SELECT: | |||
| parameters[item.variable]["type"] = "string" | |||
| parameters[item.variable]["enum"] = item.options | |||
| elif item.type == VariableEntityType.NUMBER: | |||
| parameters[item.variable]["type"] = "float" | |||
| return parameters, required | |||
| json_str = item[6:].strip() | |||
| parsed_data = json.loads(json_str) | |||
| if parsed_data.get("event") == "agent_thought": | |||
| answer += parsed_data.get("thought", "") | |||
| except json.JSONDecodeError: | |||
| continue | |||
| return answer | |||
| def process_mapping_response(app: App, response: Mapping) -> str: | |||
| """Process mapping response based on app mode""" | |||
| if app.mode in { | |||
| AppMode.ADVANCED_CHAT.value, | |||
| AppMode.COMPLETION.value, | |||
| AppMode.CHAT.value, | |||
| AppMode.AGENT_CHAT.value, | |||
| }: | |||
| return response.get("answer", "") | |||
| elif app.mode == AppMode.WORKFLOW.value: | |||
| return json.dumps(response["data"]["outputs"], ensure_ascii=False) | |||
| else: | |||
| raise ValueError("Invalid app mode: " + str(app.mode)) | |||
| def convert_input_form_to_parameters( | |||
| user_input_form: list[VariableEntity], | |||
| parameters_dict: dict[str, str], | |||
| ) -> tuple[dict[str, dict[str, Any]], list[str]]: | |||
| """Convert user input form to parameter schema""" | |||
| parameters: dict[str, dict[str, Any]] = {} | |||
| required = [] | |||
| for item in user_input_form: | |||
| if item.type in ( | |||
| VariableEntityType.FILE, | |||
| VariableEntityType.FILE_LIST, | |||
| VariableEntityType.EXTERNAL_DATA_TOOL, | |||
| ): | |||
| continue | |||
| parameters[item.variable] = {} | |||
| if item.required: | |||
| required.append(item.variable) | |||
| # if the workflow republished, the parameters not changed | |||
| # we should not raise error here | |||
| description = parameters_dict.get(item.variable, "") | |||
| parameters[item.variable]["description"] = description | |||
| if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): | |||
| parameters[item.variable]["type"] = "string" | |||
| elif item.type == VariableEntityType.SELECT: | |||
| parameters[item.variable]["type"] = "string" | |||
| parameters[item.variable]["enum"] = item.options | |||
| elif item.type == VariableEntityType.NUMBER: | |||
| parameters[item.variable]["type"] = "float" | |||
| return parameters, required | |||
| @@ -138,5 +138,5 @@ def create_mcp_error_response( | |||
| error=error_data, | |||
| ) | |||
| json_data = json.dumps(jsonable_encoder(json_response)) | |||
| sse_content = f"event: message\ndata: {json_data}\n\n".encode() | |||
| sse_content = json_data.encode() | |||
| yield sse_content | |||
| @@ -5,7 +5,7 @@ Revises: 8bcc02c9bd07 | |||
| Create Date: 2025-08-09 15:53:54.341341 | |||
| """ | |||
| from alembic import op | |||
| from alembic import op, context | |||
| from libs.uuid_utils import uuidv7 | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| @@ -43,7 +43,15 @@ def upgrade(): | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||
| migrate_existing_providers_data() | |||
| if not context.is_offline_mode(): | |||
| migrate_existing_providers_data() | |||
| else: | |||
| op.execute( | |||
| '-- [IMPORTANT] Data migration skipped!!!\n' | |||
| "-- You should manually run data migration function `migrate_existing_providers_data`\n" | |||
| f"-- inside file {__file__}\n" | |||
| "-- Please review the migration script carefully!" | |||
| ) | |||
| # Remove encrypted_config column from providers table after migration | |||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||
| @@ -119,7 +127,16 @@ def downgrade(): | |||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||
| # Migrate data back from provider_credentials to providers | |||
| migrate_data_back_to_providers() | |||
| if not context.is_offline_mode(): | |||
| migrate_data_back_to_providers() | |||
| else: | |||
| op.execute( | |||
| '-- [IMPORTANT] Data migration skipped!!!\n' | |||
| "-- You should manually run data migration function `migrate_data_back_to_providers`\n" | |||
| f"-- inside file {__file__}\n" | |||
| "-- Please review the migration script carefully!" | |||
| ) | |||
| # Remove credential_id columns | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| @@ -6,7 +6,7 @@ Create Date: 2025-08-13 16:05:42.657730 | |||
| """ | |||
| from alembic import op | |||
| from alembic import op, context | |||
| from libs.uuid_utils import uuidv7 | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| @@ -48,8 +48,16 @@ def upgrade(): | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) | |||
| # Migrate existing provider_models data | |||
| migrate_existing_provider_models_data() | |||
| if not context.is_offline_mode(): | |||
| # Migrate existing provider_models data | |||
| migrate_existing_provider_models_data() | |||
| else: | |||
| op.execute( | |||
| '-- [IMPORTANT] Data migration skipped!!!\n' | |||
| "-- You should manually run data migration function `migrate_existing_provider_models_data`\n" | |||
| f"-- inside file {__file__}\n" | |||
| "-- Please review the migration script carefully!" | |||
| ) | |||
| # Remove encrypted_config column from provider_models table after migration | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| @@ -132,8 +140,16 @@ def downgrade(): | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||
| # Migrate data back from provider_model_credentials to provider_models | |||
| migrate_data_back_to_provider_models() | |||
| if not context.is_offline_mode(): | |||
| # Migrate data back from provider_model_credentials to provider_models | |||
| migrate_data_back_to_provider_models() | |||
| else: | |||
| op.execute( | |||
| '-- [IMPORTANT] Data migration skipped!!!\n' | |||
| "-- You should manually run data migration function `migrate_data_back_to_provider_models`\n" | |||
| f"-- inside file {__file__}\n" | |||
| "-- Please review the migration script carefully!" | |||
| ) | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.drop_column('credential_id') | |||
| @@ -1,12 +1,12 @@ | |||
| import enum | |||
| import json | |||
| from datetime import datetime | |||
| from typing import Optional, cast | |||
| from typing import Optional | |||
| import sqlalchemy as sa | |||
| from flask_login import UserMixin | |||
| from sqlalchemy import DateTime, String, func, select | |||
| from sqlalchemy.orm import Mapped, mapped_column, reconstructor | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | |||
| from models.base import Base | |||
| @@ -118,10 +118,24 @@ class Account(UserMixin, Base): | |||
| @current_tenant.setter | |||
| def current_tenant(self, tenant: "Tenant"): | |||
| ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1)) | |||
| if ta: | |||
| self.role = TenantAccountRole(ta.role) | |||
| self._current_tenant = tenant | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| tenant_join_query = select(TenantAccountJoin).where( | |||
| TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id | |||
| ) | |||
| tenant_join = session.scalar(tenant_join_query) | |||
| tenant_query = select(Tenant).where(Tenant.id == tenant.id) | |||
| # TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing | |||
| # access to it after the session has been closed. | |||
| # This prevents `DetachedInstanceError` when accessing the tenant outside | |||
| # the session's lifecycle. | |||
| # (The `tenant` argument is typically loaded by `db.session` without the | |||
| # `expire_on_commit=False` flag, meaning its lifetime is tied to the web | |||
| # request's lifecycle.) | |||
| tenant_reloaded = session.scalars(tenant_query).one() | |||
| if tenant_join: | |||
| self.role = TenantAccountRole(tenant_join.role) | |||
| self._current_tenant = tenant_reloaded | |||
| return | |||
| self._current_tenant = None | |||
| @@ -130,23 +144,19 @@ class Account(UserMixin, Base): | |||
| return self._current_tenant.id if self._current_tenant else None | |||
| def set_tenant_id(self, tenant_id: str): | |||
| tenant_account_join = cast( | |||
| tuple[Tenant, TenantAccountJoin], | |||
| ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .where(Tenant.id == tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.account_id == self.id) | |||
| .one_or_none() | |||
| ), | |||
| query = ( | |||
| select(Tenant, TenantAccountJoin) | |||
| .where(Tenant.id == tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.account_id == self.id) | |||
| ) | |||
| if not tenant_account_join: | |||
| return | |||
| tenant, join = tenant_account_join | |||
| self.role = TenantAccountRole(join.role) | |||
| self._current_tenant = tenant | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| tenant_account_join = session.execute(query).first() | |||
| if not tenant_account_join: | |||
| return | |||
| tenant, join = tenant_account_join | |||
| self.role = TenantAccountRole(join.role) | |||
| self._current_tenant = tenant | |||
| @property | |||
| def current_role(self): | |||
| @@ -0,0 +1 @@ | |||
| # MCP server tests | |||
| @@ -0,0 +1,449 @@ | |||
| import json | |||
| from unittest.mock import Mock, patch | |||
| import pytest | |||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||
| from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | |||
| from core.mcp import types | |||
| from core.mcp.server.streamable_http import ( | |||
| build_parameter_schema, | |||
| convert_input_form_to_parameters, | |||
| extract_answer_from_response, | |||
| handle_call_tool, | |||
| handle_initialize, | |||
| handle_list_tools, | |||
| handle_mcp_request, | |||
| handle_ping, | |||
| prepare_tool_arguments, | |||
| process_mapping_response, | |||
| ) | |||
| from models.model import App, AppMCPServer, AppMode, EndUser | |||
| class TestHandleMCPRequest: | |||
| """Test handle_mcp_request function""" | |||
| def setup_method(self): | |||
| """Setup test fixtures""" | |||
| self.app = Mock(spec=App) | |||
| self.app.name = "test_app" | |||
| self.app.mode = AppMode.CHAT.value | |||
| self.mcp_server = Mock(spec=AppMCPServer) | |||
| self.mcp_server.description = "Test server" | |||
| self.mcp_server.parameters_dict = {} | |||
| self.end_user = Mock(spec=EndUser) | |||
| self.user_input_form = [] | |||
| # Create mock request | |||
| self.mock_request = Mock() | |||
| self.mock_request.root = Mock() | |||
| self.mock_request.root.id = 123 | |||
| def test_handle_ping_request(self): | |||
| """Test handling ping request""" | |||
| # Setup ping request | |||
| self.mock_request.root = Mock(spec=types.PingRequest) | |||
| self.mock_request.root.id = 123 | |||
| request_type = Mock(return_value=types.PingRequest) | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCResponse) | |||
| assert result.jsonrpc == "2.0" | |||
| assert result.id == 123 | |||
| def test_handle_initialize_request(self): | |||
| """Test handling initialize request""" | |||
| # Setup initialize request | |||
| self.mock_request.root = Mock(spec=types.InitializeRequest) | |||
| self.mock_request.root.id = 123 | |||
| request_type = Mock(return_value=types.InitializeRequest) | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCResponse) | |||
| assert result.jsonrpc == "2.0" | |||
| assert result.id == 123 | |||
| def test_handle_list_tools_request(self): | |||
| """Test handling list tools request""" | |||
| # Setup list tools request | |||
| self.mock_request.root = Mock(spec=types.ListToolsRequest) | |||
| self.mock_request.root.id = 123 | |||
| request_type = Mock(return_value=types.ListToolsRequest) | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCResponse) | |||
| assert result.jsonrpc == "2.0" | |||
| assert result.id == 123 | |||
| @patch("core.mcp.server.streamable_http.AppGenerateService") | |||
| def test_handle_call_tool_request(self, mock_app_generate): | |||
| """Test handling call tool request""" | |||
| # Setup call tool request | |||
| mock_call_request = Mock(spec=types.CallToolRequest) | |||
| mock_call_request.params = Mock() | |||
| mock_call_request.params.arguments = {"query": "test question"} | |||
| mock_call_request.id = 123 | |||
| self.mock_request.root = mock_call_request | |||
| request_type = Mock(return_value=types.CallToolRequest) | |||
| # Mock app generate service response | |||
| mock_response = {"answer": "test answer"} | |||
| mock_app_generate.generate.return_value = mock_response | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCResponse) | |||
| assert result.jsonrpc == "2.0" | |||
| assert result.id == 123 | |||
| # Verify AppGenerateService was called | |||
| mock_app_generate.generate.assert_called_once() | |||
| def test_handle_unknown_request_type(self): | |||
| """Test handling unknown request type""" | |||
| # Setup unknown request | |||
| class UnknownRequest: | |||
| pass | |||
| self.mock_request.root = Mock(spec=UnknownRequest) | |||
| self.mock_request.root.id = 123 | |||
| request_type = Mock(return_value=UnknownRequest) | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCError) | |||
| assert result.jsonrpc == "2.0" | |||
| assert result.id == 123 | |||
| assert result.error.code == types.METHOD_NOT_FOUND | |||
| def test_handle_value_error(self): | |||
| """Test handling ValueError""" | |||
| # Setup request that will cause ValueError | |||
| self.mock_request.root = Mock(spec=types.CallToolRequest) | |||
| self.mock_request.root.params = Mock() | |||
| self.mock_request.root.params.arguments = {} | |||
| request_type = Mock(return_value=types.CallToolRequest) | |||
| # Don't provide end_user to cause ValueError | |||
| with patch("core.mcp.server.streamable_http.type", request_type): | |||
| result = handle_mcp_request(self.app, self.mock_request, self.user_input_form, self.mcp_server, None, 123) | |||
| assert isinstance(result, types.JSONRPCError) | |||
| assert result.error.code == types.INVALID_PARAMS | |||
| def test_handle_generic_exception(self): | |||
| """Test handling generic exception""" | |||
| # Setup request that will cause generic exception | |||
| self.mock_request.root = Mock(spec=types.PingRequest) | |||
| self.mock_request.root.id = 123 | |||
| # Patch handle_ping to raise exception instead of type | |||
| with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")): | |||
| with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest): | |||
| result = handle_mcp_request( | |||
| self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123 | |||
| ) | |||
| assert isinstance(result, types.JSONRPCError) | |||
| assert result.error.code == types.INTERNAL_ERROR | |||
| class TestIndividualHandlers: | |||
| """Test individual handler functions""" | |||
| def test_handle_ping(self): | |||
| """Test ping handler""" | |||
| result = handle_ping() | |||
| assert isinstance(result, types.EmptyResult) | |||
| def test_handle_initialize(self): | |||
| """Test initialize handler""" | |||
| description = "Test server" | |||
| with patch("core.mcp.server.streamable_http.dify_config") as mock_config: | |||
| mock_config.project.version = "1.0.0" | |||
| result = handle_initialize(description) | |||
| assert isinstance(result, types.InitializeResult) | |||
| assert result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION | |||
| assert result.instructions == "Test server" | |||
| def test_handle_list_tools(self): | |||
| """Test list tools handler""" | |||
| app_name = "test_app" | |||
| app_mode = AppMode.CHAT.value | |||
| description = "Test server" | |||
| parameters_dict: dict[str, str] = {} | |||
| user_input_form: list[VariableEntity] = [] | |||
| result = handle_list_tools(app_name, app_mode, user_input_form, description, parameters_dict) | |||
| assert isinstance(result, types.ListToolsResult) | |||
| assert len(result.tools) == 1 | |||
| assert result.tools[0].name == "test_app" | |||
| assert result.tools[0].description == "Test server" | |||
| @patch("core.mcp.server.streamable_http.AppGenerateService") | |||
| def test_handle_call_tool(self, mock_app_generate): | |||
| """Test call tool handler""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.CHAT.value | |||
| # Create mock request | |||
| mock_request = Mock() | |||
| mock_call_request = Mock(spec=types.CallToolRequest) | |||
| mock_call_request.params = Mock() | |||
| mock_call_request.params.arguments = {"query": "test question"} | |||
| mock_request.root = mock_call_request | |||
| user_input_form: list[VariableEntity] = [] | |||
| end_user = Mock(spec=EndUser) | |||
| # Mock app generate service response | |||
| mock_response = {"answer": "test answer"} | |||
| mock_app_generate.generate.return_value = mock_response | |||
| result = handle_call_tool(app, mock_request, user_input_form, end_user) | |||
| assert isinstance(result, types.CallToolResult) | |||
| assert len(result.content) == 1 | |||
| # Type assertion needed due to union type | |||
| text_content = result.content[0] | |||
| assert hasattr(text_content, "text") | |||
| assert text_content.text == "test answer" # type: ignore[attr-defined] | |||
| def test_handle_call_tool_no_end_user(self): | |||
| """Test call tool handler without end user""" | |||
| app = Mock(spec=App) | |||
| mock_request = Mock() | |||
| user_input_form: list[VariableEntity] = [] | |||
| with pytest.raises(ValueError, match="End user not found"): | |||
| handle_call_tool(app, mock_request, user_input_form, None) | |||
| class TestUtilityFunctions: | |||
| """Test utility functions""" | |||
| def test_build_parameter_schema_chat_mode(self): | |||
| """Test building parameter schema for chat mode""" | |||
| app_mode = AppMode.CHAT.value | |||
| parameters_dict: dict[str, str] = {"name": "Enter your name"} | |||
| user_input_form = [ | |||
| VariableEntity( | |||
| type=VariableEntityType.TEXT_INPUT, | |||
| variable="name", | |||
| description="User name", | |||
| label="Name", | |||
| required=True, | |||
| ) | |||
| ] | |||
| schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) | |||
| assert schema["type"] == "object" | |||
| assert "query" in schema["properties"] | |||
| assert "name" in schema["properties"] | |||
| assert "query" in schema["required"] | |||
| assert "name" in schema["required"] | |||
| def test_build_parameter_schema_workflow_mode(self): | |||
| """Test building parameter schema for workflow mode""" | |||
| app_mode = AppMode.WORKFLOW.value | |||
| parameters_dict: dict[str, str] = {"input_text": "Enter text"} | |||
| user_input_form = [ | |||
| VariableEntity( | |||
| type=VariableEntityType.TEXT_INPUT, | |||
| variable="input_text", | |||
| description="Input text", | |||
| label="Input", | |||
| required=True, | |||
| ) | |||
| ] | |||
| schema = build_parameter_schema(app_mode, user_input_form, parameters_dict) | |||
| assert schema["type"] == "object" | |||
| assert "query" not in schema["properties"] | |||
| assert "input_text" in schema["properties"] | |||
| assert "input_text" in schema["required"] | |||
| def test_prepare_tool_arguments_chat_mode(self): | |||
| """Test preparing tool arguments for chat mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.CHAT.value | |||
| arguments = {"query": "test question", "name": "John"} | |||
| result = prepare_tool_arguments(app, arguments) | |||
| assert result["query"] == "test question" | |||
| assert result["inputs"]["name"] == "John" | |||
| # Original arguments should not be modified | |||
| assert arguments["query"] == "test question" | |||
| def test_prepare_tool_arguments_workflow_mode(self): | |||
| """Test preparing tool arguments for workflow mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.WORKFLOW.value | |||
| arguments = {"input_text": "test input"} | |||
| result = prepare_tool_arguments(app, arguments) | |||
| assert "inputs" in result | |||
| assert result["inputs"]["input_text"] == "test input" | |||
| def test_prepare_tool_arguments_completion_mode(self): | |||
| """Test preparing tool arguments for completion mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.COMPLETION.value | |||
| arguments = {"name": "John"} | |||
| result = prepare_tool_arguments(app, arguments) | |||
| assert result["query"] == "" | |||
| assert result["inputs"]["name"] == "John" | |||
| def test_extract_answer_from_mapping_response_chat(self): | |||
| """Test extracting answer from mapping response for chat mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.CHAT.value | |||
| response = {"answer": "test answer", "other": "data"} | |||
| result = extract_answer_from_response(app, response) | |||
| assert result == "test answer" | |||
| def test_extract_answer_from_mapping_response_workflow(self): | |||
| """Test extracting answer from mapping response for workflow mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = AppMode.WORKFLOW.value | |||
| response = {"data": {"outputs": {"result": "test result"}}} | |||
| result = extract_answer_from_response(app, response) | |||
| expected = json.dumps({"result": "test result"}, ensure_ascii=False) | |||
| assert result == expected | |||
| def test_extract_answer_from_streaming_response(self): | |||
| """Test extracting answer from streaming response""" | |||
| app = Mock(spec=App) | |||
| # Mock RateLimitGenerator | |||
| mock_generator = Mock(spec=RateLimitGenerator) | |||
| mock_generator.generator = [ | |||
| 'data: {"event": "agent_thought", "thought": "thinking..."}', | |||
| 'data: {"event": "agent_thought", "thought": "more thinking"}', | |||
| 'data: {"event": "other", "content": "ignore this"}', | |||
| "not data format", | |||
| ] | |||
| result = extract_answer_from_response(app, mock_generator) | |||
| assert result == "thinking...more thinking" | |||
| def test_process_mapping_response_invalid_mode(self): | |||
| """Test processing mapping response with invalid app mode""" | |||
| app = Mock(spec=App) | |||
| app.mode = "invalid_mode" | |||
| response = {"answer": "test"} | |||
| with pytest.raises(ValueError, match="Invalid app mode"): | |||
| process_mapping_response(app, response) | |||
| def test_convert_input_form_to_parameters(self): | |||
| """Test converting input form to parameters""" | |||
| user_input_form = [ | |||
| VariableEntity( | |||
| type=VariableEntityType.TEXT_INPUT, | |||
| variable="name", | |||
| description="User name", | |||
| label="Name", | |||
| required=True, | |||
| ), | |||
| VariableEntity( | |||
| type=VariableEntityType.SELECT, | |||
| variable="category", | |||
| description="Category", | |||
| label="Category", | |||
| required=False, | |||
| options=["A", "B", "C"], | |||
| ), | |||
| VariableEntity( | |||
| type=VariableEntityType.NUMBER, | |||
| variable="count", | |||
| description="Count", | |||
| label="Count", | |||
| required=True, | |||
| ), | |||
| VariableEntity( | |||
| type=VariableEntityType.FILE, | |||
| variable="upload", | |||
| description="File upload", | |||
| label="Upload", | |||
| required=False, | |||
| ), | |||
| ] | |||
| parameters_dict: dict[str, str] = { | |||
| "name": "Enter your name", | |||
| "category": "Select category", | |||
| "count": "Enter count", | |||
| } | |||
| parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) | |||
| # Check parameters | |||
| assert "name" in parameters | |||
| assert parameters["name"]["type"] == "string" | |||
| assert parameters["name"]["description"] == "Enter your name" | |||
| assert "category" in parameters | |||
| assert parameters["category"]["type"] == "string" | |||
| assert parameters["category"]["enum"] == ["A", "B", "C"] | |||
| assert "count" in parameters | |||
| assert parameters["count"]["type"] == "float" | |||
| # FILE type should be skipped - it creates empty dict but gets filtered later | |||
| # Check that it doesn't have any meaningful content | |||
| if "upload" in parameters: | |||
| assert parameters["upload"] == {} | |||
| # Check required fields | |||
| assert "name" in required | |||
| assert "count" in required | |||
| assert "category" not in required | |||
| # Note: _get_request_id function has been removed as request_id is now passed as parameter | |||