|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- from typing import Optional, Union
-
- from flask import Response
- from flask_restx import Resource, reqparse
- from pydantic import ValidationError
- from sqlalchemy.orm import Session
-
- from controllers.console.app.mcp_server import AppMCPServerStatus
- from controllers.mcp import mcp_ns
- from core.app.app_config.entities import VariableEntity
- from core.mcp import types as mcp_types
- from core.mcp.server.streamable_http import handle_mcp_request
- from extensions.ext_database import db
- from libs import helper
- from models.model import App, AppMCPServer, AppMode, EndUser
-
-
class MCPRequestError(Exception):
    """Error raised while processing an MCP request.

    Attributes:
        error_code: Numeric error code supplied by the raiser.
        message: Human-readable description; also used as the exception text.
    """

    def __init__(self, error_code: int, message: str):
        # Initialise the base exception first so ``str(exc)`` yields the message.
        super().__init__(message)
        self.message = message
        self.error_code = error_code
-
-
def int_or_str(value):
    """Return *value* if it is an integer or a string, otherwise ``None``.

    Used as the ``type`` callable for the JSON-RPC ``id`` argument.

    Args:
        value: The raw value taken from the request body.

    Returns:
        ``value`` unchanged when it is an ``int`` (but not a ``bool``) or a
        ``str``; ``None`` for anything else.
    """
    # ``bool`` is a subclass of ``int`` in Python, so JSON ``true``/``false``
    # would otherwise slip through as a request id. JSON-RPC 2.0 only allows
    # a string, a number or null for ``id``.
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, str)):
        return value
    return None
