| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- import json
- import logging
- from collections.abc import Mapping
- from typing import Any, cast
-
- from configs import dify_config
- from controllers.web.passport import generate_session_id
- from core.app.app_config.entities import VariableEntity, VariableEntityType
- from core.app.entities.app_invoke_entities import InvokeFrom
- from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
- from core.mcp import types
- from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
- from core.mcp.utils import create_mcp_error_response
- from core.model_runtime.utils.encoders import jsonable_encoder
- from extensions.ext_database import db
- from models.model import App, AppMCPServer, AppMode, EndUser
- from services.app_generate_service import AppGenerateService
-
- """
- Apply to MCP HTTP streamable server with stateless http
- """
- logger = logging.getLogger(__name__)
-
-
- class MCPServerStreamableHTTPRequestHandler:
- def __init__(
- self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
- ):
- self.app = app
- self.request = request
- mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
- if not mcp_server:
- raise ValueError("MCP server not found")
- self.mcp_server: AppMCPServer = mcp_server
- self.end_user = self.retrieve_end_user()
- self.user_input_form = user_input_form
-
- @property
- def request_type(self):
- return type(self.request.root)
-
- @property
- def parameter_schema(self):
- parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
- if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
- return {
- "type": "object",
- "properties": parameters,
- "required": required,
- }
- return {
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "User Input/Question content"},
- **parameters,
- },
- "required": ["query", *required],
- }
-
- @property
- def capabilities(self):
- return types.ServerCapabilities(
- tools=types.ToolsCapability(listChanged=False),
- )
-
- def response(self, response: types.Result | str):
- if isinstance(response, str):
- sse_content = f"event: ping\ndata: {response}\n\n".encode()
- yield sse_content
- return
- json_response = types.JSONRPCResponse(
- jsonrpc="2.0",
- id=(self.request.root.model_extra or {}).get("id", 1),
- result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
- )
- json_data = json.dumps(jsonable_encoder(json_response))
-
- sse_content = f"event: message\ndata: {json_data}\n\n".encode()
-
- yield sse_content
-
- def error_response(self, code: int, message: str, data=None):
- request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
- return create_mcp_error_response(request_id, code, message, data)
-
- def handle(self):
- handle_map = {
- types.InitializeRequest: self.initialize,
- types.ListToolsRequest: self.list_tools,
- types.CallToolRequest: self.invoke_tool,
- types.InitializedNotification: self.handle_notification,
- types.PingRequest: self.handle_ping,
- }
- try:
- if self.request_type in handle_map:
- return self.response(handle_map[self.request_type]())
- else:
- return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
- except ValueError as e:
- logger.exception("Invalid params")
- return self.error_response(INVALID_PARAMS, str(e))
- except Exception as e:
- logger.exception("Internal server error")
- return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
-
- def handle_notification(self):
- return "ping"
-
- def handle_ping(self):
- return types.EmptyResult()
-
- def initialize(self):
- request = cast(types.InitializeRequest, self.request.root)
- client_info = request.params.clientInfo
- client_name = f"{client_info.name}@{client_info.version}"
- if not self.end_user:
- end_user = EndUser(
- tenant_id=self.app.tenant_id,
- app_id=self.app.id,
- type="mcp",
- name=client_name,
- session_id=generate_session_id(),
- external_user_id=self.mcp_server.id,
- )
- db.session.add(end_user)
- db.session.commit()
- return types.InitializeResult(
- protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
- capabilities=self.capabilities,
- serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
- instructions=self.mcp_server.description,
- )
-
- def list_tools(self):
- if not self.end_user:
- raise ValueError("User not found")
- return types.ListToolsResult(
- tools=[
- types.Tool(
- name=self.app.name,
- description=self.mcp_server.description,
- inputSchema=self.parameter_schema,
- )
- ],
- )
-
- def invoke_tool(self):
- if not self.end_user:
- raise ValueError("User not found")
- request = cast(types.CallToolRequest, self.request.root)
- args = request.params.arguments or {}
- if self.app.mode in {AppMode.WORKFLOW.value}:
- args = {"inputs": args}
- elif self.app.mode in {AppMode.COMPLETION.value}:
- args = {"query": "", "inputs": args}
- else:
- args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
- response = AppGenerateService.generate(
- self.app,
- self.end_user,
- args,
- InvokeFrom.SERVICE_API,
- streaming=self.app.mode == AppMode.AGENT_CHAT.value,
- )
- answer = ""
- if isinstance(response, RateLimitGenerator):
- for item in response.generator:
- data = item
- if isinstance(data, str) and data.startswith("data: "):
- try:
- json_str = data[6:].strip()
- parsed_data = json.loads(json_str)
- if parsed_data.get("event") == "agent_thought":
- answer += parsed_data.get("thought", "")
- except json.JSONDecodeError:
- continue
- if isinstance(response, Mapping):
- if self.app.mode in {
- AppMode.ADVANCED_CHAT.value,
- AppMode.COMPLETION.value,
- AppMode.CHAT.value,
- AppMode.AGENT_CHAT.value,
- }:
- answer = response["answer"]
- elif self.app.mode in {AppMode.WORKFLOW.value}:
- answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
- else:
- raise ValueError("Invalid app mode")
- # Not support image yet
- return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
-
- def retrieve_end_user(self):
- return (
- db.session.query(EndUser)
- .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
- .first()
- )
-
- def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
- parameters: dict[str, dict[str, Any]] = {}
- required = []
- for item in user_input_form:
- parameters[item.variable] = {}
- if item.type in (
- VariableEntityType.FILE,
- VariableEntityType.FILE_LIST,
- VariableEntityType.EXTERNAL_DATA_TOOL,
- ):
- continue
- if item.required:
- required.append(item.variable)
- # if the workflow republished, the parameters not changed
- # we should not raise error here
- try:
- description = self.mcp_server.parameters_dict[item.variable]
- except KeyError:
- description = ""
- parameters[item.variable]["description"] = description
- if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
- parameters[item.variable]["type"] = "string"
- elif item.type == VariableEntityType.SELECT:
- parameters[item.variable]["type"] = "string"
- parameters[item.variable]["enum"] = item.options
- elif item.type == VariableEntityType.NUMBER:
- parameters[item.variable]["type"] = "float"
- return parameters, required
|