diff --git a/docs/migration.md b/docs/migration.md index 850e052550..49dee8f96c 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -634,11 +634,11 @@ server = Server("my-server", on_call_tool=handle_call_tool) The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). -The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. +The split separates shared fields from server-specific fields. There is no shared `RequestContext` generic anymore — each concrete class fixes its session type. **`RequestContext` changes:** -- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` +- The `RequestContext[SessionT, LifespanContextT, RequestT]` generic no longer exists; use `ClientRequestContext` or `ServerRequestContext[LifespanContextT, RequestT]` - Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` **Before (v1):** @@ -1164,7 +1164,33 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr `BaseSession._in_flight` and the `RequestResponder` members that supported it (`cancel()`, the `cancelled` and `in_flight` properties, the `on_complete` constructor argument, and the internal `CancelScope`) have been removed. These existed to let `ServerSession` cancel a handler when a `CancelledNotification` arrived; `ServerSession` no longer drives a receive loop, so they were dead code. Inbound-cancellation handling for the server now lives in `JSONRPCDispatcher`. -`BaseSession` is still used by `ClientSession`, which never relied on these members. `RequestResponder.respond()` is unchanged. +`BaseSession` itself has since been removed entirely; see the next section. + +### `ClientSession` now runs on `JSONRPCDispatcher`; `BaseSession` removed + +`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new keyword-only `dispatcher=` constructor argument accepts a pre-built dispatcher instead of the stream pair (for example a `DirectDispatcher` for in-process embedding). + +Code that imported or subclassed `BaseSession` directly has no shim — the class is removed outright. The receive-loop engine it implemented now lives in `JSONRPCDispatcher` (`mcp.shared.jsonrpc_dispatcher`); to customize client behavior, use the `ClientSession` constructor callbacks, or supply your own engine through the `dispatcher=` keyword. + +Behavior changes: + +- **Request ids count from 1** (previously 0). Progress tokens, which reuse the request id, shift the same way. Ids are opaque per JSON-RPC; do not assign meaning to them. +- **Timeouts**: the error message is now `Request 'tools/call' timed out` (previously `Timed out while waiting for response to CallToolRequest. Waited N seconds.`), and a timed-out or abandoned request is followed by `notifications/cancelled` on the wire, so the server stops the handler instead of leaving it running. The `initialize` request is never cancelled this way, and requests sent with resumption metadata are also exempt so they stay resumable. +- **No cancellation for requests that never reached the wire.** A timed-out or caller-cancelled request whose initial write never completed is failed locally without `notifications/cancelled` — the peer never saw the id, so there is nothing to cancel. +- **The resumption exemption applies only when the hints reach the transport.** A request sent from inside a request callback carries stream-routing metadata that takes precedence, so its resumption hints are dropped — and an abandoned one gets the courtesy `notifications/cancelled` like any other request. +- **Server-initiated requests run concurrently.** Sampling/elicitation/roots callbacks no longer serialize the receive loop: a slow callback does not block other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now actually interrupts the callback (the request is then answered with an error response). +- **Session shutdown answers in-flight server-initiated requests with `CONNECTION_CLOSED`** (-32000, `Connection closed`) instead of -32002. The write is bounded (about one second), so closing a session stays fast even when the transport has stopped accepting writes. +- **The `REQUEST_CANCELLED` constant is removed from `mcp.types`.** Its value (-32002) collided with the spec's resource-not-found error code, and the shutdown response above was its only use. +- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` start in arrival order, but each delivery runs as its own task with no completion-before-response guarantee (matching the TypeScript, C#, and Go SDKs): deliveries may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves. +- **Transport-level `Exception` items are delivered concurrently too.** An `Exception` the transport places on the read stream is dispatched to `message_handler` as its own task, like notification callbacks, instead of blocking the receive loop — and a `message_handler` that raises on it is logged, not fatal to the session. +- **Unknown-id responses are ignored**, as the spec asks. v1 surfaced them to `message_handler` as a `RuntimeError`; nothing is surfaced now. +- **Error responses with a null `id`** — the JSON-RPC shape for a peer reporting a parse error — are now dropped with a debug log. v1 surfaced them to `message_handler` as an `MCPError`. +- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. +- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed, `send_request` instead raises `MCPError` (`CONNECTION_CLOSED`), matching what an in-flight request receives — `RuntimeError` remains only for calls before entry. `send_notification` before entry still works. +- **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected. +- **The private `mcp.shared._context.RequestContext` generic is deleted.** Client callbacks now receive the concrete `mcp.client.ClientRequestContext`, whose `request_id` is always populated (the client only builds a context for inbound requests). Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. + +`mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. ### Experimental Tasks support removed diff --git a/src/mcp/client/context.py b/src/mcp/client/context.py index 2f4404e008..aecd29527f 100644 --- a/src/mcp/client/context.py +++ b/src/mcp/client/context.py @@ -1,16 +1,5 @@ """Request context for MCP client handlers.""" -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext +from mcp.client.session import ClientRequestContext -ClientRequestContext = RequestContext[ClientSession] -"""Context for handling incoming requests in a client session. - -This context is passed to client-side callbacks (sampling, elicitation, list_roots) when the server sends requests -to the client. - -Attributes: - request_id: The unique identifier for this request. - meta: Optional metadata associated with the request. - session: The client session handling this request. -""" +__all__ = ["ClientRequestContext"] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3a0485649f..e7dd1291ad 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,28 +1,49 @@ from __future__ import annotations import logging +from collections.abc import Mapping +from dataclasses import dataclass +from types import TracebackType from typing import Any, Protocol, cast, get_args +import anyio +import anyio.abc import anyio.lowlevel -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError +from typing_extensions import Self, TypeVar from mcp import types from mcp.client._transport import ReadStream, WriteStream -from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared._compat import resync_tracer +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.session import ProgressFnT, RequestResponder +from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types._types import RequestParamsMeta +from mcp.types import RequestId, RequestParamsMeta DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") logger = logging.getLogger("client") +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +@dataclass(kw_only=True) +class ClientRequestContext: + """Context for a server-initiated request, passed to the sampling/elicitation/list-roots callbacks.""" + + session: ClientSession + request_id: RequestId + meta: RequestParamsMeta | None = None + class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch @@ -30,14 +51,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext[ClientSession] + self, context: ClientRequestContext ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch @@ -59,7 +80,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( @@ -69,7 +90,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( @@ -79,7 +100,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -104,19 +125,21 @@ async def _default_logging_callback( answered with METHOD_NOT_FOUND instead of failing union validation.""" -class ClientSession( - BaseSession[ - types.ClientRequest, - types.ClientNotification, - types.ClientResult, - types.ServerRequest, - types.ServerNotification, - ] -): +class ClientSession: + """Client half of an MCP connection, running on a `Dispatcher`. + + Construct it over a transport's stream pair (or pass a pre-built + `dispatcher=`), enter as an async context manager, then call + `initialize()`. The dispatcher owns the receive loop and request + correlation; this class owns the typed MCP layer and the constructor + callbacks. Transport `Exception` items reach `message_handler` only when + the session builds its own dispatcher from a stream pair. + """ + def __init__( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception] | None = None, + write_stream: WriteStream[SessionMessage] | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, @@ -126,8 +149,9 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, + dispatcher: Dispatcher[Any] | None = None, ) -> None: - super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) + self._session_read_timeout_seconds = read_timeout_seconds self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities @@ -137,18 +161,91 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None + self._task_group: anyio.abc.TaskGroup | None = None + if dispatcher is not None: + if read_stream is not None or write_stream is not None: + raise ValueError("pass read_stream/write_stream or dispatcher, not both") + self._dispatcher: Dispatcher[Any] = dispatcher + else: + if read_stream is None or write_stream is None: + raise ValueError("read_stream and write_stream are required when no dispatcher is given") + # Built eagerly so notifications can be sent before entering the context manager. + self._dispatcher = JSONRPCDispatcher( + read_stream, write_stream, on_stream_exception=self._on_stream_exception + ) - @property - def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: - return types.server_request_adapter + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + try: + await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + except BaseException: + # Unwind the entered task group before propagating: a cancellation + # landing here (e.g. `move_on_after` around connect) would abandon + # it and anyio would later raise "exited non-innermost cancel scope". + task_group = self._task_group + self._task_group = None + task_group.cancel_scope.cancel() + # Shield the group's own scope (a new one would break LIFO exit) + # so a pending outer cancellation cannot re-fire inside __aexit__. + task_group.cancel_scope.shield = True + await task_group.__aexit__(None, None, None) + raise + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + # Exit must not block: cancel the dispatcher and in-flight callbacks. + assert self._task_group is not None + self._task_group.cancel_scope.cancel() + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + await resync_tracer() + return result - @property - def _receive_request_methods(self) -> frozenset[str]: - return _SERVER_REQUEST_METHODS + async def send_request( + self, + request: types.ClientRequest, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: float | None = None, + metadata: ClientMessageMetadata | None = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT: + """Send a request and wait for its typed result. - @property - def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]: - return types.server_notification_adapter + Args: + metadata: Streamable HTTP resumption hints. + + Raises: + MCPError: Error response, read timeout, or connection closed. + RuntimeError: Called before entering the context manager. + """ + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + method: str = data["method"] + opts: CallOptions = {} + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + if timeout is not None: + opts["timeout"] = timeout + if progress_callback is not None: + opts["on_progress"] = progress_callback + if metadata is not None: + if metadata.resumption_token is not None: + opts["resumption_token"] = metadata.resumption_token + if metadata.on_resumption_token_update is not None: + opts["on_resumption_token"] = metadata.on_resumption_token_update + if method == "initialize": + # The spec forbids cancelling initialize. + opts["cancel_on_abandon"] = False + raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) + return result_type.model_validate(raw, by_name=False) + + async def send_notification(self, notification: types.ClientNotification) -> None: + """Send a one-way notification. Usable before entering the context manager.""" + data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) + await self._dispatcher.notify(data["method"], data.get("params")) async def initialize(self) -> types.InitializeResult: sampling = ( @@ -397,49 +494,68 @@ async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) - async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) - - match responder.request: - case types.CreateMessageRequest(params=params): - with responder: - response = await self._sampling_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case types.ElicitRequest(params=params): - with responder: - response = await self._elicitation_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case types.ListRootsRequest(): - with responder: + async def _on_request( + self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + """Answer a server-initiated request via the registered callbacks.""" + if method not in _SERVER_REQUEST_METHODS: + raise MCPError(code=types.METHOD_NOT_FOUND, message="Method not found", data=method) + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + request = types.server_request_adapter.validate_python(payload, by_name=False) + + response: types.ClientResult | types.ErrorData + if isinstance(request, types.PingRequest): + # Answered without a context: ping has no callback that would need one. + response = types.EmptyResult() + else: + assert dctx.request_id is not None # the callback-driving dispatchers always assign ids + ctx = ClientRequestContext( + session=self, request_id=dctx.request_id, meta=request.params.meta if request.params else None + ) + match request: + case types.CreateMessageRequest(params=sampling_params): + response = await self._sampling_callback(ctx, sampling_params) + case types.ElicitRequest(params=elicit_params): + response = await self._elicitation_callback(ctx, elicit_params) + case types.ListRootsRequest(): # pragma: no branch response = await self._list_roots_callback(ctx) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) + client_response = ClientResponse.validate_python(response) + if isinstance(client_response, types.ErrorData): + raise MCPError.from_error_data(client_response) + return client_response.model_dump(by_alias=True, mode="json", exclude_none=True) - case types.PingRequest(): # pragma: no branch - with responder: - await responder.respond(types.EmptyResult()) - - async def _handle_incoming( - self, - req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + async def _on_notify( + self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> None: - """Handle incoming messages by forwarding to the message handler.""" - await self._message_handler(req) - - async def _received_notification(self, notification: types.ServerNotification) -> None: - """Handle notifications from the server.""" - # Process specific notification types - match notification: - case types.LoggingMessageNotification(params=params): - await self._logging_callback(params) - case types.ElicitCompleteNotification(params=params): - # Handle elicitation completion notification - # Clients MAY use this to retry requests or update UI - # The notification contains the elicitationId of the completed elicitation - pass - case _: - pass + """Route a server notification: validate, run the typed callback, tee to message_handler.""" + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + try: + notification = types.server_notification_adapter.validate_python(payload, by_name=False) + except ValidationError: + logger.warning("Failed to validate notification: %s", payload, exc_info=True) + return + if isinstance(notification, types.CancelledNotification): + # The dispatcher already applied the cancellation; not surfaced to message_handler. + return + if isinstance(notification, types.LoggingMessageNotification): + await self._logging_callback(notification.params) + await self._message_handler(notification) + + async def _on_stream_exception(self, exc: Exception) -> None: + """Deliver a transport-level fault to message_handler via a spawned task. + + Running the handler inline would park the dispatcher's read loop and + deadlock handlers that await session I/O. + """ + assert self._task_group is not None + self._task_group.start_soon(self._deliver_stream_exception, exc) + + async def _deliver_stream_exception(self, exc: Exception) -> None: + try: + await self._message_handler(exc) + except Exception: + logger.exception("message_handler raised on transport exception") diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 220d46f9a3..93904d6cc1 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -717,6 +717,9 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) + except anyio.ClosedResourceError: + # Session teardown can close the stream while the writer is between dequeues. + pass except Exception: logger.exception("Error in standalone SSE writer") # pragma: no cover finally: diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py deleted file mode 100644 index bbcee2d02c..0000000000 --- a/src/mcp/shared/_context.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Request context for MCP handlers.""" - -from dataclasses import dataclass -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParamsMeta - -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) - - -@dataclass(kw_only=True) -class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests. - - For request handlers, request_id is always populated. - For notification handlers, request_id is None. - """ - - session: SessionT - request_id: RequestId | None = None - meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 5b3d29c8d8..6bba749879 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -50,7 +50,7 @@ class _DirectDispatchContext: _back_request: _Request _back_notify: _Notify request_id: RequestId | None = None - """Always `None`: direct dispatch has no wire-level request id.""" + """A dispatcher-synthesized id for requests; `None` for notifications.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None @@ -91,6 +91,7 @@ def __init__(self, transport_ctx: TransportContext): self._peer: DirectDispatcher | None = None self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None + self._next_id = 0 self._ready = anyio.Event() self._closed = anyio.Event() @@ -128,13 +129,16 @@ async def run( def close(self) -> None: self._closed.set() - def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + def _make_context( + self, on_progress: ProgressFnT | None = None, request_id: RequestId | None = None + ) -> _DirectDispatchContext: assert self._peer is not None peer = self._peer return _DirectDispatchContext( transport=self._transport_ctx, _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), _back_notify=lambda m, p: peer._dispatch_notify(m, p), + request_id=request_id, _on_progress=on_progress, ) @@ -147,7 +151,9 @@ async def _dispatch_request( await self._ready.wait() assert self._on_request is not None opts = opts or {} - dctx = self._make_context(on_progress=opts.get("on_progress")) + # Synthesize an id: the DispatchContext contract reserves None for notifications. + self._next_id += 1 + dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id) try: with anyio.fail_after(opts.get("timeout")): try: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index cffdfd22f8..888e55ba33 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -55,6 +55,13 @@ class CallOptions(TypedDict, total=False): timeout: float """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" + cancel_on_abandon: bool + """Whether abandoning this request (timeout or caller cancellation) sends `notifications/cancelled`. + + Defaults to `True`. Set `False` for requests the protocol forbids cancelling, such as `initialize`. + Also suppressed when resumption hints reach the transport, or when the request was never written. + """ + on_progress: ProgressFnT """Receive `notifications/progress` updates for this request.""" @@ -97,9 +104,6 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a request and await its raw result dict. - `opts` carries per-call `timeout` / `on_progress` / resumption hints; - see `CallOptions`. - Raises: MCPError: If the peer responded with an error, or the handler raised. Implementations normalize all handler exceptions to @@ -187,6 +191,8 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. + + The lifecycle surface is provisional; `run()` may change before v2 stable. """ async def run( diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 457e6b6f77..2ca08954fe 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -1,21 +1,8 @@ -"""JSON-RPC `Dispatcher` implementation. - -Consumes the existing `SessionMessage`-based stream contract that all current -transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, -the receive loop, per-request task isolation, cancellation/progress wiring, and -the single exception-to-wire boundary. - -The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and -sees only `(ctx, method, params) -> dict`. Transports sit below and see only -`SessionMessage` reads/writes. - -The dispatcher is *mostly* MCP-agnostic - methods/params are opaque strings and -dicts - but it intercepts `notifications/cancelled` and -`notifications/progress` because request correlation, cancellation and -progress are exactly the wiring this layer exists to provide. Those few wire -shapes are extracted with structural `match` patterns (no casts, no -`mcp.types` model coupling); a malformed payload simply fails to match and -the correlation is skipped. +"""JSON-RPC `Dispatcher` over the `SessionMessage` stream contract all transports speak. + +Owns request-id correlation, the receive loop, per-request task isolation, +cancellation/progress wiring, and the single exception-to-wire boundary; +methods and params are otherwise opaque strings and dicts. """ from __future__ import annotations @@ -24,17 +11,19 @@ import logging from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, cast, overload +from functools import partial +from typing import Any, Generic, Literal, cast import anyio import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from opentelemetry.trace import SpanKind from pydantic import ValidationError +from typing_extensions import TypeVar from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -47,7 +36,6 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, - REQUEST_CANCELLED, REQUEST_TIMEOUT, ErrorData, JSONRPCError, @@ -63,23 +51,22 @@ logger = logging.getLogger(__name__) -TransportT = TypeVar("TransportT", bound=TransportContext) +_SHIELDED_WRITE_TIMEOUT: float = 5 +"""Bound for courtesy abandon-path writes; without it a wedged transport +would turn the shielded write into an uncancellable hang.""" -PeerCancelMode = Literal["interrupt", "signal"] -"""How inbound `notifications/cancelled` is applied to a running handler. +_SHUTDOWN_WRITE_TIMEOUT: float = 1 +"""Tighter bound for the shutdown-arm error write so a wedged transport can't hold session close.""" -`"interrupt"` (default) cancels the handler's scope. `"signal"` only sets -`ctx.cancel_requested` and lets the handler observe it cooperatively. -""" +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) +PeerCancelMode = Literal["interrupt", "signal"] +"""How `notifications/cancelled` is applied: `"interrupt"` (default) cancels +the handler's scope; `"signal"` only sets `ctx.cancel_requested`.""" -def _coerce_id(request_id: RequestId) -> RequestId: - """Coerce a string request ID to int when it's a valid int literal. - `_allocate_id` only ever produces `int` keys for `_pending`, but a peer - may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` - both perform this coercion at lookup time so the response still correlates. - """ +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" if isinstance(request_id, str): try: return int(request_id) @@ -113,12 +100,7 @@ class _JSONRPCDispatchContext(Generic[TransportT]): _dispatcher: JSONRPCDispatcher[TransportT] _request_id: RequestId | None message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework - """The transport-attached `SessionMessage.metadata` for this inbound message. - - Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks) - that the server lifts onto its request context. `None` for transports - that attach nothing. - """ + """Transport-attached `SessionMessage.metadata` that the server lifts onto its request context.""" _progress_token: ProgressToken | None = None _closed: bool = False cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -166,13 +148,7 @@ def _default_transport_builder(_meta: MessageMetadata) -> TransportContext: def _shielded_progress(fn: ProgressFnT) -> ProgressFnT: - """Wrap a user progress callback so it can't crash the dispatcher. - - The callback runs as a bare task in the dispatcher's task group; an - uncaught exception would cancel every sibling (the read loop and all - in-flight requests). Swallow and log instead, matching the previous - receive-loop's behavior. - """ + """Wrap a user progress callback so an exception can't cancel the dispatcher's task group.""" async def _wrapped(progress: float, total: float | None, message: str | None) -> None: try: @@ -183,61 +159,57 @@ async def _wrapped(progress: float, total: float | None, message: str | None) -> return _wrapped -def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: - """Choose the `SessionMessage.metadata` for an outgoing request/notification. +def _contained_notify(fn: OnNotify) -> OnNotify: + """Wrap a notification handler so it can't crash the dispatcher (same boundary as `_shielded_progress`).""" - `ServerMessageMetadata` tags a server-to-client message with the inbound - request it belongs to (so streamable-HTTP can route it onto that request's - SSE stream). `ClientMessageMetadata` carries resumption hints to the - client transport. `None` is the common case. + async def _wrapped(dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + try: + await fn(dctx, method, params) + except Exception: + logger.exception("notification handler for %r raised", method) - `SessionMessage.metadata` carries exactly one of these, so when - `related_request_id` is set it takes precedence and any resumption hints - in `opts` are dropped (with a debug log): requests made from a dispatch - context are routed onto the inbound request's stream, not resumed. + return _wrapped + + +@dataclass(slots=True, frozen=True) +class _OutboundPlan: + """Outbound metadata plus whether abandoning the request sends a courtesy `notifications/cancelled`.""" + + metadata: MessageMetadata + cancel_on_abandon: bool + + +def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | None) -> _OutboundPlan: + """Choose the outbound `SessionMessage.metadata` and the abandon-cancellation policy. + + `related_request_id` wins over resumption hints (they are dropped). Only + hints that actually reach the transport suppress the courtesy cancel - a + request that is neither resumable nor cancelled would leak the peer's work. """ + opts = opts or {} + cancel_on_abandon = opts.get("cancel_on_abandon", True) + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") if related_request_id is not None: - if opts and (opts.get("resumption_token") is not None or opts.get("on_resumption_token") is not None): + if token is not None or on_token is not None: logger.debug( "dropping resumption hints: related_request_id %r takes precedence on metadata", related_request_id ) - return ServerMessageMetadata(related_request_id=related_request_id) - if opts: - token = opts.get("resumption_token") - on_token = opts.get("on_resumption_token") - if token is not None or on_token is not None: - return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) - return None + return _OutboundPlan(ServerMessageMetadata(related_request_id=related_request_id), cancel_on_abandon) + if token is not None or on_token is not None: + return _OutboundPlan( + ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token), + cancel_on_abandon=False, + ) + return _OutboundPlan(None, cancel_on_abandon) class JSONRPCDispatcher(Dispatcher[TransportT]): - """`Dispatcher` over the existing `SessionMessage` stream contract. + """`Dispatcher` over the `SessionMessage` stream contract. - Inherits the `Dispatcher` Protocol explicitly so pyright checks - conformance at the class definition rather than at first use. + Explicit Protocol base so pyright checks conformance at the class definition. """ - @overload - def __init__( - self: JSONRPCDispatcher[TransportContext], - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - *, - peer_cancel_mode: PeerCancelMode = "interrupt", - raise_handler_exceptions: bool = False, - inline_methods: frozenset[str] = frozenset(), - ) -> None: ... - @overload - def __init__( - self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - *, - transport_builder: Callable[[MessageMetadata], TransportT], - peer_cancel_mode: PeerCancelMode = "interrupt", - raise_handler_exceptions: bool = False, - inline_methods: frozenset[str] = frozenset(), - ) -> None: ... def __init__( self, read_stream: ReadStream[SessionMessage | Exception], @@ -247,33 +219,40 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: + """Wire a dispatcher over a transport's `SessionMessage` stream pair. + + Args: + transport_builder: Builds each message's `TransportContext` from + its `SessionMessage.metadata`. + raise_handler_exceptions: Re-raise handler exceptions out of + `run()` after the error response is written. + inline_methods: Methods awaited in the read loop before the next + message is dequeued (e.g. `initialize`); an inline handler + that awaits the peer deadlocks the parked loop. + on_stream_exception: Observer for `Exception` items on the read + stream; without it they are debug-logged and dropped. + """ self._read_stream = read_stream self._write_stream = write_stream - # The overloads guarantee that when `transport_builder` is omitted, - # `TransportT` is `TransportContext`, so the default is type-correct; - # pyright can't see across overloads, hence the cast. + # With transport_builder omitted, TransportT defaults to + # TransportContext; pyright can't connect the two, hence the cast. self._transport_builder = cast( "Callable[[MessageMetadata], TransportT]", transport_builder or _default_transport_builder, ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions - # Request methods handled inline in the read loop (awaited before the - # next message is dequeued) instead of spawned concurrently. Use for - # methods whose side effects must be observable to the next message, - # e.g. `initialize`, so a pipelined follow-up sees the initialized state. - # Only suitable for handlers that complete quickly, since inline handling - # blocks dequeuing; a handler that awaits the peer (`send_raw_request`) - # while inline will deadlock because the parked read loop cannot dequeue - # the response. self._inline_methods = inline_methods + self._on_stream_exception = on_stream_exception self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} self._tg: anyio.abc.TaskGroup | None = None self._running = False + self._closed = False async def send_raw_request( self, @@ -285,82 +264,91 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a JSON-RPC request and await its response. - `_related_request_id` is set only by `_JSONRPCDispatchContext` when a - handler makes a server-to-client request mid-flight; it routes the - outgoing message onto the correct per-request SSE stream (SHTTP) via - `ServerMessageMetadata`. Top-level callers leave it `None`. + `_related_request_id` is set only by `_JSONRPCDispatchContext` so that + mid-handler requests route onto the inbound request's SSE stream. Raises: - MCPError: The peer responded with a JSON-RPC error; or - `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or - `CONNECTION_CLOSED` if the dispatcher shut down while - awaiting the response. - RuntimeError: Called before `run()` has started or after it has - finished. + MCPError: Peer error response; `REQUEST_TIMEOUT` if + `opts["timeout"]` elapsed; `CONNECTION_CLOSED` if the + transport closed or the dispatcher shut down. + RuntimeError: Called before `run()`. """ + # Post-close sends get the same CONNECTION_CLOSED contract as in-flight waiters. + if self._closed: + raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") if not self._running: - raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run()") opts = opts or {} request_id = self._allocate_id() out_params = dict(params) if params is not None else {} out_meta = dict(out_params.get("_meta") or {}) on_progress = opts.get("on_progress") if on_progress is not None: - # The caller wants progress updates. The spec mechanism is: include - # `_meta.progressToken` on the request; the peer echoes that token on - # any `notifications/progress` it sends. We use the request id as the - # token so the receive loop can find this `_Pending.on_progress` by - # `_pending[token]` without a second lookup table. + # The request id doubles as the progress token, so `_pending[token]` finds `on_progress` directly. out_meta["progressToken"] = request_id out_params["_meta"] = out_meta - # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from - # `_resolve_pending`/`_fan_out_closed` means the waiter already has an - # outcome and dropping the late/redundant signal is correct. buffer=0 - # is unsafe - there's a window between registering `_pending[id]` and - # parking in `receive()` where a close signal would be lost. + # buffer=1: a close signal can arrive before the waiter parks in receive(); + # a WouldBlock later just means the waiter already has its one outcome. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending - metadata = _outbound_metadata(_related_request_id, opts) + plan = _plan_outbound(_related_request_id, opts) + # Spec MUST: only previously-issued requests may be cancelled, so the + # courtesy cancel arms only once the request write completes. + request_written = False + target = out_params.get("name") span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" - # TODO(maxisbey): the otel span + inject below mirror - # BaseSession.send_request for parity. They belong in an outbound - # middleware (symmetric with otel_middleware on the inbound side) once - # that seam exists; the dispatcher should not own otel. + # TODO(maxisbey): move the otel span + inject into an outbound + # middleware once that seam exists; the dispatcher should not own otel. try: with otel_span( span_name, kind=SpanKind.CLIENT, attributes={"mcp.method.name": method, "jsonrpc.request.id": str(request_id)}, ): - # Inject W3C trace context into _meta (SEP-414). With a no-op - # tracer this writes nothing, but `_meta` itself is still - # present on the wire (and the interaction suite pins that). + # SEP-414: inject W3C trace context; `_meta` stays on the wire even with a no-op tracer. inject_trace_context(out_meta) msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) - await self._write(msg, metadata) + try: + await self._write(msg, plan.metadata) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # Transport tore down before run() noticed EOF; surface the documented contract. + raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") from None + request_written = True with anyio.fail_after(opts.get("timeout")): outcome = await receive.receive() except TimeoutError: - # Spec-recommended courtesy: tell the peer we've given up so it can - # stop work and free resources. v1's BaseSession.send_request does - # NOT do this; it's new behaviour. - await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) + # Courtesy cancel (spec-recommended, new vs v1) so the peer stops work; + # unshielded so an outer caller cancellation can still interrupt the write. + if plan.cancel_on_abandon and request_written: + await self._final_write( + partial( + self._cancel_outbound, + request_id, + f"timed out after {opts.get('timeout')}s", + _related_request_id, + ), + shield=False, + timeout=_SHIELDED_WRITE_TIMEOUT, + describe=f"courtesy cancel for timed-out request {request_id!r}", + ) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): - # Our caller's scope was cancelled. We're already inside a cancelled - # scope, so any bare `await` here re-raises immediately - shield to - # let the courtesy cancel notification go out before we propagate. - with anyio.CancelScope(shield=True): - await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) + # Caller cancelled: bare awaits re-raise here, so the shielded helper + # lets the courtesy cancel go out before we propagate. + if plan.cancel_on_abandon and request_written: + await self._final_write( + partial(self._cancel_outbound, request_id, "caller cancelled", _related_request_id), + shield=True, + timeout=_SHIELDED_WRITE_TIMEOUT, + describe=f"courtesy cancel for caller-cancelled request {request_id!r}", + ) raise finally: - # Always remove the waiter, even on cancel/timeout, so a late - # response from the peer (race) hits a closed stream and is dropped - # in `_dispatch` rather than leaking. + # Remove the waiter on every path so a late response is dropped, not leaked. self._pending.pop(request_id, None) send.close() receive.close() @@ -376,15 +364,13 @@ async def notify( *, _related_request_id: RequestId | None = None, ) -> None: - # Leave `params` unset (not explicitly None) when there are none: - # transports serialize with `exclude_unset=True`, and an explicit None - # would survive as `"params": null`, which JSON-RPC 2.0 forbids and - # strict peers (e.g. the TypeScript SDK's zod schemas) reject. + # Leave `params` unset when None: with `exclude_unset=True` an explicit + # None would serialize as `"params": null`, which JSON-RPC 2.0 forbids. if params is not None: msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params)) else: msg = JSONRPCNotification(jsonrpc="2.0", method=method) - await self._write(msg, _outbound_metadata(_related_request_id, None)) + await self._write(msg, _plan_outbound(_related_request_id, None).metadata) async def run( self, @@ -395,10 +381,7 @@ async def run( ) -> None: """Drive the receive loop until the read stream closes. - Each inbound request is handled in its own task in an internal task - group; `task_status.started()` fires once that group is open, so - `await tg.start(dispatcher.run, ...)` resumes when `send_raw_request` - is usable. + `task_status.started()` fires once `send_raw_request` is usable. """ try: async with anyio.create_task_group() as tg: @@ -409,34 +392,27 @@ async def run( async with self._read_stream, self._write_stream: try: async for item in self._read_stream: - # Duck-typed: `_context_streams.ContextReceiveStream` - # exposes `.last_context` (the sender's contextvars - # snapshot per message). Plain memory streams don't. + # Duck-typed: only `ContextReceiveStream` carries the + # sender's per-message contextvars snapshot. sender_ctx: contextvars.Context | None = getattr( self._read_stream, "last_context", None ) await self._dispatch(item, on_request, on_notify, sender_ctx) except anyio.ClosedResourceError: - # The transport closed our receive end and we looped - # back to `__anext__` on the now-closed stream - # (stateless SHTTP teardown). Same as EOF. + # Receive end closed under us (stateless SHTTP teardown); same as EOF. logger.debug("read stream closed by transport; treating as EOF") - # Read stream EOF: wake any blocked `send_raw_request` waiters - # (callers outside this task group) with CONNECTION_CLOSED. + # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. self._running = False + self._closed = True self._fan_out_closed() finally: - # Transport closed: cancel in-flight handlers. Without this - # the task-group join waits for them, and a handler that - # outlives its caller (its request timed out client-side, or - # the client disconnected mid-call) would keep `run()` from - # returning forever. Same behaviour as `Server.run()` before - # the dispatcher rework. + # Cancel in-flight handlers; otherwise the task-group join + # waits on handlers whose callers are already gone. tg.cancel_scope.cancel() finally: - # Covers the cancel/crash paths where the inline fan-out above is - # never reached. Idempotent. + # Covers cancel/crash paths that skip the inline fan-out; idempotent. self._running = False + self._closed = True self._tg = None self._fan_out_closed() @@ -449,13 +425,17 @@ async def _dispatch( ) -> None: """Route one inbound item. - Everything here is `send_nowait` or `_spawn`; the only `await` is for - `inline_methods` requests, which deliberately block dequeuing until - handled. Any other `await` would let one slow message head-of-line - block the entire read loop. + Only `inline_methods` requests and the `on_stream_exception` observer + are awaited; any other `await` would head-of-line block the read loop. """ if isinstance(item, Exception): - logger.debug("transport yielded exception: %r", item) + if self._on_stream_exception is None: + logger.debug("transport yielded exception: %r", item) + return + try: + await self._on_stream_exception(item) + except Exception: + logger.exception("on_stream_exception observer raised") return metadata = item.metadata msg = item.message @@ -467,9 +447,7 @@ async def _dispatch( case JSONRPCResponse(): self._resolve_pending(msg.id, msg.result) case JSONRPCError(): # pragma: no branch - # `id` may be None per JSON-RPC (parse error before id known). - # The match is exhaustive over JSONRPCMessage; the no-match arc - # on this final case is unreachable. + # Exhaustive over JSONRPCMessage, so the no-match arc is unreachable. self._resolve_pending(msg.id, msg.error) async def _dispatch_request( @@ -481,8 +459,7 @@ async def _dispatch_request( ) -> None: progress_token: ProgressToken | None match req.params: - # The bool guard matters: `int()` patterns match bool (a subclass), - # and `True == 1` would alias dict lookups to request id 1. + # bool subclasses int: without the guard True would alias request id 1. case {"_meta": {"progressToken": str() | int() as progress_token}} if not isinstance(progress_token, bool): pass case _: @@ -490,9 +467,7 @@ async def _dispatch_request( try: transport_ctx = self._transport_builder(metadata) except Exception: - # Containment boundary for the user-supplied builder: a raising - # builder must cost only this message, not the whole connection - # (the exception would otherwise escape into run()'s read loop). + # A raising builder must cost only this message, not the connection. logger.exception("transport_builder raised; rejecting request %r", req.id) self._spawn( self._write_error, @@ -509,20 +484,13 @@ async def _dispatch_request( _progress_token=progress_token, ) scope = anyio.CancelScope() - # TODO(maxisbey): the spec puts request-id uniqueness on the sender; - # neither v1 nor the TS SDK guards a duplicate id here, so for now we - # blind-overwrite (parity). Revisit rejecting with INVALID_REQUEST. - # Coerced key so `notifications/cancelled` correlates regardless of - # whether the peer stringifies the id between request and cancel - # (`_dispatch_notification` coerces at lookup; responses still echo - # `req.id` verbatim). + # TODO(maxisbey): duplicate ids blind-overwrite (v1/TS parity); revisit + # rejecting with INVALID_REQUEST. Key coerced so a stringified + # `notifications/cancelled` id still correlates. self._in_flight[_coerce_id(req.id)] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: - # Spawn (so `sender_ctx` applies, matching the concurrent path) but - # park the read loop until the handler returns; that's the inline - # ordering guarantee. Because the read loop is parked, a handler - # that awaits the peer here (e.g. `dctx.send_raw_request`) will - # deadlock: the response can never be dequeued. + # Spawn so `sender_ctx` applies, but park the read loop until the + # handler returns - that's the inline ordering guarantee. done = anyio.Event() async def _run_inline() -> None: @@ -546,17 +514,12 @@ def _dispatch_notification( """Route one inbound notification. `notifications/cancelled` and `notifications/progress` are intercepted - here because they correlate against JSON-RPC request IDs - the - `_in_flight` / `_pending` tables this layer owns - so no higher layer - can act on them. Both are still teed to `on_notify` afterwards, so - middleware and registered notification handlers observe every inbound - notification. See the module docstring for the design rationale. + here (they correlate against the `_in_flight`/`_pending` tables this + layer owns) and still teed to `on_notify` afterwards. """ if msg.method == "notifications/cancelled": match msg.params: - # The bool guards here and below matter: `int()` patterns match - # bool (a subclass), and `True == 1` would alias the dict lookup - # to the entry keyed by request id 1. + # bool subclasses int: the guards keep True from aliasing request id 1. case {"requestId": str() | int() as rid} if ( not isinstance(rid, bool) and (in_flight := self._in_flight.get(_coerce_id(rid))) is not None ): @@ -565,9 +528,6 @@ def _dispatch_notification( in_flight.scope.cancel() case _: pass - # fall through: cancelled is also teed to on_notify so middleware - # and registered handlers can observe it (matches DirectDispatcher, - # which forwards every notification). elif msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( @@ -587,18 +547,16 @@ def _dispatch_notification( ) case _: pass - # fall through: progress is also teed to on_notify try: transport_ctx = self._transport_builder(metadata) except Exception: - # Same containment boundary as `_dispatch_request`: a raising - # builder drops this notification instead of killing the read loop. + # Same containment as `_dispatch_request`: drop the notification, keep the loop. logger.exception("transport_builder raised; dropping notification %r", msg.method) return dctx = _JSONRPCDispatchContext( transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata ) - self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + self._spawn(_contained_notify(on_notify), dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None @@ -618,10 +576,8 @@ def _spawn( ) -> None: """Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars. - ASGI middleware (auth, OTel) sets contextvars on the request task that - wrote into the read stream. `Context.run(tg.start_soon, ...)` makes - the spawned handler inherit *that* context instead of the receive - loop's, so `auth_context_var` and OTel spans survive. + ASGI middleware (auth, OTel) sets contextvars on the task that wrote the + message; `Context.run` makes the spawned handler inherit that context. """ assert self._tg is not None if sender_ctx is not None: @@ -632,10 +588,9 @@ def _spawn( def _fan_out_closed(self) -> None: """Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`. - Synchronous (uses `send_nowait`) because it's called from `finally` - which may be inside a cancelled scope. Idempotent. + Synchronous: callers may be inside a cancelled scope. Idempotent. """ - closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + closed = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") for pending in self._pending.values(): try: pending.send.send_nowait(closed) @@ -652,61 +607,54 @@ async def _handle_request( ) -> None: """Run `on_request` for one inbound request and write its response. - This is the single exception-to-wire boundary: handler exceptions are - caught here and serialized to `JSONRPCError`. Nothing above this in - the stack constructs wire errors. + The single exception-to-wire boundary: handler exceptions become `JSONRPCError` here. """ try: with scope: try: result = await on_request(dctx, req.method, req.params) finally: - # Handler done: close the back-channel (detached work that - # later calls `dctx.send_raw_request()` should see - # `NoBackChannelError`) and drop from `_in_flight` so a - # late `notifications/cancelled` is a no-op rather than - # racing the result write below. No checkpoint between - # handler return and the pop, so the cancel can't - # interleave there. + # Close the back-channel and drop from `_in_flight`; no checkpoint + # since handler return, so a peer cancel can't interleave. + # Identity guard: don't evict a duplicate id's newer entry. dctx.close() - self._in_flight.pop(_coerce_id(req.id), None) + key = _coerce_id(req.id) + if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: + del self._in_flight[key] await self._write_result(req.id, result) - if scope.cancel_called: - # Peer-cancel: `_dispatch_notification` cancelled this scope - # while the handler was running. anyio swallows a scope's *own* - # cancel at __exit__, so execution lands here rather than the - # `except cancelled` arm below. - # TODO(maxisbey): spec says SHOULD NOT respond after cancel. - # The existing server always has, so match that for now. + if scope.cancelled_caught: + # anyio absorbs the scope's own cancel at __exit__, and + # `cancelled_caught` (unlike `cancel_called`) guarantees the + # result write above did not happen - no double response. + # TODO(maxisbey): spec says SHOULD NOT respond after cancel; + # the existing server always has, so match that for now. await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): - # Outer-cancel: run()'s task group is shutting down. Any bare - # `await` here re-raises immediately, so shield the courtesy write. - with anyio.CancelScope(shield=True): - await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + # Shutdown: answer the request so the peer isn't left waiting; the + # shielded helper is needed because bare awaits re-raise here. + await self._final_write( + partial(self._write_error, req.id, ErrorData(code=CONNECTION_CLOSED, message="Connection closed")), + shield=True, + timeout=_SHUTDOWN_WRITE_TIMEOUT, + describe=f"shutdown error response for request {req.id!r}", + ) raise except MCPError as e: await self._write_error(req.id, e.error) except ValidationError: - # TODO(maxisbey): data="" is pinned compat with the existing - # server (which never leaked pydantic error text onto the wire). - # Consider putting the validation detail in `data` once the - # interaction suite's divergence entry is resolved. + # TODO(maxisbey): data="" pins existing-server compat (no pydantic + # text on the wire); revisit per the suite's divergence entry. await self._write_error( req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") ) except Exception as e: logger.exception("handler for %r raised", req.method) - # TODO(maxisbey): code=0 is pinned compat with the existing - # server's `_handle_request`. JSON-RPC says INTERNAL_ERROR - # (-32603); revisit once the suite's divergence entry is resolved. + # TODO(maxisbey): code=0 pins existing-server compat; JSON-RPC says + # INTERNAL_ERROR. Revisit per the suite's divergence entry. await self._write_error(req.id, ErrorData(code=0, message=str(e))) if self._raise_handler_exceptions: raise - # No outer `_in_flight` pop here: the inner `finally` above already - # removes the entry on every path out of the handler, and a second - # pop after the awaited response writes could evict a newer request - # that reused the id during that window. + # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. def _allocate_id(self) -> int: self._next_id += 1 @@ -727,11 +675,28 @@ async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped error for %r: write stream closed", request_id) + async def _final_write( + self, + write: Callable[[], Awaitable[None]], + *, + shield: bool, + timeout: float, + describe: str, + ) -> None: + """Attempt one last write under the shared abandon/teardown policy. + + `shield=True` is for arms already inside a cancelled scope (a bare + `await` would re-raise); the bound keeps a wedged transport write + from becoming an uncancellable hang. + """ + with anyio.move_on_after(timeout, shield=shield) as scope: + await write() + if scope.cancelled_caught: + logger.warning("%s gave up: transport write blocked", describe) + async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: - # Thread `related_request_id` so streamable-HTTP routes the cancel onto - # the same per-request SSE stream as the request it cancels; without it - # the notification falls through to the standalone GET stream and is - # dropped when no GET stream is open. + # Thread `related_request_id` so streamable HTTP routes the cancel onto + # the request's own SSE stream instead of a possibly-absent GET stream. try: await self.notify( "notifications/cancelled", diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 61279ad8b8..b4f0beedf1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,487 +1,21 @@ -from __future__ import annotations +"""Compatibility names that outlived the removed v1 session layer (`BaseSession`).""" -import contextvars -import logging -from contextlib import AsyncExitStack -from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import Generic, TypeVar -import anyio -from anyio.streams.memory import MemoryObjectSendStream -from opentelemetry.trace import SpanKind -from pydantic import BaseModel, TypeAdapter -from typing_extensions import Self - -from mcp.shared._compat import resync_tracer -from mcp.shared._otel import inject_trace_context, otel_span -from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.exceptions import MCPError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.types import ( - CONNECTION_CLOSED, - INVALID_PARAMS, - METHOD_NOT_FOUND, - REQUEST_TIMEOUT, - CancelledNotification, - ClientNotification, - ClientRequest, - ClientResult, - ErrorData, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - ProgressNotification, - ProgressToken, - RequestParamsMeta, - ServerNotification, - ServerRequest, - ServerResult, -) - -SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) -SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) -SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) -ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) -ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) -ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) +from mcp.shared.dispatcher import ProgressFnT as ProgressFnT +from mcp.shared.message import MessageMetadata +from mcp.types import RequestParamsMeta RequestId = str | int - -class ProgressFnT(Protocol): - """Protocol for progress notification callbacks.""" - - async def __call__( - self, progress: float, total: float | None, message: str | None - ) -> None: ... # pragma: no branch +ReceiveRequestT = TypeVar("ReceiveRequestT") +SendResultT = TypeVar("SendResultT") class RequestResponder(Generic[ReceiveRequestT, SendResultT]): - """Handles responding to MCP requests and manages request lifecycle. - - This class MUST be used as a context manager to ensure proper cleanup and - cancellation handling: - - Example: - ```python - with request_responder as resp: - await resp.respond(result) - ``` - - The context manager ensures: - 1. Proper cancellation scope setup and cleanup - 2. Request completion tracking - 3. Cleanup of in-flight requests - """ - - def __init__( - self, - request_id: RequestId, - request_meta: RequestParamsMeta | None, - request: ReceiveRequestT, - session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - message_metadata: MessageMetadata = None, - context: contextvars.Context | None = None, - ) -> None: - self.request_id = request_id - self.request_meta = request_meta - self.request = request - self.message_metadata = message_metadata - self.context = context - self._session = session - self._completed = False - self._entered = False # Track if we're in a context manager - - def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: - self._entered = True - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self._entered = False - - async def respond(self, response: SendResultT | ErrorData) -> None: - """Send a response for this request. - - Must be called within a context manager block. - - Raises: - RuntimeError: If not used within a context manager - AssertionError: If request was already responded to - """ - if not self._entered: # pragma: no cover - raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" - self._completed = True - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, response=response - ) - - -class BaseSession( - Generic[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], -): - """Implements an MCP "session" on top of read/write streams, including features - like request/response linking, notifications, and progress. - - This class is an async context manager that automatically starts processing - messages when entered. - """ - - _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] - _request_id: int - _progress_callbacks: dict[RequestId, ProgressFnT] - - def __init__( - self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - # If none, reading will never time out - read_timeout_seconds: float | None = None, - ) -> None: - self._read_stream = read_stream - self._write_stream = write_stream - self._response_streams = {} - self._request_id = 0 - self._session_read_timeout_seconds = read_timeout_seconds - self._progress_callbacks = {} - self._exit_stack = AsyncExitStack() - - async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. - self._task_group.cancel_scope.cancel() - result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - await resync_tracer() - return result - - async def send_request( - self, - request: SendRequestT, - result_type: type[ReceiveResultT], - request_read_timeout_seconds: float | None = None, - metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, - ) -> ReceiveResultT: - """Sends a request and waits for a response. - - Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take - precedence over the session read timeout. - - Do not use this method to emit notifications! Use send_notification() instead. - """ - request_id = self._request_id - self._request_id = request_id + 1 - - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream - - # Set up progress token if progress callback is provided - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token - if "params" not in request_data: # pragma: lax no cover - request_data["params"] = {} - if "_meta" not in request_data["params"]: # pragma: lax no cover - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback - - try: - target = request_data.get("params", {}).get("name") - span_name = f"MCP send {request.method} {target}" if target else f"MCP send {request.method}" - - with otel_span( - span_name, - kind=SpanKind.CLIENT, - attributes={"mcp.method.name": request.method, "jsonrpc.request.id": str(request_id)}, - ): - # Inject W3C trace context into _meta (SEP-414). - meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {}) - inject_trace_context(meta) - - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) - - finally: - self._response_streams.pop(request_id, None) - self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() - - async def send_notification( - self, - notification: SendNotificationT, - related_request_id: RequestId | None = None, - ) -> None: - """Emits a notification, which is a one-way message that does not expect a response.""" - # Some transport implementations may need to set the related_request_id - # to attribute to the notifications to the request that triggered them. - jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", - **notification.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage( - message=jsonrpc_notification, - metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, - ) - await self._write_stream.send(session_message) - - async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: - if isinstance(response, ErrorData): - jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=jsonrpc_error) - await self._write_stream.send(session_message) - else: - jsonrpc_response = JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage(message=jsonrpc_response) - await self._write_stream.send(session_message) - - @property - def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: - """Each subclass must provide its own request adapter.""" - raise NotImplementedError - - @property - def _receive_request_methods(self) -> frozenset[str]: - """Method names in the receive-request union; anything else is - answered with METHOD_NOT_FOUND before validation is attempted.""" - raise NotImplementedError - - @property - def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: - raise NotImplementedError - - async def _receive_loop(self) -> None: - async with self._read_stream, self._write_stream: - try: - - async def _handle_session_message(message: SessionMessage) -> None: - sender_context: contextvars.Context | None = getattr(self._read_stream, "last_context", None) - if isinstance(message.message, JSONRPCRequest): - if message.message.method not in self._receive_request_methods: - # Unknown methods are METHOD_NOT_FOUND (-32601) per - # JSON-RPC 2.0, not validation failures (-32602). - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.id, - error=ErrorData( - code=METHOD_NOT_FOUND, message="Method not found", data=message.message.method - ), - ) - await self._write_stream.send(SessionMessage(message=error_response)) - return - try: - validated_request = self._receive_request_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, - ) - responder = RequestResponder( - request_id=message.message.id, - request_meta=validated_request.params.meta if validated_request.params else None, - request=validated_request, - session=self, - message_metadata=message.metadata, - context=sender_context, - ) - await self._received_request(responder) - except Exception: - # For request validation errors, send a proper JSON-RPC error - # response instead of crashing the server - logging.warning("Failed to validate request", exc_info=True) - logging.debug(f"Message that failed validation: {message.message}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.id, - error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), - ) - session_message = SessionMessage(message=error_response) - await self._write_stream.send(session_message) - - elif isinstance(message.message, JSONRPCNotification): - try: - notification = self._receive_notification_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, - ) - if isinstance(notification, CancelledNotification): - # ClientSession runs server-initiated requests - # inline in this loop, so by the time a peer - # cancellation is read there is nothing left to - # cancel. Consume it here so message_handler - # keeps the contract it had before the - # dispatcher swap removed _in_flight. - return - # Handle progress notifications callback - if isinstance(notification, ProgressNotification): - progress_token = notification.params.progress_token - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.params.progress, - notification.params.total, - notification.params.message, - ) - except Exception: - logging.exception("Progress callback raised an exception") - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception: - # For other validation errors, log and continue - logging.warning( - "Failed to validate notification: %s", - message.message, - exc_info=True, - ) - else: # Response or error - await self._handle_response(message) - - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - continue - - await _handle_session_message(message) - - except anyio.ClosedResourceError: - # This is expected when the client disconnects abruptly. - # Without this handler, the exception would propagate up and - # crash the server's task group. - logging.debug("Read stream closed by client") - except Exception as e: - # Other exceptions are not expected and should be logged. We purposefully - # catch all exceptions here to avoid crashing the server. - logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover - finally: - # after the read stream is closed, we need to send errors - # to any pending requests - # Snapshot: stream.send() wakes the waiter, whose finally pops - # from _response_streams before the next __next__() call. - for id, stream in list(self._response_streams.items()): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: # pragma: lax no cover - # Stream might already be closed - pass - self._response_streams.clear() - - def _normalize_request_id(self, response_id: RequestId) -> RequestId: - """Normalize a response ID to match how request IDs are stored. - - Since the client always sends integer IDs, we normalize string IDs - to integers when possible. This matches the TypeScript SDK approach: - https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 - - Args: - response_id: The response ID from the incoming message. - - Returns: - The normalized ID (int if possible, otherwise original value). - """ - if isinstance(response_id, str): - try: - return int(response_id) - except ValueError: - logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") - return response_id - - async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message.""" - # This check is always true at runtime: the caller (_receive_loop) only invokes - # this method in the else branch after checking for JSONRPCRequest and - # JSONRPCNotification. However, the type checker can't infer this from the - # method signature, so we need this guard for type narrowing. - if not isinstance(message.message, JSONRPCResponse | JSONRPCError): - return # pragma: no cover - - if message.message.id is None: - # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId - error = message.message.error - logging.warning(f"Received error with null ID: {error.message}") - await self._handle_incoming(MCPError(error.code, error.message, error.data)) - return - # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(message.message.id) - - stream = self._response_streams.pop(response_id, None) - if stream: - await stream.send(message.message) - else: - await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) - - async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: - """Can be overridden by subclasses to handle a request without needing to - listen on the message stream. - - If the request is responded to within this method, it will not be - forwarded on to the message stream. - """ - - async def _received_notification(self, notification: ReceiveNotificationT) -> None: - """Can be overridden by subclasses to handle a notification without needing - to listen on the message stream. - """ - - async def send_progress_notification( - self, - progress_token: ProgressToken, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """Sends a progress notification for a request that is currently being processed.""" + """Typing stub for the v1 responder; the SDK never instantiates it.""" - async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception - ) -> None: - """A generic handler for incoming messages. Overridden by subclasses.""" + request_id: RequestId + request_meta: RequestParamsMeta | None + request: ReceiveRequestT + message_metadata: MessageMetadata diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index cb49ff29db..b2d537fb70 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -152,7 +152,6 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, - REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -320,7 +319,6 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", - "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 14743c33b0..84304a37c1 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,7 +43,6 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 -REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index be4b9a97b9..a26ef45b27 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -2,9 +2,8 @@ from pydantic import FileUrl from mcp import Client -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent @@ -20,7 +19,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py deleted file mode 100644 index c7bf8fafa4..0000000000 --- a/tests/client/test_resource_cleanup.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any -from unittest.mock import patch - -import anyio -import pytest -from pydantic import TypeAdapter - -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestId, SendResultT -from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest - - -@pytest.mark.anyio -async def test_send_request_stream_cleanup(): - """Test that send_request properly cleans up streams when an exception occurs. - - This test mocks out most of the session functionality to focus on stream cleanup. - """ - - # Create a mock session with the minimal required functionality - class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]): - async def _send_response( - self, request_id: RequestId, response: SendResultT | ErrorData - ) -> None: # pragma: no cover - pass - - @property - def _receive_request_adapter(self) -> TypeAdapter[Any]: - return TypeAdapter(object) # pragma: no cover - - @property - def _receive_notification_adapter(self) -> TypeAdapter[Any]: - return TypeAdapter(object) # pragma: no cover - - # Create streams - write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) - read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) - - # Create the session - session = TestSession(read_stream_receive, write_stream_send) - - # Create a test request - request = PingRequest() - - # Patch the _write_stream.send method to raise an exception - async def mock_send(*args: Any, **kwargs: Any): - raise RuntimeError("Simulated network error") - - # Record the response streams before the test - initial_stream_count = len(session._response_streams) - - # Run the test with the patched method - with patch.object(session._write_stream, "send", mock_send): - with pytest.raises(RuntimeError): - await session.send_request(request, EmptyResult) - - # Verify that no response streams were leaked - assert len(session._response_streams) == initial_stream_count, ( - f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}" - ) - - # Clean up - await write_stream_send.aclose() - await write_stream_receive.aclose() - await read_stream_send.aclose() - await read_stream_receive.aclose() diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 6efcac0a52..2b90b00afa 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,9 +1,8 @@ import pytest from mcp import Client -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, CreateMessageResult, @@ -26,7 +25,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return @@ -71,7 +70,7 @@ async def test_create_message_backwards_compat_single_content(): ) async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 28d212d007..48ef5bab78 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,22 +1,28 @@ from __future__ import annotations -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Mapping +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any import anyio +import anyio.abc import anyio.streams.memory import pytest +from pydantic import FileUrl from mcp import types +from mcp.client import ClientRequestContext from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession -from mcp.shared._context import RequestContext +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( INVALID_PARAMS, LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, CallToolResult, Implementation, InitializedNotification, @@ -420,7 +426,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -430,7 +436,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) @@ -504,7 +510,7 @@ async def test_client_capabilities_with_sampling_tools(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -751,8 +757,32 @@ async def test_receive_loop_answers_malformed_inbound_request_with_invalid_param @pytest.mark.anyio -async def test_receive_loop_answers_invalid_params_when_sampling_callback_raises(): - """Same boundary catches exceptions from the request handler itself.""" +async def test_receive_loop_answers_unknown_request_method_with_method_not_found(): + """An unknown request method is answered with METHOD_NOT_FOUND, not INVALID_PARAMS (spec-mandated).""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.id == 7 + assert out.message.error == types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown") + + +@pytest.mark.anyio +async def test_receive_loop_drops_unknown_notification_method_without_response(): + """An unknown notification method is dropped silently: JSON-RPC forbids responses to notifications.""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="x/unknown"))) + # The answered follow-up ping proves no response was emitted and the loop survived. + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCResponse) + assert out.message.id == 1 + + +@pytest.mark.anyio +async def test_raising_sampling_callback_answers_with_code_zero(): + """A raising sampling callback is answered with code 0 and `str(exc)` (SDK-defined). + Raw streams because the assertion is the outbound `JSONRPCError` envelope itself.""" async def boom(ctx: object, params: object) -> types.CreateMessageResult: raise RuntimeError("sampling boom") @@ -767,12 +797,13 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: ) out = await from_client.receive() assert isinstance(out.message, JSONRPCError) - assert out.message.error.code == INVALID_PARAMS + assert out.message.error == types.ErrorData(code=0, message="sampling boom") @pytest.mark.anyio async def test_receive_loop_logs_and_drops_malformed_notification(caplog: pytest.LogCaptureFixture): - """A notification that fails ServerNotification validation is logged and dropped.""" + """A malformed notification is logged and dropped without reaching `message_handler` (SDK-defined). + Scripted peer: the typed API cannot emit a method outside the spec's notification union.""" seen: list[object] = [] delivered = anyio.Event() @@ -792,19 +823,54 @@ async def handler(msg: object) -> None: @pytest.mark.anyio -async def test_receive_loop_forwards_transport_exception_to_message_handler(): +async def test_raising_message_handler_on_transport_exception_costs_the_delivery_not_the_connection( + caplog: pytest.LogCaptureFixture, +): + """A `message_handler` that raises on a transport-level `Exception` item is contained: the + failure is logged and the receive loop keeps serving (SDK-defined). Raw streams because + only a transport can put an `Exception` item on the read stream.""" seen: list[object] = [] delivered = anyio.Event() async def handler(msg: object) -> None: seen.append(msg) delivered.set() + # No checkpoint between set() and the containment log, so after wait() the log entry exists. + raise RuntimeError("handler boom") - async with raw_client_session(message_handler=handler) as (_session, to_client, _): + async with raw_client_session(message_handler=handler) as (_session, to_client, from_client): exc = ValueError("bad bytes") await to_client.send(exc) await delivered.wait() + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=9, method="ping"))) + out = await from_client.receive() assert seen == [exc] + assert isinstance(out.message, JSONRPCResponse) + assert out.message.id == 9 + assert "message_handler raised on transport exception" in caplog.text + + +@pytest.mark.anyio +async def test_message_handler_awaiting_session_traffic_on_transport_exception_completes(): + """A `message_handler` that awaits session traffic on a transport `Exception` item completes: + fault deliveries are spawned into the task group, not run inline in the read loop (SDK-defined). + Raw streams because only a transport can put an `Exception` item on the read stream.""" + ponged = anyio.Event() + + # `session` resolves at call time, after the `as` clause binds it. + async def handler(msg: object) -> None: + assert isinstance(msg, Exception) + await session.send_ping() + ponged.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, from_client): + await to_client.send(ValueError("bad bytes")) + # Serve the handler's ping like a transport would; inline delivery would deadlock here. + out = await from_client.receive() + assert isinstance(out.message, JSONRPCRequest) + assert out.message.method == "ping" + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={}))) + await ponged.wait() @pytest.mark.anyio @@ -814,6 +880,7 @@ async def test_receive_loop_consumes_server_cancelled_without_reaching_message_h The server dispatcher now emits this on sampling/elicitation timeout, but ClientSession has no in-flight tracking to act on it, so surfacing it would only break user handlers that exhaustively match ServerNotification. + Scripted peer: the typed server API cannot emit a bare `notifications/cancelled`. """ seen: list[object] = [] delivered = anyio.Event() @@ -841,23 +908,214 @@ async def handler(msg: object) -> None: @pytest.mark.anyio -async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture): +async def test_progress_notification_reaches_request_callback_and_message_handler(): + """A `notifications/progress` for an in-flight request reaches both the `progress_callback` and + `message_handler` (SDK-defined). Scripted peer: the progress token must echo the wire request id.""" + updates: list[tuple[float, float | None, str | None]] = [] + teed: list[types.ProgressNotification] = [] + request_id: types.RequestId | None = None + progressed = anyio.Event() delivered = anyio.Event() - async def boom(progress: float, total: float | None, message: str | None) -> None: - raise RuntimeError("progress boom") + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + updates.append((progress, total, message)) + progressed.set() async def handler(msg: object) -> None: + # Only the progress notification is teed to the message handler here. + assert isinstance(msg, types.ProgressNotification) + teed.append(msg) delivered.set() - async with raw_client_session(message_handler=handler) as (session, to_client, _): - # Register the callback under a known token without sending a request. - session._progress_callbacks[42] = boom # pyright: ignore[reportPrivateUsage] - params = {"progressToken": 42, "progress": 0.5} - await to_client.send( - SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) - ) - # The progress notification also reaches the message handler after the - # callback runs, so this fires once the callback's exception is handled. - await delivered.wait() - assert "Progress callback raised an exception" in caplog.text + async with raw_client_session(message_handler=handler) as (session, to_client, from_client): + async with anyio.create_task_group() as tg: + + async def call() -> None: + await session.send_request(types.PingRequest(), types.EmptyResult, progress_callback=on_progress) + + tg.start_soon(call) + request = await from_client.receive() + assert isinstance(request.message, JSONRPCRequest) + request_id = request.message.id + # The request id doubles as the progress token. + params = {"progressToken": request_id, "progress": 0.5, "total": 1.0, "message": "halfway"} + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) + ) + await progressed.wait() + await delivered.wait() + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + assert updates == [(0.5, 1.0, "halfway")] + assert request_id is not None + assert len(teed) == 1 + assert teed[0].params == types.ProgressNotificationParams( + progress_token=request_id, progress=0.5, total=1.0, message="halfway" + ) + + +@pytest.mark.anyio +async def test_dispatcher_keyword_runs_over_direct_dispatch(): + """A session built with dispatcher= works without a stream pair (in-process embedding).""" + client_side, server_side = create_direct_dispatcher_pair() + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> dict[str, object]: + assert method == "ping" + return {} + + notified: list[str] = [] + + async def server_on_notify( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> None: + notified.append(method) + + session = ClientSession(dispatcher=client_side) + results: list[types.EmptyResult] = [] + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, server_on_request, server_on_notify) + async with session: + results.append(await session.send_ping(meta=None)) + # Server-to-client: direct dispatch delivers ping with no params member (no _meta injection). + assert await server_side.send_raw_request("ping", None) == {} + await session.send_notification(types.RootsListChangedNotification()) + server_side.close() + assert results == [types.EmptyResult()] + assert notified == ["notifications/roots/list_changed"] + + +@pytest.mark.anyio +async def test_direct_dispatch_roots_list_reaches_callback_with_synthesized_request_id(): + """A server-initiated roots/list over dispatcher= reaches the registered callback and round-trips + the result; the callback context carries an int request_id (SDK-defined: DirectDispatcher + synthesizes ids).""" + client_side, server_side = create_direct_dispatcher_pair() + contexts: list[ClientRequestContext] = [] + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + contexts.append(context) + return types.ListRootsResult(roots=[types.Root(uri=FileUrl("file:///workspace"))]) + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> dict[str, object]: + raise NotImplementedError + + async def server_on_notify( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> None: + raise NotImplementedError + + session = ClientSession(dispatcher=client_side, list_roots_callback=list_roots) + result: dict[str, Any] | None = None + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, server_on_request, server_on_notify) + async with session: + result = await server_side.send_raw_request("roots/list", None) + server_side.close() + assert result == {"roots": [{"uri": "file:///workspace"}]} + assert len(contexts) == 1 + assert isinstance(contexts[0].request_id, int) + + +@pytest.mark.anyio +async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): + """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids + cancelling it — and leaves the option unset for every other method.""" + + class RecordingDispatcher: + """Records `send_raw_request` opts and answers with canned results.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, CallOptions]] = [] + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, opts or {})) + if method == "initialize": + return InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + return {} + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + pass + + dispatcher = RecordingDispatcher() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.initialize() + await session.send_ping() + opts_by_method = dict(dispatcher.calls) + assert opts_by_method["initialize"].get("cancel_on_abandon") is False + assert "cancel_on_abandon" not in opts_by_method["ping"] + + +def test_constructor_rejects_streams_and_dispatcher_together(): + client_side, _server_side = create_direct_dispatcher_pair() + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + with pytest.raises(ValueError, match="not both"): + ClientSession(s2c_recv, dispatcher=client_side) + s2c_send.close() + s2c_recv.close() + + +def test_constructor_requires_both_streams_without_dispatcher(): + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + with pytest.raises(ValueError, match="read_stream and write_stream are required"): + ClientSession(s2c_recv) + with pytest.raises(ValueError, match="read_stream and write_stream are required"): + ClientSession() + s2c_send.close() + s2c_recv.close() + + +@pytest.mark.anyio +async def test_aenter_cancelled_while_dispatcher_starts_unwinds_cleanly(): + """Cancellation while `__aenter__` waits for the dispatcher to start unwinds the half-entered + task group cleanly, not via anyio's "exited non-innermost cancel scope" RuntimeError (SDK-defined).""" + + class NeverStartsDispatcher: + """`run()` parks without ever signalling `task_status.started()`.""" + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + session = ClientSession(dispatcher=NeverStartsDispatcher()) + async with AsyncExitStack() as stack: + # `start()` is parked forever, so the deadline only ends the wait — any duration is non-racy. + with anyio.move_on_after(0.01) as scope: + await stack.enter_async_context(session) + assert scope.cancelled_caught + # The failed enter must not leave the session half-entered. + assert session._task_group is None diff --git a/tests/interaction/README.md b/tests/interaction/README.md index be68c3b0f1..473e79c83b 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -193,11 +193,13 @@ many requirements at once; if the assertions would be separate, write separate t ### Notifications and concurrency -The client's receive loop dispatches each incoming message to completion before reading the next, -and the in-memory transport delivers everything on one ordered stream. Together these guarantee -that every notification a server handler emits before its response reaches the client callback -before the originating request returns — so tests collect notifications into a plain list and -assert after the call, with no synchronisation. The exceptions: +The client's dispatcher starts a task per incoming notification in arrival order but does not +await it before reading the next message, so completion order is not structural. What still +holds: the in-memory transport delivers everything on one ordered stream, and a callback that +records synchronously (no `await` before the append) finishes its scheduling slice before the +awaited request's waiter — woken strictly later — resumes. So tests whose callbacks are plain +appends may still collect into a list and assert after the call. A callback that awaits before +recording loses that ordering and must synchronise. The other exceptions: - a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in the receiving handler and awaited under `anyio.fail_after(5)`; @@ -220,9 +222,8 @@ but still inside an outer `async with`, and no restructure can avoid it. A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose execution is timing-dependent under the in-process HTTP bridge — the POST-stream and -stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` -check in `message_router`, and the response-stream double-close guard in -`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to -strict `no cover` without first making the teardown ordering deterministic. The suite also relies -on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the -SSE leg would otherwise leak. +stateless-session `except Exception` handlers in `server/streamable_http*.py` and the +`_terminated` check in `message_router`. `strict-no-cover` does not check `lax` lines; do not +promote them to strict `no cover` without first making the teardown ordering deterministic. The +suite also relies on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that +closes a stream the SSE leg would otherwise leak. diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 25833b0ca5..54d41e1e7b 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -67,8 +67,9 @@ def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage] self._log = log async def send(self, item: SessionMessage, /) -> None: - self._log.append(item) + # Record only after the inner send returns: a failed or cancelled send never reached the transport. await self._inner.send(item) + self._log.append(item) async def aclose(self) -> None: await self._inner.aclose() diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index caed8905d0..acaef072c1 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -268,18 +268,15 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " - "request; the server sends an error response (code 0, 'Request cancelled'), which is what " - "unblocks the SDK client's pending call." + "request; both seats send an error response (code 0, 'Request cancelled') instead — the " + "server for cancelled client requests, and the client for cancelled server-initiated " + "requests — which is what unblocks the sender's pending call." ), ), ), "protocol:cancel:initialize-not-cancellable": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior="The client never sends notifications/cancelled for the initialize request.", - deferred=( - "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " - "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." - ), ), "protocol:cancel:late-response-ignored": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", @@ -287,14 +284,6 @@ def __post_init__(self) -> None: "A response that arrives after the sender issued notifications/cancelled is ignored; the " "request stays failed and no error is raised." ), - divergence=Divergence( - note=( - "A response whose id matches no in-flight request is delivered to the message handler " - "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " - "same code path; tested in its unknown-id form because that is deterministic without the " - "client-side cancellation API the SDK does not yet provide." - ), - ), ), "protocol:cancel:server-survives": Requirement( source="sdk", @@ -306,19 +295,6 @@ def __post_init__(self) -> None: "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " "cancels it, and the client stops processing the cancelled request." ), - divergence=Divergence( - note=( - "Abandoning a server-side send_request emits no cancellation notification, and the client " - "could not act on one anyway: client callbacks run inline in the receive loop, so a " - "cancellation is not even read until the callback has finished." - ), - ), - deferred=( - "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " - "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " - "the client could not act on one anyway because client callbacks run inline in the receive " - "loop, so a cancellation would not even be read until the callback had already finished." - ), ), "protocol:cancel:unknown-id-ignored": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", @@ -363,6 +339,26 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic#responses", behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", ), + "protocol:error:null-id": Requirement( + source="sdk", + behavior=( + "An error response carrying a null id — the JSON-RPC shape for a peer reporting a failure it " + "could not attribute to a request, such as a parse error — is surfaced to the application " + "rather than silently discarded." + ), + divergence=Divergence( + note=( + "The dispatcher drops null-id error responses with a debug log; v1 surfaced them to " + "message_handler as an MCPError. A typed fault channel restoring visibility is planned " + "before v2 stable." + ), + ), + deferred=( + "Not yet covered here: the current drop is pinned at the dispatcher level by " + "tests/shared/test_jsonrpc_dispatcher.py; an interaction-level test waits on the planned " + "fault channel." + ), + ), "protocol:meta:related-task": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", behavior="Messages may carry related-task _meta associating them with a task.", @@ -466,11 +462,6 @@ def __post_init__(self) -> None: "When a request times out, the sender issues notifications/cancelled for that request before " "failing the local call." ), - divergence=Divergence( - note=( - "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." - ), - ), ), "protocol:timeout:session-survives": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 6f1454e58a..6e6c2b6f60 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -11,11 +11,12 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client import ClientSession +from mcp.client import ClientRequestContext, ClientSession from mcp.server import Server, ServerRequestContext from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( + REQUEST_TIMEOUT, CallToolResult, EmptyResult, ErrorData, @@ -155,14 +156,71 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) +@requirement("protocol:cancel:server-to-client") +async def test_abandoned_server_request_cancels_the_client_callback(connect: Connect) -> None: + """A server that abandons a sampling request cancels it, interrupting the client's callback mid-await.""" + callback_started = anyio.Event() + callback_cancelled = anyio.Event() + + async def sampling_callback( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + callback_started.set() + try: + await anyio.Event().wait() # blocks until the cancellation interrupts it + except anyio.get_cancelled_exc_class(): + callback_cancelled.set() + raise + raise NotImplementedError # unreachable + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "impatient" + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=8, + ) + ) + async with anyio.create_task_group() as abandon_scope: + + async def sample() -> None: + await ctx.session.send_request(request, types.CreateMessageResult) + raise NotImplementedError # unreachable: the scope is cancelled + + abandon_scope.start_soon(sample) + with anyio.fail_after(5): + await callback_started.wait() + abandon_scope.cancel_scope.cancel() + with anyio.fail_after(5): + await callback_cancelled.wait() + return CallToolResult(content=[TextContent(text="abandoned")]) + + server = Server("abandoner", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("impatient", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="abandoned")])) + assert callback_cancelled.is_set() + + @requirement("protocol:cancel:late-response-ignored") -async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: - """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. +async def test_a_response_for_an_unknown_request_id_is_ignored() -> None: + """A response whose id matches no in-flight request is ignored, as the spec asks. The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; that is the same client-side code path as any response with an unknown id, and that form is - deterministic to test without depending on the cancellation API the SDK does not yet provide. - See the divergence note on the requirement. + deterministic to test without a client-side cancellation API. + + "Ignored" is proved in two halves: the pong round-trip proves the read loop survived the + fabricated response (the ordered in-memory stream routed it first), and `surfaced` holding + only the control notification proves the fabricated response was never delivered to + `message_handler` (v1 surfaced it there as a RuntimeError). A real Server cannot be made to answer with a fabricated id, so the test plays the server's side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The @@ -208,14 +266,18 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage assert isinstance(ping, SessionMessage) assert isinstance(ping.message, JSONRPCRequest) assert ping.message.method == "ping" - # First answer with a fabricated id that matches nothing in flight, then the real id. + # First a fabricated id that matches nothing in flight, then a control notification that + # is surfaced to message_handler (proving the handler is live), then the real id. await server_write.send(respond(9999, EmptyResult())) + await server_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) await server_write.send(respond(ping.message.id, EmptyResult())) - incoming: list[IncomingMessage] = [] + surfaced: list[IncomingMessage] = [] async def message_handler(message: IncomingMessage) -> None: - incoming.append(message) + surfaced.append(message) async with ( create_client_server_memory_streams() as ((client_read, client_write), server_streams), @@ -228,7 +290,56 @@ async def message_handler(message: IncomingMessage) -> None: pong = await session.send_request(PingRequest(), EmptyResult) assert pong == snapshot(EmptyResult()) - assert len(incoming) == 1 - assert isinstance(incoming[0], RuntimeError) - # The full message embeds the response object's repr; only the prefix is stable. - assert str(incoming[0]).startswith("Received response with an unknown request ID:") + # The stream is ordered, so the fabricated response was routed before the control + # notification: only the control surfaced, so the unknown-id response was dropped. + assert surfaced == snapshot([types.ToolListChangedNotification()]) + + +@requirement("protocol:cancel:initialize-not-cancellable") +async def test_timed_out_initialize_sends_no_cancellation() -> None: + """An abandoned initialize is not followed by notifications/cancelled on the wire (spec-mandated). + + A real Server always answers initialize, so the test plays a stalling server by hand. + """ + received_methods: list[str] = [] + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + # Hold the initialize request unanswered until the client's read timeout fires. + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + received_methods.append(init.message.method) + + follow_up = await server_read.receive() + assert isinstance(follow_up, SessionMessage) + assert isinstance(follow_up.message, JSONRPCRequest) + received_methods.append(follow_up.message.method) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=follow_up.message.id, + result=EmptyResult().model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + # The session-level read timeout is the only public pathway that abandons initialize. + ClientSession(client_read, client_write, read_timeout_seconds=0.000001) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.initialize() + assert exc_info.value.error.code == REQUEST_TIMEOUT + # Override the session-level timeout: this ping must round-trip normally. + pong = await session.send_request(PingRequest(), EmptyResult, request_read_timeout_seconds=5) + + assert pong == snapshot(EmptyResult()) + # The stream is ordered, so a courtesy cancel would have arrived ahead of the ping. + assert received_methods == snapshot(["initialize", "ping"]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index fba632ef4d..b8f9d3d776 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -1,12 +1,8 @@ """Logging interactions against the low-level Server, driven through the public Client API. -Notification ordering: the in-memory transport delivers every server-to-client message on one -ordered stream, and the client's receive loop dispatches each incoming message to completion -before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds -only for messages that carry a ``related_request_id`` (they ride the originating request's POST -stream); without it the message routes to the standalone GET stream and may arrive after the -response. These tests pass ``related_request_id`` so they can collect into a plain list and -assert after the request completes on every transport leg -- no events, no waiting. +Notification ordering: await-free callbacks finish in arrival order, and passing +``related_request_id`` keeps each notification on the originating request's POST stream over +streamable HTTP, so plain-list collection is deterministic on every transport leg. """ import pytest diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 6350c33a33..a89039b99e 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -87,8 +87,8 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N async with connect(server) as client: result = await client.call_tool("inspect", {}, progress_callback=ignore) - # The token is the request id of the tools/call request itself (initialize is request 0). - assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + # The token is the request id of the tools/call request itself (initialize is request 1). + assert result == snapshot(CallToolResult(content=[TextContent(text="2")])) @requirement("protocol:progress:no-token") diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index b440f32106..903829845e 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -13,9 +13,13 @@ from trio.testing import MockClock from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext -from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from mcp.shared.message import SessionMessage +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, JSONRPCNotification, TextContent +from tests.interaction._helpers import RecordingTransport from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -26,16 +30,19 @@ async def test_request_timeout_fails_the_pending_call() -> None: """A request whose response does not arrive within its read timeout fails with a timeout error. - No cancellation is sent to the server (see the divergence note on the requirement): the handler - starts and is still running after the caller has already given up. The test waits for the - handler to have started only after the timeout has fired, so the timeout itself races nothing. + The timeout is followed by notifications/cancelled, which interrupts the server's handler. """ handler_started = anyio.Event() + handler_cancelled = anyio.Event() async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" handler_started.set() - await anyio.Event().wait() # blocks until the session is torn down + try: + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise raise NotImplementedError # unreachable server = Server("blocker", on_call_tool=call_tool) @@ -44,18 +51,79 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}, read_timeout_seconds=0.000001) - # The request was already on the wire: the handler still runs even though the caller gave up. + # The request was already on the wire: the handler started and was then cancelled. with anyio.fail_after(5): await handler_started.wait() + await handler_cancelled.wait() assert exc_info.value.error == snapshot( ErrorData( code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + message="Request 'tools/call' timed out", ) ) +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_server_request_timeout_sends_cancellation_to_the_client() -> None: + """A server-initiated request that times out fails server-side and cancels the client's work. + + The sampling callback answers only after the server gave up; the late response is discarded. + """ + release = anyio.Event() + callback_started = anyio.Event() + errors: list[ErrorData] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "impatient" + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=8, + ) + ) + with pytest.raises(MCPError) as exc_info: + await ctx.session.send_request(request, types.CreateMessageResult, request_read_timeout_seconds=0.000001) + errors.append(exc_info.value.error) + release.set() + return CallToolResult(content=[TextContent(text="gave up")]) + + server = Server("impatient", on_list_tools=list_tools, on_call_tool=call_tool) + recording = RecordingTransport(InMemoryTransport(server)) + + async def sampling_callback( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + callback_started.set() + with anyio.fail_after(5): + await release.wait() + return types.CreateMessageResult(role="assistant", content=TextContent(text="too late"), model="test-model") + + async with Client(recording, sampling_callback=sampling_callback) as client: + result = await client.call_tool("impatient", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="gave up")])) + assert callback_started.is_set() + assert errors == snapshot([ErrorData(code=REQUEST_TIMEOUT, message="Request 'sampling/createMessage' timed out")]) + cancellations = [ + item.message + for item in recording.received + if isinstance(item, SessionMessage) + and isinstance(item.message, JSONRPCNotification) + and item.message.method == "notifications/cancelled" + ] + # requestId 1 is the sampling request, the server's first outbound request. + assert [notification.params for notification in cancellations] == snapshot( + [{"requestId": 1, "reason": "timed out after 1e-06s"}] + ) + + @requirement("protocol:timeout:session-survives") async def test_session_serves_requests_after_timeout() -> None: """A timed-out request does not poison the session: the next request succeeds.""" @@ -73,7 +141,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: if params.name == "echo": return CallToolResult(content=[TextContent(text="still alive")]) - await anyio.Event().wait() # blocks until the session is torn down + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it raise NotImplementedError # unreachable server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) @@ -105,7 +173,7 @@ async def test_session_level_timeout_applies_to_every_request() -> None: async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" - await anyio.Event().wait() # blocks until the session is torn down + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it raise NotImplementedError # unreachable server = Server("blocker", on_call_tool=call_tool) @@ -117,6 +185,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert exc_info.value.error == snapshot( ErrorData( code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + message="Request 'tools/call' timed out", ) ) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index 0f9c58aa7a..178c2c1c38 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -61,7 +61,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_request_ids_are_unique_and_never_null() -> None: """Every request the client sends carries a distinct, non-null id. - The id sequence is pinned: sequential integers from zero, in send order. + The id sequence is pinned: sequential integers from one, in send order. """ recording = RecordingTransport(InMemoryTransport(_echo_server())) @@ -77,7 +77,7 @@ async def test_request_ids_are_unique_and_never_null() -> None: assert len(request_ids) == len(set(request_ids)) # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a # schema-cache refresh here because the explicit tools/list already populated the cache. - assert request_ids == snapshot([0, 1, 2, 3, 4]) + assert request_ids == snapshot([1, 2, 3, 4, 5]) @requirement("protocol:notifications:no-response") diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 6b593d2a54..b1c6a4f709 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -12,7 +12,14 @@ from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent +from mcp.types import ( + REQUEST_TIMEOUT, + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, +) @pytest.mark.anyio @@ -55,7 +62,8 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" if params.name == "slow": - await slow_request_lock.wait() # it should timeout here + # The client's timeout fires during this wait; the courtesy cancellation then interrupts it. + await slow_request_lock.wait() text = f"slow {request_count}" else: text = f"fast {request_count}" @@ -95,9 +103,9 @@ async def client( # Use very small timeout to trigger quickly without waiting with pytest.raises(MCPError) as exc_info: await session.call_tool("slow", read_timeout_seconds=0.000001) # artificial timeout that always fails - assert "Timed out while waiting" in str(exc_info.value) + assert exc_info.value.error.code == REQUEST_TIMEOUT - # release the slow request not to have hanging process + # No-op if the courtesy cancellation already interrupted the handler. slow_request_lock.set() # Third call should work (fast operation, no timeout), diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 9292586b32..26908ed16e 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, Field from mcp import Client, types -from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client import ClientRequestContext +from mcp.client.session import ElicitationFnT from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -64,7 +64,7 @@ async def test_elicitation_accept_returns_the_users_answer_to_the_tool(): create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: # pragma: no cover @@ -81,7 +81,7 @@ async def test_elicitation_decline_reaches_the_tool_without_content(): mcp = MCPServer(name="ElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -119,9 +119,7 @@ class InvalidNestedSchema(BaseModel): create_validation_tool("nested_model", InvalidNestedSchema) # Dummy callback (won't be called due to validation failure) - async def elicitation_callback( - context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover return ElicitResult(action="accept", content={}) async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -176,7 +174,7 @@ async def optional_tool(ctx: Context) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -194,9 +192,7 @@ async def invalid_optional_tool(ctx: Context) -> str: except TypeError as e: return f"Validation failed: {str(e)}" - async def elicitation_callback( - context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover return ElicitResult(action="accept", content={}) await call_tool_and_assert( @@ -219,7 +215,7 @@ async def valid_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" return f"User {result.action}" # pragma: no cover - async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams): if "Please provide tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -239,7 +235,7 @@ async def optional_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {tags_str}" return f"User {result.action}" # pragma: no cover - async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def optional_multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams): if "Please provide optional tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -273,7 +269,7 @@ async def defaults_tool(ctx: Context) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_schema_verify(context: ClientRequestContext, params: ElicitRequestParams): # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -295,7 +291,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession], params: ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_override(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) @@ -371,7 +367,7 @@ async def select_color_legacy(ctx: Context) -> str: return f"User: {result.data.user_name}, Color: {result.data.color}" return f"User {result.action}" # pragma: no cover - async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def enum_callback(context: ClientRequestContext, params: ElicitRequestParams): if "colors" in params.message and "legacy" not in params.message: return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) elif "color" in params.message: diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index f71c0574cd..5bac39dfee 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -26,9 +26,7 @@ structured_output, tool_progress, ) -from mcp.client import Client -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext +from mcp.client import Client, ClientRequestContext from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, @@ -80,9 +78,7 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -async def sampling_callback( - context: RequestContext[ClientSession], params: CreateMessageRequestParams -) -> CreateMessageResult: +async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( role="assistant", @@ -94,7 +90,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): +async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 21352b5f2f..60d30342c4 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1098,10 +1098,10 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert "Logged messages for test" in content.text assert mock_log.call_count == 4 - mock_log.assert_any_call(level="debug", data="Debug message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="info", data="Info message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="1") + mock_log.assert_any_call(level="debug", data="Debug message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="info", data="Info message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="2") async def test_optional_context(self): """Test that context is optional.""" diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index af90dc208b..9ab03fcdab 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -5,10 +5,9 @@ from pydantic import BaseModel, Field from mcp import Client, types -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -28,7 +27,7 @@ async def request_api_key(ctx: Context) -> str: return f"User {result.action}" # Create elicitation callback that accepts URL mode - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" assert params.url == "https://example.com/api_key_setup" assert params.elicitation_id == "test-elicitation-001" @@ -57,7 +56,7 @@ async def oauth_flow(ctx: Context) -> str: # Test only checks decline path return f"User {result.action} authorization" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="decline") @@ -83,7 +82,7 @@ async def payment_flow(ctx: Context) -> str: # Test only checks cancel path return f"User {result.action} payment" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="cancel") @@ -110,7 +109,7 @@ async def setup_credentials(ctx: Context) -> str: # Test only checks accept path - return the type name return type(result).__name__ - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -137,7 +136,7 @@ async def check_url_response(ctx: Context) -> str: assert result.content is None return f"Action: {result.action}, Content: {result.content}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify that this is URL mode assert params.mode == "url" assert isinstance(params, types.ElicitRequestURLParams) @@ -170,7 +169,7 @@ async def ask_name(ctx: Context) -> str: assert result.data is not None return f"Hello, {result.data.name}!" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify form mode parameters assert params.mode == "form" assert isinstance(params, types.ElicitRequestFormParams) @@ -206,7 +205,7 @@ async def trigger_elicitation(ctx: Context) -> str: return "Elicitation completed" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -263,7 +262,7 @@ async def test_cancel(ctx: Context) -> str: return "Not cancelled" # pragma: no cover # Test declined result - async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def decline_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="decline") async with Client(mcp, elicitation_callback=decline_callback) as client: @@ -273,7 +272,7 @@ async def decline_callback(context: RequestContext[ClientSession], params: Elici assert result.content[0].text == "Declined" # Test cancelled result - async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def cancel_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="cancel") async with Client(mcp, elicitation_callback=cancel_callback) as client: @@ -303,7 +302,7 @@ async def use_deprecated_elicit(ctx: Context) -> str: return f"Email: {result.content.get('email', 'none')}" return "No email provided" # pragma: no cover - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify this is form mode assert params.mode == "form" assert params.requested_schema is not None @@ -331,7 +330,7 @@ async def direct_elicit_url(ctx: Context) -> str: ) return f"Result: {result.action}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") diff --git a/tests/shared/test_context_streams.py b/tests/shared/test_context_streams.py new file mode 100644 index 0000000000..b035892303 --- /dev/null +++ b/tests/shared/test_context_streams.py @@ -0,0 +1,20 @@ +"""Tests for the contextvars-carrying memory-stream wrappers.""" + +import anyio +import pytest + +from mcp.shared._context_streams import create_context_streams + +pytestmark = pytest.mark.anyio + + +async def test_sync_close_closes_the_underlying_streams() -> None: + """The wrappers mirror anyio's memory streams: close() is the sync form of aclose().""" + send, receive = create_context_streams[str](1) + await send.send("queued") + send.close() + receive.close() + with pytest.raises(anyio.ClosedResourceError): + await send.send("after close") + with pytest.raises(anyio.ClosedResourceError): + await receive.receive() diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 745f4b3875..01150a21ca 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -228,13 +228,15 @@ async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair @pytest.mark.anyio async def test_ctx_request_id_exposes_inbound_id(pair_factory: PairFactory): - """JSON-RPC carries the wire id through; direct dispatch has none.""" + """Every dispatcher assigns each inbound request a distinct int id; JSON-RPC carries + the wire id through, DirectDispatcher synthesizes one (SDK-defined).""" async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await client.send_raw_request("tools/call", None) a, b = (ctx.request_id for ctx in srec.contexts) - assert (a is None and b is None) or (isinstance(a, int) and isinstance(b, int) and a != b) + assert isinstance(a, int) and isinstance(b, int) + assert a != b @pytest.mark.anyio diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264d..c6b5750928 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -3,7 +3,7 @@ import pytest from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError def test_url_elicitation_required_error_create_with_single_elicitation() -> None: @@ -162,3 +162,14 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +def test_from_jsonrpc_error_preserves_code_message_and_data() -> None: + """Building an MCPError from a wire JSONRPCError keeps every error field.""" + wire = JSONRPCError( + jsonrpc="2.0", + id=3, + error=ErrorData(code=URL_ELICITATION_REQUIRED, message="go elsewhere", data={"hint": "y"}), + ) + error = MCPError.from_jsonrpc_error(wire) + assert error.error == ErrorData(code=URL_ELICITATION_REQUIRED, message="go elsewhere", data={"hint": "y"}) diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index b2a24c87dc..f6b51dd5c2 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -1,20 +1,16 @@ -"""JSON-RPC-specific Dispatcher tests. - -Behaviors with no `DirectDispatcher` analog: request-id correlation, the -exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. -The contract tests shared with `DirectDispatcher` live in -`test_dispatcher.py`. -""" +"""JSON-RPC-specific dispatcher tests; contract tests shared with `DirectDispatcher` live in `test_dispatcher.py`.""" import contextvars import json import logging from collections.abc import Mapping +from types import TracebackType from typing import Any import anyio import anyio.lowlevel import pytest +from trio.testing import MockClock from mcp import Client from mcp.server import Server, ServerRequestContext @@ -24,8 +20,9 @@ from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, _coerce_id, - _outbound_metadata, + _OutboundPlan, _Pending, + _plan_outbound, ) from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext @@ -33,6 +30,7 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, + REQUEST_TIMEOUT, CallToolRequest, CallToolRequestParams, CallToolResult, @@ -53,6 +51,30 @@ DCtx = DispatchContext[TransportContext] +class RecordingWriteStream: + """Records sends without a checkpoint, so a pending cancellation cannot interrupt the write or mask it.""" + + def __init__(self) -> None: + self.sent: list[SessionMessage] = [] + + async def send(self, item: SessionMessage) -> None: + self.sent.append(item) + + async def aclose(self) -> None: + raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + + async def __aenter__(self) -> "RecordingWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + @pytest.mark.anyio async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): release_first = anyio.Event() @@ -129,6 +151,49 @@ async def call_then_record() -> None: assert seen_error == [ErrorData(code=0, message="Request cancelled")] +@pytest.mark.anyio +async def test_peer_cancel_landing_after_handlers_last_checkpoint_writes_only_the_result(): + """A peer cancel that fails to interrupt the handler writes only the result: one answer per + id goes on the wire (SDK-defined). The recording stream is needed because a memory stream's + `send` checkpoints, letting the deferred cancellation land mid-write and hide a double answer.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + recording = RecordingWriteStream() + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + handler_started = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + return {"completed": "after-cancel"} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + # The cancel is also the handler's wakeup, so anyio defers it and the handler completes. + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ) + ) + ) + # Quiesce: the handler has resumed, completed, and exited its scope. + await anyio.wait_all_tasks_blocked() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + assert [m.message for m in recording.sent] == [ + JSONRPCResponse(jsonrpc="2.0", id=1, result={"completed": "after-cancel"}) + ] + + @pytest.mark.anyio async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): handler_started = anyio.Event() @@ -142,7 +207,6 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | def factory(*, can_send_request: bool = True): client, server, close = jsonrpc_pair(can_send_request=can_send_request) - # Reach in to set signal mode on the server side. assert isinstance(server, JSONRPCDispatcher) server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] return client, server, close @@ -189,17 +253,12 @@ async def caller() -> None: @pytest.mark.anyio async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): - """Iterating a closed receive end raises ClosedResourceError; run() treats it as EOF. - - Stateless SHTTP teardown closes the dispatcher's receive end after the - request is handled; the next loop iteration must not surface as a crash. - """ + """Iterating a closed receive end is EOF, not a crash (stateless SHTTP closes it during teardown).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) on_request, on_notify = echo_handlers(Recorder()) - # Close the dispatcher's own receive end (not the send end) before run() - # iterates it: __anext__ on a closed stream raises ClosedResourceError. + # Close the receive end itself (not the send end): __anext__ then raises ClosedResourceError. c2s_recv.close() with anyio.fail_after(5): await server.run(on_request, on_notify) @@ -209,12 +268,8 @@ async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): @pytest.mark.anyio async def test_run_cancels_in_flight_handlers_when_read_stream_eofs(): - """A handler that outlives its caller must not keep run() from returning. - - Without the cancel-at-EOF, the task-group join would wait on this handler - forever (over SSE that leaks the handler task and the GET request hosting - the session). - """ + """run() cancels still-running handlers at read-stream EOF; otherwise its join waits forever + (over SSE, leaking the handler and the GET request hosting the session).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -260,7 +315,6 @@ async def test_run_closes_write_stream_on_exit(): await tg.start(server.run, on_request, on_notify) c2s_send.close() # EOF the read side; run() exits with anyio.fail_after(5), pytest.raises(anyio.EndOfStream): # pragma: no branch - # Write end was entered and released by run(); peer's receive sees EOF. await s2c_recv.receive() s2c_recv.close() @@ -279,8 +333,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): with pytest.raises(MCPError): # REQUEST_TIMEOUT await client.send_raw_request("slow", None, {"timeout": 0}) - # The server handler is still running; let it finish and write a - # response for an id the client has already discarded. + # Let the parked handler respond to an id the client has already discarded. await handler_started.wait() proceed.set() # One more round-trip proves the dispatcher is still healthy. @@ -309,7 +362,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> with pytest.raises(BaseException) as exc: async with anyio.create_task_group() as tg: await tg.start(server.run, boom, on_notify) - # Inject a request directly onto the server's read stream. await c2s_send.send( SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) ) @@ -340,7 +392,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> try: async with anyio.create_task_group() as tg: await tg.start(server.run, server_on_request, on_notify) - # Kick the server with an inbound request id=7. await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) with anyio.fail_after(5): outbound = await s2c_recv.receive() @@ -348,7 +399,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> assert isinstance(outbound.message, JSONRPCRequest) assert isinstance(outbound.metadata, ServerMessageMetadata) assert outbound.metadata.related_request_id == 7 - # Reply so the handler completes cleanly. await c2s_send.send( SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) ) @@ -365,12 +415,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_courtesy_cancel_on_timeout_tags_outbound_with_server_message_metadata(): - """The timeout-path `notifications/cancelled` carries the originating request id. - - Streamable-HTTP's `message_router` keys on `ServerMessageMetadata.related_request_id`; - a cancel without it would fall through to the standalone GET stream and be dropped - when no GET stream is open, so the client never learns to stop work. - """ + """The timeout-path `notifications/cancelled` carries the originating request id: SHTTP's + `message_router` keys on `related_request_id`; without it the cancel would be dropped.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -413,6 +459,549 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() +@pytest.mark.anyio +async def test_dispatch_context_request_with_dropped_resumption_hints_still_sends_courtesy_cancel(): + """Resumption hints that never reach the transport must not suppress the abandon cancel: + `related_request_id` takes metadata precedence and drops the hints, so the request is not resumable.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await ctx.send_raw_request("sampling/createMessage", None, {"timeout": 0, "resumption_token": "tok"}) + return {"gave_up": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + # The hints were dropped: dispatch-context routing won the metadata. + assert isinstance(outbound.metadata, ServerMessageMetadata) + sampling_id = outbound.message.id + # Don't respond; let the timeout fire. Next on the wire must be the courtesy cancel. + with anyio.fail_after(5): + cancel = await s2c_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": sampling_id, "reason": "timed out after 0s"} + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.result == {"gave_up": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_caller_cancel_sends_courtesy_cancellation_on_the_wire(): + """Cancelling the scope around send_raw_request emits notifications/cancelled by default.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + cancel = await c2s_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": request.message.id, "reason": "caller cancelled"} + assert cancel.metadata is None + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert scopes[0].cancelled_caught + + +@pytest.mark.anyio +async def test_caller_cancel_during_blocked_request_write_sends_no_cancelled_notification(): + """A caller cancelled mid-request-write must not emit `notifications/cancelled` (the spec only + allows cancelling issued requests). The fake stream wedges only the first write, so a later + courtesy cancel - the bug - would still be captured.""" + + class FirstWriteWedgedStream: + def __init__(self) -> None: + self.sent: list[SessionMessage] = [] + self.first_write_started = anyio.Event() + + async def send(self, item: SessionMessage) -> None: + if not self.first_write_started.is_set(): + self.first_write_started.set() + await anyio.sleep_forever() # the request write wedges until the caller is cancelled + self.sent.append(item) + + async def aclose(self) -> None: + raise NotImplementedError + + async def __aenter__(self) -> "FirstWriteWedgedStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + wedged = FirstWriteWedgedStream() + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, wedged) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + await wedged.first_write_started.wait() # the caller is parked in the request write + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + await client.notify("notifications/marker", None) + tg.cancel_scope.cancel() + finally: + s2c_send.close() + s2c_recv.close() + assert scopes[0].cancelled_caught + # Only the marker went out post-wedge: no cancel for a request the peer never received. + assert [m.message for m in wedged.sent] == [JSONRPCNotification(jsonrpc="2.0", method="notifications/marker")] + + +@pytest.mark.anyio +async def test_caller_cancel_with_resumption_hints_suppresses_the_courtesy_cancellation(): + """A request sent with resumption hints is meant to be resumed; abandoning it must not stop the peer's work.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + async def on_token(token: str) -> None: + raise NotImplementedError + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None, {"on_resumption_token": on_token}) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + # A courtesy cancel would have to precede the marker on the ordered stream. + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_timeout_with_resumption_hints_suppresses_the_courtesy_cancellation(): + """A timed-out request that carries resumption hints stays resumable: no cancellation is sent.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0, "resumption_token": "tok"}) + assert exc.value.error.code == REQUEST_TIMEOUT + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_on_abandon_false_suppresses_the_courtesy_cancellation_on_timeout(): + """Callers opt out per call for requests the protocol forbids cancelling (initialize).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0, "cancel_on_abandon": False}) + assert exc.value.error.code == REQUEST_TIMEOUT + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): + """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang: + `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy-cancel write (SDK-defined bound). On regression + the test hangs rather than failing fast - fail_after cannot cancel through the shield.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + # Both bounds exceed the in-loop _SHIELDED_WRITE_TIMEOUT (5s); the virtual clock makes them instant. + with anyio.fail_after(30): + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + # Consume only the request; the later courtesy cancel finds no reader and wedges. + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + scopes[0].cancel() + with anyio.fail_after(20): + await gave_up.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert scopes[0].cancelled_caught + # The warning proves it was the bound (not a completed write) that released the shield. + assert "courtesy cancel for caller-cancelled request" in caplog.text + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): + """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined + bound): `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy cancel so the error still surfaces.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + errors: list[MCPError] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 1}) + errors.append(exc.value) + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + # Consume only the request; the later courtesy cancel finds no reader and wedges. + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + # Exceeds the request timeout (1s) plus _SHIELDED_WRITE_TIMEOUT (5s); virtual clock, no wall time. + with anyio.fail_after(10): + await gave_up.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert errors[0].error.code == REQUEST_TIMEOUT + assert "courtesy cancel for timed-out request" in caplog.text + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): + """Cancelling the task group hosting run() completes even when the shutdown error write wedges: + only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). The fake stream is needed + because run()'s teardown closes a memory stream, which would wake the blocked send.""" + + class WedgedWriteStream: + async def send(self, item: SessionMessage) -> None: + await anyio.sleep_forever() + + async def aclose(self) -> None: + raise NotImplementedError + + async def __aenter__(self) -> "WedgedWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, WedgedWriteStream()) + handler_started = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await anyio.sleep_forever() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + # 3s sits between _SHUTDOWN_WRITE_TIMEOUT (1s) and _SHIELDED_WRITE_TIMEOUT (5s): pins the tighter bound. + with anyio.fail_after(3): + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(server.run, park, on_notify) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + await handler_started.wait() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + # The warning proves the bound (not a completed write) released the join. + assert "shutdown error response for request" in caplog.text + + +@pytest.mark.anyio +async def test_shutdown_answers_in_flight_request_with_connection_closed(): + """Cancelling run() answers a still-running request with CONNECTION_CLOSED (SDK-defined). The + recording stream is needed because run()'s exit would close a memory stream before the shielded write lands.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + recording = RecordingWriteStream() + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + handler_started = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await anyio.sleep_forever() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, park, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + assert [m.message for m in recording.sent] == [ + JSONRPCError(jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed")) + ] + + +@pytest.mark.anyio +async def test_request_write_failure_propagates_and_leaves_no_pending_entry(): + """A request whose transport write raises must not leak its `_pending` entry (v1 regression cover).""" + boom = RuntimeError("write failed") + + class RaisingWriteStream: + async def send(self, item: SessionMessage) -> None: + raise boom + + async def aclose(self) -> None: + raise NotImplementedError + + async def __aenter__(self) -> "RaisingWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, RaisingWriteStream()) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5), pytest.raises(RuntimeError) as exc: + await client.send_raw_request("ping", None) + assert exc.value is boom + assert client._pending == {} # pyright: ignore[reportPrivateUsage] + tg.cancel_scope.cancel() + finally: + s2c_send.close() + s2c_recv.close() + + +@pytest.mark.anyio +async def test_request_write_on_torn_down_transport_raises_connection_closed(): + """A write onto a torn-down transport surfaces as MCPError(CONNECTION_CLOSED), not a raw `BrokenResourceError`.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + # Close only the peer's receive end, so run() has not observed EOF when the write fails. + c2s_recv.close() + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_notification_handler_exception_is_contained(caplog: pytest.LogCaptureFixture): + """A raising notification handler costs only that notification, never the connection (parity with TS/C#/Go).""" + + async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise RuntimeError("notify boom") + + async with running_pair(jsonrpc_pair, server_on_notify=server_on_notify) as (client, *_): + with anyio.fail_after(5): + await client.notify("boom", None) + # The connection survived: a full round-trip still works. + result = await client.send_raw_request("ping", None) + assert result == {"echoed": "ping", "params": {}} + assert "notification handler for 'boom' raised" in caplog.text + + +@pytest.mark.anyio +async def test_spawned_notification_handlers_run_concurrently(): + """Notification handlers are spawned, not serialized (parity with TS/C#): the first handler + waits for the second to start, so serialized dispatch would deadlock here.""" + second_started = anyio.Event() + completed: list[str] = [] + done = anyio.Event() + + async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + if method == "first": + await second_started.wait() + else: + second_started.set() + completed.append(method) + if len(completed) == 2: + done.set() + + async with running_pair(jsonrpc_pair, server_on_notify=server_on_notify) as (client, *_): + with anyio.fail_after(5): + await client.notify("first", None) + await client.notify("second", None) + await done.wait() + assert completed == ["second", "first"] + + @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_request_metadata(): """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" @@ -503,12 +1092,8 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio async def test_ctx_after_handler_return_reports_closed_and_drops_backchannel_traffic(): - """Once `_handle_request` closes the dctx, the back-channel guard and ops agree. - - Detached work that outlives the handler must see `can_send_request == False`, - get `NoBackChannelError` from `send_raw_request`, and have `notify`/`progress` - silently dropped rather than emitted with a stale `related_request_id`. - """ + """After `_handle_request` closes the dctx, `can_send_request` is False, `send_raw_request` raises + NoBackChannelError, and `notify`/`progress` are dropped rather than sent with a stale `related_request_id`.""" captured: list[DCtx] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -528,8 +1113,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) await dctx.send_raw_request("sampling/createMessage", None) await dctx.notify("notifications/message", {"level": "info"}) await dctx.progress(0.9) - # A second round-trip flushes any notification the server might have - # written, so an empty client recorder afterwards proves the drop. + # A second round-trip flushes any server write; an empty recorder then proves the drop. await client.send_raw_request("ping", None) assert crec.notifications == [] @@ -549,15 +1133,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): result = await client.send_raw_request("t", None, opts) - # Request still completes; the callback's crash was swallowed. assert result == {"ok": True} assert "progress callback raised" in caplog.text @pytest.mark.anyio async def test_inline_methods_are_handled_before_next_message_is_dequeued(): - """A method in `inline_methods` runs to completion before subsequent - messages are dispatched, so its side effects are visible to them.""" + """An `inline_methods` method runs to completion before the next message is dispatched.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( @@ -589,10 +1171,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_send_raw_request_always_carries_meta_on_the_wire(): - """Outbound requests always include `params._meta` (otel injection per SEP-414). - - Caller-supplied `_meta` keys are preserved; the progress token is merged in. - """ + """Outbound requests always carry `params._meta` (otel injection per SEP-414); caller-supplied + keys are preserved and the progress token is merged in.""" seen: list[Mapping[str, Any] | None] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -607,9 +1187,7 @@ async def noop_progress(progress: float, total: float | None, message: str | Non with anyio.fail_after(5): await client.send_raw_request("a", None) await client.send_raw_request("b", {"x": 1, "_meta": {"k": "v"}}, opts) - # `_meta` is always present. Its contents depend on the active otel - # tracer (traceparent/tracestate may be injected), so assert presence - # and that anything beyond W3C keys is exactly what we expect. + # `_meta` contents depend on the active otel tracer, so pin only what sits beyond the W3C keys. w3c = {"traceparent", "tracestate"} assert seen[0] is not None and seen[0].keys() == {"_meta"} assert set(seen[0]["_meta"].keys()) <= w3c @@ -643,6 +1221,26 @@ async def test_send_raw_request_before_run_raises_runtimeerror(): s.close() +@pytest.mark.anyio +async def test_send_raw_request_after_connection_close_raises_connection_closed(): + """Sending after run() saw EOF raises MCPError(CONNECTION_CLOSED) — the same contract + in-flight waiters get — not RuntimeError (SDK-defined).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + s2c_send.close() # peer drops: run() sees immediate EOF and returns + with anyio.fail_after(5): + await client.run(on_request, on_notify) + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_transport_exception_in_read_stream_is_logged_and_dropped(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) @@ -665,6 +1263,63 @@ async def test_transport_exception_in_read_stream_is_logged_and_dropped(): s.close() +@pytest.mark.anyio +async def test_on_stream_exception_observes_transport_exceptions(): + """With an observer set, Exception items reach it instead of being dropped; the loop stays healthy.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + + seen: list[Exception] = [] + + async def observe(exc: Exception) -> None: + seen.append(exc) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe) + on_request, on_notify = echo_handlers(Recorder()) + hiccup = ValueError("transport hiccup") + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(hiccup) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [hiccup] + + +@pytest.mark.anyio +async def test_on_stream_exception_observer_raising_is_contained(caplog: pytest.LogCaptureFixture): + """A raising observer costs the item, not the connection: it runs in the read loop itself.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + + async def observe(exc: Exception) -> None: + raise RuntimeError("observer boom") + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert "on_stream_exception observer raised" in caplog.text + + @pytest.mark.anyio async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): @@ -722,8 +1377,7 @@ async def call() -> None: @pytest.mark.anyio @pytest.mark.parametrize("inline", [frozenset[str](), frozenset({"t"})], ids=["spawned", "inline"]) async def test_handler_inherits_sender_contextvars(inline: frozenset[str]): - """The handler task sees contextvars set by the task that wrote into the - read stream, on both the spawned and the inline-method dispatch paths.""" + """The handler sees the sender's contextvars on both the spawned and the inline-method dispatch paths.""" raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) write_send = ContextSendStream[SessionMessage | Exception](raw_send) @@ -823,9 +1477,7 @@ async def caller() -> None: sent = await c2s_recv.receive() assert isinstance(sent, SessionMessage) assert isinstance(sent.message, JSONRPCRequest) - # Now safe: close the client's write end, then cancel the caller. The - # shielded `_cancel_outbound` write hits ClosedResourceError and is - # swallowed; cancellation propagates cleanly. + # The shielded `_cancel_outbound` write now hits ClosedResourceError and is swallowed. c2s_send.close() caller_scope.cancel() with anyio.fail_after(5): @@ -855,7 +1507,6 @@ def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) - # Register a fake pending and pre-fill its single buffer slot. d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] send.send_nowait({"real": "result"}) d._fan_out_closed() # pyright: ignore[reportPrivateUsage] @@ -866,12 +1517,14 @@ def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): s.close() -def test_outbound_metadata_with_resumption_token_returns_client_metadata(): - md = _outbound_metadata(None, {"resumption_token": "abc"}) - assert isinstance(md, ClientMessageMetadata) - assert md.resumption_token == "abc" - assert _outbound_metadata(None, None) is None - assert _outbound_metadata(None, {}) is None +def test_plan_outbound_with_resumption_token_returns_client_metadata_and_suppresses_abandon_cancel(): + """Hints that reach the transport make the request resumable, so abandoning it must not cancel the peer's work.""" + plan = _plan_outbound(None, {"resumption_token": "abc"}) + assert isinstance(plan.metadata, ClientMessageMetadata) + assert plan.metadata.resumption_token == "abc" + assert plan.cancel_on_abandon is False + assert _plan_outbound(None, None) == _OutboundPlan(metadata=None, cancel_on_abandon=True) + assert _plan_outbound(None, {}) == _OutboundPlan(metadata=None, cancel_on_abandon=True) @pytest.mark.anyio @@ -905,6 +1558,43 @@ async def respond_stringly() -> None: s.close() +@pytest.mark.anyio +async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (same `_coerce_id` path).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def reject_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCError( + jsonrpc="2.0", id=str(rid), error=ErrorData(code=INVALID_PARAMS, message="bad cursor") + ) + ) + ) + + tg.start_soon(reject_stringly) + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" # the peer's error, passed through + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) @@ -957,7 +1647,7 @@ def test_coerce_id_passes_through_non_numeric_string_and_int(): @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): - """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + """Parse-error responses (id=null) have no waiter; they're dropped and the read loop stays healthy.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -968,7 +1658,18 @@ async def test_jsonrpc_error_response_with_null_id_is_dropped(): await s2c_send.send( SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) ) - await anyio.sleep(0) + with anyio.fail_after(5): + # Ordered stream: this round-trip completing proves the null-id error was consumed. + async def respond() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={"ok": True})) + ) + + tg.start_soon(respond) + assert await client.send_raw_request("ping", None) == {"ok": True} tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -977,12 +1678,7 @@ async def test_jsonrpc_error_response_with_null_id_is_dropped(): @pytest.mark.anyio async def test_notify_without_params_omits_params_key_on_the_wire(): - """JSON-RPC 2.0 forbids `params: null`; the member must be absent. - - Transports serialize with `exclude_unset=True`, so `notify` must leave - `params` unset on the model rather than passing an explicit None (strict - peers like the TypeScript SDK reject `"params": null`). - """ + """JSON-RPC 2.0 forbids `params: null`: `notify` leaves `params` unset (transports use `exclude_unset=True`).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1102,9 +1798,7 @@ async def call() -> None: tg.start_soon(call) await handler_started.wait() await client.notify("notifications/cancelled", {"requestId": True}) - # The malformed cancel is teed to on_notify; once observed, the - # correlation arm has already run - and must not have cancelled - # the request keyed by id 1. + # Once the teed notification is observed, the correlation arm has already run. await srec.notified.wait() assert not handler_exited.is_set() await client.notify("notifications/cancelled", {"requestId": 1}) @@ -1161,8 +1855,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) @pytest.mark.anyio async def test_request_with_bool_meta_progress_token_is_not_adopted(): - """A bool `_meta.progressToken` is malformed; `ctx.progress()` must be a no-op - instead of emitting `progressToken: true` on the wire.""" + """A bool `_meta.progressToken` is malformed: `ctx.progress()` must be a no-op, not emit `progressToken: true`.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -1200,8 +1893,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ids=["string-cancel-for-int-request", "int-cancel-for-string-request"], ) async def test_cancelled_correlates_across_string_and_int_request_id_forms(request_id: RequestId, cancel_id: object): - """A peer that stringifies the id between request and cancel still cancels - (same `_coerce_id` treatment as the response path).""" + """A peer that stringifies the id between request and cancel still cancels (same `_coerce_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -1240,9 +1932,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_completed_handler_does_not_evict_reused_request_id_from_in_flight(): - """The awaited response write sits after the `_in_flight` pop; a second - request reusing the id during that window must keep its own entry (a - second, post-write pop would evict it and break its peer-cancellation).""" + """A second request reusing an id while the first handler is parked in its response write + keeps its own `_in_flight` entry (a post-write pop would evict it and break peer-cancellation).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) # buffer=0: the first handler's response write parks until the test receives. s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) @@ -1300,28 +1991,90 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() -def test_outbound_metadata_with_related_request_id_drops_resumption_hints_with_debug_log( +@pytest.mark.anyio +async def test_duplicate_request_id_completion_of_first_handler_keeps_second_cancellable(): + """A duplicate inbound id overwrites `_in_flight` (parity with v1/TS); the identity-guarded pop + keeps the first handler's completion from evicting the second's entry and breaking its cancellation.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + first_started = anyio.Event() + release_first = anyio.Event() + second_started = anyio.Event() + second_exited = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + first_started.set() + await release_first.wait() + return {"first": True} + second_started.set() + try: + await anyio.sleep_forever() + finally: + second_exited.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="first"))) + await first_started.wait() + # Duplicate id: the table entry now belongs to the second request. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="second"))) + await second_started.wait() + release_first.set() + resp1 = await s2c_recv.receive() + assert isinstance(resp1, SessionMessage) + assert isinstance(resp1.message, JSONRPCResponse) + assert resp1.message.result == {"first": True} + # Let the first handler task run past its pop entirely. + await anyio.wait_all_tasks_blocked() + assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] + # The surviving entry must still be cancellable by the peer. + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 7} + ) + ) + ) + resp2 = await s2c_recv.receive() + assert isinstance(resp2, SessionMessage) + assert isinstance(resp2.message, JSONRPCError) + assert resp2.message.error == ErrorData(code=0, message="Request cancelled") + assert second_exited.is_set() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_plan_outbound_with_related_request_id_drops_resumption_hints_but_keeps_abandon_cancel( caplog: pytest.LogCaptureFixture, ): - """`SessionMessage.metadata` carries one object; `related_request_id` wins - and resumption hints are dropped observably (debug log).""" + """`related_request_id` wins the metadata slot; dropped hints don't suppress the abandon cancel.""" with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): - md = _outbound_metadata(7, {"resumption_token": "abc"}) - assert isinstance(md, ServerMessageMetadata) - assert md.related_request_id == 7 + plan = _plan_outbound(7, {"resumption_token": "abc"}) + assert isinstance(plan.metadata, ServerMessageMetadata) + assert plan.metadata.related_request_id == 7 + assert plan.cancel_on_abandon is True assert "dropping resumption hints" in caplog.text caplog.clear() with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): - md = _outbound_metadata(7, {"timeout": 1.0}) - assert isinstance(md, ServerMessageMetadata) + plan = _plan_outbound(7, {"timeout": 1.0}) + assert isinstance(plan.metadata, ServerMessageMetadata) assert "dropping resumption hints" not in caplog.text @pytest.mark.anyio async def test_server_middleware_observes_cancelled_notification(): - """End-to-end over the JSON-RPC path: `Server.middleware` wraps every inbound - notification, including `notifications/cancelled` (the dispatcher applies - the cancellation itself, then forwards the notification).""" + """`Server.middleware` wraps every inbound notification, including `notifications/cancelled` + (the dispatcher applies the cancellation itself, then forwards the notification).""" handler_started = anyio.Event() cancel_observed = anyio.Event() observed: list[tuple[str, dict[str, Any]]] = [] diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py deleted file mode 100644 index 38f36d82cc..0000000000 --- a/tests/shared/test_session.py +++ /dev/null @@ -1,447 +0,0 @@ -import anyio -import pytest - -from mcp import Client, types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.shared.exceptions import MCPError -from mcp.shared.memory import create_client_server_memory_streams -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - METHOD_NOT_FOUND, - PARSE_ERROR, - CancelledNotification, - CancelledNotificationParams, - ClientResult, - EmptyResult, - ErrorData, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - ServerNotification, - ServerRequest, -) - - -@pytest.mark.anyio -async def test_request_cancellation(): - """Test that requests can be cancelled while in-flight.""" - ev_tool_called = anyio.Event() - ev_cancelled = anyio.Event() - request_id = None - - # Create a server with a slow tool - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - nonlocal request_id, ev_tool_called - if params.name == "slow_tool": - request_id = ctx.request_id - ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return types.CallToolResult(content=[]) # pragma: no cover - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - raise NotImplementedError - - server = Server( - name="TestSessionServer", - on_call_tool=handle_call_tool, - on_list_tools=handle_list_tools, - ) - - async def make_request(client: Client): - nonlocal ev_cancelled - try: - await client.session.send_request( - types.CallToolRequest( - params=types.CallToolRequestParams(name="slow_tool", arguments={}), - ), - types.CallToolResult, - ) - pytest.fail("Request should have been cancelled") # pragma: no cover - except MCPError as e: - # Expected - request was cancelled - assert "Request cancelled" in str(e) - ev_cancelled.set() - - async with Client(server) as client: - async with anyio.create_task_group() as tg: # pragma: no branch - tg.start_soon(make_request, client) - - # Wait for the request to be in-flight - with anyio.fail_after(1): # Timeout after 1 second - await ev_tool_called.wait() - - # Send cancellation notification - assert request_id is not None - await client.session.send_notification( - CancelledNotification(params=CancelledNotificationParams(request_id=request_id)) - ) - - # Give cancellation time to process - with anyio.fail_after(1): # pragma: no branch - await ev_cancelled.wait() - - -@pytest.mark.anyio -async def test_response_id_type_mismatch_string_to_int(): - """Test that responses with string IDs are correctly matched to requests sent with - integer IDs. - - This handles the case where a server returns "id": "0" (string) but the client - sent "id": 0 (integer). Without ID type normalization, this would cause a timeout. - """ - ev_response_received = anyio.Event() - result_holder: list[types.EmptyResult] = [] - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with a string ID instead of integer.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - root = message.message - assert isinstance(root, JSONRPCRequest) - # Get the original request ID (which is an integer) - request_id = root.id - assert isinstance(request_id, int), f"Expected int, got {type(request_id)}" - - # Respond with the ID as a string (simulating a buggy server) - response = JSONRPCResponse( - jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch - result={}, - ) - await server_write.send(SessionMessage(message=response)) - - async def make_request(client_session: ClientSession): - nonlocal result_holder - # Send a ping request (uses integer ID internally) - result = await client_session.send_ping() - result_holder.append(result) - ev_response_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_response_received.wait() - - assert len(result_holder) == 1 - assert isinstance(result_holder[0], EmptyResult) - - -@pytest.mark.anyio -async def test_error_response_id_type_mismatch_string_to_int(): - """Test that error responses with string IDs are correctly matched to requests - sent with integer IDs. - - This handles the case where a server returns an error with "id": "0" (string) - but the client sent "id": 0 (integer). - """ - ev_error_received = anyio.Event() - error_holder: list[MCPError | Exception] = [] - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with an error using a string ID.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - root = message.message - assert isinstance(root, JSONRPCRequest) - request_id = root.id - assert isinstance(request_id, int) - - # Respond with an error, using the ID as a string - error_response = JSONRPCError( - jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch - error=ErrorData(code=-32600, message="Test error"), - ) - await server_write.send(SessionMessage(message=error_response)) - - async def make_request(client_session: ClientSession): - nonlocal error_holder - try: - await client_session.send_ping() - pytest.fail("Expected MCPError to be raised") # pragma: no cover - except MCPError as e: - error_holder.append(e) - ev_error_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - - assert len(error_holder) == 1 - assert "Test error" in str(error_holder[0]) - - -@pytest.mark.anyio -async def test_response_id_non_numeric_string_no_match(): - """Test that responses with non-numeric string IDs don't incorrectly match - integer request IDs. - - If a server returns "id": "abc" (non-numeric string), it should not match - a request sent with "id": 0 (integer). - """ - ev_timeout = anyio.Event() - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with a non-numeric string ID.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - - # Respond with a non-numeric string ID (should not match) - response = JSONRPCResponse( - jsonrpc="2.0", - id="not_a_number", # Non-numeric string - result={}, - ) - await server_write.send(SessionMessage(message=response)) - - async def make_request(client_session: ClientSession): - try: - # Use a short timeout since we expect this to fail - await client_session.send_request( - types.PingRequest(), - types.EmptyResult, - request_read_timeout_seconds=0.5, - ) - pytest.fail("Expected timeout") # pragma: no cover - except MCPError as e: - assert "Timed out" in str(e) - ev_timeout.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_timeout.wait() - - -@pytest.mark.anyio -async def test_connection_closed(): - """Test that pending requests are cancelled when the connection is closed remotely.""" - - ev_closed = anyio.Event() - ev_response = anyio.Event() - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def make_request(client_session: ClientSession): - """Send a request in a separate task""" - nonlocal ev_response - try: - # any request will do - await client_session.initialize() - pytest.fail("Request should have errored") # pragma: no cover - except MCPError as e: - # Expected - request errored - assert "Connection closed" in str(e) - ev_response.set() - - async def mock_server(): - """Wait for a request, then close the connection""" - nonlocal ev_closed - # Wait for a request - await server_read.receive() - # Close the connection, as if the server exited - server_write.close() - server_read.close() - ev_closed.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(make_request, client_session) - tg.start_soon(mock_server) - - with anyio.fail_after(1): - await ev_closed.wait() - with anyio.fail_after(1): # pragma: no branch - await ev_response.wait() - - -@pytest.mark.anyio -async def test_null_id_error_surfaced_via_message_handler(): - """Test that a JSONRPCError with id=None is surfaced to the message handler. - - Per JSON-RPC 2.0, error responses use id=null when the request id could not - be determined (e.g., parse errors). These cannot be correlated to any pending - request, so they are forwarded to the message handler as MCPError. - """ - ev_error_received = anyio.Event() - error_holder: list[MCPError] = [] - - async def capture_errors( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - assert isinstance(message, MCPError) - error_holder.append(message) - ev_error_received.set() - - sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - _server_read, server_write = server_streams - - async def mock_server(): - """Send a null-id error (simulating a parse error).""" - error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) - await server_write.send(SessionMessage(message=error_response)) - - async with ( - anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - message_handler=capture_errors, - ) as _client_session, - ): - tg.start_soon(mock_server) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - - assert len(error_holder) == 1 - assert error_holder[0].error == sent_error - - -@pytest.mark.anyio -async def test_null_id_error_does_not_affect_pending_request(): - """Test that a null-id error doesn't interfere with an in-flight request. - - When a null-id error arrives while a request is pending, the error should - go to the message handler and the pending request should still complete - normally with its own response. - """ - ev_error_received = anyio.Event() - ev_response_received = anyio.Event() - error_holder: list[MCPError] = [] - result_holder: list[EmptyResult] = [] - - async def capture_errors( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - assert isinstance(message, MCPError) - error_holder.append(message) - ev_error_received.set() - - sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Read a request, inject a null-id error, then respond normally.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - assert isinstance(message.message, JSONRPCRequest) - request_id = message.message.id - - # First, send a null-id error (should go to message handler) - await server_write.send(SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=sent_error))) - - # Then, respond normally to the pending request - await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) - - async def make_request(client_session: ClientSession): - result = await client_session.send_ping() - result_holder.append(result) - ev_response_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - message_handler=capture_errors, - ) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - await ev_response_received.wait() - - # Null-id error reached the message handler - assert len(error_holder) == 1 - assert error_holder[0].error == sent_error - - # Pending request completed successfully - assert len(result_holder) == 1 - assert isinstance(result_holder[0], EmptyResult) - - -@pytest.mark.anyio -async def test_receive_loop_answers_unknown_request_method_with_method_not_found(): - """A peer request whose method is not in the receive union gets -32601 - (METHOD_NOT_FOUND) on the wire, not a validation failure (-32602).""" - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async with ClientSession(read_stream=client_read, write_stream=client_write): - await server_write.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) - with anyio.fail_after(5): # pragma: no branch - out = await server_read.receive() - - assert isinstance(out, SessionMessage) - assert isinstance(out.message, JSONRPCError) - assert out.message.id == 7 - assert out.message.error == ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown") - - -@pytest.mark.anyio -async def test_receive_loop_drops_unknown_notification_method_without_response(): - """An unknown notification method is dropped silently: JSON-RPC forbids - responses to notifications, and the receive loop keeps serving.""" - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async with ClientSession(read_stream=client_read, write_stream=client_write): - await server_write.send(SessionMessage(message=JSONRPCNotification(jsonrpc="2.0", method="x/unknown"))) - # The next wire output must be the answer to this follow-up ping, - # proving the notification produced no response and the loop survived. - await server_write.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) - with anyio.fail_after(5): # pragma: no branch - out = await server_read.receive() - - assert isinstance(out, SessionMessage) - assert isinstance(out.message, JSONRPCResponse) - assert out.message.id == 1 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7db7e68fb2..02976656e8 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -18,16 +18,20 @@ import anyio import httpx import pytest +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount +from starlette.types import Message, Scope from mcp import MCPError, types +from mcp.client import ClientRequestContext from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( + GET_STREAM_KEY, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, @@ -41,7 +45,6 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._compat import resync_tracer -from mcp.shared._context import RequestContext from mcp.shared._context_streams import create_context_streams from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -1232,7 +1235,7 @@ async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params @@ -2224,3 +2227,115 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_standalone_stream_teardown_mid_listen_is_not_an_error(caplog: pytest.LogCaptureFixture) -> None: + """Standalone-stream teardown while the writer is parked in receive() logs no error (SDK-defined).""" + session_manager = StreamableHTTPSessionManager( + app=_create_server(), + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + notified = anyio.Event() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + # Only the standalone-stream notification is teed to the handler here. + assert isinstance(message, types.ResourceUpdatedNotification) + notified.set() + + async with session_manager.run(): + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + # A notification with no related request rides the GET stream, proving the writer is live. + await session.call_tool("test_tool_with_standalone_notification", {}) + with anyio.fail_after(5): + await notified.wait() + # Tear the standalone stream down while the writer is parked on it. + (transport,) = session_manager._server_instances.values() # pyright: ignore[reportPrivateUsage] + await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] + assert "Error in standalone SSE writer" not in caplog.text + + +@pytest.mark.anyio +async def test_standalone_stream_teardown_between_dequeues_is_not_an_error( + caplog: pytest.LogCaptureFixture, +) -> None: + """Teardown landing while the standalone writer is between dequeues logs no error. + + SDK-defined: after teardown the writer's next dequeue hits its own closed stream — expected + disconnect noise. The public surface cannot force this window (the in-process client consumes + SSE without backpressure), so the test drives the transport's ASGI entry point with a gated `send`. + """ + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + # The GET handler only checks that a read-stream writer exists; it is never written to. + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + transport._read_stream_writer = read_stream_writer # pyright: ignore[reportPrivateUsage] + + stream_registered = anyio.Event() + + class SignalingStreams( + dict[types.RequestId, tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]]] + ): + # Only the GET handler inserts here, so any insert is the standalone stream registration. + def __setitem__( + self, + key: types.RequestId, + value: tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]], + ) -> None: + super().__setitem__(key, value) + stream_registered.set() + + transport._request_streams = SignalingStreams() # pyright: ignore[reportPrivateUsage] + + gate = anyio.Event() + sent: list[Message] = [] + + async def asgi_send(message: Message) -> None: + sent.append(message) + await gate.wait() + + # Never delivers anything, parking the response's disconnect listener. + disconnect_send, disconnect_receive = anyio.create_memory_object_stream[Message](0) + + async def asgi_receive() -> Message: + return await disconnect_receive.receive() + + scope: Scope = { + "type": "http", + "method": "GET", + "path": "/mcp", + "query_string": b"", + "headers": [(b"accept", b"text/event-stream")], + } + notification = types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + + async with read_stream_writer, read_stream, disconnect_send, disconnect_receive: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(transport.handle_request, scope, asgi_receive, asgi_send) + await stream_registered.wait() + standalone_send = transport._request_streams[GET_STREAM_KEY][0] # pyright: ignore[reportPrivateUsage] + # Zero-buffer rendezvous: once send() returns, the writer has dequeued the event + # and is blocked forwarding it past the closed gate — the between-dequeues window. + await standalone_send.send(EventMessage(notification)) + await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] + # Unblock the response; the writer's next dequeue hits its closed stream. + gate.set() + + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 200 + body_chunks = [message for message in sent if message["type"] == "http.response.body"] + assert b"notifications/initialized" in body_chunks[0]["body"] + assert body_chunks[-1] == {"type": "http.response.body", "body": b"", "more_body": False} + assert "Error in standalone SSE writer" not in caplog.text + assert "Error in standalone SSE response" not in caplog.text