You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

streamable_http.py 8.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import json
  2. import logging
  3. from collections.abc import Mapping
  4. from typing import Any, cast
  5. from configs import dify_config
  6. from controllers.web.passport import generate_session_id
  7. from core.app.app_config.entities import VariableEntity, VariableEntityType
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
  10. from core.mcp import types
  11. from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
  12. from core.mcp.utils import create_mcp_error_response
  13. from core.model_runtime.utils.encoders import jsonable_encoder
  14. from extensions.ext_database import db
  15. from models.model import App, AppMCPServer, AppMode, EndUser
  16. from services.app_generate_service import AppGenerateService
  17. """
  18. Apply to MCP HTTP streamable server with stateless http
  19. """
  20. logger = logging.getLogger(__name__)
  21. class MCPServerStreamableHTTPRequestHandler:
  22. def __init__(
  23. self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
  24. ):
  25. self.app = app
  26. self.request = request
  27. mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
  28. if not mcp_server:
  29. raise ValueError("MCP server not found")
  30. self.mcp_server: AppMCPServer = mcp_server
  31. self.end_user = self.retrieve_end_user()
  32. self.user_input_form = user_input_form
  33. @property
  34. def request_type(self):
  35. return type(self.request.root)
  36. @property
  37. def parameter_schema(self):
  38. parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
  39. if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
  40. return {
  41. "type": "object",
  42. "properties": parameters,
  43. "required": required,
  44. }
  45. return {
  46. "type": "object",
  47. "properties": {
  48. "query": {"type": "string", "description": "User Input/Question content"},
  49. **parameters,
  50. },
  51. "required": ["query", *required],
  52. }
  53. @property
  54. def capabilities(self):
  55. return types.ServerCapabilities(
  56. tools=types.ToolsCapability(listChanged=False),
  57. )
  58. def response(self, response: types.Result | str):
  59. if isinstance(response, str):
  60. sse_content = f"event: ping\ndata: {response}\n\n".encode()
  61. yield sse_content
  62. return
  63. json_response = types.JSONRPCResponse(
  64. jsonrpc="2.0",
  65. id=(self.request.root.model_extra or {}).get("id", 1),
  66. result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
  67. )
  68. json_data = json.dumps(jsonable_encoder(json_response))
  69. sse_content = f"event: message\ndata: {json_data}\n\n".encode()
  70. yield sse_content
  71. def error_response(self, code: int, message: str, data=None):
  72. request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
  73. return create_mcp_error_response(request_id, code, message, data)
  74. def handle(self):
  75. handle_map = {
  76. types.InitializeRequest: self.initialize,
  77. types.ListToolsRequest: self.list_tools,
  78. types.CallToolRequest: self.invoke_tool,
  79. types.InitializedNotification: self.handle_notification,
  80. }
  81. try:
  82. if self.request_type in handle_map:
  83. return self.response(handle_map[self.request_type]())
  84. else:
  85. return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
  86. except ValueError as e:
  87. logger.exception("Invalid params")
  88. return self.error_response(INVALID_PARAMS, str(e))
  89. except Exception as e:
  90. logger.exception("Internal server error")
  91. return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
  92. def handle_notification(self):
  93. return "ping"
  94. def initialize(self):
  95. request = cast(types.InitializeRequest, self.request.root)
  96. client_info = request.params.clientInfo
  97. clinet_name = f"{client_info.name}@{client_info.version}"
  98. if not self.end_user:
  99. end_user = EndUser(
  100. tenant_id=self.app.tenant_id,
  101. app_id=self.app.id,
  102. type="mcp",
  103. name=clinet_name,
  104. session_id=generate_session_id(),
  105. external_user_id=self.mcp_server.id,
  106. )
  107. db.session.add(end_user)
  108. db.session.commit()
  109. return types.InitializeResult(
  110. protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
  111. capabilities=self.capabilities,
  112. serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
  113. instructions=self.mcp_server.description,
  114. )
  115. def list_tools(self):
  116. if not self.end_user:
  117. raise ValueError("User not found")
  118. return types.ListToolsResult(
  119. tools=[
  120. types.Tool(
  121. name=self.app.name,
  122. description=self.mcp_server.description,
  123. inputSchema=self.parameter_schema,
  124. )
  125. ],
  126. )
  127. def invoke_tool(self):
  128. if not self.end_user:
  129. raise ValueError("User not found")
  130. request = cast(types.CallToolRequest, self.request.root)
  131. args = request.params.arguments
  132. if not args:
  133. raise ValueError("No arguments provided")
  134. if self.app.mode in {AppMode.WORKFLOW.value}:
  135. args = {"inputs": args}
  136. elif self.app.mode in {AppMode.COMPLETION.value}:
  137. args = {"query": "", "inputs": args}
  138. else:
  139. args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
  140. response = AppGenerateService.generate(
  141. self.app,
  142. self.end_user,
  143. args,
  144. InvokeFrom.SERVICE_API,
  145. streaming=self.app.mode == AppMode.AGENT_CHAT.value,
  146. )
  147. answer = ""
  148. if isinstance(response, RateLimitGenerator):
  149. for item in response.generator:
  150. data = item
  151. if isinstance(data, str) and data.startswith("data: "):
  152. try:
  153. json_str = data[6:].strip()
  154. parsed_data = json.loads(json_str)
  155. if parsed_data.get("event") == "agent_thought":
  156. answer += parsed_data.get("thought", "")
  157. except json.JSONDecodeError:
  158. continue
  159. if isinstance(response, Mapping):
  160. if self.app.mode in {
  161. AppMode.ADVANCED_CHAT.value,
  162. AppMode.COMPLETION.value,
  163. AppMode.CHAT.value,
  164. AppMode.AGENT_CHAT.value,
  165. }:
  166. answer = response["answer"]
  167. elif self.app.mode in {AppMode.WORKFLOW.value}:
  168. answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
  169. else:
  170. raise ValueError("Invalid app mode")
  171. # Not support image yet
  172. return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
  173. def retrieve_end_user(self):
  174. return (
  175. db.session.query(EndUser)
  176. .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
  177. .first()
  178. )
  179. def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
  180. parameters: dict[str, dict[str, Any]] = {}
  181. required = []
  182. for item in user_input_form:
  183. parameters[item.variable] = {}
  184. if item.type in (
  185. VariableEntityType.FILE,
  186. VariableEntityType.FILE_LIST,
  187. VariableEntityType.EXTERNAL_DATA_TOOL,
  188. ):
  189. continue
  190. if item.required:
  191. required.append(item.variable)
  192. # if the workflow republished, the parameters not changed
  193. # we should not raise error here
  194. try:
  195. description = self.mcp_server.parameters_dict[item.variable]
  196. except KeyError:
  197. description = ""
  198. parameters[item.variable]["description"] = description
  199. if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
  200. parameters[item.variable]["type"] = "string"
  201. elif item.type == VariableEntityType.SELECT:
  202. parameters[item.variable]["type"] = "string"
  203. parameters[item.variable]["enum"] = item.options
  204. elif item.type == VariableEntityType.NUMBER:
  205. parameters[item.variable]["type"] = "float"
  206. return parameters, required