ソースを参照

Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

tags/2.0.0-beta.1
-LAN- 2ヶ月前
コミット
5415d0c6d1
コミッターのメールアドレスに関連付けられたアカウントが存在しません

+ 0
- 7
.github/workflows/api-tests.yml ファイルの表示

@@ -2,13 +2,6 @@ name: Run Pytest

on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
- .github/workflows/api-tests.yml

concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}

+ 0
- 4
.github/workflows/autofix.yml ファイルの表示

@@ -1,10 +1,6 @@
name: autofix.ci
on:
workflow_call:
pull_request:
branches: [ "main" ]
push:
branches: [ "main" ]
permissions:
contents: read


+ 6
- 7
.github/workflows/db-migration-test.yml ファイルの表示

@@ -2,13 +2,6 @@ name: DB Migration Test

on:
workflow_call:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml

concurrency:
group: db-migration-test-${{ github.ref }}
@@ -34,6 +27,12 @@ jobs:

- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql

- name: Prepare middleware env
run: |

+ 3
- 0
.github/workflows/main-ci.yml ファイルの表示

@@ -3,11 +3,14 @@ name: Main CI Pipeline
on:
pull_request:
branches: [ "main" ]
push:
branches: [ "main" ]

permissions:
contents: write
pull-requests: write
checks: write
statuses: write

concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}

+ 0
- 3
.github/workflows/style.yml ファイルの表示

@@ -2,9 +2,6 @@ name: Style check

on:
workflow_call:
pull_request:
branches:
- main

concurrency:
group: style-${{ github.head_ref || github.run_id }}

+ 0
- 9
.github/workflows/vdb-tests.yml ファイルの表示

@@ -2,15 +2,6 @@ name: Run VDB Tests

