diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index b42581df..d6ea34df 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -11,6 +11,7 @@ from .._exceptions import ConnectError, ConnectTimeout from .._models import Origin, Request, Response from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import AsyncLock from .._trace import Trace from .http11 import AsyncHTTP11Connection @@ -148,8 +149,9 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: kwargs = { "ssl_context": ssl_context, - "server_hostname": sni_hostname - or self._origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + sni_hostname or self._origin.host.decode("ascii") + ), "timeout": timeout, } async with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d9206..798ee01d 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -17,6 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import AsyncLock from .._trace import Trace from .connection import AsyncHTTPConnection @@ -309,7 +310,9 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + self._remote_origin.host.decode("ascii") + ), "timeout": timeout, } async with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index b363f55a..6c0cc2e7 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -10,6 +10,7 @@ from .._exceptions import ConnectionNotAvailable, ProxyError from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import AsyncLock from .._trace import Trace from .connection_pool import AsyncConnectionPool @@ -258,8 +259,9 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": sni_hostname - or self._remote_origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + sni_hostname or self._remote_origin.host.decode("ascii") + ), "timeout": timeout, } async with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index a140095e..f148eff1 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -14,7 +14,7 @@ WriteTimeout, map_exceptions, ) -from .._utils import is_socket_readable +from .._utils import is_socket_readable, normalize_tls_server_hostname from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -70,7 +70,7 @@ async def start_tls( ssl_stream = await anyio.streams.tls.TLSStream.wrap( self._stream, ssl_context=ssl_context, - hostname=server_hostname, + hostname=normalize_tls_server_hostname(server_hostname), standard_compatible=False, server_side=False, ) diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 4018a09c..61966168 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -16,7 +16,7 @@ WriteTimeout, map_exceptions, ) -from .._utils import is_socket_readable +from .._utils import is_socket_readable, normalize_tls_server_hostname from .base import SOCKET_OPTION, NetworkBackend, NetworkStream @@ -45,7 +45,7 @@ def __init__( self.ssl_obj = ssl_context.wrap_bio( incoming=self._incoming, outgoing=self._outgoing, - server_hostname=server_hostname, + server_hostname=normalize_tls_server_hostname(server_hostname), ) self._sock.settimeout(timeout) @@ -163,7 +163,10 @@ def start_tls( else: self._sock.settimeout(timeout) sock = ssl_context.wrap_socket( - self._sock, server_hostname=server_hostname + self._sock, + server_hostname=normalize_tls_server_hostname( + server_hostname + ), ) except Exception as exc: # pragma: nocover self.close() diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index 6f53f5f2..3c013475 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -15,6 +15,7 @@ WriteTimeout, map_exceptions, ) +from .._utils import normalize_tls_server_hostname from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -65,7 +66,7 @@ async def start_tls( ssl_stream = trio.SSLStream( self._stream, ssl_context=ssl_context, - server_hostname=server_hostname, + server_hostname=normalize_tls_server_hostname(server_hostname), https_compatible=True, server_side=False, ) diff --git a/httpcore/_ssl.py b/httpcore/_ssl.py index c99c5a67..d4bc3990 100644 --- a/httpcore/_ssl.py +++ b/httpcore/_ssl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ssl import certifi @@ -7,3 +9,10 @@ def default_ssl_context() -> ssl.SSLContext: context = ssl.create_default_context() context.load_verify_locations(certifi.where()) return context + + +def normalize_server_hostname(server_hostname: str | None) -> str | None: + if server_hostname is None: + return None + + return server_hostname.rstrip(".") diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 363f8be8..797ec54a 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -11,6 +11,7 @@ from .._exceptions import ConnectError, ConnectTimeout from .._models import Origin, Request, Response from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import Lock from .._trace import Trace from .http11 import HTTP11Connection @@ -148,8 +149,9 @@ def _connect(self, request: Request) -> NetworkStream: kwargs = { "ssl_context": ssl_context, - "server_hostname": sni_hostname - or self._origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + sni_hostname or self._origin.host.decode("ascii") + ), "timeout": timeout, } with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7..31d505b4 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -17,6 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import Lock from .._trace import Trace from .connection import HTTPConnection @@ -309,7 +310,9 @@ def handle_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + self._remote_origin.host.decode("ascii") + ), "timeout": timeout, } with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 0ca96ddf..c3b3e464 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -10,6 +10,7 @@ from .._exceptions import ConnectionNotAvailable, ProxyError from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url from .._ssl import default_ssl_context +from .._ssl import normalize_server_hostname from .._synchronization import Lock from .._trace import Trace from .connection_pool import ConnectionPool @@ -258,8 +259,9 @@ def handle_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, - "server_hostname": sni_hostname - or self._remote_origin.host.decode("ascii"), + "server_hostname": normalize_server_hostname( + sni_hostname or self._remote_origin.host.decode("ascii") + ), "timeout": timeout, } with Trace("start_tls", logger, request, kwargs) as trace: diff --git a/httpcore/_utils.py b/httpcore/_utils.py index c44ff93c..035ea93c 100644 --- a/httpcore/_utils.py +++ b/httpcore/_utils.py @@ -35,3 +35,9 @@ def is_socket_readable(sock: socket.socket | None) -> bool: p = select.poll() p.register(sock_fd, select.POLLIN) return bool(p.poll(0)) + + +def normalize_tls_server_hostname(server_hostname: str | None) -> str | None: + if server_hostname is None: + return None + return server_hostname.rstrip(".") diff --git a/tests/_async/test_connection.py b/tests/_async/test_connection.py index b6ee0c7e..fb7d4837 100644 --- a/tests/_async/test_connection.py +++ b/tests/_async/test_connection.py @@ -235,6 +235,49 @@ async def test_request_to_incorrect_origin(): await conn.request("GET", "https://other.com/") +@pytest.mark.anyio +async def test_connection_strips_trailing_dot_from_sni_hostname(): + class CaptureHostnameBackend(AsyncMockBackend): + def __init__(self, buffer: typing.List[bytes]) -> None: + super().__init__(buffer) + self.server_hostname: typing.Optional[str] = None + + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncMockStream: + backend = self + + class CaptureHostnameStream(AsyncMockStream): + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + backend.server_hostname = server_hostname + return await super().start_tls( + ssl_context, server_hostname, timeout + ) + + return CaptureHostnameStream(list(self._buffer), http2=self._http2) + + origin = Origin(b"https", b"example.com.", 443) + network_backend = CaptureHostnameBackend( + [b"HTTP/1.1 200 OK\r\n", b"Content-Length: 0\r\n", b"\r\n"] + ) + + async with AsyncHTTPConnection(origin=origin, network_backend=network_backend) as conn: + response = await conn.request("GET", "https://example.com./") + assert response.status == 200 + + assert network_backend.server_hostname == "example.com" + + class NeedsRetryBackend(AsyncMockBackend): def __init__( self, diff --git a/tests/_async/test_tls_hostname.py b/tests/_async/test_tls_hostname.py new file mode 100644 index 00000000..9d051a96 --- /dev/null +++ b/tests/_async/test_tls_hostname.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import ssl + +import anyio +import pytest +import trio + +from httpcore._backends.anyio import AnyIOStream +from httpcore._backends.trio import TrioStream + + +class DummyAsyncStream: + async def aclose(self) -> None: + pass + + +@pytest.mark.anyio +async def test_anyio_stream_start_tls_strips_trailing_dot_from_server_hostname( + monkeypatch: pytest.MonkeyPatch, +) -> None: + recorded: dict[str, str | None] = {} + + class DummyTLSStream: + async def aclose(self) -> None: + pass + + def extra(self, attribute: object, default: object = None) -> object: + return default + + async def wrap( + stream: DummyAsyncStream, + *, + ssl_context: ssl.SSLContext, + hostname: str | None, + standard_compatible: bool, + server_side: bool, + ) -> DummyTLSStream: + recorded["server_hostname"] = hostname + return DummyTLSStream() + + monkeypatch.setattr(anyio.streams.tls.TLSStream, "wrap", wrap) + + stream = AnyIOStream(DummyAsyncStream()) + wrapped = await stream.start_tls( + ssl.create_default_context(), + server_hostname="localhost.", + timeout=1, + ) + + assert recorded["server_hostname"] == "localhost" + assert isinstance(wrapped, AnyIOStream) + + +@pytest.mark.trio +async def test_trio_stream_start_tls_strips_trailing_dot_from_server_hostname( + monkeypatch: pytest.MonkeyPatch, +) -> None: + recorded: dict[str, str | None] = {} + + class DummySSLStream: + def __init__( + self, + stream: DummyAsyncStream, + *, + ssl_context: ssl.SSLContext, + server_hostname: str | None, + https_compatible: bool, + server_side: bool, + ) -> None: + self.transport_stream = stream + recorded["server_hostname"] = server_hostname + + async def do_handshake(self) -> None: + return None + + async def aclose(self) -> None: + pass + + monkeypatch.setattr(trio, "SSLStream", DummySSLStream) + + stream = TrioStream(DummyAsyncStream()) + wrapped = await stream.start_tls( + ssl.create_default_context(), + server_hostname="localhost.", + timeout=1, + ) + + assert recorded["server_hostname"] == "localhost" + assert isinstance(wrapped, TrioStream) diff --git a/tests/_sync/test_connection.py b/tests/_sync/test_connection.py index 37c82e02..fdb557c2 100644 --- a/tests/_sync/test_connection.py +++ b/tests/_sync/test_connection.py @@ -235,6 +235,46 @@ def test_request_to_incorrect_origin(): conn.request("GET", "https://other.com/") +def test_connection_strips_trailing_dot_from_sni_hostname(): + class CaptureHostnameBackend(MockBackend): + def __init__(self, buffer: typing.List[bytes]) -> None: + super().__init__(buffer) + self.server_hostname: typing.Optional[str] = None + + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> MockStream: + backend = self + + class CaptureHostnameStream(MockStream): + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> NetworkStream: + backend.server_hostname = server_hostname + return super().start_tls(ssl_context, server_hostname, timeout) + + return CaptureHostnameStream(list(self._buffer), http2=self._http2) + + origin = Origin(b"https", b"example.com.", 443) + network_backend = CaptureHostnameBackend( + [b"HTTP/1.1 200 OK\r\n", b"Content-Length: 0\r\n", b"\r\n"] + ) + + with HTTPConnection(origin=origin, network_backend=network_backend) as conn: + response = conn.request("GET", "https://example.com./") + assert response.status == 200 + + assert network_backend.server_hostname == "example.com" + + class NeedsRetryBackend(MockBackend): def __init__( self, diff --git a/tests/_sync/test_tls_hostname.py b/tests/_sync/test_tls_hostname.py new file mode 100644 index 00000000..de3988f6 --- /dev/null +++ b/tests/_sync/test_tls_hostname.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import ssl + +from httpcore._backends.sync import SyncStream, TLSinTLSStream + + +class DummySocket: + def settimeout(self, timeout: float | None) -> None: + self.timeout = timeout + + def sendall(self, data: bytes) -> None: + self.data = data + + def close(self) -> None: + pass + + +def test_sync_stream_start_tls_strips_trailing_dot_from_server_hostname() -> None: + recorded: dict[str, str | None] = {} + + class DummyContext: + def wrap_socket(self, sock: DummySocket, server_hostname: str | None = None) -> DummySocket: + recorded["server_hostname"] = server_hostname + return sock + + stream = SyncStream(DummySocket()) + + wrapped = stream.start_tls( + DummyContext(), + server_hostname="localhost.", + timeout=1, + ) + + assert recorded["server_hostname"] == "localhost" + assert isinstance(wrapped, SyncStream) + + +def test_tls_in_tls_stream_strips_trailing_dot_from_server_hostname() -> None: + recorded: dict[str, str | None] = {} + + class DummySSLObject: + def do_handshake(self) -> None: + return None + + class DummyContext: + def wrap_bio( + self, + incoming: ssl.MemoryBIO, + outgoing: ssl.MemoryBIO, + server_hostname: str | None = None, + ) -> DummySSLObject: + recorded["server_hostname"] = server_hostname + return DummySSLObject() + + TLSinTLSStream( + DummySocket(), + DummyContext(), + server_hostname="localhost.", + timeout=1, + ) + + assert recorded["server_hostname"] == "localhost"