diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..4930775d73 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -245,7 +245,7 @@ def _create_session_message( async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -421,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: no cover + if not has_json: response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -669,7 +669,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): return # Handle resumability: check for Last-Event-ID header @@ -1017,7 +1017,7 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) else: diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e469..8d18e34b3f 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,8 +4,6 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager @@ -19,7 +17,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,9 +39,8 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn +def make_unicode_server_app() -> Starlette: # pragma: no cover + """Create the Unicode test server.""" # Need to recreate the server setup in this process async def handle_list_tools( @@ -137,43 +134,14 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: lifespan=lifespan, ) - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] + return app @pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) +def running_unicode_server() -> Generator[str, None, None]: + """Start a Unicode test server without preselecting a port.""" + with run_uvicorn_in_thread(make_unicode_server_app()) as url: + yield url @pytest.mark.anyio diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index e95dc51b31..bbb32f5332 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,15 +1,14 @@ """Tests for SSE server request validation.""" import logging -import multiprocessing import re -import socket +from collections.abc import Generator +from contextlib import contextmanager import anyio import httpx import pytest import sse_starlette.sse -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -24,7 +23,7 @@ from mcp.shared._stream_protocols import WriteStream from mcp.shared.message import SessionMessage from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" @@ -39,19 +38,7 @@ def reset_sse_starlette_exit_event() -> None: app_status.should_exit_event = None -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) @@ -59,8 +46,8 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" +def make_server_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: # pragma: no cover + """Create the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) @@ -79,94 +66,74 @@ async def handle_sse(request: Request): Mount("/messages/", app=sse_transport.handle_post_message), ] - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + return Starlette(routes=routes) -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +@contextmanager +def run_server_with_settings( + security_settings: TransportSecuritySettings | None = None, +) -> Generator[str, None, None]: + """Run the SSE server without preselecting a port.""" + with run_uvicorn_in_thread(make_server_app(security_settings)) as url: + yield url @pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): +async def test_sse_security_default_settings(): """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) - - try: + with run_server_with_settings() as server_url: headers = {"Host": "evil.com", "Origin": "http://evil.com"} async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: assert response.status_code == 200 - finally: - process.terminate() - process.join() @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: # Test with invalid host header headers = {"Host": "evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{server_url}/sse", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(): """Test SSE with invalid Origin header.""" # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: # Test with invalid origin header headers = {"Origin": "http://evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{server_url}/sse", headers=headers) assert response.status_code == 403 assert response.text == "Invalid Origin header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(): """Test POST endpoint with invalid Content-Type header.""" # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{server_url}/messages/?session_id={fake_session_id}", headers={"Content-Type": "text/plain"}, content="test", ) @@ -174,55 +141,41 @@ async def test_sse_security_post_invalid_content_type(server_port: int): assert response.text == "Invalid Content-Type header" # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) + response = await client.post(f"{server_url}/messages/?session_id={fake_session_id}", content="test") assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: + with run_server_with_settings(settings) as server_url: # Test with invalid host header - should still work headers = {"Host": "evil.com"} async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: # Should connect successfully even with invalid host assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_with_settings(settings) as server_url: # Test with custom allowed host headers = {"Host": "custom.host"} async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: # Should connect successfully with custom host assert response.status_code == 200 @@ -230,33 +183,27 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): headers = {"Host": "evil.com"} async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{server_url}/sse", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_with_settings(settings) as server_url: # Test with various port numbers for test_port in [8080, 3000, 9999]: headers = {"Host": f"localhost:{test_port}"} async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 @@ -264,25 +211,19 @@ async def test_sse_security_wildcard_ports(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: # Should connect successfully with any port assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(): """Test POST endpoint with valid Content-Type headers.""" # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: async with httpx.AsyncClient() as client: # Test with various valid content types valid_content_types = [ @@ -296,7 +237,7 @@ async def test_sse_security_post_valid_content_type(server_port: int): # Use a valid UUID format (even though session won't exist) fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{server_url}/messages/?session_id={fake_session_id}", headers={"Content-Type": content_type}, json={"test": "data"}, ) @@ -305,10 +246,6 @@ async def test_sse_security_post_valid_content_type(server_port: int): assert response.status_code == 404 assert response.text == "Could not find session" - finally: - process.terminate() - process.join() - def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..3db71f1035 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,13 +1,10 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator, Generator +from contextlib import asynccontextmanager, contextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount from starlette.types import Receive, Scope, Send @@ -16,24 +13,12 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_streamable_http_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) @@ -41,8 +26,8 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" +def make_server_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: # pragma: no cover + """Create the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() # Create session manager with security settings @@ -67,30 +52,27 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: Mount("/", app=handle_streamable_http), ] - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + return Starlette(routes=routes, lifespan=lifespan) -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process +@contextmanager +def run_server_with_settings( + security_settings: TransportSecuritySettings | None = None, +) -> Generator[str, None, None]: + """Run the StreamableHTTP server without preselecting a port.""" + with run_uvicorn_in_thread(make_server_app(security_settings)) as url: + yield url @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): +async def test_streamable_http_security_default_settings(): """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: + with run_server_with_settings() as server_url: # Test with valid localhost headers async with httpx.AsyncClient(timeout=5.0) as client: # POST request to initialize session response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers={ "Accept": "application/json, text/event-stream", @@ -100,18 +82,12 @@ async def test_streamable_http_security_default_settings(server_port: int): assert response.status_code == 200 assert "mcp-session-id" in response.headers - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): +async def test_streamable_http_security_invalid_host_header(): """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: # Test with invalid host header headers = { "Host": "evil.com", @@ -121,25 +97,19 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): +async def test_streamable_http_security_invalid_origin_header(): """Test StreamableHTTP with invalid Origin header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: # Test with invalid origin header headers = { "Origin": "http://evil.com", @@ -149,28 +119,22 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) assert response.status_code == 403 assert response.text == "Invalid Origin header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): +async def test_streamable_http_security_invalid_content_type(): """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: + with run_server_with_settings() as server_url: async with httpx.AsyncClient(timeout=5.0) as client: # Test POST with invalid content type response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", headers={ "Content-Type": "text/plain", "Accept": "application/json, text/event-stream", @@ -182,25 +146,19 @@ async def test_streamable_http_security_invalid_content_type(server_port: int): # Test POST with missing content type response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", headers={"Accept": "application/json, text/event-stream"}, content="test", ) assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): +async def test_streamable_http_security_disabled(): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: + with run_server_with_settings(settings) as server_url: # Test with invalid host header - should still work headers = { "Host": "evil.com", @@ -210,29 +168,23 @@ async def test_streamable_http_security_disabled(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) # Should connect successfully even with invalid host assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): +async def test_streamable_http_security_custom_allowed_hosts(): """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_with_settings(settings) as server_url: # Test with custom allowed host headers = { "Host": "custom.host", @@ -242,24 +194,19 @@ async def test_streamable_http_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.post( - f"http://127.0.0.1:{server_port}/", + f"{server_url}/", json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, headers=headers, ) # Should connect successfully with custom host assert response.status_code == 200 - finally: - process.terminate() - process.join() @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): +async def test_streamable_http_security_get_request(): """Test StreamableHTTP GET request with security.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: + with run_server_with_settings(security_settings) as server_url: # Test GET request with invalid host header headers = { "Host": "evil.com", @@ -267,7 +214,7 @@ async def test_streamable_http_security_get_request(server_port: int): } async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + response = await client.get(f"{server_url}/", headers=headers) assert response.status_code == 421 assert response.text == "Invalid Host header" @@ -280,12 +227,8 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: # GET requests need a session ID in StreamableHTTP # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + response = await client.get(f"{server_url}/", headers=headers) # This should pass security but fail on session validation assert response.status_code == 400 body = response.json() assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707b..48675fd8be 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,6 +1,4 @@ import json -import multiprocessing -import socket from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -9,7 +7,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -41,23 +38,11 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_SSE" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - - async def _handle_read_resource( # pragma: no cover ctx: ServerRequestContext, params: ReadResourceRequestParams ) -> ReadResourceResult: @@ -93,7 +78,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) -def _create_server() -> Server: # pragma: no cover +def _create_server() -> Server: return Server( SERVER_NAME, on_read_resource=_handle_read_resource, @@ -127,31 +112,15 @@ async def handle_sse(request: Request) -> Response: return app -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() +def server() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as url: + yield url - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +@pytest.fixture +def server_url(server: str) -> str: + return server @pytest.fixture() @@ -297,37 +266,21 @@ async def test_sse_client_timeout( # pragma: no cover pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: # pragma: no cover +def make_mounted_server_app() -> Starlette: app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() + return main_app @pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +def mounted_server() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_mounted_server_app()) as url: + yield url @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: +async def test_sse_client_basic_connection_mounted_app(mounted_server: str) -> None: + async with sse_client(mounted_server + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -381,7 +334,7 @@ async def _handle_context_list_tools( # pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover +def make_context_server_app() -> Starlette: # pragma: no cover """Run a server that captures request context""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( @@ -406,33 +359,18 @@ async def handle_sse(request: Request) -> Response: ] ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() + return app @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: +def context_server() -> Generator[str, None, None]: """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") + with run_uvicorn_in_thread(make_context_server_app()) as url: + yield url @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(context_server: str) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -441,7 +379,7 @@ async def test_request_context_propagation(context_server: None, server_url: str "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( + async with sse_client(context_server + "/sse", headers=custom_headers) as ( read_stream, write_stream, ): @@ -465,7 +403,7 @@ async def test_request_context_propagation(context_server: None, server_url: str @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(context_server: str) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] @@ -473,7 +411,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( + async with sse_client(context_server + "/sse", headers=headers) as ( read_stream, write_stream, ): diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb61..256e2b3278 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,10 +6,8 @@ from __future__ import annotations as _annotations import json -import multiprocessing import socket import time -import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -21,7 +19,6 @@ import httpx import pytest import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -66,7 +63,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread # Test constants SERVER_NAME = "test_streamable_http_server" @@ -145,7 +142,7 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() @@ -383,7 +380,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -432,74 +429,16 @@ def create_app( return app -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. - """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", +@pytest.fixture +def basic_server() -> Generator[str, None, None]: + """Start a basic server.""" + with run_uvicorn_in_thread( + create_app(), limit_concurrency=10, timeout_keep_alive=5, access_log=False, - ) - - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests -@pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) + ) as url: + yield url @pytest.fixture @@ -509,65 +448,50 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) + with run_uvicorn_in_thread( + create_app(event_store=event_store, retry_interval=500), + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) as url: + yield event_store, url @pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: +def json_response_server() -> Generator[str, None, None]: """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) + with run_uvicorn_in_thread( + create_app(is_json_response_enabled=True), + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) as url: + yield url @pytest.fixture -def basic_server_url(basic_server_port: int) -> str: +def basic_server_url(basic_server: str) -> str: """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" + return basic_server @pytest.fixture -def json_server_url(json_server_port: int) -> str: +def json_server_url(json_response_server: str) -> str: """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" + return json_response_server + + +def test_basic_server_url_uses_reserved_port_regression_2704(basic_server_url: str) -> None: + """The server URL must come from the socket uvicorn is already serving.""" + parsed = urlparse(basic_server_url) + assert parsed.hostname is not None + assert parsed.port is not None + + with socket.socket() as sock: + with pytest.raises(OSError): + sock.bind((parsed.hostname, parsed.port)) # Basic request validation tests @@ -1517,8 +1441,7 @@ async def _handle_context_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover +def make_context_aware_app() -> Starlette: # pragma: no cover """Run the context-aware test server.""" server = Server( "ContextAwareServer", @@ -1540,36 +1463,18 @@ def run_context_aware_server(port: int): # pragma: no cover lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() + return app @pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: +def context_aware_server() -> Generator[str, None, None]: """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") + with run_uvicorn_in_thread(make_context_aware_app()) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1578,7 +1483,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1602,7 +1507,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1615,7 +1520,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1640,9 +1545,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init(context_aware_server: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -2255,7 +2160,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str + context_aware_server: str, ) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header @@ -2265,7 +2170,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2286,7 +2194,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str + context_aware_server: str, ) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { @@ -2296,7 +2204,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 810c72820b..42aca1a9dd 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -2,9 +2,8 @@ import socket import threading -import time from collections.abc import Generator -from contextlib import contextmanager +from contextlib import contextmanager, suppress from typing import Any import uvicorn @@ -56,30 +55,5 @@ def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None finally: server.should_exit = True thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) - - -def wait_for_server(port: int, timeout: float = 20.0) -> None: - """Wait for server to be ready to accept connections. - - Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. - - Args: - port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) - - Raises: - TimeoutError: If server doesn't start within the timeout period - """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(0.1) - s.connect(("127.0.0.1", port)) - # Server is ready - return - except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly - time.sleep(0.01) - raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover + with suppress(OSError): + sock.close()