-
-
# Request parser used both to document the endpoint and to validate the
# incoming JSON body. All arguments are read from the JSON payload.
mcp_request_parser = reqparse.RequestParser()
mcp_request_parser.add_argument(
    "jsonrpc",
    type=str,
    required=True,
    location="json",
    help="JSON-RPC version (should be '2.0')",
)
mcp_request_parser.add_argument(
    "method",
    type=str,
    required=True,
    location="json",
    help="The method to invoke",
)
mcp_request_parser.add_argument(
    "params",
    type=dict,
    required=False,
    location="json",
    help="Parameters for the method",
)
mcp_request_parser.add_argument(
    "id",
    type=int_or_str,
    required=False,
    location="json",
    help="Request ID for tracking responses",
)
-
-
@mcp_ns.route("/server/<string:server_code>/mcp")
class MCPAppApi(Resource):
    """REST resource that receives MCP (JSON-RPC) messages for one MCP server code."""

    @mcp_ns.expect(mcp_request_parser)
    @mcp_ns.doc("handle_mcp_request")
    @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
    @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
    @mcp_ns.doc(
        responses={
            200: "MCP response successfully processed",
            400: "Invalid MCP request or parameters",
            404: "Server or app not found",
        }
    )
    def post(self, server_code: str):
        """Handle MCP requests for a specific server.

        Parses the JSON-RPC body, looks up the MCP server and its app,
        checks that the server is active and then dispatches the message.

        Args:
            server_code: Unique identifier for the MCP server

        Returns:
            Response: HTTP 202 with an empty body for an accepted notification,
            otherwise the response built from the MCP handler result.

        Raises:
            MCPRequestError: The body is not a valid MCP request/notification,
                the server or app cannot be found or used, or the request
                lacks an id.
        """
        args = mcp_request_parser.parse_args()
        request_id: Optional[Union[int, str]] = args.get("id")
        mcp_request = self._parse_mcp_request(args)

        with Session(db.engine, expire_on_commit=False) as session:
            # Get MCP server and app
            mcp_server, app = self._get_mcp_server_and_app(server_code, session)
            self._validate_server_status(mcp_server)

            # Get user input form
            user_input_form = self._get_user_input_form(app)

            # Handle notification vs request differently
            return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)

    def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
        """Load the MCP server for ``server_code`` and the app it points to.

        Raises:
            MCPRequestError: No server matches the code, or its app is missing.
        """
        mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
        if not mcp_server:
            raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")

        app = session.query(App).where(App.id == mcp_server.app_id).first()
        if not app:
            raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")

        return mcp_server, app

    def _validate_server_status(self, mcp_server: AppMCPServer):
        """Raise ``MCPRequestError`` unless the MCP server status is ACTIVE."""
        if mcp_server.status != AppMCPServerStatus.ACTIVE:
            raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")

    def _process_mcp_message(
        self,
        mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
        request_id: Optional[Union[int, str]],
        app: App,
        mcp_server: AppMCPServer,
        user_input_form: list[VariableEntity],
        session: Session,
    ) -> Response:
        """Route the parsed message to the notification or the request handler."""
        if isinstance(mcp_request, mcp_types.ClientNotification):
            return self._handle_notification(mcp_request)
        return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)

    def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
        """Accept the ``notifications/initialized`` notification; reject any other.

        Raises:
            MCPRequestError: The notification method is not supported.
        """
        # For notifications, only support init notification
        if mcp_request.root.method != "notifications/initialized":
            raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
        # Return HTTP 202 Accepted for notifications (no response body)
        return Response("", status=202, content_type="application/json")

    def _handle_request(
        self,
        mcp_request: mcp_types.ClientRequest,
        request_id: Optional[Union[int, str]],
        app: App,
        mcp_server: AppMCPServer,
        user_input_form: list[VariableEntity],
        session: Session,
    ) -> Response:
        """Run an MCP request and serialise its result into an HTTP response.

        Raises:
            MCPRequestError: ``request_id`` is missing, or the handler produced
                no result.
        """
        if request_id is None:
            raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")

        result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
        if result is None:
            # This shouldn't happen for requests, but handle gracefully
            raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")

        return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))

    def _get_user_input_form(self, app: App) -> list[VariableEntity]:
        """Read the app's user input form and convert it to ``VariableEntity`` objects.

        Advanced-chat and workflow apps take the form from ``app.workflow``;
        every other mode takes it from ``app.app_model_config``.

        Raises:
            MCPRequestError: The app has no workflow / model config, or the
                form fails validation.
        """
        # Get raw user input form based on app mode
        if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
            if not app.workflow:
                raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
            raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
        else:
            if not app.app_model_config:
                raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
            features_dict = app.app_model_config.to_dict()
            raw_user_input_form = features_dict.get("user_input_form", [])

        # Convert to VariableEntity objects
        try:
            return self._convert_user_input_form(raw_user_input_form)
        except ValidationError as e:
            # Chain the cause so the original validation details stay in the traceback.
            raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") from e

    def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
        """Convert each raw form item to a ``VariableEntity``, preserving order."""
        return [self._create_variable_entity(item) for item in raw_form]

    def _create_variable_entity(self, item: dict) -> VariableEntity:
        """Create a single ``VariableEntity`` from a raw form item.

        The variable type is the item's ``"type"`` value when that is truthy,
        otherwise the item's first key; the variable definition is the value
        stored under that type key.

        Raises:
            MCPRequestError: The item is empty or has no entry for its type.
        """
        try:
            variable_type = item.get("type", "") or list(item.keys())[0]
            variable = item[variable_type]
        except (IndexError, KeyError) as e:
            # An empty item or a missing type entry is bad form data; report it
            # the same way as a validation failure rather than a raw lookup error.
            raise MCPRequestError(
                mcp_types.INVALID_PARAMS, "Invalid user_input_form: malformed variable definition"
            ) from e

        return VariableEntity(
            type=variable_type,
            variable=variable.get("variable"),
            description=variable.get("description") or "",
            label=variable.get("label"),
            required=variable.get("required", False),
            max_length=variable.get("max_length"),
            options=variable.get("options") or [],
        )

    def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
        """Validate ``args`` as a client request, falling back to a client notification.

        Raises:
            MCPRequestError: ``args`` validates as neither a request nor a notification.
        """
        try:
            return mcp_types.ClientRequest.model_validate(args)
        except ValidationError:
            try:
                return mcp_types.ClientNotification.model_validate(args)
            except ValidationError as e:
                raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}") from e

    def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
        """Return the first ``mcp`` end user of the tenant whose session id is the server id."""
        return (
            session.query(EndUser)
            .where(EndUser.tenant_id == tenant_id)
            .where(EndUser.session_id == mcp_server_id)
            .where(EndUser.type == "mcp")
            .first()
        )

    def _create_end_user(
        self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
    ) -> EndUser:
        """Add a new ``mcp`` end user to ``session`` and return it after a flush and refresh."""
        end_user = EndUser(
            tenant_id=tenant_id,
            app_id=app_id,
            type="mcp",
            name=client_name,
            session_id=mcp_server_id,
        )
        session.add(end_user)
        session.flush()  # Use flush instead of commit to keep transaction open
        session.refresh(end_user)
        return end_user

    def _handle_mcp_request(
        self,
        app: App,
        mcp_server: AppMCPServer,
        mcp_request: mcp_types.ClientRequest,
        user_input_form: list[VariableEntity],
        session: Session,
        request_id: Union[int, str],
    ) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
        """Resolve the end user for this server and delegate to ``handle_mcp_request``.

        An end user is created only when none exists yet and the message is an
        initialize request; for any other request the lookup result (possibly
        ``None``) is passed through unchanged.
        """
        end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)

        if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
            client_info = mcp_request.root.params.clientInfo
            client_name = f"{client_info.name}@{client_info.version}"
            # Commit the session before creating end user to avoid transaction conflicts
            session.commit()
            with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
                end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)

        return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|