Du kannst nicht mehr als 25 Themen auswählen Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

client_session.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from datetime import timedelta
  2. from typing import Any, Protocol
  3. from pydantic import AnyUrl, TypeAdapter
  4. from configs import dify_config
  5. from core.mcp import types
  6. from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
  7. from core.mcp.session.base_session import BaseSession, RequestResponder
  8. DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
  9. class SamplingFnT(Protocol):
  10. def __call__(
  11. self,
  12. context: RequestContext["ClientSession", Any],
  13. params: types.CreateMessageRequestParams,
  14. ) -> types.CreateMessageResult | types.ErrorData: ...
  15. class ListRootsFnT(Protocol):
  16. def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
  17. class LoggingFnT(Protocol):
  18. def __call__(
  19. self,
  20. params: types.LoggingMessageNotificationParams,
  21. ) -> None: ...
  22. class MessageHandlerFnT(Protocol):
  23. def __call__(
  24. self,
  25. message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  26. ) -> None: ...
  27. def _default_message_handler(
  28. message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  29. ) -> None:
  30. if isinstance(message, Exception):
  31. raise ValueError(str(message))
  32. elif isinstance(message, (types.ServerNotification | RequestResponder)):
  33. pass
  34. def _default_sampling_callback(
  35. context: RequestContext["ClientSession", Any],
  36. params: types.CreateMessageRequestParams,
  37. ) -> types.CreateMessageResult | types.ErrorData:
  38. return types.ErrorData(
  39. code=types.INVALID_REQUEST,
  40. message="Sampling not supported",
  41. )
  42. def _default_list_roots_callback(
  43. context: RequestContext["ClientSession", Any],
  44. ) -> types.ListRootsResult | types.ErrorData:
  45. return types.ErrorData(
  46. code=types.INVALID_REQUEST,
  47. message="List roots not supported",
  48. )
  49. def _default_logging_callback(
  50. params: types.LoggingMessageNotificationParams,
  51. ) -> None:
  52. pass
  53. ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
  54. class ClientSession(
  55. BaseSession[
  56. types.ClientRequest,
  57. types.ClientNotification,
  58. types.ClientResult,
  59. types.ServerRequest,
  60. types.ServerNotification,
  61. ]
  62. ):
  63. def __init__(
  64. self,
  65. read_stream,
  66. write_stream,
  67. read_timeout_seconds: timedelta | None = None,
  68. sampling_callback: SamplingFnT | None = None,
  69. list_roots_callback: ListRootsFnT | None = None,
  70. logging_callback: LoggingFnT | None = None,
  71. message_handler: MessageHandlerFnT | None = None,
  72. client_info: types.Implementation | None = None,
  73. ) -> None:
  74. super().__init__(
  75. read_stream,
  76. write_stream,
  77. types.ServerRequest,
  78. types.ServerNotification,
  79. read_timeout_seconds=read_timeout_seconds,
  80. )
  81. self._client_info = client_info or DEFAULT_CLIENT_INFO
  82. self._sampling_callback = sampling_callback or _default_sampling_callback
  83. self._list_roots_callback = list_roots_callback or _default_list_roots_callback
  84. self._logging_callback = logging_callback or _default_logging_callback
  85. self._message_handler = message_handler or _default_message_handler
  86. def initialize(self) -> types.InitializeResult:
  87. sampling = types.SamplingCapability()
  88. roots = types.RootsCapability(
  89. # TODO: Should this be based on whether we
  90. # _will_ send notifications, or only whether
  91. # they're supported?
  92. listChanged=True,
  93. )
  94. result = self.send_request(
  95. types.ClientRequest(
  96. types.InitializeRequest(
  97. method="initialize",
  98. params=types.InitializeRequestParams(
  99. protocolVersion=types.LATEST_PROTOCOL_VERSION,
  100. capabilities=types.ClientCapabilities(
  101. sampling=sampling,
  102. experimental=None,
  103. roots=roots,
  104. ),
  105. clientInfo=self._client_info,
  106. ),
  107. )
  108. ),
  109. types.InitializeResult,
  110. )
  111. if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
  112. raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
  113. self.send_notification(
  114. types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
  115. )
  116. return result
  117. def send_ping(self) -> types.EmptyResult:
  118. """Send a ping request."""
  119. return self.send_request(
  120. types.ClientRequest(
  121. types.PingRequest(
  122. method="ping",
  123. )
  124. ),
  125. types.EmptyResult,
  126. )
  127. def send_progress_notification(
  128. self, progress_token: str | int, progress: float, total: float | None = None
  129. ) -> None:
  130. """Send a progress notification."""
  131. self.send_notification(
  132. types.ClientNotification(
  133. types.ProgressNotification(
  134. method="notifications/progress",
  135. params=types.ProgressNotificationParams(
  136. progressToken=progress_token,
  137. progress=progress,
  138. total=total,
  139. ),
  140. ),
  141. )
  142. )
  143. def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
  144. """Send a logging/setLevel request."""
  145. return self.send_request(
  146. types.ClientRequest(
  147. types.SetLevelRequest(
  148. method="logging/setLevel",
  149. params=types.SetLevelRequestParams(level=level),
  150. )
  151. ),
  152. types.EmptyResult,
  153. )
  154. def list_resources(self) -> types.ListResourcesResult:
  155. """Send a resources/list request."""
  156. return self.send_request(
  157. types.ClientRequest(
  158. types.ListResourcesRequest(
  159. method="resources/list",
  160. )
  161. ),
  162. types.ListResourcesResult,
  163. )
  164. def list_resource_templates(self) -> types.ListResourceTemplatesResult:
  165. """Send a resources/templates/list request."""
  166. return self.send_request(
  167. types.ClientRequest(
  168. types.ListResourceTemplatesRequest(
  169. method="resources/templates/list",
  170. )
  171. ),
  172. types.ListResourceTemplatesResult,
  173. )
  174. def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
  175. """Send a resources/read request."""
  176. return self.send_request(
  177. types.ClientRequest(
  178. types.ReadResourceRequest(
  179. method="resources/read",
  180. params=types.ReadResourceRequestParams(uri=uri),
  181. )
  182. ),
  183. types.ReadResourceResult,
  184. )
  185. def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
  186. """Send a resources/subscribe request."""
  187. return self.send_request(
  188. types.ClientRequest(
  189. types.SubscribeRequest(
  190. method="resources/subscribe",
  191. params=types.SubscribeRequestParams(uri=uri),
  192. )
  193. ),
  194. types.EmptyResult,
  195. )
  196. def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
  197. """Send a resources/unsubscribe request."""
  198. return self.send_request(
  199. types.ClientRequest(
  200. types.UnsubscribeRequest(
  201. method="resources/unsubscribe",
  202. params=types.UnsubscribeRequestParams(uri=uri),
  203. )
  204. ),
  205. types.EmptyResult,
  206. )
  207. def call_tool(
  208. self,
  209. name: str,
  210. arguments: dict[str, Any] | None = None,
  211. read_timeout_seconds: timedelta | None = None,
  212. ) -> types.CallToolResult:
  213. """Send a tools/call request."""
  214. return self.send_request(
  215. types.ClientRequest(
  216. types.CallToolRequest(
  217. method="tools/call",
  218. params=types.CallToolRequestParams(name=name, arguments=arguments),
  219. )
  220. ),
  221. types.CallToolResult,
  222. request_read_timeout_seconds=read_timeout_seconds,
  223. )
  224. def list_prompts(self) -> types.ListPromptsResult:
  225. """Send a prompts/list request."""
  226. return self.send_request(
  227. types.ClientRequest(
  228. types.ListPromptsRequest(
  229. method="prompts/list",
  230. )
  231. ),
  232. types.ListPromptsResult,
  233. )
  234. def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
  235. """Send a prompts/get request."""
  236. return self.send_request(
  237. types.ClientRequest(
  238. types.GetPromptRequest(
  239. method="prompts/get",
  240. params=types.GetPromptRequestParams(name=name, arguments=arguments),
  241. )
  242. ),
  243. types.GetPromptResult,
  244. )
  245. def complete(
  246. self,
  247. ref: types.ResourceReference | types.PromptReference,
  248. argument: dict[str, str],
  249. ) -> types.CompleteResult:
  250. """Send a completion/complete request."""
  251. return self.send_request(
  252. types.ClientRequest(
  253. types.CompleteRequest(
  254. method="completion/complete",
  255. params=types.CompleteRequestParams(
  256. ref=ref,
  257. argument=types.CompletionArgument(**argument),
  258. ),
  259. )
  260. ),
  261. types.CompleteResult,
  262. )
  263. def list_tools(self) -> types.ListToolsResult:
  264. """Send a tools/list request."""
  265. return self.send_request(
  266. types.ClientRequest(
  267. types.ListToolsRequest(
  268. method="tools/list",
  269. )
  270. ),
  271. types.ListToolsResult,
  272. )
  273. def send_roots_list_changed(self) -> None:
  274. """Send a roots/list_changed notification."""
  275. self.send_notification(
  276. types.ClientNotification(
  277. types.RootsListChangedNotification(
  278. method="notifications/roots/list_changed",
  279. )
  280. )
  281. )
  282. def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
  283. ctx = RequestContext[ClientSession, Any](
  284. request_id=responder.request_id,
  285. meta=responder.request_meta,
  286. session=self,
  287. lifespan_context=None,
  288. )
  289. match responder.request.root:
  290. case types.CreateMessageRequest(params=params):
  291. with responder:
  292. response = self._sampling_callback(ctx, params)
  293. client_response = ClientResponse.validate_python(response)
  294. responder.respond(client_response)
  295. case types.ListRootsRequest():
  296. with responder:
  297. list_roots_response = self._list_roots_callback(ctx)
  298. client_response = ClientResponse.validate_python(list_roots_response)
  299. responder.respond(client_response)
  300. case types.PingRequest():
  301. with responder:
  302. return responder.respond(types.ClientResult(root=types.EmptyResult()))
  303. def _handle_incoming(
  304. self,
  305. req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  306. ) -> None:
  307. """Handle incoming messages by forwarding to the message handler."""
  308. self._message_handler(req)
  309. def _received_notification(self, notification: types.ServerNotification) -> None:
  310. """Handle notifications from the server."""
  311. # Process specific notification types
  312. match notification.root:
  313. case types.LoggingMessageNotification(params=params):
  314. self._logging_callback(params)
  315. case _:
  316. pass