Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

streamable_http.py 8.8KB


  1. import json
  2. import logging
  3. from collections.abc import Mapping
  4. from typing import Any, cast
  5. from configs import dify_config
  6. from controllers.web.passport import generate_session_id
  7. from core.app.app_config.entities import VariableEntity, VariableEntityType
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
  10. from core.mcp import types
  11. from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
  12. from core.mcp.utils import create_mcp_error_response
  13. from core.model_runtime.utils.encoders import jsonable_encoder
  14. from extensions.ext_database import db
  15. from models.model import App, AppMCPServer, AppMode, EndUser
  16. from services.app_generate_service import AppGenerateService
  17. """
  18. Apply to MCP HTTP streamable server with stateless http
  19. """
  20. logger = logging.getLogger(__name__)
  21. class MCPServerStreamableHTTPRequestHandler:
  22. def __init__(
  23. self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
  24. ):
  25. self.app = app
  26. self.request = request
  27. mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
  28. if not mcp_server:
  29. raise ValueError("MCP server not found")
  30. self.mcp_server: AppMCPServer = mcp_server
  31. self.end_user = self.retrieve_end_user()
  32. self.user_input_form = user_input_form
  33. @property
  34. def request_type(self):
  35. return type(self.request.root)
  36. @property
  37. def parameter_schema(self):
  38. parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
  39. if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
  40. return {
  41. "type": "object",
  42. "properties": parameters,
  43. "required": required,
  44. }
  45. return {
  46. "type": "object",
  47. "properties": {
  48. "query": {"type": "string", "description": "User Input/Question content"},
  49. **parameters,
  50. },
  51. "required": ["query", *required],
  52. }
  53. @property
  54. def capabilities(self):
  55. return types.ServerCapabilities(
  56. tools=types.ToolsCapability(listChanged=False),
  57. )
  58. def response(self, response: types.Result | str):
  59. if isinstance(response, str):
  60. sse_content = f"event: ping\ndata: {response}\n\n".encode()
  61. yield sse_content
  62. return
  63. json_response = types.JSONRPCResponse(
  64. jsonrpc="2.0",
  65. id=(self.request.root.model_extra or {}).get("id", 1),
  66. result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
  67. )
  68. json_data = json.dumps(jsonable_encoder(json_response))
  69. sse_content = f"event: message\ndata: {json_data}\n\n".encode()
  70. yield sse_content
  71. def error_response(self, code: int, message: str, data=None):
  72. request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
  73. return create_mcp_error_response(request_id, code, message, data)
  74. def handle(self):
  75. handle_map = {
  76. types.InitializeRequest: self.initialize,
  77. types.ListToolsRequest: self.list_tools,
  78. types.CallToolRequest: self.invoke_tool,
  79. types.InitializedNotification: self.handle_notification,
  80. types.PingRequest: self.handle_ping,
  81. }
  82. try:
  83. if self.request_type in handle_map:
  84. return self.response(handle_map[self.request_type]())
  85. else:
  86. return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
  87. except ValueError as e:
  88. logger.exception("Invalid params")
  89. return self.error_response(INVALID_PARAMS, str(e))
  90. except Exception as e:
  91. logger.exception("Internal server error")
  92. return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
  93. def handle_notification(self):
  94. return "ping"
  95. def handle_ping(self):
  96. return types.EmptyResult()
  97. def initialize(self):
  98. request = cast(types.InitializeRequest, self.request.root)
  99. client_info = request.params.clientInfo
  100. client_name = f"{client_info.name}@{client_info.version}"
  101. if not self.end_user:
  102. end_user = EndUser(
  103. tenant_id=self.app.tenant_id,
  104. app_id=self.app.id,
  105. type="mcp",
  106. name=client_name,
  107. session_id=generate_session_id(),
  108. external_user_id=self.mcp_server.id,
  109. )
  110. db.session.add(end_user)
  111. db.session.commit()
  112. return types.InitializeResult(
  113. protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
  114. capabilities=self.capabilities,
  115. serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
  116. instructions=self.mcp_server.description,
  117. )
  118. def list_tools(self):
  119. if not self.end_user:
  120. raise ValueError("User not found")
  121. return types.ListToolsResult(
  122. tools=[
  123. types.Tool(
  124. name=self.app.name,
  125. description=self.mcp_server.description,
  126. inputSchema=self.parameter_schema,
  127. )
  128. ],
  129. )
  130. def invoke_tool(self):
  131. if not self.end_user:
  132. raise ValueError("User not found")
  133. request = cast(types.CallToolRequest, self.request.root)
  134. args = request.params.arguments or {}
  135. if self.app.mode in {AppMode.WORKFLOW.value}:
  136. args = {"inputs": args}
  137. elif self.app.mode in {AppMode.COMPLETION.value}:
  138. args = {"query": "", "inputs": args}
  139. else:
  140. args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
  141. response = AppGenerateService.generate(
  142. self.app,
  143. self.end_user,
  144. args,
  145. InvokeFrom.SERVICE_API,
  146. streaming=self.app.mode == AppMode.AGENT_CHAT.value,
  147. )
  148. answer = ""
  149. if isinstance(response, RateLimitGenerator):
  150. for item in response.generator:
  151. data = item
  152. if isinstance(data, str) and data.startswith("data: "):
  153. try:
  154. json_str = data[6:].strip()
  155. parsed_data = json.loads(json_str)
  156. if parsed_data.get("event") == "agent_thought":
  157. answer += parsed_data.get("thought", "")
  158. except json.JSONDecodeError:
  159. continue
  160. if isinstance(response, Mapping):
  161. if self.app.mode in {
  162. AppMode.ADVANCED_CHAT.value,
  163. AppMode.COMPLETION.value,
  164. AppMode.CHAT.value,
  165. AppMode.AGENT_CHAT.value,
  166. }:
  167. answer = response["answer"]
  168. elif self.app.mode in {AppMode.WORKFLOW.value}:
  169. answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
  170. else:
  171. raise ValueError("Invalid app mode")
  172. # Not support image yet
  173. return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
  174. def retrieve_end_user(self):
  175. return (
  176. db.session.query(EndUser)
  177. .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
  178. .first()
  179. )
  180. def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
  181. parameters: dict[str, dict[str, Any]] = {}
  182. required = []
  183. for item in user_input_form:
  184. parameters[item.variable] = {}
  185. if item.type in (
  186. VariableEntityType.FILE,
  187. VariableEntityType.FILE_LIST,
  188. VariableEntityType.EXTERNAL_DATA_TOOL,
  189. ):
  190. continue
  191. if item.required:
  192. required.append(item.variable)
  193. # if the workflow republished, the parameters not changed
  194. # we should not raise error here
  195. try:
  196. description = self.mcp_server.parameters_dict[item.variable]
  197. except KeyError:
  198. description = ""
  199. parameters[item.variable]["description"] = description
  200. if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
  201. parameters[item.variable]["type"] = "string"
  202. elif item.type == VariableEntityType.SELECT:
  203. parameters[item.variable]["type"] = "string"
  204. parameters[item.variable]["enum"] = item.options
  205. elif item.type == VariableEntityType.NUMBER:
  206. parameters[item.variable]["type"] = "float"
  207. return parameters, required