You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

client_session.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import queue
  2. from datetime import timedelta
  3. from typing import Any, Protocol
  4. from pydantic import AnyUrl, TypeAdapter
  5. from configs import dify_config
  6. from core.mcp import types
  7. from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
  8. from core.mcp.session.base_session import BaseSession, RequestResponder
  9. DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
  10. class SamplingFnT(Protocol):
  11. def __call__(
  12. self,
  13. context: RequestContext["ClientSession", Any],
  14. params: types.CreateMessageRequestParams,
  15. ) -> types.CreateMessageResult | types.ErrorData: ...
  16. class ListRootsFnT(Protocol):
  17. def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
  18. class LoggingFnT(Protocol):
  19. def __call__(
  20. self,
  21. params: types.LoggingMessageNotificationParams,
  22. ) -> None: ...
  23. class MessageHandlerFnT(Protocol):
  24. def __call__(
  25. self,
  26. message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  27. ) -> None: ...
  28. def _default_message_handler(
  29. message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  30. ) -> None:
  31. if isinstance(message, Exception):
  32. raise ValueError(str(message))
  33. elif isinstance(message, (types.ServerNotification | RequestResponder)):
  34. pass
  35. def _default_sampling_callback(
  36. context: RequestContext["ClientSession", Any],
  37. params: types.CreateMessageRequestParams,
  38. ) -> types.CreateMessageResult | types.ErrorData:
  39. return types.ErrorData(
  40. code=types.INVALID_REQUEST,
  41. message="Sampling not supported",
  42. )
  43. def _default_list_roots_callback(
  44. context: RequestContext["ClientSession", Any],
  45. ) -> types.ListRootsResult | types.ErrorData:
  46. return types.ErrorData(
  47. code=types.INVALID_REQUEST,
  48. message="List roots not supported",
  49. )
  50. def _default_logging_callback(
  51. params: types.LoggingMessageNotificationParams,
  52. ) -> None:
  53. pass
  54. ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
  55. class ClientSession(
  56. BaseSession[
  57. types.ClientRequest,
  58. types.ClientNotification,
  59. types.ClientResult,
  60. types.ServerRequest,
  61. types.ServerNotification,
  62. ]
  63. ):
  64. def __init__(
  65. self,
  66. read_stream: queue.Queue,
  67. write_stream: queue.Queue,
  68. read_timeout_seconds: timedelta | None = None,
  69. sampling_callback: SamplingFnT | None = None,
  70. list_roots_callback: ListRootsFnT | None = None,
  71. logging_callback: LoggingFnT | None = None,
  72. message_handler: MessageHandlerFnT | None = None,
  73. client_info: types.Implementation | None = None,
  74. ) -> None:
  75. super().__init__(
  76. read_stream,
  77. write_stream,
  78. types.ServerRequest,
  79. types.ServerNotification,
  80. read_timeout_seconds=read_timeout_seconds,
  81. )
  82. self._client_info = client_info or DEFAULT_CLIENT_INFO
  83. self._sampling_callback = sampling_callback or _default_sampling_callback
  84. self._list_roots_callback = list_roots_callback or _default_list_roots_callback
  85. self._logging_callback = logging_callback or _default_logging_callback
  86. self._message_handler = message_handler or _default_message_handler
  87. def initialize(self) -> types.InitializeResult:
  88. sampling = types.SamplingCapability()
  89. roots = types.RootsCapability(
  90. # TODO: Should this be based on whether we
  91. # _will_ send notifications, or only whether
  92. # they're supported?
  93. listChanged=True,
  94. )
  95. result = self.send_request(
  96. types.ClientRequest(
  97. types.InitializeRequest(
  98. method="initialize",
  99. params=types.InitializeRequestParams(
  100. protocolVersion=types.LATEST_PROTOCOL_VERSION,
  101. capabilities=types.ClientCapabilities(
  102. sampling=sampling,
  103. experimental=None,
  104. roots=roots,
  105. ),
  106. clientInfo=self._client_info,
  107. ),
  108. )
  109. ),
  110. types.InitializeResult,
  111. )
  112. if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
  113. raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
  114. self.send_notification(
  115. types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
  116. )
  117. return result
  118. def send_ping(self) -> types.EmptyResult:
  119. """Send a ping request."""
  120. return self.send_request(
  121. types.ClientRequest(
  122. types.PingRequest(
  123. method="ping",
  124. )
  125. ),
  126. types.EmptyResult,
  127. )
  128. def send_progress_notification(
  129. self, progress_token: str | int, progress: float, total: float | None = None
  130. ) -> None:
  131. """Send a progress notification."""
  132. self.send_notification(
  133. types.ClientNotification(
  134. types.ProgressNotification(
  135. method="notifications/progress",
  136. params=types.ProgressNotificationParams(
  137. progressToken=progress_token,
  138. progress=progress,
  139. total=total,
  140. ),
  141. ),
  142. )
  143. )
  144. def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
  145. """Send a logging/setLevel request."""
  146. return self.send_request(
  147. types.ClientRequest(
  148. types.SetLevelRequest(
  149. method="logging/setLevel",
  150. params=types.SetLevelRequestParams(level=level),
  151. )
  152. ),
  153. types.EmptyResult,
  154. )
  155. def list_resources(self) -> types.ListResourcesResult:
  156. """Send a resources/list request."""
  157. return self.send_request(
  158. types.ClientRequest(
  159. types.ListResourcesRequest(
  160. method="resources/list",
  161. )
  162. ),
  163. types.ListResourcesResult,
  164. )
  165. def list_resource_templates(self) -> types.ListResourceTemplatesResult:
  166. """Send a resources/templates/list request."""
  167. return self.send_request(
  168. types.ClientRequest(
  169. types.ListResourceTemplatesRequest(
  170. method="resources/templates/list",
  171. )
  172. ),
  173. types.ListResourceTemplatesResult,
  174. )
  175. def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
  176. """Send a resources/read request."""
  177. return self.send_request(
  178. types.ClientRequest(
  179. types.ReadResourceRequest(
  180. method="resources/read",
  181. params=types.ReadResourceRequestParams(uri=uri),
  182. )
  183. ),
  184. types.ReadResourceResult,
  185. )
  186. def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
  187. """Send a resources/subscribe request."""
  188. return self.send_request(
  189. types.ClientRequest(
  190. types.SubscribeRequest(
  191. method="resources/subscribe",
  192. params=types.SubscribeRequestParams(uri=uri),
  193. )
  194. ),
  195. types.EmptyResult,
  196. )
  197. def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
  198. """Send a resources/unsubscribe request."""
  199. return self.send_request(
  200. types.ClientRequest(
  201. types.UnsubscribeRequest(
  202. method="resources/unsubscribe",
  203. params=types.UnsubscribeRequestParams(uri=uri),
  204. )
  205. ),
  206. types.EmptyResult,
  207. )
  208. def call_tool(
  209. self,
  210. name: str,
  211. arguments: dict[str, Any] | None = None,
  212. read_timeout_seconds: timedelta | None = None,
  213. ) -> types.CallToolResult:
  214. """Send a tools/call request."""
  215. return self.send_request(
  216. types.ClientRequest(
  217. types.CallToolRequest(
  218. method="tools/call",
  219. params=types.CallToolRequestParams(name=name, arguments=arguments),
  220. )
  221. ),
  222. types.CallToolResult,
  223. request_read_timeout_seconds=read_timeout_seconds,
  224. )
  225. def list_prompts(self) -> types.ListPromptsResult:
  226. """Send a prompts/list request."""
  227. return self.send_request(
  228. types.ClientRequest(
  229. types.ListPromptsRequest(
  230. method="prompts/list",
  231. )
  232. ),
  233. types.ListPromptsResult,
  234. )
  235. def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
  236. """Send a prompts/get request."""
  237. return self.send_request(
  238. types.ClientRequest(
  239. types.GetPromptRequest(
  240. method="prompts/get",
  241. params=types.GetPromptRequestParams(name=name, arguments=arguments),
  242. )
  243. ),
  244. types.GetPromptResult,
  245. )
  246. def complete(
  247. self,
  248. ref: types.ResourceReference | types.PromptReference,
  249. argument: dict[str, str],
  250. ) -> types.CompleteResult:
  251. """Send a completion/complete request."""
  252. return self.send_request(
  253. types.ClientRequest(
  254. types.CompleteRequest(
  255. method="completion/complete",
  256. params=types.CompleteRequestParams(
  257. ref=ref,
  258. argument=types.CompletionArgument(**argument),
  259. ),
  260. )
  261. ),
  262. types.CompleteResult,
  263. )
  264. def list_tools(self) -> types.ListToolsResult:
  265. """Send a tools/list request."""
  266. return self.send_request(
  267. types.ClientRequest(
  268. types.ListToolsRequest(
  269. method="tools/list",
  270. )
  271. ),
  272. types.ListToolsResult,
  273. )
  274. def send_roots_list_changed(self) -> None:
  275. """Send a roots/list_changed notification."""
  276. self.send_notification(
  277. types.ClientNotification(
  278. types.RootsListChangedNotification(
  279. method="notifications/roots/list_changed",
  280. )
  281. )
  282. )
  283. def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
  284. ctx = RequestContext[ClientSession, Any](
  285. request_id=responder.request_id,
  286. meta=responder.request_meta,
  287. session=self,
  288. lifespan_context=None,
  289. )
  290. match responder.request.root:
  291. case types.CreateMessageRequest(params=params):
  292. with responder:
  293. response = self._sampling_callback(ctx, params)
  294. client_response = ClientResponse.validate_python(response)
  295. responder.respond(client_response)
  296. case types.ListRootsRequest():
  297. with responder:
  298. list_roots_response = self._list_roots_callback(ctx)
  299. client_response = ClientResponse.validate_python(list_roots_response)
  300. responder.respond(client_response)
  301. case types.PingRequest():
  302. with responder:
  303. return responder.respond(types.ClientResult(root=types.EmptyResult()))
  304. def _handle_incoming(
  305. self,
  306. req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
  307. ) -> None:
  308. """Handle incoming messages by forwarding to the message handler."""
  309. self._message_handler(req)
  310. def _received_notification(self, notification: types.ServerNotification) -> None:
  311. """Handle notifications from the server."""
  312. # Process specific notification types
  313. match notification.root:
  314. case types.LoggingMessageNotification(params=params):
  315. self._logging_callback(params)
  316. case _:
  317. pass