diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index ec38f35838..5be19ac726 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -5,15 +5,80 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket -from collections.abc import Generator +import gc +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager +from typing import Any +import httpx import pytest +from sse_starlette.sse import AppStatus +from starlette.applications import Starlette +from starlette.routing import Mount +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from tests.test_helpers import wait_for_server +from mcp.server import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import CallToolResult, TextContent, Tool +from tests.interaction.transports import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown when +# run in process; the old subprocess harness never observed them. The interaction suite registers +# the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# themselves passing in isolated runs. Markers are item-scoped, so the autouse +# `_collect_leaked_streams` fixture below garbage-collects each test's leaks inside its own +# teardown, where these filters apply; without it, leaks GC'd at session cleanup escape the +# scoped ignores. The filters are scoped to anyio's MemoryObject*Stream leak signature so an +# unrelated leak still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + + +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response (and `json_response=False` below means every request + in this module is served as one). sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -35,28 +100,12 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - # Import inside the function since this runs in a separate process - from collections.abc import AsyncGenerator - from contextlib import asynccontextmanager - from typing import Any - - import uvicorn - from starlette.applications import Starlette - from starlette.routing import Mount - - import mcp.types as types - from mcp.server import Server - from mcp.server.streamable_http_manager import StreamableHTTPSessionManager - from mcp.types import TextContent, Tool - - # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") +def make_unicode_server() -> Server[object, object]: + """The Unicode echo server: tool and prompt contents that exercise non-ASCII round trips.""" + server: Server[object, object] = Server(name="unicode_test_server") @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" + async def handle_list_tools() -> list[Tool]: return [ Tool( name="echo_unicode", @@ -72,22 +121,12 @@ async def list_tools() -> list[Tool]: ] @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] - else: - raise ValueError(f"Unknown tool: {name}") + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + assert name == "echo_unicode" + return CallToolResult(content=[TextContent(type="text", text=f"Echo: {arguments['text']}")]) @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" + async def handle_list_prompts() -> list[types.Prompt]: return [ types.Prompt( name="unicode_prompt", @@ -97,137 +136,90 @@ async def list_prompts() -> list[types.Prompt]: ] @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] - ) - raise ValueError(f"Unknown prompt: {name}") - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing - ) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - # Create an ASGI application - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lifespan, - ) - - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: + assert name == "unicode_prompt" + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), + ) + ] + ) + + return server + + +@asynccontextmanager +async def unicode_session() -> AsyncIterator[ClientSession]: + """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the + Unicode test server, entirely in process.""" + # SSE response mode, so Unicode rides the SSE event encoding rather than a plain JSON body. + session_manager = StreamableHTTPSessionManager(app=make_unicode_server(), json_response=False) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + + async with ( + session_manager.run(), + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects + # the bare /mcp path to /mcp/. + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True + ) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as ( + read_stream, + write_stream, + _get_session_id, + ), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call() -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 + async with unicode_session() as session: + # Test 1: List tools (server→client Unicode in descriptions) + tools = await session.list_tools() + assert len(tools.tools) == 1 - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Check Unicode in tool descriptions + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call (client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Verify server correctly received and echoed back Unicode + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts() -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" + async with unicode_session() as session: + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index ba58da7321..4a93f998b9 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,4 +1,5 @@ import errno +import gc import os import shutil import sys @@ -10,7 +11,12 @@ import pytest from mcp.client.session import ClientSession -from mcp.client.stdio import StdioServerParameters, _create_platform_compatible_process, stdio_client +from mcp.client.stdio import ( + StdioServerParameters, + _create_platform_compatible_process, + _terminate_process_tree, + stdio_client, +) from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -219,6 +225,46 @@ def sigint_handler(signum, frame): raise +async def _wait_for_first_write(path: str) -> None: + """Poll until the file at *path* exists and has grown beyond its initial empty state. + + The marker files below are created empty before the writer is spawned, so any + growth proves the writing process booted and reached its write loop. Polling + replaces fixed startup sleeps, which flake on loaded machines where interpreter + startup can exceed any fixed window. Bounded so a writer that never starts + fails the test instead of hanging it. + """ + with anyio.fail_after(15): + while not os.path.exists(path) or os.path.getsize(path) == 0: + await anyio.sleep(0.05) + + +async def _wait_for_writes_to_stop(path: str) -> None: + """Poll until the file at *path* stops growing. + + Returns once the size is unchanged across three successive 0.3 second gaps + (each three times the writers' 0.1 second write interval), so a writer that + is merely starved of CPU for a single gap is not mistaken for a terminated + one. Any observed growth resets the consecutive-stable counter. The sentinel + forces at least one non-stable iteration before counting starts. If the file + never stops growing, the timeout fails the test: a writer that survives + _terminate_process_tree is a genuine cleanup failure that must not be masked. + """ + last_size = -1 + stable_pairs = 0 + with anyio.fail_after(15): + while True: + current_size = os.path.getsize(path) + if current_size == last_size: + stable_pairs += 1 + else: + stable_pairs = 0 + last_size = current_size + if stable_pairs == 3: + return + await anyio.sleep(0.3) + + class TestChildProcessCleanup: """ Tests for child process cleanup functionality using _terminate_process_tree. @@ -259,84 +305,67 @@ async def test_basic_child_process_cleanup(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: parent_marker = f.name - try: - # Parent script that spawns a child process - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Mark that parent started - with open({escape_path_for_python(parent_marker)}, 'w') as f: - f.write('parent started\\n') - - # Child script that writes continuously - child_script = f''' - import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"{time.time()}") - f.flush() - time.sleep(0.1) - ''' - - # Start the child process - child = subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent just sleeps + # Parent script that spawns a child process + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import os + + # Mark that parent started + with open({escape_path_for_python(parent_marker)}, 'w') as f: + f.write('parent started\\n') + + # Child script that writes continuously + child_script = f''' + import time + with open({escape_path_for_python(marker_file)}, 'a') as f: while True: + f.write(f"{time.time()}") + f.flush() time.sleep(0.1) - """ - ) + ''' - print("\nStarting child process termination test...") + # Start the child process + child = subprocess.Popen([sys.executable, '-c', child_script]) - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Parent just sleeps + while True: + time.sleep(0.1) + """ + ) - # Wait for processes to start - await anyio.sleep(0.5) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False - # Verify parent started - assert os.path.exists(parent_marker), "Parent process didn't start" + try: + # Wait for the parent to start and the child to reach its write loop + await _wait_for_first_write(parent_marker) + assert os.path.getsize(parent_marker) > 0, "Parent process didn't start" - # Verify child is writing - if os.path.exists(marker_file): # pragma: no branch - initial_size = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size_after_wait = os.path.getsize(marker_file) - assert size_after_wait > initial_size, "Child process should be writing" - print(f"Child is writing (file grew from {initial_size} to {size_after_wait} bytes)") + await _wait_for_first_write(marker_file) + assert os.path.getsize(marker_file) > 0, "Child process should be writing" # Terminate using our function - print("Terminating process and children...") - from mcp.client.stdio import _terminate_process_tree - await _terminate_process_tree(proc) + tree_killed = True - # Verify processes stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size_after_cleanup = os.path.getsize(marker_file) - await anyio.sleep(0.5) - final_size = os.path.getsize(marker_file) - - print(f"After cleanup: file size {size_after_cleanup} -> {final_size}") - assert final_size == size_after_cleanup, ( - f"Child process still running! File grew by {final_size - size_after_cleanup} bytes" - ) - - print("SUCCESS: Child process was properly terminated") - + # Verify the child stopped writing; a survivor times out and fails the test + await _wait_for_writes_to_stop(marker_file) finally: + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up files for f in [marker_file, parent_marker]: try: os.unlink(f) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") @@ -353,88 +382,79 @@ async def test_nested_process_tree(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f3: grandchild_file = f3.name - try: - # Simple nested process tree test - # We create parent -> child -> grandchild, each writing to a file - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import os - - # Child will spawn grandchild and write to child file - child_script = f'''import subprocess - import sys - import time - - # Grandchild just writes to file - grandchild_script = \"\"\"import time - with open({escape_path_for_python(grandchild_file)}, 'a') as f: - while True: - f.write(f"gc {{time.time()}}") - f.flush() - time.sleep(0.1)\"\"\" - - # Spawn grandchild - subprocess.Popen([sys.executable, '-c', grandchild_script]) - - # Child writes to its file - with open({escape_path_for_python(child_file)}, 'a') as f: - while True: - f.write(f"c {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Spawn child process - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent writes to its file - with open({escape_path_for_python(parent_file)}, 'a') as f: - while True: - f.write(f"p {time.time()}") - f.flush() - time.sleep(0.1) - """ - ) + # Simple nested process tree test + # We create parent -> child -> grandchild, each writing to a file + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import os + + # Child will spawn grandchild and write to child file + child_script = f'''import subprocess + import sys + import time + + # Grandchild just writes to file + grandchild_script = \"\"\"import time + with open({escape_path_for_python(grandchild_file)}, 'a') as f: + while True: + f.write(f"gc {{time.time()}}") + f.flush() + time.sleep(0.1)\"\"\" - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Spawn grandchild + subprocess.Popen([sys.executable, '-c', grandchild_script]) - # Let all processes start - await anyio.sleep(1.0) + # Child writes to its file + with open({escape_path_for_python(child_file)}, 'a') as f: + while True: + f.write(f"c {time.time()}") + f.flush() + time.sleep(0.1)''' - # Verify all are writing - for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - initial_size = os.path.getsize(file_path) - await anyio.sleep(0.3) - new_size = os.path.getsize(file_path) - assert new_size > initial_size, f"{name} process should be writing" + # Spawn child process + subprocess.Popen([sys.executable, '-c', child_script]) - # Terminate the whole tree - from mcp.client.stdio import _terminate_process_tree + # Parent writes to its file + with open({escape_path_for_python(parent_file)}, 'a') as f: + while True: + f.write(f"p {time.time()}") + f.flush() + time.sleep(0.1) + """ + ) - await _terminate_process_tree(proc) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False - # Verify all stopped - await anyio.sleep(0.5) + try: + # Wait for every level of the tree to reach its write loop for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): # pragma: no branch - size1 = os.path.getsize(file_path) - await anyio.sleep(0.3) - size2 = os.path.getsize(file_path) - assert size1 == size2, f"{name} still writing after cleanup!" + await _wait_for_first_write(file_path) + assert os.path.getsize(file_path) > 0, f"{name} process should be writing" - print("SUCCESS: All processes in tree terminated") + # Terminate the whole tree + await _terminate_process_tree(proc) + tree_killed = True + # Verify every level stopped writing; a survivor times out and fails the test + for file_path in (parent_file, child_file, grandchild_file): + await _wait_for_writes_to_stop(file_path) finally: + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up all marker files for f in [parent_file, child_file, grandchild_file]: try: os.unlink(f) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") @@ -448,72 +468,63 @@ async def test_early_parent_exit(self): with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: marker_file = f.name - try: - # Parent that spawns child and waits briefly - parent_script = textwrap.dedent( - f""" - import subprocess - import sys - import time - import signal - - # Child that continues running - child_script = f'''import time - with open({escape_path_for_python(marker_file)}, 'a') as f: - while True: - f.write(f"child {time.time()}") - f.flush() - time.sleep(0.1)''' - - # Start child in same process group - subprocess.Popen([sys.executable, '-c', child_script]) - - # Parent waits a bit then exits on SIGTERM - def handle_term(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGTERM, handle_term) - - # Wait + # Parent that spawns child and waits briefly + parent_script = textwrap.dedent( + f""" + import subprocess + import sys + import time + import signal + + # Child that continues running + child_script = f'''import time + with open({escape_path_for_python(marker_file)}, 'a') as f: while True: - time.sleep(0.1) - """ - ) + f.write(f"child {time.time()}") + f.flush() + time.sleep(0.1)''' - # Start the parent process - proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + # Start child in same process group + subprocess.Popen([sys.executable, '-c', child_script]) - # Let child start writing - await anyio.sleep(0.5) + # Parent waits a bit then exits on SIGTERM + def handle_term(sig, frame): + sys.exit(0) - # Verify child is writing - if os.path.exists(marker_file): # pragma: no cover - size1 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size2 = os.path.getsize(marker_file) - assert size2 > size1, "Child should be writing" + signal.signal(signal.SIGTERM, handle_term) - # Terminate - this will kill the process group even if parent exits first - from mcp.client.stdio import _terminate_process_tree + # Wait + while True: + time.sleep(0.1) + """ + ) - await _terminate_process_tree(proc) + # Start the parent process + proc = await _create_platform_compatible_process(sys.executable, ["-c", parent_script]) + tree_killed = False - # Verify child stopped - await anyio.sleep(0.5) - if os.path.exists(marker_file): # pragma: no branch - size3 = os.path.getsize(marker_file) - await anyio.sleep(0.3) - size4 = os.path.getsize(marker_file) - assert size3 == size4, "Child should be terminated" + try: + # Wait for the child to reach its write loop + await _wait_for_first_write(marker_file) + assert os.path.getsize(marker_file) > 0, "Child should be writing" - print("SUCCESS: Child terminated even with parent exit during cleanup") + # Terminate - this will kill the process group even if parent exits first + await _terminate_process_tree(proc) + tree_killed = True + # Verify the child stopped writing; a survivor times out and fails the test + await _wait_for_writes_to_stop(marker_file) finally: + if not tree_killed: # pragma: no cover - cleanup only reached when the test failed mid-flight + await _terminate_process_tree(proc) # Clean up marker file try: os.unlink(marker_file) except OSError: # pragma: no cover pass + # Collect subprocess transports now, while this test's warning filters + # are active, so GC-time ResourceWarnings cannot hit a later test + gc.collect() @pytest.mark.anyio diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index 597a87082c..92119cb1aa 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -23,6 +23,22 @@ def pytest_configure(config: pytest.Config) -> None: "filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning" ) config.addinivalue_line("filterwarnings", "ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning") + # The trio-mockclock leg of the session-level timeout test (test_timeouts.py) is the suite's + # only test on the trio backend. v1's streamable-HTTP client abandons its httpx/httpx-sse + # response generators when the session task group is cancelled at teardown; asyncio finalizes + # abandoned async generators silently at loop shutdown, but trio's finalizer warns about each + # one (`Async generator ... was garbage collected before it had been exhausted`). Abandoning + # `EventSource.aiter_sse` abandons the whole generator chain nested under it (`aiter_lines` -> + # `aiter_text` -> `aiter_bytes` -> `aiter_raw`), and which links the finalizer reports depends + # on GC timing and Python version. The fixes live in `src/` on `main` and are out of scope for + # this tests-only backport. The filters are scoped to the httpx/httpx-sse generator signatures + # (every generator in that chain lives on `Response` or `EventSource`) so an unrelated leak + # still fails the suite. + config.addinivalue_line("filterwarnings", "ignore:Async generator 'httpx:ResourceWarning") + config.addinivalue_line( + "filterwarnings", + "ignore:.*async_generator object (Response|EventSource).aiter_:pytest.PytestUnraisableExceptionWarning", + ) _FACTORIES: dict[str, Connect] = { diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index 2a3b885a6d..c80e98405a 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -1,8 +1,9 @@ """Request timeouts against the low-level Server, driven through the public client API. The handler blocks on an event that is never set, so the awaited response can never arrive and -any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore -set to an effectively-zero duration: the tests add no wall-clock time to the suite. (Zero itself +any positive timeout fires deterministically on the next event-loop pass. Per-request timeouts are +set to an effectively-zero duration; the session-level test runs on trio's virtual clock instead +(see the comment there). Either way the tests add no wall-clock time to the suite. (Zero itself cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) """ @@ -12,6 +13,7 @@ import anyio import pytest from inline_snapshot import snapshot +from trio.testing import MockClock from mcp import McpError, types from mcp.server.lowlevel import Server @@ -82,7 +84,19 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB assert result == snapshot(CallToolResult(content=[TextContent(type="text", text="still alive")])) +# A session-level timeout cannot use the effectively-zero pattern above: it also governs the +# initialize handshake, which must complete before the blocked tool call can wait the timeout +# out in full. Any real-clock margin is a bet against CI scheduler stalls (a 50ms value lost +# that bet in CI; the in-process handshake tail reaches ~190ms on a loaded windows runner), so +# this test runs on trio's virtual clock instead. With autojump, time advances only when every +# task is blocked: the handshake always has a runnable task and therefore cannot time out no +# matter how slow the runner, and once the tool call blocks on the never-answered request the +# run goes idle and the clock jumps straight to the deadline — deterministic, with no real wait. @requirement("protocol:timeout:session-default") +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) async def test_session_level_timeout_applies_to_every_request(connect: Connect) -> None: """A read timeout configured on the client applies to requests that do not set their own.""" server: Server[Any] = Server("blocker") @@ -93,12 +107,6 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB await anyio.Event().wait() # blocks until the session is torn down raise NotImplementedError # unreachable - # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the - # per-request timeouts: a session-level timeout also governs the initialize handshake, so the - # value must be long enough for the in-process handshake to complete before the blocked tool - # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual - # latency; lowering it only erodes the margin against CI scheduler jitter without saving - # anything perceptible. async with connect(server, read_timeout_seconds=timedelta(seconds=0.05)) as client: with pytest.raises(McpError) as exc_info: await client.call_tool("block", {}) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 7821c1eed5..3abb7bf048 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -27,6 +27,8 @@ _HARNESS_SELF_TESTS = { "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_a_second_response_after_the_first_completes_is_invisible_to_the_client", + "tests.interaction.transports.test_bridge.test_body_chunks_after_the_final_chunk_are_ignored", "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py index e69de29bb2..b5bbb633c2 100644 --- a/tests/interaction/transports/__init__.py +++ b/tests/interaction/transports/__init__.py @@ -0,0 +1,9 @@ +"""Transport-specific interaction tests, and the in-process streaming bridge they are built on. + +`StreamingASGITransport` is re-exported here as the sanctioned import point for test code +outside this suite (the bridge module itself is suite-private). +""" + +from tests.interaction.transports._bridge import StreamingASGITransport + +__all__ = ["StreamingASGITransport"] diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py index f78c6d14b5..17c2432ae1 100644 --- a/tests/interaction/transports/_bridge.py +++ b/tests/interaction/transports/_bridge.py @@ -12,6 +12,8 @@ - The request body is buffered before the application is invoked (MCP requests are small JSON documents); the response streams chunk by chunk. +- The response ends at the first `http.response.body` whose `more_body` is falsy; anything the + application sends after that is ignored, exactly as a real server's client never observes it. - Closing the response — or the whole client — delivers `http.disconnect` to the application, exactly as a real server sees when its peer goes away. - An exception the application raises before sending `http.response.start` fails the originating @@ -47,9 +49,14 @@ def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected self._chunks = chunks self._client_disconnected = client_disconnected - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self._chunks: - yield chunk + def __aiter__(self) -> AsyncIterator[bytes]: + # Delegate to the memory stream's own async iterator instead of wrapping it in an async + # generator. httpx abandons the iterator without closing it when a streamed response is + # closed mid-stream; trio's asyncgen finalizer warns about abandoned generators (asyncio + # finalizes them silently at loop shutdown), which would fail the suite's one trio-backend + # test. The memory stream is a plain async iterator with the same EndOfStream -> + # StopAsyncIteration semantics and is not tracked by that machinery. + return self._chunks async def aclose(self) -> None: self._client_disconnected.set() @@ -111,6 +118,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: request_delivered = False client_disconnected = anyio.Event() response_started = anyio.Event() + response_complete = False response_status = 0 response_headers: list[tuple[bytes, bytes]] = [] application_error: Exception | None = None @@ -125,7 +133,14 @@ async def receive_request() -> Message: return {"type": "http.disconnect"} async def send_response(message: Message) -> None: - nonlocal response_status, response_headers + nonlocal response_complete, response_status, response_headers + if response_complete: + # The response ended with the final body chunk below; a real server's client never + # observes anything sent after that, so drop it. Starlette's `request_response` + # makes this path real: an endpoint whose sub-application already sent a complete + # rejection response (the legacy SSE transport's request validation) still returns + # a `Response`, which sends a trailing second start/body pair. + return if message["type"] == "http.response.start": response_status = message["status"] response_headers = list(message.get("headers", [])) @@ -136,6 +151,7 @@ async def send_response(message: Message) -> None: if body: await chunk_writer.send(body) if not message.get("more_body", False): + response_complete = True await chunk_writer.aclose() async def run_application() -> None: diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py index 7420b9d902..d51fbd88d2 100644 --- a/tests/interaction/transports/test_bridge.py +++ b/tests/interaction/transports/test_bridge.py @@ -40,6 +40,53 @@ async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: assert chunks == [b"first", b"second"] +async def test_a_second_response_after_the_first_completes_is_invisible_to_the_client() -> None: + """Only the first complete response reaches the client; a trailing start/body pair is dropped. + + Starlette's `request_response` produces exactly this sequence when an endpoint's + sub-application has already sent a complete rejection response (the legacy SSE transport's + request validation): the endpoint still returns a `Response`, which sends a second response. + """ + + async def double_responding_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 421, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"rejected", "more_body": False}) + await send({"type": "http.response.start", "status": 200, "headers": [(b"x-late", b"yes")]}) + await send({"type": "http.response.body", "body": b"too late", "more_body": False}) + + transport = StreamingASGITransport(double_responding_app) + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + response = await http.get("/double") + + assert response.status_code == 421 + assert response.text == "rejected" + assert "x-late" not in response.headers + + +async def test_body_chunks_after_the_final_chunk_are_ignored() -> None: + """Extra body chunks after `more_body: False` neither reach the client nor fail the application.""" + application_finished = anyio.Event() + + async def overflowing_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b"complete", "more_body": False}) + await send({"type": "http.response.body", "body": b"overflow", "more_body": True}) + application_finished.set() + + transport = StreamingASGITransport(overflowing_app) + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + response = await http.get("/overflow") + with anyio.fail_after(5): + await application_finished.wait() + + assert response.status_code == 200 + assert response.text == "complete" + + async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: """A client that closes the response early is seen by the application as an http.disconnect.""" seen_after_request: list[Message] = [] diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 716a308a53..deb36e5f28 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,16 +1,13 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket from collections.abc import Iterator from typing import Any import anyio import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -22,12 +19,15 @@ from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + @pytest.fixture(autouse=True) def reset_sse_starlette_exit_event() -> Iterator[None]: @@ -46,275 +46,161 @@ def clear() -> None: clear() -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" - app = SecurityTestServer() +def sse_security_client(security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: + """An httpx client whose requests are served in process by an SSE app with the given settings.""" + server = Server(SERVER_NAME) sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): + async def handle_sse(request: Request) -> Response: try: - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: - await app.run(streams[0], streams[1], app.create_initialization_options()) + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await server.run(read, write, server.create_initialization_options()) except ValueError as e: - # Validation error was already handled inside connect_sse + # Validation error was already handled inside connect_sse, which sent the rejection + # response itself; its non-empty body checkpoints, so the test reads the rejection + # status before the trailing Response() below sends a second response start. logger.debug(f"SSE connection failed validation: {e}") return Response() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] - - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) + # The SSE GET runs until it observes a disconnect, so the bridge must let the application + # drain on close rather than cancelling it. + transport = StreamingASGITransport(app, cancel_on_close=False) + return httpx.AsyncClient(transport=transport, base_url=BASE_URL) @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +async def test_sse_security_default_settings() -> None: + """With default security settings (protection disabled), any Host and Origin connect.""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with sse_security_client() as client: + async with client.stream("GET", "/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): - """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list +async def test_sse_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): - """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins +async def test_sse_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + async with sse_security_client(security_settings) as client: + response = await client.get("/sse", headers={"Origin": "http://evil.com"}) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): - """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host +async def test_sse_security_post_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": "text/plain"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + fake_session_id = "12345678123456781234567812345678" - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" + async with sse_security_client(security_settings) as client: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() + response = await client.post(f"/messages/?session_id={fake_session_id}", content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): - """Test SSE with security disabled.""" +async def test_sse_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host still connects.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "evil.com"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): - """Test SSE with custom allowed hosts.""" +async def test_sse_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts connects; hosts outside the list are still rejected.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = {"Host": "custom.host"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host - assert response.status_code == 200 - # Test with non-allowed host - headers = {"Host": "evil.com"} + async with sse_security_client(settings) as client: + async with client.stream("GET", "/sse", headers={"Host": "custom.host"}) as response: + assert response.status_code == 200 - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + response = await client.get("/sse", headers={"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): - """Test SSE with wildcard port patterns.""" +async def test_sse_security_wildcard_ports() -> None: + """A `host:*` pattern accepts that host with any port, for Host and Origin alike.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - try: - # Test with various port numbers + async with sse_security_client(settings) as client: for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 - - headers = {"Origin": f"http://localhost:{test_port}"} - - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port - assert response.status_code == 200 + async with client.stream("GET", "/sse", headers={"Host": f"localhost:{test_port}"}) as response: + assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with client.stream("GET", "/sse", headers={"Origin": f"http://localhost:{test_port}"}) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): - """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host +async def test_sse_security_post_valid_content_type() -> None: + """Every application/json Content-Type variant passes validation (reaching the session lookup).""" security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: - async with httpx.AsyncClient() as client: - # Test with various valid content types - valid_content_types = [ - "application/json", - "application/json; charset=utf-8", - "application/json;charset=utf-8", - "APPLICATION/JSON", # Case insensitive - ] - - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) - fake_session_id = "12345678123456781234567812345678" - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", - headers={"Content-Type": content_type}, - json={"test": "data"}, - ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check - assert response.status_code == 404 - assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + # A well-formed session ID that no live session owns. + fake_session_id = "12345678123456781234567812345678" + + async with sse_security_client(security_settings) as client: + for content_type in valid_content_types: + response = await client.post( + f"/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # 404 proves the request passed the content-type check and reached the session lookup. + assert response.status_code == 404 + assert response.text == "Could not find session" def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 0ae07c43ad..e5b9710873 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,6 +1,7 @@ """Tests for StreamableHTTPSessionManager.""" import json +import logging from typing import Any from unittest.mock import AsyncMock, patch @@ -317,12 +318,33 @@ async def mock_receive(): assert error_data["error"]["message"] == "Session not found" +class _IdleTimeoutObserver(logging.Handler): + """Resolves `reaped` when the manager logs that a session's idle timeout fired.""" + + def __init__(self) -> None: + super().__init__() + self.reaped = anyio.Event() + + def emit(self, record: logging.LogRecord) -> None: + if "idle timeout" in record.getMessage(): + self.reaped.set() + + @pytest.mark.anyio -async def test_idle_session_is_reaped(): +async def test_idle_session_is_reaped(caplog: pytest.LogCaptureFixture, request: pytest.FixtureRequest): """After idle timeout fires, the session returns 404.""" app = Server("test-idle-reap") manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + # The reap is observed through the manager's own "idle timeout" log record: the manager pops + # the session synchronously after emitting it, before its next await, so a waiter woken by + # the record always finds the session gone. caplog.set_level enables INFO so it is created. + observer = _IdleTimeoutObserver() + manager_logger = logging.getLogger(streamable_http_manager.__name__) + manager_logger.addHandler(observer) + request.addfinalizer(lambda: manager_logger.removeHandler(observer)) + caplog.set_level(logging.INFO, logger=streamable_http_manager.__name__) + async with manager.run(): sent_messages: list[Message] = [] @@ -353,8 +375,10 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Wait for the 50ms idle timeout to fire and cleanup to complete - await anyio.sleep(0.1) + # Wait for the 50ms idle timeout to fire and the session to be unregistered. Re-requesting + # the session to poll for the 404 would push its idle deadline forward and keep it alive. + with anyio.fail_after(5): + await observer.reaped.wait() # Verify via public API: old session ID now returns 404 response_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index a637b1dce0..9f4117dff4 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,293 +1,180 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import logging -import multiprocessing -import socket -from collections.abc import AsyncGenerator +import gc +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.routing import Mount -from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport -logger = logging.getLogger(__name__) SERVER_NAME = "test_streamable_http_security_server" - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - async def on_list_tools(self) -> list[Tool]: - return [] - - -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" - app = SecurityTestServer() - - # Create session manager with security settings - session_manager = StreamableHTTPSessionManager( - app=app, - json_response=False, - stateless=False, - security_settings=security_settings, - ) - - # Create the ASGI handler - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - await session_manager.handle_request(scope, receive, send) - - # Create Starlette app with lifespan - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - routes = [ - Mount("/", app=handle_streamable_http), - ] - - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") - - -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +# The in-process app is mounted at this origin purely so URLs are well-formed and the default +# Host header is a localhost form; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown when +# run in process; the old subprocess harness never observed them. The interaction suite registers +# the same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# that complete the initialize handshake passing in isolated runs. The filters are scoped to +# anyio's MemoryObject*Stream leak signature so an unrelated leak still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + + +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + +@asynccontextmanager +async def streamable_http_security_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncIterator[httpx.AsyncClient]: + """Yield an httpx client served in process by a StreamableHTTP app with the given settings.""" + session_manager = StreamableHTTPSessionManager(app=Server(SERVER_NAME), security_settings=security_settings) + app = Starlette(routes=[Mount("/", app=session_manager.handle_request)]) + + async with session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as client: + yield client + + +def _base_headers() -> dict[str, str]: + """Headers every well-formed request carries, so each test varies only the header under test.""" + return {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} + + +def _initialize_body() -> dict[str, object]: + """A minimal initialize POST body; these tests assert header validation, not the handshake.""" + return {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): - """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_default_settings() -> None: + """With default security settings, a request with localhost headers is served.""" + async with streamable_http_security_client() as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers()) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): - """Test StreamableHTTP with invalid Host header.""" +async def test_streamable_http_security_invalid_host_header() -> None: + """A Host header outside allowed_hosts is rejected with 421.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): - """Test StreamableHTTP with invalid Origin header.""" +async def test_streamable_http_security_invalid_origin_header() -> None: + """An Origin header outside allowed_origins is rejected with 403.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.post( + "/", json=_initialize_body(), headers=_base_headers() | {"Origin": "http://evil.com"} + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): - """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() +async def test_streamable_http_security_invalid_content_type() -> None: + """A POST whose Content-Type is not application/json (or is missing) is rejected with 400.""" + async with streamable_http_security_client() as client: + response = await client.post("/", headers=_base_headers() | {"Content-Type": "text/plain"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + response = await client.post("/", headers={"Accept": "application/json, text/event-stream"}, content="test") + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): - """Test StreamableHTTP with security disabled.""" +async def test_streamable_http_security_disabled() -> None: + """With protection explicitly disabled, a disallowed Host is still served.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): - """Test StreamableHTTP with custom allowed hosts.""" +async def test_streamable_http_security_custom_allowed_hosts() -> None: + """A custom entry in allowed_hosts is served.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(settings) as client: + response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "custom.host"}) + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): - """Test StreamableHTTP GET request with security.""" +async def test_streamable_http_security_get_request() -> None: + """GET requests pass the same Host validation before any session handling.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: - # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + + async with streamable_http_security_client(security_settings) as client: + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) + # An allowed host passes security and fails on session validation instead. + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7604450f81..856606488e 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,18 +1,18 @@ +"""Tests for the SSE client and server transports, driven entirely in process.""" + +import gc import json -import multiprocessing -import socket -import time -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Iterable, Iterator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from pydantic import AnyUrl +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -23,188 +23,191 @@ from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client from mcp.server import Server +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import McpError from mcp.types import ( + CallToolResult, EmptyResult, ErrorData, Implementation, InitializeResult, JSONRPCResponse, - ReadResourceResult, ServerCapabilities, TextContent, TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_server_for_SSE" +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + +# v1's HTTP server transports leak a handful of anyio memory streams on teardown when run in +# process; the old subprocess harness never observed them. The interaction suite registers the +# same two scoped filters globally from tests/interaction/conftest.py (see the comment there), +# but they only take effect when that package's conftest is loaded; these markers keep the tests +# themselves passing in isolated runs. Markers are item-scoped, so the autouse +# `_collect_leaked_streams` fixture below garbage-collects each test's leaks inside its own +# teardown, where these filters apply; without it, leaks GC'd at session cleanup escape the +# scoped ignores. The filters are scoped to anyio's MemoryObject*Stream leak signature so an +# unrelated leak still fails the suite. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + + +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - - -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: AnyUrl) -> str | bytes: - if uri.scheme == "foobar": - return f"Read {uri.host}" - elif uri.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {uri.host}" - - raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ) - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixture in tests/interaction/conftest.py, which guards the + interaction suite the same way. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + + +def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: + """An httpx_client_factory for sse_client whose clients are served in process by `app`.""" + + def factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE GET runs until it observes a disconnect, so the bridge must let the + # application drain on close rather than cancelling it. follow_redirects matches + # create_mcp_http_client, the factory this one stands in for. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + + return factory + + +def make_test_server() -> Server[object, Request]: + """A server whose read_resource handler answers foobar:// URIs and 404s everything else.""" + server: Server[object, Request] = Server(SERVER_NAME) + + @server.read_resource() + async def handle_read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: + if uri.scheme == "foobar": + return [ReadResourceContents(content=f"Read {uri.host}", mime_type="text/plain")] + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) + + return server + + +def make_app(server: Server[Any, Any]) -> Starlette: + """Mount `server` on a Starlette app exposing the SSE transport at /sse and /messages/.""" + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the transport security behaviour itself is pinned by + # tests/server/test_sse_security.py. + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) + async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) +def make_server_app() -> Starlette: + return make_app(make_test_server()) -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +@pytest.mark.anyio +async def test_raw_sse_connection() -> None: + """The SSE GET responds 200 with an event-stream content type, announcing the session + endpoint as its first event.""" + http_client = httpx.AsyncClient( + transport=StreamingASGITransport(make_server_app(), cancel_on_close=False), base_url=BASE_URL + ) + with anyio.fail_after(5): + async with http_client, http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -# Tests -@pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: - """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): # pragma: no branch - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + lines = response.aiter_lines() + assert await anext(lines) == "event: endpoint" + assert (await anext(lines)).startswith("data: /messages/?session_id=") @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection() -> None: + """A client initializes against, and pings, a server over the SSE transport.""" + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.serverInfo.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None - - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id - - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: +async def test_sse_client_on_session_created() -> None: + """The session-created callback receives the new session ID before sse_client yields.""" + factory = in_process_client_factory(make_server_app()) + captured: list[str] = [] + + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None - assert len(captured_session_id) > 0 + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -219,13 +222,14 @@ def on_session_created(session_id: str) -> None: ], ) def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + """The session ID is read from the endpoint URL's sessionId/session_id query parameters.""" assert _extract_session_id_from_endpoint(endpoint_url) == expected @pytest.mark.anyio -async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + """No session-created callback fires when the endpoint URL carries no session ID.""" + factory = in_process_client_factory(make_server_app()) callback_mock = Mock() def mock_extract(url: str) -> None: @@ -233,17 +237,19 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. + callback_mock.assert_not_called() @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session() -> AsyncGenerator[ClientSession, None]: + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -253,6 +259,7 @@ async def initialized_sse_client_session(server: None, server_url: str) -> Async async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, ) -> None: + """A resource read round-trips its arguments and the handler's content over SSE.""" session = initialized_sse_client_session response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 @@ -264,234 +271,132 @@ async def test_sse_client_happy_request_and_response( async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, ) -> None: + """A server-side McpError reaches the client with its message intact.""" session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) @pytest.mark.anyio -@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( # pragma: no cover - initialized_sse_client_session: ClientSession, -) -> None: - session = initialized_sse_client_session - - # sanity check that normal, fast responses are working - response = await session.read_resource(uri=AnyUrl("foobar://1")) - assert isinstance(response, ReadResourceResult) - - with anyio.move_on_after(3): - with pytest.raises(McpError, match="Read timed out"): - response = await session.read_resource(uri=AnyUrl("slow://2")) - # we should receive an error here - return - - pytest.fail("the client should have timed out and returned an error already") - - -def run_mounted_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() +async def test_sse_client_basic_connection_mounted_app() -> None: + """The SSE transport works unchanged when its app is mounted under a sub-path.""" + main_app = Starlette(routes=[Mount("/mounted_app", app=make_server_app())]) + factory = in_process_client_factory(main_app) - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: + async with sse_client(f"{BASE_URL}/mounted_app/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.serverInfo.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - inputSchema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] - - -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), +def make_context_server() -> Server[object, Request]: + """A server whose tools echo back the request headers seen via the request context.""" + server: Server[object, Request] = Server("request_context_server") + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> CallToolResult: + assert name in ("echo_headers", "echo_context") + ctx = server.request_context + assert ctx.request is not None + headers_info = dict(ctx.request.headers) + + if name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + context_data = { + "request_id": args.get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echoes request headers", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + inputSchema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), ] - ) - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - -@pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() + return server - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") +def make_context_server_app() -> Starlette: + return make_app(make_context_server()) @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: - """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers +async def test_request_context_propagation() -> None: + """Custom HTTP headers on the SSE connection are visible to server handlers via the request context.""" + factory = in_process_client_factory(make_context_server_app()) + custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=custom_headers) as streams: + async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content = tool_result.content[0] + assert isinstance(content, TextContent) + headers_data = json.loads(content.text) - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: - """Test that request contexts are isolated between different SSE clients.""" - contexts: list[dict[str, Any]] = [] - - # Create multiple clients with different headers - for i in range(3): +async def test_request_context_isolation() -> None: + """Each SSE connection's handlers see only that connection's request headers.""" + factory = in_process_client_factory(make_context_server_app()) + + # Connect three clients in turn, each with its own headers. Each connection is + # verified inside its own block: on Python 3.11 the line tracer is lost once an + # async-with teardown throws (python/cpython#106749), so statements placed after + # this loop would be reported uncovered on some matrix cells. The loop's exit + # arc fires after the final teardown and sits in the same shadow, hence the + # branch exclusion. + for i in range(3): # pragma: no branch headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: + async with ClientSession(*streams) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + content = tool_result.content[0] + assert isinstance(content, TextContent) + ctx = json.loads(content.text) + assert ctx["request_id"] == f"request-{i}" + assert ctx["headers"].get("x-request-id") == f"request-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" - # Verify each request had its own context - assert len(contexts) == 3 - for i, ctx in enumerate(contexts): - assert ctx["request_id"] == f"request-{i}" - assert ctx["headers"].get("x-request-id") == f"request-{i}" - assert ctx["headers"].get("x-custom-value") == f"value-{i}" - -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. @@ -525,7 +430,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception @@ -602,3 +507,33 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: assert not isinstance(msg, Exception) assert isinstance(msg.message.root, types.JSONRPCResponse) assert msg.message.root.id == 1 + + +@pytest.mark.anyio +async def test_sse_session_cleanup_on_disconnect() -> None: + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 + + When a client disconnects, the server should remove the session from + _read_stream_writers. Without this cleanup, stale sessions accumulate and + POST requests to disconnected sessions return 202 Accepted followed by a + ClosedResourceError when the server tries to write to the dead stream. + """ + factory = in_process_client_factory(make_server_app()) + captured: list[str] = [] + + # Connect a client session, then disconnect + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # After disconnect, POST to the stale session should return 404 + # (not 202 as it did before the fix) + async with factory() as client: + response = await client.post( + f"/messages/?session_id={captured[0]}", + json={"jsonrpc": "2.0", "method": "ping", "id": 99}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 404 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 731dd20dd3..61d7793240 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1,14 +1,15 @@ -""" -Tests for the StreamableHTTP server and client transport. +"""Tests for the StreamableHTTP server and client transport. -Contains tests for both server and client sides of the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport, driven +entirely in process. """ +import gc import json -import multiprocessing -import socket import time -from collections.abc import Generator +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from datetime import timedelta from typing import Any from unittest.mock import MagicMock @@ -16,10 +17,9 @@ import anyio import httpx import pytest -import requests -import uvicorn from httpx_sse import ServerSentEvent from pydantic import AnyUrl +from sse_starlette.sse import AppStatus from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount @@ -45,7 +45,6 @@ ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage @@ -58,11 +57,61 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports._bridge import StreamingASGITransport + +# v1's streamable-HTTP server transport leaks a handful of anyio memory streams on teardown +# (e.g. `_handle_get_request` only closes `sse_stream_reader` on the exception path; the +# session manager's per-session task-group cancel can race the per-request cleanup). The old +# socket-based version of this file ran the transport in a separate process and so never +# observed these `__del__`-time ResourceWarnings; running in-process via the streaming bridge +# does. The fixes live in `src/` on `main` and are out of scope for this tests-only change. +# The filters are scoped to anyio's `MemoryObject*Stream` leak signature so an unrelated leak +# still fails the suite; tests/interaction/conftest.py applies the same pair for the same reason. +pytestmark = [ + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:pytest.PytestUnraisableExceptionWarning"), + pytest.mark.filterwarnings("ignore:.*MemoryObject(Send|Receive)Stream:ResourceWarning"), +] + + +@pytest.fixture(autouse=True) +def _collect_leaked_streams() -> Iterator[None]: + """Garbage-collect each test's leaked memory streams inside its own teardown. + + The filterwarnings marks above only apply while a test in this file is the + active warning context. The leaked streams sit in reference cycles, so without + a forced collection their deallocator warnings fire wherever the garbage + collector happens to run next: during an unrelated test (failing it, since the + global ``filterwarnings = ["error"]`` has no ignore there) or at pytest's + session-unconfigure unraisable sweep (exit code 1 after all tests passed when + running without xdist, e.g. ``-n 0`` for ``--pdb`` debugging). + """ + yield + gc.collect() + + +@pytest.fixture(autouse=True) +def _reset_sse_starlette_exit_event() -> Iterator[None]: + """Reset sse-starlette's module-global exit Event around each test. + + sse-starlette <3.0 (allowed by this branch's dependency floor; CI's lowest-direct leg + installs it) stores an `anyio.Event` on the `AppStatus` class the first time an + `EventSourceResponse` runs; that Event is bound to the test's event loop and breaks every + subsequent in-process SSE response. sse-starlette 3.x switched to a ContextVar and has no + such attribute. Resetting on both sides of the test keeps this module immune to a stale + Event left behind by an earlier test on the same worker as well as cleaning up after its + own. This mirrors the autouse fixtures in tests/shared/test_sse.py and + tests/interaction/conftest.py. + """ + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + # setattr keeps pyright happy: the locked sse-starlette 3.x has no such attribute. + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + yield + if hasattr(AppStatus, "should_exit_event"): # pragma: no branch + setattr(AppStatus, "should_exit_event", None) # pragma: lax no cover + # Test constants SERVER_NAME = "test_streamable_http_server" -TEST_SESSION_ID = "test-session-id-12345" INIT_REQUEST = { "jsonrpc": "2.0", "method": "initialize", @@ -74,16 +123,19 @@ "id": "init-1", } +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: # pragma: no cover +def extract_protocol_version_from_sse(response: httpx.Response) -> str: """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): if line.startswith("data: "): init_data = json.loads(line[6:]) return init_data["result"]["protocolVersion"] - raise ValueError("Could not extract protocol version from SSE response") + raise ValueError("Could not extract protocol version from SSE response") # pragma: no cover # Simple in-memory event store for testing @@ -94,412 +146,263 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event( # pragma: no cover - self, stream_id: StreamId, message: types.JSONRPCMessage | None - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, ) -> StreamId | None: """Replay events after the specified ID.""" - # Find the stream ID of the last event - target_stream_id = None - for stream_id, event_id, _ in self._events: - if event_id == last_event_id: - target_stream_id = stream_id - break - - if target_stream_id is None: - # If event ID not found, return None - return None + # Find the stream ID of the last event; clients always resume from a stored event. + target_stream_id = next(stream_id for stream_id, event_id, _ in self._events if event_id == last_event_id) # Convert last_event_id to int for comparison last_event_id_int = int(last_event_id) - # Replay only events from the same stream with ID > last_event_id + # Replay only events from the same stream with ID > last_event_id, skipping priming + # events (None message). for stream_id, event_id, message in self._events: - if stream_id == target_stream_id and int(event_id) > last_event_id_int: - # Skip priming events (None message) - if message is not None: - await send_callback(EventMessage(message, event_id)) + if stream_id == target_stream_id and message is not None and int(event_id) > last_event_id_int: + await send_callback(EventMessage(message, event_id)) return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: AnyUrl) -> str | bytes: - if uri.scheme == "foobar": - return f"Read {uri.host}" - elif uri.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {uri.host}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - inputSchema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, - }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - inputSchema={"type": "object", "properties": {}}, - ), - ] +@dataclass +class ServerState: + lock: anyio.Event = field(default_factory=anyio.Event) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) - return [TextContent(type="text", text=f"Called {name}")] +@asynccontextmanager +async def _server_lifespan(_server: Server[ServerState, Request]) -> AsyncIterator[ServerState]: + yield ServerState() - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) +def _create_server() -> Server[ServerState, Request]: + server: Server[ServerState, Request] = Server(SERVER_NAME, lifespan=_server_lifespan) - await anyio.sleep(0.1) + @server.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + raise ValueError(f"Unknown resource: {uri}") - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + ] - return [TextContent(type="text", text="Completed!")] + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context - elif name == "test_sampling_tool": - # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, - ) + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) + return [TextContent(type="text", text=f"Called {name}")] - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", + elif name == "test_sampling_tool": + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, - ) - - # Now wait for the lock to be released - await self._lock.wait() + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, + assert sampling_result.content.type == "text" + return [ + TextContent( + type="text", + text=f"Response from sampling: {sampling_result.content.text}", ) + ] - return [TextContent(type="text", text="Completed")] - - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" - - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + elif name == "wait_for_lock_with_notification": + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] - - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) - - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + await ctx.lifespan_context.lock.wait() - if ctx.close_sse_stream: - await ctx.close_sse_stream() + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - await anyio.sleep(sleep_time) + return [TextContent(type="text", text="Completed")] - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + elif name == "release_lock": + ctx.lifespan_context.lock.set() + return [TextContent(type="text", text="Lock released")] - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1")) + elif name == "tool_with_stream_close": + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="Done")] + + elif name == "tool_with_multiple_notifications_and_close": + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="All notifications sent")] - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + elif name == "tool_with_standalone_stream_close": + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1")) + await anyio.sleep(0.1) - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + assert ctx.close_standalone_sse_stream is not None + await ctx.close_standalone_sse_stream() - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + await anyio.sleep(1.5) + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2")) - # 4. Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2")) + return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text="Standalone stream close test done")] + return [TextContent(type="text", text=f"Called {name}")] - return [TextContent(type="text", text=f"Called {name}")] + return server -def create_app( +@asynccontextmanager +async def running_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover - """Create a Starlette application for testing using the session manager. + server: Server[Any, Request] | None = None, +) -> AsyncIterator[Starlette]: + """Serve the test server's streamable HTTP app in process for the duration. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. retry_interval: Retry interval in milliseconds for SSE polling. + server: Server to mount; defaults to the file's shared test server. """ - # Create server instance - server = ServerTest() - - # Create the session manager - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the protection itself is pinned by + # tests/server/test_streamable_http_security.py. session_manager = StreamableHTTPSessionManager( - app=server, + app=server if server is not None else _create_server(), event_store=event_store, json_response=is_json_response_enabled, - security_settings=security_settings, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), retry_interval=retry_interval, ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app - # Create an ASGI application that uses the session manager - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - return app +def make_client(app: Starlette, headers: dict[str, str] | None = None) -> httpx.AsyncClient: + """An httpx client served in process by `app`, with create_mcp_http_client's redirect default. - -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. + (Starlette's Mount 307-redirects the bare /mcp path to /mcp/, which the SDK's own client + factory follows.) """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, + return httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, headers=headers, follow_redirects=True ) - # Start the server - server = uvicorn.Server(config=config) - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - import traceback - - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests +# Test fixtures @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def basic_app() -> AsyncIterator[Starlette]: + """The test server's app with SSE response mode.""" + async with running_app() as app: + yield app @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +async def json_app() -> AsyncIterator[Starlette]: + """The test server's app with JSON response mode.""" + async with running_app(is_json_response_enabled=True) as app: + yield app @pytest.fixture @@ -509,160 +412,138 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +async def event_app(event_store: SimpleEventStore) -> AsyncIterator[tuple[SimpleEventStore, Starlette]]: + """The test server's app with an event store and retry_interval enabled.""" + async with running_app(event_store=event_store, retry_interval=500) as app: + yield event_store, app -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) +# Basic request validation tests +@pytest.mark.anyio +async def test_accept_header_validation(basic_app: Starlette) -> None: + """A POST without an Accept header is rejected with 406.""" + async with make_client(basic_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() +@pytest.mark.anyio +@pytest.mark.parametrize( + "accept_header", + [ + "text/html", + "application/*", + "text/*", + ], +) +async def test_accept_header_incompatible(basic_app: Starlette, accept_header: str) -> None: + """Accept headers that do not literally include both required media types are rejected for SSE mode. - # Wait for server to be running - wait_for_server(json_server_port) + (v1 matches Accept media types literally; wildcard support is a main-only change, #2152.) + """ + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": accept_header, + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - yield - # Clean up - proc.kill() - proc.join(timeout=2) +@pytest.mark.anyio +async def test_content_type_validation(basic_app: Starlette) -> None: + """A POST whose Content-Type is not application/json is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + content="This is not JSON", + ) + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +@pytest.mark.anyio +async def test_json_validation(basic_app: Starlette) -> None: + """A POST body that is not valid JSON is rejected with a parse error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + content="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text -@pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +@pytest.mark.anyio +async def test_json_parsing(basic_app: Starlette) -> None: + """Valid JSON that is not a JSON-RPC message is rejected with a validation error.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text -# Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): - """Test that Accept header is properly validated.""" - # Test without Accept header - response = requests.post( - f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_content_type_validation(basic_server: None, basic_server_url: str): - """Test that Content-Type header is properly validated.""" - # Test with incorrect Content-Type - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "text/plain", - }, - data="This is not JSON", - ) - assert response.status_code == 400 - assert "Invalid Content-Type" in response.text +@pytest.mark.anyio +async def test_method_not_allowed(basic_app: Starlette) -> None: + """Unsupported HTTP methods are rejected with 405.""" + async with make_client(basic_app) as client: + response = await client.put( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): - """Test that JSON content is properly validated.""" - # Test with invalid JSON - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - data="this is not valid json", - ) - assert response.status_code == 400 - assert "Parse error" in response.text - - -def test_json_parsing(basic_server: None, basic_server_url: str): - """Test that JSON content is properly parse.""" - # Test with valid JSON but invalid JSON-RPC - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"foo": "bar"}, - ) - assert response.status_code == 400 - assert "Validation error" in response.text - - -def test_method_not_allowed(basic_server: None, basic_server_url: str): - """Test that unsupported HTTP methods are rejected.""" - # Test with unsupported method (PUT) - response = requests.put( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, - ) - assert response.status_code == 405 - assert "Method Not Allowed" in response.text - - -def test_session_validation(basic_server: None, basic_server_url: str): - """Test session ID validation.""" - # session_id not used directly in this test - - # Test without session ID - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, - ) - assert response.status_code == 400 - assert "Missing session ID" in response.text +@pytest.mark.anyio +async def test_session_validation(basic_app: Starlette) -> None: + """A non-initialize request without a session ID is rejected with 400.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text -def test_session_id_pattern(): - """Test that SESSION_ID_PATTERN correctly validates session IDs.""" +def test_session_id_pattern() -> None: + """SESSION_ID_PATTERN accepts visible ASCII (0x21-0x7E) and rejects everything else.""" # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) valid_session_ids = [ "test-session-id", @@ -696,8 +577,8 @@ def test_session_id_pattern(): assert SESSION_ID_PATTERN.fullmatch(session_id) is None -def test_streamable_http_transport_init_validation(): - """Test that StreamableHTTPServerTransport validates session ID on init.""" +def test_streamable_http_transport_init_validation() -> None: + """StreamableHTTPServerTransport accepts valid or absent session IDs and rejects invalid ones.""" # Valid session ID should initialize without errors valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") assert valid_transport.mcp_session_id == "valid-id" @@ -719,299 +600,265 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): - """Test session termination via DELETE and subsequent request handling.""" - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_session_termination(basic_app: Starlette) -> None: + """DELETE terminates the session, after which requests for it return 404.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + response = await client.delete( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now terminate the session - session_id = response.headers.get(MCP_SESSION_ID_HEADER) - response = requests.delete( - f"{basic_server_url}/mcp", - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 200 - - # Try to use the terminated session - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "ping", "id": 2}, - ) - assert response.status_code == 404 - assert "Session has been terminated" in response.text - - -def test_response(basic_server: None, basic_server_url: str): - """Test response handling for a valid request.""" - mcp_url = f"{basic_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 +@pytest.mark.anyio +async def test_response(basic_app: Starlette) -> None: + """A request on an initialized session is answered on a text/event-stream response.""" + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now get the session ID + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Try to use the session with proper headers + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + ) as tools_response: + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" - # Extract negotiated protocol version from SSE response - negotiated_version = extract_protocol_version_from_sse(response) - # Now get the session ID - session_id = response.headers.get(MCP_SESSION_ID_HEADER) +@pytest.mark.anyio +async def test_json_response(json_app: Starlette) -> None: + """With JSON response mode enabled, requests are answered with application/json bodies.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" - # Try to use the session with proper headers - tools_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, - stream=True, - ) - assert tools_response.status_code == 200 - assert tools_response.headers.get("Content-Type") == "text/event-stream" - - -def test_json_response(json_response_server: None, json_server_url: str): - """Test response handling when is_json_response_enabled is True.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): - """Test that json_response servers only require application/json in Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 200 - assert response.headers.get("Content-Type") == "application/json" - - -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" - mcp_url = f"{json_server_url}/mcp" - response = requests.post( - mcp_url, - headers={ - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text +@pytest.mark.anyio +async def test_json_response_accept_json_only(json_app: Starlette) -> None: + """JSON response mode only requires application/json in the Accept header.""" + async with make_client(json_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests with incorrect Accept header.""" - mcp_url = f"{json_server_url}/mcp" - # Test with only text/event-stream (wrong for JSON server) - response = requests.post( - mcp_url, - headers={ - "Accept": "text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - -def test_get_sse_stream(basic_server: None, basic_server_url: str): - """Test establishing an SSE stream via GET request.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None +@pytest.mark.anyio +async def test_json_response_missing_accept_header(json_app: Starlette) -> None: + """JSON response mode still rejects requests without an Accept header.""" + async with make_client(json_app) as client: + # Suppress the httpx client default Accept: */* header + del client.headers["accept"] + response = await client.post( + "/mcp", + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Now attempt to establish an SSE stream via GET - get_response = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - # Verify we got a successful response with the right content type - assert get_response.status_code == 200 - assert get_response.headers.get("Content-Type") == "text/event-stream" +@pytest.mark.anyio +async def test_json_response_incorrect_accept_header(json_app: Starlette) -> None: + """JSON response mode rejects an Accept header that does not cover application/json.""" + async with make_client(json_app) as client: + # Test with only text/event-stream (wrong for JSON server) + response = await client.post( + "/mcp", + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text - # Test that a second GET request gets rejected (only one stream allowed) - second_get = requests.get( - mcp_url, - headers={ - "Accept": "text/event-stream", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - # Should get CONFLICT (409) since there's already a stream - # Note: This might fail if the first stream fully closed before this runs, - # but generally it should work in the test environment where it runs quickly - assert second_get.status_code == 409 - - -def test_get_validation(basic_server: None, basic_server_url: str): - """Test validation for GET requests.""" - # First, we need to initialize a session - mcp_url = f"{basic_server_url}/mcp" - init_response = requests.post( - mcp_url, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 +@pytest.mark.anyio +async def test_get_sse_stream(basic_app: Starlette) -> None: + """GET establishes the standalone SSE stream, and a second GET is rejected with 409.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 - # Get the session ID - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - assert session_id is not None + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) - # Extract negotiated protocol version from SSE response - init_data = None - assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): # pragma: no branch - if line.startswith("data: "): # pragma: no cover - init_data = json.loads(line[6:]) - break - assert init_data is not None - negotiated_version = init_data["result"]["protocolVersion"] - - # Test without Accept header - response = requests.get( - mcp_url, - headers={ - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - stream=True, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text - - # Test with wrong Accept header - response = requests.get( - mcp_url, - headers={ - "Accept": "application/json", + # Now attempt to establish an SSE stream via GET + get_headers = { + "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - ) - assert response.status_code == 406 - assert "Not Acceptable" in response.text + } + # The streams enter in order, so the second GET arrives while the first is held open. + async with ( + client.stream("GET", "/mcp", headers=get_headers) as get_response, + client.stream("GET", "/mcp", headers=get_headers) as second_get, + ): + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + # The second GET gets CONFLICT (409): only one standalone stream is allowed per session. + assert second_get.status_code == 409 -# Client-specific fixtures -@pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client + +@pytest.mark.anyio +async def test_get_validation(basic_app: Starlette) -> None: + """A GET without an Accept header covering text/event-stream is rejected with 406.""" + async with make_client(basic_app) as client: + # First, we need to initialize a session + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + negotiated_version = extract_protocol_version_from_sse(init_response) + + # Test without Accept header (suppress the httpx client default Accept: */*) + del client.headers["accept"] + response = await client.get( + "/mcp", + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = await client.get( + "/mcp", + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text +# Client-specific fixtures @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_app: Starlette) -> AsyncIterator[ClientSession]: """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - await session.initialize() - yield session + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): - """Test basic client connection with initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_basic_connection(basic_app: Starlette) -> None: + """A client initializes against a server over the StreamableHTTP transport.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME @pytest.mark.anyio -async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): - """Test client resource read functionality.""" +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None: + """A resource read round-trips its arguments and the handler's content.""" response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 assert response.contents[0].uri == AnyUrl("foobar://test-resource") @@ -1020,11 +867,11 @@ async def test_streamable_http_client_resource_read(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): - """Test client tool invocation.""" +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession) -> None: + """A tool call reaches the handler and returns its content.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 10 + assert len(tools.tools) == 8 assert tools.tools[0].name == "test_tool" # Call the tool @@ -1035,8 +882,8 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio -async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): - """Test error handling in client.""" +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: + """A server-side error reaches the client as an McpError with the handler's message.""" with pytest.raises(McpError) as exc_info: await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) assert exc_info.value.error.code == 0 @@ -1044,66 +891,56 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): - """Test that session ID persists across requests.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_session_persistence(basic_app: Starlette) -> None: + """The session persists across multiple requests on one connection.""" + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Make multiple requests to verify session persistence - tools = await session.list_tools() - assert len(tools.tools) == 10 + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 8 - # Read a resource - resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) - assert isinstance(resource.contents[0], TextResourceContents) is True - content = resource.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Read test-persist" + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): - """Test client with JSON response mode.""" - async with streamable_http_client(f"{json_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_streamable_http_client_json_response(json_app: Starlette) -> None: + """The client works identically against a server in JSON response mode.""" + async with ( + make_client(json_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession( - read_stream, - write_stream, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == SERVER_NAME + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME - # Check tool listing - tools = await session.list_tools() - assert len(tools.tools) == 10 + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 8 - # Call a tool and verify JSON response handling - result = await session.call_tool("test_tool", {}) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Called test_tool" + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): - """Test GET stream functionality for server-initiated messages.""" - import mcp.types as types - +async def test_streamable_http_client_get_stream(basic_app: Starlette) -> None: + """A server-initiated notification reaches the client on the standalone GET stream.""" notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications @@ -1113,79 +950,91 @@ async def message_handler( # pragma: no branch if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - this triggers the GET stream setup - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) + + # Verify we received the notification + assert len(notifications_received) > 0 - # Call the special tool that sends a notification - await session.call_tool("test_tool_with_standalone_notification", {}) + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True - # Verify we received the notification - assert len(notifications_received) > 0 + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" - # Verify the notification is a ResourceUpdatedNotification - resource_update_found = False - for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch - assert str(notif.root.params.uri) == "http://test_resource/" - resource_update_found = True - assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" +def create_session_id_capturing_client(app: Starlette) -> tuple[httpx.AsyncClient, list[str]]: + """Create an in-process httpx client that captures the session ID from responses.""" + captured_ids: list[str] = [] + + async def capture_session_id(response: httpx.Response) -> None: + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if session_id: + captured_ids.append(session_id) + + client = httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url=BASE_URL, + follow_redirects=True, + event_hooks={"response": [capture_session_id]}, + ) + return client, captured_ids @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): - """Test client session termination functionality.""" +async def test_streamable_http_client_session_termination(basic_app: Starlette) -> None: + """After the client terminates its session on close, a new connection with that session ID fails.""" + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) - captured_session_id = None + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} - # Create the streamable_http_client with a custom httpx client to capture headers - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - - # Make a request to confirm session is working - tools = await session.list_tools() - assert len(tools.tools) == 10 - - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 8 + + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): - """Test client session termination functionality with a 204 response. + basic_app: Starlette, monkeypatch: pytest.MonkeyPatch +) -> None: + """Session termination also succeeds when the server answers the DELETE with 204. This test patches the httpx client to return a 204 response for DELETEs. """ @@ -1210,55 +1059,50 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Apply the patch to the httpx client monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) - captured_session_id = None + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(basic_app) - # Create the streamable_http_client with a custom httpx client to capture headers - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - - # Make a request to confirm session is working - tools = await session.list_tools() - assert len(tools.tools) == 10 - - headers: dict[str, str] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 8 + + async with make_client(basic_app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises( # pragma: no branch - McpError, - match="Session terminated", - ): + with pytest.raises(McpError, match="Session terminated"): # pragma: no branch await session.list_tools() @pytest.mark.anyio -async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): - """Test client session resumption using sync primitives for reliable coordination.""" - _, server_url = event_server +async def test_streamable_http_client_resumption(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """A second client resumes an interrupted request with a resumption token and receives the rest.""" + _, app = event_app # Variables to track the state - captured_session_id = None - captured_resumption_token = None + captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version = None - first_notification_received = False + first_notification_received = anyio.Event() + resumption_token_received = anyio.Event() async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -1268,83 +1112,88 @@ async def message_handler( # pragma: no branch # Look for our first notification if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch if message.root.params.data == "First notification before lock": - nonlocal first_notification_received - first_notification_received = True + first_notification_received.set() async def on_resumption_token_update(token: str) -> None: nonlocal captured_resumption_token captured_resumption_token = token + resumption_token_received.set() + + # Use httpx client with event hooks to capture session ID + httpx_client, captured_ids = create_session_id_capturing_client(app) # First, start the client session and begin the tool that waits on lock - async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False) as ( - read_stream, - write_stream, - get_session_id, - ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - captured_session_id = get_session_id() - assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocolVersion - - # Start the tool that will wait on lock in a task - async with anyio.create_task_group() as tg: - - async def run_tool(): - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams( - name="wait_for_lock_with_notification", arguments={} - ), - ) - ), - types.CallToolResult, - metadata=metadata, - ) + async with httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", terminate_on_close=False, http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert len(captured_ids) > 0 + captured_session_id = captured_ids[0] + assert captured_session_id is not None + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocolVersion, + } - tg.start_soon(run_tool) + # Start the tool that will wait on lock in a task + async with anyio.create_task_group() as tg: # pragma: no branch - # Wait for the first notification and resumption token - while not first_notification_received or not captured_resumption_token: - await anyio.sleep(0.1) + async def run_tool(): + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams( + name="wait_for_lock_with_notification", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) - # Kill the client session while tool is waiting on lock - tg.cancel_scope.cancel() + tg.start_soon(run_tool) - # Verify we received exactly one notification - assert len(captured_notifications) == 1 # pragma: no cover - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover + # Wait for the first notification and resumption token + with anyio.fail_after(5): + await first_notification_received.wait() + await resumption_token_received.wait() - # Clear notifications for the second phase - captured_notifications = [] # pragma: no cover + # first_notification_received is set by message_handler immediately + # after appending to captured_notifications. The server tool is + # blocked on its lock, so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() - # Now resume the session with the same mcp-session-id and protocol version - headers: dict[str, Any] = {} # pragma: no cover - if captured_session_id: # pragma: no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - if captured_protocol_version: # pragma: no cover - headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + # Kill the client session while tool is waiting on lock + tg.cancel_scope.cancel() - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client) as ( + async with make_client(app, headers=headers) as httpx_client2: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as ( read_stream, write_stream, _, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: # pragma: no branch result = await session.send_request( types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="release_lock", arguments={}), - ) + types.CallToolRequest(params=types.CallToolRequestParams(name="release_lock", arguments={})) ), types.CallToolResult, ) @@ -1367,14 +1216,13 @@ async def run_tool(): # We should have received the remaining notifications assert len(captured_notifications) == 1 - - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover - assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) + assert captured_notifications[0].root.params.data == "Second notification after lock" @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): - """Test server-initiated sampling request through streamable HTTP transport.""" +async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: + """A server-initiated sampling request reaches the client callback and its result the tool.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False captured_message_params = None @@ -1401,153 +1249,99 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session, ): - async with ClientSession( - read_stream, - write_stream, - sampling_callback=sampling_callback, - ) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) - # Call the tool that triggers server-side sampling - tool_result = await session.call_tool("test_sampling_tool", {}) + # Call the tool that triggers server-side sampling + tool_result = await session.call_tool("test_sampling_tool", {}) - # Verify the tool result contains the expected content - assert len(tool_result.content) == 1 - assert tool_result.content[0].type == "text" - assert "Response from sampling: Received message from server" in tool_result.content[0].text + # Verify the tool result contains the expected content + assert len(tool_result.content) == 1 + assert tool_result.content[0].type == "text" + assert "Response from sampling: Received message from server" in tool_result.content[0].text - # Verify sampling callback was invoked - assert sampling_callback_invoked - assert captured_message_params is not None - assert len(captured_message_params.messages) == 1 - assert captured_message_params.messages[0].content.text == "Server needs client sampling" + # Verify sampling callback was invoked + assert sampling_callback_invoked + assert captured_message_params is not None + assert len(captured_message_params.messages) == 1 + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - inputSchema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +def _create_context_server() -> Server[dict[str, Any], Request]: + server: Server[dict[str, Any], Request] = Server("ContextAwareServer") + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echo request headers from context", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + inputSchema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, }, - ), - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - - return [TextContent(type="text", text=f"Unknown tool: {name}")] + "required": ["request_id"], + }, + ), + ] + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + assert name in ("echo_headers", "echo_context") + assert isinstance(ctx.request, Request) + + if name == "echo_headers": + return [TextContent(type="text", text=json.dumps(dict(ctx.request.headers)))] + + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": dict(ctx.request.headers), + "method": ctx.request.method, + "path": ctx.request.url.path, + } + return [TextContent(type="text", text=json.dumps(context_data))] + return server -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" - server = ContextAwareServerTest() +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" + server = _create_context_server() session_manager = StreamableHTTPSessionManager( app=server, - event_store=None, - json_response=False, - ) - - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lambda app: session_manager.run(), - ) - - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - server_instance.run() - - -@pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + async with session_manager.run(): + yield app @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request context is properly propagated through StreamableHTTP.""" +async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: + """Custom HTTP headers on the connection are visible to server handlers via ctx.request.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, _, @@ -1572,11 +1366,11 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: - """Test that request contexts are isolated between StreamableHTTP clients.""" +async def test_streamablehttp_request_context_isolation(context_app: Starlette) -> None: + """Each connection's handlers see only that connection's request headers.""" contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = { "X-Request-Id": f"request-{i}", @@ -1584,8 +1378,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with make_client(context_app, headers=headers) as httpx_client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client) as ( read_stream, write_stream, _, @@ -1602,8 +1396,8 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No contexts.append(context_data) # Verify each request had its own context - assert len(contexts) == 3 # pragma: no cover - for i, ctx in enumerate(contexts): # pragma: no cover + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" @@ -1611,157 +1405,160 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): - """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, +async def test_client_includes_protocol_version_header_after_init(context_app: Starlette) -> None: + """After initialization, every client request carries the negotiated protocol version header.""" + async with ( + make_client(context_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocolVersion - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify protocol version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version - - -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): - """Test that server returns 400 Bad Request version if header unsupported or invalid.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request with invalid protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "invalid-version", - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocolVersion - # Test request with valid protocol version (should succeed) - negotiated_version = extract_protocol_version_from_sse(init_response) + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: negotiated_version, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, - ) - assert response.status_code == 200 - - -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): - """Test server accepts requests without protocol version header.""" - # First initialize a session to get a valid session ID - init_response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - json=INIT_REQUEST, - ) - assert init_response.status_code == 200 - session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) - - # Test request without mcp-protocol-version header (backwards compatibility) - response = requests.post( - f"{basic_server_url}/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, - stream=True, - ) - assert response.status_code == 200 # Should succeed for backwards compatibility - assert response.headers.get("Content-Type") == "text/event-stream" + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): - """Test that cases where the client crashes are handled gracefully.""" +async def test_server_validates_protocol_version_header(basic_app: Starlette) -> None: + """An invalid or unsupported protocol version header is rejected with 400; the negotiated one passes.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request with invalid protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + negotiated_version = extract_protocol_version_from_sse(init_response) + + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None: + """A request without a protocol version header is accepted for backwards compatibility.""" + async with make_client(basic_app) as client: + # First initialize a session to get a valid session ID + init_response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test request without mcp-protocol-version header (backwards compatibility) + async with client.stream( + "POST", + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + ) as response: + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_app: Starlette) -> None: + """A client crashing mid-session does not prevent later clients from connecting.""" # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - raise Exception("client crash") + await session.initialize() + raise Exception("client crash") - # Run bad client a few times to trigger the crash + # Run bad client a few times to trigger the crash. The crash surfaces wrapped in exception + # groups whose exact shape is not the subject here — what matters is that the server survives. for _ in range(3): try: await bad_client() except Exception: pass - await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamable_http_client(f"{basic_server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(basic_app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - tools = await session.list_tools() - assert tools.tools + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools @pytest.mark.anyio -async def test_handle_sse_event_skips_empty_data(): - """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" +async def test_handle_sse_event_skips_empty_data() -> None: + """_handle_sse_event skips empty SSE data (keep-alive pings) without writing to the stream.""" transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") # Create a mock SSE event with empty data (keep-alive ping) @@ -1786,8 +1583,8 @@ async def test_handle_sse_event_skips_empty_data(): @pytest.mark.anyio -async def test_priming_event_not_sent_for_old_protocol_version(): - """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" +async def test_priming_event_not_sent_for_old_protocol_version() -> None: + """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1815,8 +1612,8 @@ async def test_priming_event_not_sent_for_old_protocol_version(): @pytest.mark.anyio -async def test_priming_event_not_sent_without_event_store(): - """Test that _maybe_send_priming_event returns early when no event_store is configured.""" +async def test_priming_event_not_sent_without_event_store() -> None: + """_maybe_send_priming_event returns early when no event_store is configured.""" # Create a transport WITHOUT an event store transport = StreamableHTTPServerTransport("/mcp") @@ -1835,8 +1632,8 @@ async def test_priming_event_not_sent_without_event_store(): @pytest.mark.anyio -async def test_priming_event_includes_retry_interval(): - """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" +async def test_priming_event_includes_retry_interval() -> None: + """_maybe_send_priming_event includes the retry field when retry_interval is set.""" # Create a transport with an event store AND retry_interval transport = StreamableHTTPServerTransport( "/mcp", @@ -1864,8 +1661,8 @@ async def test_priming_event_includes_retry_interval(): @pytest.mark.anyio -async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): - """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() -> None: + """close_sse_stream callbacks are only provided for protocol versions that support polling.""" # Create a transport with an event store transport = StreamableHTTPServerTransport( "/mcp", @@ -1897,83 +1694,78 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_streamable_http_client_receives_priming_event( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should receive priming event (resumption token update) on POST SSE stream.""" - _, server_url = event_server + _, app = event_app captured_resumption_tokens: list[str] = [] async def on_resumption_token_update(token: str) -> None: captured_resumption_tokens.append(token) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - # Call tool with resumption token callback via send_request - metadata = ClientMessageMetadata( - on_resumption_token_update=on_resumption_token_update, - ) - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="test_tool", arguments={}), - ) - ), - types.CallToolResult, - metadata=metadata, - ) - assert result is not None - - # Should have received priming event token BEFORE response data - # Priming event = 1 token (empty data, id only) - # Response = 1 token (actual JSON-RPC response) - # Total = 2 tokens minimum - assert len(captured_resumption_tokens) >= 2, ( - f"Server must send priming event before response. " - f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" - ) - assert captured_resumption_tokens[0] is not None + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest(params=types.CallToolRequestParams(name="test_tool", arguments={})) + ), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None @pytest.mark.anyio async def test_server_close_sse_stream_via_context( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Server tool can call ctx.close_sse_stream() to close connection.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - # Call tool that closes stream mid-operation - # This should NOT raise NotImplementedError when fully implemented - result = await session.call_tool("tool_with_stream_close", {}) + # Call tool that closes stream mid-operation + result = await session.call_tool("tool_with_stream_close", {}) - # Client should still receive complete response (via auto-reconnect) - assert result is not None - assert len(result.content) > 0 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_auto_reconnects( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" - _, server_url = event_server + _, app = event_app captured_notifications: list[str] = [] async def message_handler( @@ -1985,71 +1777,63 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch captured_notifications.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification - # 2. Closes SSE stream - # 3. Sends more notifications (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_stream_close", {}) - - # Client should have auto-reconnected and received ALL notifications - assert len(captured_notifications) >= 2, ( - "Client should auto-reconnect and receive notifications sent both before and after stream close" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_client_respects_retry_interval( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Client MUST respect retry field, waiting specified ms before reconnecting.""" - _, server_url = event_server + _, app = event_app - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, ): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + await session.initialize() - start_time = time.monotonic() - result = await session.call_tool("tool_with_stream_close", {}) - elapsed = time.monotonic() - start_time + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time - # Verify result was received - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" - # The elapsed time should include at least the retry interval - # if reconnection occurred. This test may be flaky depending on - # implementation details, but demonstrates the expected behavior. - # Note: This assertion may need adjustment based on actual implementation - assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + # The elapsed time should include at least the retry interval (500ms) before + # the client reconnected; the tool's own work only accounts for ~100ms. + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" @pytest.mark.anyio async def test_streamable_http_sse_polling_full_cycle( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """End-to-end test: server closes stream, client reconnects, receives all events.""" - _, server_url = event_server + _, app = event_app all_notifications: list[str] = [] async def message_handler( @@ -2061,43 +1845,38 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch all_notifications.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that simulates polling pattern: - # 1. Server sends priming event - # 2. Server sends "Before close" notification - # 3. Server closes stream (calls close_sse_stream) - # 4. (client reconnects automatically) - # 5. Server sends "After close" notification - # 6. Server sends final response - result = await session.call_tool("tool_with_stream_close", {}) - - # Verify all notifications received in order - assert "Before close" in all_notifications, "Should receive notification sent before stream close" - assert "After close" in all_notifications, ( - "Should receive notification sent after stream close (via auto-reconnect)" - ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Done" + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" @pytest.mark.anyio async def test_streamable_http_events_replayed_after_disconnect( - event_server: tuple[SimpleEventStore, str], + event_app: tuple[SimpleEventStore, Starlette], ) -> None: """Events sent while client is disconnected should be replayed on reconnect.""" - _, server_url = event_server + _, app = event_app notification_data: list[str] = [] async def message_handler( @@ -2109,45 +1888,43 @@ async def message_handler( if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch notification_data.append(str(message.root.params.data)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() + await session.initialize() - # Tool sends: notification1, close_stream, notification2, notification3, response - # Client should receive all notifications even though 2&3 were sent during disconnect - result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) - assert "notification1" in notification_data, "Should receive notification1 (sent before close)" - assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" - assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" - # Verify order: notification1 should come before notification2 and notification3 - idx1 = notification_data.index("notification1") - idx2 = notification_data.index("notification2") - idx3 = notification_data.index("notification3") - assert idx1 < idx2 < idx3, "Notifications should be received in order" + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "All notifications sent" + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" @pytest.mark.anyio -async def test_streamable_http_multiple_reconnections( - event_server: tuple[SimpleEventStore, str], -): - """Verify multiple close_sse_stream() calls each trigger a client reconnect. +async def test_streamable_http_multiple_reconnections() -> None: + """Every close_sse_stream() severs a live connection and triggers its own client reconnect. - Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure - client has time to reconnect before the next checkpoint. + The tool closes its SSE stream three times; before each next cycle it waits until the + client has observed the previous cycle's two new resumption tokens (the checkpoint and the + new connection's priming event). The priming event is sent only after the server has + re-registered the resumed stream, so once the client holds its token the next close is + guaranteed to sever a live connection rather than silently no-op — making the exact token + count below a consequence of causality, not timing margins. This pins reconnect-per-close + accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval. With 3 checkpoints, we expect 8 resumption tokens: - 1 priming (initial POST connection) @@ -2155,50 +1932,77 @@ async def test_streamable_http_multiple_reconnections( - 3 priming (one per reconnect after each close) - 1 response """ - _, server_url = event_server resumption_tokens: list[str] = [] + # milestones[n] fires when the client has observed n tokens. After the initial priming + # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the + # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens. + milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()} async def on_resumption_token(token: str) -> None: resumption_tokens.append(token) - - async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Use send_request with metadata to track resumption tokens - metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="tool_with_multiple_stream_closes", - # retry_interval=500ms, so sleep 600ms to ensure reconnect completes - arguments={"checkpoints": 3, "sleep_time": 0.6}, - ), - ) - ), - types.CallToolResult, - metadata=metadata, + milestone = milestones.get(len(resumption_tokens)) + if milestone is not None: + milestone.set() + + server: Server[dict[str, Any], Request] = Server("multi_reconnect_server") + + @server.call_tool() + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + assert name == "multi_close_tool" + for i, milestone in enumerate(milestones.values()): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Client and server share one event loop, so the tool can wait directly on the + # client-side callback observing the reconnect. + with anyio.fail_after(5): + await milestone.wait() + return [TextContent(type="text", text="Completed 3 checkpoints")] + + async with ( + # retry_interval is small to keep the test fast, but nonzero so each dying connection + # finishes unwinding before its replacement registers. + running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app, + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="multi_close_tool", arguments={}), + ) + ), + types.CallToolResult, + metadata=metadata, + ) - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Completed 3 checkpoints" in result.content[0].text + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio -async def test_standalone_get_stream_reconnection( - event_server: tuple[SimpleEventStore, str], -) -> None: - """ - Test that standalone GET stream automatically reconnects after server closes it. +async def test_standalone_get_stream_reconnection(event_app: tuple[SimpleEventStore, Starlette]) -> None: + """Test that standalone GET stream automatically reconnects after server closes it. Verifies: 1. Client receives notification 1 via GET stream @@ -2206,10 +2010,10 @@ async def test_standalone_get_stream_reconnection( 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection - Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + Note: Requires the event store app because close_standalone_sse_stream callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - _, server_url = event_server + _, app = event_app received_notifications: list[str] = [] async def message_handler( @@ -2221,53 +2025,46 @@ async def message_handler( if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch received_notifications.append(str(message.root.params.uri)) - async with streamable_http_client(f"{server_url}/mcp") as ( - read_stream, - write_stream, - _, + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream, _), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): - async with ClientSession( - read_stream, - write_stream, - message_handler=message_handler, - ) as session: - await session.initialize() - - # Call tool that: - # 1. Sends notification_1 via GET stream - # 2. Closes standalone GET stream - # 3. Sends notification_2 (stored in event_store) - # 4. Returns response - result = await session.call_tool("tool_with_standalone_stream_close", {}) - - # Verify the tool completed - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Standalone stream close test done" - - # Verify both notifications were received - assert "http://notification_1/" in received_notifications, ( - f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" - ) - assert "http://notification_2/" in received_notifications, ( - f"Should receive notification 2 after reconnect, got: {received_notifications}" - ) + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1/" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2/" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: - """Test that streamable_http_client does not mutate the provided httpx client's headers.""" +async def test_streamable_http_client_does_not_mutate_provided_client(basic_app: Starlette) -> None: + """streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { "X-Custom-Header": "custom-value", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + async with make_client(basic_app, headers=original_headers) as custom_client: # Use the client with streamable_http_client - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=custom_client) as ( read_stream, write_stream, _, @@ -2276,35 +2073,34 @@ async def test_streamable_http_client_does_not_mutate_provided_client( result = await session.initialize() assert isinstance(result, InitializeResult) - # Verify client headers were not mutated with MCP protocol headers + # Verify client headers were not mutated with MCP protocol headers. + # These checks deliberately sit after the streamable_http_client context + # exits (a teardown-time mutation would otherwise escape them), which on + # Python 3.11 places them in the post-teardown trace-loss shadow + # (python/cpython#106749): they run and assert on every leg but go + # unmeasured on 3.11 cells, hence the lax exclusions. # If accept header exists, it should still be httpx default, not MCP's - if "accept" in custom_client.headers: # pragma: no branch + if "accept" in custom_client.headers: # pragma: lax no cover assert custom_client.headers.get("accept") == "*/*" # MCP content-type should not have been added - assert custom_client.headers.get("content-type") != "application/json" + assert custom_client.headers.get("content-type") != "application/json" # pragma: lax no cover # Verify custom headers are still present and unchanged - assert custom_client.headers.get("X-Custom-Header") == "custom-value" - assert custom_client.headers.get("Authorization") == "Bearer test-token" + assert custom_client.headers.get("X-Custom-Header") == "custom-value" # pragma: lax no cover + assert custom_client.headers.get("Authorization") == "Bearer test-token" # pragma: lax no cover @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that MCP protocol headers override httpx.AsyncClient default headers.""" +async def test_streamable_http_client_mcp_headers_override_defaults(context_app: Starlette) -> None: + """MCP protocol headers override the httpx client's default headers in actual requests.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests - async with httpx.AsyncClient(follow_redirects=True) as client: + async with make_client(context_app) as client: # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( - read_stream, - write_stream, - _, - ): + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2324,22 +2120,16 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: - """Test that both custom headers and MCP protocol headers are sent in requests.""" +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_app: Starlette) -> None: + """Custom client headers and MCP protocol headers are both sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( - read_stream, - write_stream, - _, - ): + async with make_client(context_app, headers=custom_headers) as client: + async with streamable_http_client(f"{BASE_URL}/mcp", http_client=client) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2363,12 +2153,11 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert headers_data["content-type"] == "application/json" -@pytest.mark.anyio -async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None: - """Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored.""" +def test_streamable_http_transport_deprecated_params_ignored() -> None: + """Deprecated parameters passed to StreamableHTTPTransport are accepted but ignored.""" with pytest.warns(DeprecationWarning): transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated] - url=f"{basic_server_url}/mcp", + url=f"{BASE_URL}/mcp", headers={"X-Should-Be-Ignored": "ignored"}, timeout=999, sse_read_timeout=timedelta(seconds=999), @@ -2382,10 +2171,27 @@ async def test_streamable_http_transport_deprecated_params_ignored(basic_server: @pytest.mark.anyio -async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None: - """Test that the old streamablehttp_client() function issues a deprecation warning.""" +async def test_streamablehttp_client_deprecation_warning(basic_app: Starlette) -> None: + """The old streamablehttp_client() function issues a deprecation warning.""" + + def in_process_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(basic_app), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) + with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"): - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated] + async with streamablehttp_client( # pyright: ignore[reportDeprecated] + f"{BASE_URL}/mcp", httpx_client_factory=in_process_client_factory + ) as ( read_stream, write_stream, _,