Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from urllib.parse import urlparse

import anyio
import httpx
Expand Down Expand Up @@ -72,13 +73,15 @@ class RequestContext:
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

def __init__(self, url: str) -> None:
def __init__(self, url: str, default_origin: str | None = None) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
default_origin: Optional Origin header value to include with requests.
"""
self.url = url
self.default_origin = default_origin
self.session_id: str | None = None
self.protocol_version: str | None = None

Expand All @@ -92,6 +95,8 @@ def _prepare_headers(self) -> dict[str, str]:
"accept": "application/json, text/event-stream",
"content-type": "application/json",
}
if self.default_origin:
headers["origin"] = self.default_origin
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
Expand Down Expand Up @@ -547,7 +552,13 @@ async def streamable_http_client(
# Create default client with recommended MCP timeouts
client = create_mcp_http_client()

transport = StreamableHTTPTransport(url)
default_origin = None
if "origin" not in client.headers:
parsed_url = urlparse(url)
if parsed_url.scheme and parsed_url.netloc:
default_origin = f"{parsed_url.scheme}://{parsed_url.netloc}"

transport = StreamableHTTPTransport(url, default_origin=default_origin)

logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

Expand Down
15 changes: 15 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def extract_protocol_version_from_sse(response: requests.Response) -> str:
raise ValueError("Could not extract protocol version from SSE response") # pragma: no cover


def test_streamable_http_transport_includes_default_origin_header():
transport = StreamableHTTPTransport(
"https://example.com/mcp",
default_origin="https://example.com",
)

assert transport._prepare_headers()["origin"] == "https://example.com"


def test_streamable_http_transport_omits_origin_header_without_default_origin():
transport = StreamableHTTPTransport("https://example.com/mcp")

assert "origin" not in transport._prepare_headers()


# Simple in-memory event store for testing
class SimpleEventStore(EventStore):
"""Simple in-memory event store for testing."""
Expand Down
Loading