on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/core/rag/datasource/**
- docker/**
- .github/workflows/vdb-tests.yml
- api/uv.lock
- api/pyproject.toml

concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}

+ 0
- 5
.github/workflows/web-tests.yml ファイルの表示

@@ -2,11 +2,6 @@ name: Web Tests

on:
workflow_call:
pull_request:
branches:
- main
paths:
- web/**

concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}

+ 171
- 66
api/controllers/mcp/mcp.py ファイルの表示

@@ -1,18 +1,27 @@
from typing import Optional, Union

from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session

from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
from models.model import App, AppMCPServer, AppMode, EndUser


class MCPRequestError(Exception):
"""Custom exception for MCP request processing errors"""

def __init__(self, error_code: int, message: str):
self.error_code = error_code
self.message = message
super().__init__(message)


def int_or_str(value):
@@ -63,77 +72,173 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
# Parse and validate all arguments
args = mcp_request_parser.parse_args()

request_id: Optional[Union[int, str]] = args.get("id")
mcp_request = self._parse_mcp_request(args)

server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
)
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)

if server.status != AppMCPServerStatus.ACTIVE:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
# Get user input form
user_input_form = self._get_user_input_form(app)

app = db.session.query(App).where(App.id == server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
)
# Handle notification vs request differently
return self._process_mcp_message(mcp_request, request_id, app, mcp_server, user_input_form, session)

if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")

user_input_form = workflow.user_input_form(to_old_structure=True)
app = session.query(App).where(App.id == mcp_server.app_id).first()
if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")

return mcp_server, app

def _validate_server_status(self, mcp_server: AppMCPServer) -> None:
"""Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")

def _process_mcp_message(
self,
mcp_request: mcp_types.ClientRequest | mcp_types.ClientNotification,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Process MCP message (notification or request)"""
if isinstance(mcp_request, mcp_types.ClientNotification):
return self._handle_notification(mcp_request)
else:
app_model_config = app.app_model_config
if app_model_config is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)

features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
)
return self._handle_request(mcp_request, request_id, app, mcp_server, user_input_form, session)

def _handle_notification(self, mcp_request: mcp_types.ClientNotification) -> Response:
"""Handle MCP notification"""
# For notifications, only support init notification
if mcp_request.root.method != "notifications/initialized":
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Invalid notification method")
# Return HTTP 202 Accepted for notifications (no response body)
return Response("", status=202, content_type="application/json")

def _handle_request(
self,
mcp_request: mcp_types.ClientRequest,
request_id: Optional[Union[int, str]],
app: App,
mcp_server: AppMCPServer,
user_input_form: list[VariableEntity],
session: Session,
) -> Response:
"""Handle MCP request"""
if request_id is None:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Request ID is required")

result = self._handle_mcp_request(app, mcp_server, mcp_request, user_input_form, session, request_id)
if result is None:
# This shouldn't happen for requests, but handle gracefully
raise MCPRequestError(mcp_types.INTERNAL_ERROR, "No response generated for request")

return helper.compact_generate_response(result.model_dump(by_alias=True, mode="json", exclude_none=True))

def _get_user_input_form(self, app: App) -> list[VariableEntity]:
"""Get and convert user input form"""
# Get raw user input form based on app mode
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if not app.workflow:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
raw_user_input_form = app.workflow.user_input_form(to_old_structure=True)
else:
if not app.app_model_config:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable")
features_dict = app.app_model_config.to_dict()
raw_user_input_form = features_dict.get("user_input_form", [])

# Convert to VariableEntity objects
try:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
return self._convert_user_input_form(raw_user_input_form)
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")

def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]

def _create_variable_entity(self, item: dict) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]

return VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)

def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)
except ValidationError:
try:
notification = ClientNotification.model_validate(args)
request = notification
return mcp_types.ClientNotification.model_validate(args)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
)

mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")

def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
"""Get end user from existing session - optimized query"""
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.first()
)

def _create_end_user(
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
) -> EndUser:
"""Create end user in existing session"""
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type="mcp",
name=client_name,
session_id=mcp_server_id,
)
session.add(end_user)
session.flush() # Use flush instead of commit to keep transaction open
session.refresh(end_user)
return end_user

def _handle_mcp_request(
self,
app: App,
mcp_server: AppMCPServer,
mcp_request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
session: Session,
request_id: Union[int, str],
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
"""Handle MCP request and return response"""
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)

if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
# Commit the session before creating end user to avoid transaction conflicts
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)

return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

+ 230
- 195
api/core/mcp/server/streamable_http.py ファイルの表示

@@ -4,224 +4,259 @@ from collections.abc import Mapping
from typing import Any, cast

from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from core.mcp import types as mcp_types
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

logger = logging.getLogger(__name__)


class MCPServerStreamableHTTPRequestHandler:
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
mcp_server: AppMCPServer,
end_user: EndUser | None = None,
request_id: int | str = 1,
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError:
"""
Apply to MCP HTTP streamable server with stateless http
Handle MCP request and return JSON-RPC response

Args:
app: The Dify app instance
request: The JSON-RPC request message
user_input_form: List of variable entities for the app
mcp_server: The MCP server configuration
end_user: Optional end user
request_id: The request ID

Returns:
JSON-RPC response or error
"""

def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form

@property
def request_type(self):
return type(self.request.root)

@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
request_type = type(request.root)

@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse:
"""Create success response with business result data"""
return mcp_types.JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=result_data.model_dump(by_alias=True, mode="json", exclude_none=True),
)

def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
def create_error_response(code: int, message: str) -> mcp_types.JSONRPCError:
"""Create error response with error code and message"""
from core.mcp.types import ErrorData

error_data = ErrorData(code=code, message=message)
return mcp_types.JSONRPCError(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
id=request_id,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))

sse_content = f"event: message\ndata: {json_data}\n\n".encode()
# Request handler mapping using functional approach
request_handlers = {
mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description),
mcp_types.ListToolsRequest: lambda: handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
),
mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user),
mcp_types.PingRequest: lambda: handle_ping(),
}

yield sse_content
try:
# Dispatch request to appropriate handler
handler = request_handlers.get(request_type)
if handler:
return create_success_response(handler())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")

def error_response(self, code: int, message: str, data=None):
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
except ValueError as e:
logger.exception("Invalid params")
return create_error_response(mcp_types.INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return create_error_response(mcp_types.INTERNAL_ERROR, "Internal server error: " + str(e))

def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
types.PingRequest: self.handle_ping,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
logger.exception("Invalid params")
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")

def handle_notification(self):
return "ping"

def handle_ping(self):
return types.EmptyResult()

def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=client_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,

def handle_ping() -> mcp_types.EmptyResult:
"""Handle ping request"""
return mcp_types.EmptyResult()


def handle_initialize(description: str) -> mcp_types.InitializeResult:
"""Handle initialize request"""
capabilities = mcp_types.ServerCapabilities(
tools=mcp_types.ToolsCapability(listChanged=False),
)

return mcp_types.InitializeResult(
protocolVersion=mcp_types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=mcp_types.Implementation(name="Dify", version=dify_config.project.version),
instructions=description,
)


def handle_list_tools(
app_name: str,
app_mode: str,
user_input_form: list[VariableEntity],
description: str,
parameters_dict: dict[str, str],
) -> mcp_types.ListToolsResult:
"""Handle list tools request"""
parameter_schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)

return mcp_types.ListToolsResult(
tools=[
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=parameter_schema,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
instructions=self.mcp_server.description,
)
],
)

def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.app.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)

def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments or {}
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
args = {"query": "", "inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(
self.app,
self.end_user,
args,
InvokeFrom.SERVICE_API,
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
)
answer = ""
if isinstance(response, RateLimitGenerator):
for item in response.generator:
data = item
if isinstance(data, str) and data.startswith("data: "):
try:
json_str = data[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
if isinstance(response, Mapping):
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])

def retrieve_end_user(self):
return (
db.session.query(EndUser)
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def handle_call_tool(
app: App,
request: mcp_types.ClientRequest,
user_input_form: list[VariableEntity],
end_user: EndUser | None,
) -> mcp_types.CallToolResult:
"""Handle call tool request"""
request_obj = cast(mcp_types.CallToolRequest, request.root)
args = prepare_tool_arguments(app, request_obj.params.arguments or {})

def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
if not end_user:
raise ValueError("End user not found")

response = AppGenerateService.generate(
app,
end_user,
args,
InvokeFrom.SERVICE_API,
streaming=app.mode == AppMode.AGENT_CHAT.value,
)

answer = extract_answer_from_response(app, response)
return mcp_types.CallToolResult(content=[mcp_types.TextContent(text=answer, type="text")])


def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)

if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}


def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW.value:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION.value:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}


def extract_answer_from_response(app: App, response: Any) -> str:
"""Extract answer from app generate response"""
answer = ""

if isinstance(response, RateLimitGenerator):
answer = process_streaming_response(response)
elif isinstance(response, Mapping):
answer = process_mapping_response(app, response)
else:
logger.warning("Unexpected response type: %s", type(response))

return answer


def process_streaming_response(response: RateLimitGenerator) -> str:
"""Process streaming response for agent chat mode"""
answer = ""
for item in response.generator:
if isinstance(item, str) and item.startswith("data: "):
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required
json_str = item[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
return answer


def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
if app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW.value:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))


def convert_input_form_to_parameters(
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> tuple[dict[str, dict[str, Any]], list[str]]:
"""Convert user input form to parameter schema"""
parameters: dict[str, dict[str, Any]] = {}
required = []

for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
parameters[item.variable] = {}
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
description = parameters_dict.get(item.variable, "")
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required

+ 1
- 1
api/core/mcp/utils.py ファイルの表示

@@ -138,5 +138,5 @@ def create_mcp_error_response(
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
sse_content = json_data.encode()
yield sse_content

+ 20
- 3
api/migrations/versions/2025_08_09_1553-e8446f481c1e_add_provider_credential_pool_support.py ファイルの表示

@@ -5,7 +5,7 @@ Revises: 8bcc02c9bd07
Create Date: 2025-08-09 15:53:54.341341

"""
from alembic import op
from alembic import op, context
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
@@ -43,7 +43,15 @@ def upgrade():
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True))

migrate_existing_providers_data()
if not context.is_offline_mode():
migrate_existing_providers_data()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_existing_providers_data`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)

# Remove encrypted_config column from providers table after migration
with op.batch_alter_table('providers', schema=None) as batch_op:
@@ -119,7 +127,16 @@ def downgrade():
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))

# Migrate data back from provider_credentials to providers
migrate_data_back_to_providers()

if not context.is_offline_mode():
migrate_data_back_to_providers()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_data_back_to_providers`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)

# Remove credential_id columns
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:

+ 21
- 5
api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py ファイルの表示

@@ -6,7 +6,7 @@ Create Date: 2025-08-13 16:05:42.657730

"""

from alembic import op
from alembic import op, context
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
@@ -48,8 +48,16 @@ def upgrade():
with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True))

# Migrate existing provider_models data
migrate_existing_provider_models_data()
if not context.is_offline_mode():
# Migrate existing provider_models data
migrate_existing_provider_models_data()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_existing_provider_models_data`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)

# Remove encrypted_config column from provider_models table after migration
with op.batch_alter_table('provider_models', schema=None) as batch_op:
@@ -132,8 +140,16 @@ def downgrade():
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))

# Migrate data back from provider_model_credentials to provider_models
migrate_data_back_to_provider_models()
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
migrate_data_back_to_provider_models()
else:
op.execute(
'-- [IMPORTANT] Data migration skipped!!!\n'
"-- You should manually run data migration function `migrate_data_back_to_provider_models`\n"
f"-- inside file {__file__}\n"
"-- Please review the migration script carefully!"
)

with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.drop_column('credential_id')

+ 32
- 22
api/models/account.py ファイルの表示

@@ -1,12 +1,12 @@
import enum
import json
from datetime import datetime
from typing import Optional, cast
from typing import Optional

import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor

from models.base import Base

@@ -118,10 +118,24 @@ class Account(UserMixin, Base):

@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_join_query = select(TenantAccountJoin).where(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id
)
tenant_join = session.scalar(tenant_join_query)
tenant_query = select(Tenant).where(Tenant.id == tenant.id)
# TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing
# access to it after the session has been closed.
# This prevents `DetachedInstanceError` when accessing the tenant outside
# the session's lifecycle.
# (The `tenant` argument is typically loaded by `db.session` without the
# `expire_on_commit=False` flag, meaning its lifetime is tied to the web
# request's lifecycle.)
tenant_reloaded = session.scalars(tenant_query).one()

if tenant_join:
self.role = TenantAccountRole(tenant_join.role)
self._current_tenant = tenant_reloaded
return
self._current_tenant = None

@@ -130,23 +144,19 @@ class Account(UserMixin, Base):
return self._current_tenant.id if self._current_tenant else None

def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
.one_or_none()
),
query = (
select(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
)
if not tenant_account_join:
return
tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_account_join = session.execute(query).first()
if not tenant_account_join:
return
tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant

@property
def current_role(self):

+ 1
- 0
api/tests/unit_tests/core/mcp/server/__init__.py ファイルの表示

@@ -0,0 +1 @@
# MCP server tests

+ 449
- 0
api/tests/unit_tests/core/mcp/server/test_streamable_http.py ファイルの表示

@@ -0,0 +1,449 @@
import json
from unittest.mock import Mock, patch

import pytest

from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.server.streamable_http import (
build_parameter_schema,
convert_input_form_to_parameters,
extract_answer_from_response,
handle_call_tool,
handle_initialize,
handle_list_tools,
handle_mcp_request,
handle_ping,
prepare_tool_arguments,
process_mapping_response,
)
from models.model import App, AppMCPServer, AppMode, EndUser


class TestHandleMCPRequest:
"""Test handle_mcp_request function"""

def setup_method(self):
"""Setup test fixtures"""
self.app = Mock(spec=App)
self.app.name = "test_app"
self.app.mode = AppMode.CHAT.value

self.mcp_server = Mock(spec=AppMCPServer)
self.mcp_server.description = "Test server"
self.mcp_server.parameters_dict = {}

self.end_user = Mock(spec=EndUser)
self.user_input_form = []

# Create mock request
self.mock_request = Mock()
self.mock_request.root = Mock()
self.mock_request.root.id = 123

def test_handle_ping_request(self):
"""Test handling ping request"""
# Setup ping request
self.mock_request.root = Mock(spec=types.PingRequest)
self.mock_request.root.id = 123
request_type = Mock(return_value=types.PingRequest)

with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCResponse)
assert result.jsonrpc == "2.0"
assert result.id == 123

def test_handle_initialize_request(self):
"""Test handling initialize request"""
# Setup initialize request
self.mock_request.root = Mock(spec=types.InitializeRequest)
self.mock_request.root.id = 123
request_type = Mock(return_value=types.InitializeRequest)

with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCResponse)
assert result.jsonrpc == "2.0"
assert result.id == 123

def test_handle_list_tools_request(self):
"""Test handling list tools request"""
# Setup list tools request
self.mock_request.root = Mock(spec=types.ListToolsRequest)
self.mock_request.root.id = 123
request_type = Mock(return_value=types.ListToolsRequest)

with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCResponse)
assert result.jsonrpc == "2.0"
assert result.id == 123

@patch("core.mcp.server.streamable_http.AppGenerateService")
def test_handle_call_tool_request(self, mock_app_generate):
"""Test handling call tool request"""
# Setup call tool request
mock_call_request = Mock(spec=types.CallToolRequest)
mock_call_request.params = Mock()
mock_call_request.params.arguments = {"query": "test question"}
mock_call_request.id = 123

self.mock_request.root = mock_call_request
request_type = Mock(return_value=types.CallToolRequest)

# Mock app generate service response
mock_response = {"answer": "test answer"}
mock_app_generate.generate.return_value = mock_response

with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCResponse)
assert result.jsonrpc == "2.0"
assert result.id == 123

# Verify AppGenerateService was called
mock_app_generate.generate.assert_called_once()

def test_handle_unknown_request_type(self):
"""Test handling unknown request type"""

# Setup unknown request
class UnknownRequest:
pass

self.mock_request.root = Mock(spec=UnknownRequest)
self.mock_request.root.id = 123
request_type = Mock(return_value=UnknownRequest)

with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCError)
assert result.jsonrpc == "2.0"
assert result.id == 123
assert result.error.code == types.METHOD_NOT_FOUND

def test_handle_value_error(self):
"""Test handling ValueError"""
# Setup request that will cause ValueError
self.mock_request.root = Mock(spec=types.CallToolRequest)
self.mock_request.root.params = Mock()
self.mock_request.root.params.arguments = {}

request_type = Mock(return_value=types.CallToolRequest)

# Don't provide end_user to cause ValueError
with patch("core.mcp.server.streamable_http.type", request_type):
result = handle_mcp_request(self.app, self.mock_request, self.user_input_form, self.mcp_server, None, 123)

assert isinstance(result, types.JSONRPCError)
assert result.error.code == types.INVALID_PARAMS

def test_handle_generic_exception(self):
"""Test handling generic exception"""
# Setup request that will cause generic exception
self.mock_request.root = Mock(spec=types.PingRequest)
self.mock_request.root.id = 123

# Patch handle_ping to raise exception instead of type
with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")):
with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest):
result = handle_mcp_request(
self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
)

assert isinstance(result, types.JSONRPCError)
assert result.error.code == types.INTERNAL_ERROR


class TestIndividualHandlers:
"""Test individual handler functions"""

def test_handle_ping(self):
"""Test ping handler"""
result = handle_ping()
assert isinstance(result, types.EmptyResult)

def test_handle_initialize(self):
"""Test initialize handler"""
description = "Test server"

with patch("core.mcp.server.streamable_http.dify_config") as mock_config:
mock_config.project.version = "1.0.0"
result = handle_initialize(description)

assert isinstance(result, types.InitializeResult)
assert result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION
assert result.instructions == "Test server"

def test_handle_list_tools(self):
"""Test list tools handler"""
app_name = "test_app"
app_mode = AppMode.CHAT.value
description = "Test server"
parameters_dict: dict[str, str] = {}
user_input_form: list[VariableEntity] = []

result = handle_list_tools(app_name, app_mode, user_input_form, description, parameters_dict)

assert isinstance(result, types.ListToolsResult)
assert len(result.tools) == 1
assert result.tools[0].name == "test_app"
assert result.tools[0].description == "Test server"

@patch("core.mcp.server.streamable_http.AppGenerateService")
def test_handle_call_tool(self, mock_app_generate):
"""Test call tool handler"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value

# Create mock request
mock_request = Mock()
mock_call_request = Mock(spec=types.CallToolRequest)
mock_call_request.params = Mock()
mock_call_request.params.arguments = {"query": "test question"}
mock_request.root = mock_call_request

user_input_form: list[VariableEntity] = []
end_user = Mock(spec=EndUser)

# Mock app generate service response
mock_response = {"answer": "test answer"}
mock_app_generate.generate.return_value = mock_response

result = handle_call_tool(app, mock_request, user_input_form, end_user)

assert isinstance(result, types.CallToolResult)
assert len(result.content) == 1
# Type assertion needed due to union type
text_content = result.content[0]
assert hasattr(text_content, "text")
assert text_content.text == "test answer" # type: ignore[attr-defined]

def test_handle_call_tool_no_end_user(self):
"""Test call tool handler without end user"""
app = Mock(spec=App)
mock_request = Mock()
user_input_form: list[VariableEntity] = []

with pytest.raises(ValueError, match="End user not found"):
handle_call_tool(app, mock_request, user_input_form, None)


class TestUtilityFunctions:
"""Test utility functions"""

def test_build_parameter_schema_chat_mode(self):
"""Test building parameter schema for chat mode"""
app_mode = AppMode.CHAT.value
parameters_dict: dict[str, str] = {"name": "Enter your name"}

user_input_form = [
VariableEntity(
type=VariableEntityType.TEXT_INPUT,
variable="name",
description="User name",
label="Name",
required=True,
)
]

schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)

assert schema["type"] == "object"
assert "query" in schema["properties"]
assert "name" in schema["properties"]
assert "query" in schema["required"]
assert "name" in schema["required"]

def test_build_parameter_schema_workflow_mode(self):
"""Test building parameter schema for workflow mode"""
app_mode = AppMode.WORKFLOW.value
parameters_dict: dict[str, str] = {"input_text": "Enter text"}

user_input_form = [
VariableEntity(
type=VariableEntityType.TEXT_INPUT,
variable="input_text",
description="Input text",
label="Input",
required=True,
)
]

schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)

assert schema["type"] == "object"
assert "query" not in schema["properties"]
assert "input_text" in schema["properties"]
assert "input_text" in schema["required"]

def test_prepare_tool_arguments_chat_mode(self):
"""Test preparing tool arguments for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value

arguments = {"query": "test question", "name": "John"}

result = prepare_tool_arguments(app, arguments)

assert result["query"] == "test question"
assert result["inputs"]["name"] == "John"
# Original arguments should not be modified
assert arguments["query"] == "test question"

def test_prepare_tool_arguments_workflow_mode(self):
"""Test preparing tool arguments for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value

arguments = {"input_text": "test input"}

result = prepare_tool_arguments(app, arguments)

assert "inputs" in result
assert result["inputs"]["input_text"] == "test input"

def test_prepare_tool_arguments_completion_mode(self):
"""Test preparing tool arguments for completion mode"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value

arguments = {"name": "John"}

result = prepare_tool_arguments(app, arguments)

assert result["query"] == ""
assert result["inputs"]["name"] == "John"

def test_extract_answer_from_mapping_response_chat(self):
"""Test extracting answer from mapping response for chat mode"""
app = Mock(spec=App)
app.mode = AppMode.CHAT.value

response = {"answer": "test answer", "other": "data"}

result = extract_answer_from_response(app, response)

assert result == "test answer"

def test_extract_answer_from_mapping_response_workflow(self):
"""Test extracting answer from mapping response for workflow mode"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value

response = {"data": {"outputs": {"result": "test result"}}}

result = extract_answer_from_response(app, response)

expected = json.dumps({"result": "test result"}, ensure_ascii=False)
assert result == expected

def test_extract_answer_from_streaming_response(self):
"""Test extracting answer from streaming response"""
app = Mock(spec=App)

# Mock RateLimitGenerator
mock_generator = Mock(spec=RateLimitGenerator)
mock_generator.generator = [
'data: {"event": "agent_thought", "thought": "thinking..."}',
'data: {"event": "agent_thought", "thought": "more thinking"}',
'data: {"event": "other", "content": "ignore this"}',
"not data format",
]

result = extract_answer_from_response(app, mock_generator)

assert result == "thinking...more thinking"

def test_process_mapping_response_invalid_mode(self):
"""Test processing mapping response with invalid app mode"""
app = Mock(spec=App)
app.mode = "invalid_mode"

response = {"answer": "test"}

with pytest.raises(ValueError, match="Invalid app mode"):
process_mapping_response(app, response)

def test_convert_input_form_to_parameters(self):
"""Test converting input form to parameters"""
user_input_form = [
VariableEntity(
type=VariableEntityType.TEXT_INPUT,
variable="name",
description="User name",
label="Name",
required=True,
),
VariableEntity(
type=VariableEntityType.SELECT,
variable="category",
description="Category",
label="Category",
required=False,
options=["A", "B", "C"],
),
VariableEntity(
type=VariableEntityType.NUMBER,
variable="count",
description="Count",
label="Count",
required=True,
),
VariableEntity(
type=VariableEntityType.FILE,
variable="upload",
description="File upload",
label="Upload",
required=False,
),
]

parameters_dict: dict[str, str] = {
"name": "Enter your name",
"category": "Select category",
"count": "Enter count",
}

parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)

# Check parameters
assert "name" in parameters
assert parameters["name"]["type"] == "string"
assert parameters["name"]["description"] == "Enter your name"

assert "category" in parameters
assert parameters["category"]["type"] == "string"
assert parameters["category"]["enum"] == ["A", "B", "C"]

assert "count" in parameters
assert parameters["count"]["type"] == "float"

# FILE type should be skipped - it creates empty dict but gets filtered later
# Check that it doesn't have any meaningful content
if "upload" in parameters:
assert parameters["upload"] == {}

# Check required fields
assert "name" in required
assert "count" in required
assert "category" not in required

# Note: _get_request_id function has been removed as request_id is now passed as parameter

読み込み中…
キャンセル
保存