Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import AsyncLock
from .._trace import Trace
from .http11 import AsyncHTTP11Connection
Expand Down Expand Up @@ -148,8 +149,9 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
sni_hostname or self._origin.host.decode("ascii")
),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
5 changes: 4 additions & 1 deletion httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
enforce_url,
)
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import AsyncLock
from .._trace import Trace
from .connection import AsyncHTTPConnection
Expand Down Expand Up @@ -309,7 +310,9 @@ async def handle_async_request(self, request: Request) -> Response:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
self._remote_origin.host.decode("ascii")
),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
6 changes: 4 additions & 2 deletions httpcore/_async/socks_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .._exceptions import ConnectionNotAvailable, ProxyError
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import AsyncLock
from .._trace import Trace
from .connection_pool import AsyncConnectionPool
Expand Down Expand Up @@ -258,8 +259,9 @@ async def handle_async_request(self, request: Request) -> Response:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._remote_origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
sni_hostname or self._remote_origin.host.decode("ascii")
),
"timeout": timeout,
}
async with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
4 changes: 2 additions & 2 deletions httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .._utils import is_socket_readable, normalize_tls_server_hostname
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


Expand Down Expand Up @@ -70,7 +70,7 @@ async def start_tls(
ssl_stream = await anyio.streams.tls.TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=server_hostname,
hostname=normalize_tls_server_hostname(server_hostname),
standard_compatible=False,
server_side=False,
)
Expand Down
9 changes: 6 additions & 3 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .._utils import is_socket_readable, normalize_tls_server_hostname
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream


Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(
self.ssl_obj = ssl_context.wrap_bio(
incoming=self._incoming,
outgoing=self._outgoing,
server_hostname=server_hostname,
server_hostname=normalize_tls_server_hostname(server_hostname),
)

self._sock.settimeout(timeout)
Expand Down Expand Up @@ -163,7 +163,10 @@ def start_tls(
else:
self._sock.settimeout(timeout)
sock = ssl_context.wrap_socket(
self._sock, server_hostname=server_hostname
self._sock,
server_hostname=normalize_tls_server_hostname(
server_hostname
),
)
except Exception as exc: # pragma: nocover
self.close()
Expand Down
3 changes: 2 additions & 1 deletion httpcore/_backends/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
WriteTimeout,
map_exceptions,
)
from .._utils import normalize_tls_server_hostname
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream


Expand Down Expand Up @@ -65,7 +66,7 @@ async def start_tls(
ssl_stream = trio.SSLStream(
self._stream,
ssl_context=ssl_context,
server_hostname=server_hostname,
server_hostname=normalize_tls_server_hostname(server_hostname),
https_compatible=True,
server_side=False,
)
Expand Down
9 changes: 9 additions & 0 deletions httpcore/_ssl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import ssl

import certifi
Expand All @@ -7,3 +9,10 @@ def default_ssl_context() -> ssl.SSLContext:
context = ssl.create_default_context()
context.load_verify_locations(certifi.where())
return context


def normalize_server_hostname(server_hostname: str | None) -> str | None:
if server_hostname is None:
return None

return server_hostname.rstrip(".")
6 changes: 4 additions & 2 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .._exceptions import ConnectError, ConnectTimeout
from .._models import Origin, Request, Response
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import Lock
from .._trace import Trace
from .http11 import HTTP11Connection
Expand Down Expand Up @@ -148,8 +149,9 @@ def _connect(self, request: Request) -> NetworkStream:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
sni_hostname or self._origin.host.decode("ascii")
),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
5 changes: 4 additions & 1 deletion httpcore/_sync/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
enforce_url,
)
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import Lock
from .._trace import Trace
from .connection import HTTPConnection
Expand Down Expand Up @@ -309,7 +310,9 @@ def handle_request(self, request: Request) -> Response:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
self._remote_origin.host.decode("ascii")
),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
6 changes: 4 additions & 2 deletions httpcore/_sync/socks_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .._exceptions import ConnectionNotAvailable, ProxyError
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
from .._ssl import default_ssl_context
from .._ssl import normalize_server_hostname
from .._synchronization import Lock
from .._trace import Trace
from .connection_pool import ConnectionPool
Expand Down Expand Up @@ -258,8 +259,9 @@ def handle_request(self, request: Request) -> Response:

kwargs = {
"ssl_context": ssl_context,
"server_hostname": sni_hostname
or self._remote_origin.host.decode("ascii"),
"server_hostname": normalize_server_hostname(
sni_hostname or self._remote_origin.host.decode("ascii")
),
"timeout": timeout,
}
with Trace("start_tls", logger, request, kwargs) as trace:
Expand Down
6 changes: 6 additions & 0 deletions httpcore/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@ def is_socket_readable(sock: socket.socket | None) -> bool:
p = select.poll()
p.register(sock_fd, select.POLLIN)
return bool(p.poll(0))


def normalize_tls_server_hostname(server_hostname: str | None) -> str | None:
if server_hostname is None:
return None
return server_hostname.rstrip(".")
43 changes: 43 additions & 0 deletions tests/_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,49 @@ async def test_request_to_incorrect_origin():
await conn.request("GET", "https://other.com/")


@pytest.mark.anyio
async def test_connection_strips_trailing_dot_from_sni_hostname():
class CaptureHostnameBackend(AsyncMockBackend):
def __init__(self, buffer: typing.List[bytes]) -> None:
super().__init__(buffer)
self.server_hostname: typing.Optional[str] = None

async def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> AsyncMockStream:
backend = self

class CaptureHostnameStream(AsyncMockStream):
async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> AsyncNetworkStream:
backend.server_hostname = server_hostname
return await super().start_tls(
ssl_context, server_hostname, timeout
)

return CaptureHostnameStream(list(self._buffer), http2=self._http2)

origin = Origin(b"https", b"example.com.", 443)
network_backend = CaptureHostnameBackend(
[b"HTTP/1.1 200 OK\r\n", b"Content-Length: 0\r\n", b"\r\n"]
)

async with AsyncHTTPConnection(origin=origin, network_backend=network_backend) as conn:
response = await conn.request("GET", "https://example.com./")
assert response.status == 200

assert network_backend.server_hostname == "example.com"


class NeedsRetryBackend(AsyncMockBackend):
def __init__(
self,
Expand Down
90 changes: 90 additions & 0 deletions tests/_async/test_tls_hostname.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

import ssl

import anyio
import pytest
import trio

from httpcore._backends.anyio import AnyIOStream
from httpcore._backends.trio import TrioStream


class DummyAsyncStream:
async def aclose(self) -> None:
pass


@pytest.mark.anyio
async def test_anyio_stream_start_tls_strips_trailing_dot_from_server_hostname(
monkeypatch: pytest.MonkeyPatch,
) -> None:
recorded: dict[str, str | None] = {}

class DummyTLSStream:
async def aclose(self) -> None:
pass

def extra(self, attribute: object, default: object = None) -> object:
return default

async def wrap(
stream: DummyAsyncStream,
*,
ssl_context: ssl.SSLContext,
hostname: str | None,
standard_compatible: bool,
server_side: bool,
) -> DummyTLSStream:
recorded["server_hostname"] = hostname
return DummyTLSStream()

monkeypatch.setattr(anyio.streams.tls.TLSStream, "wrap", wrap)

stream = AnyIOStream(DummyAsyncStream())
wrapped = await stream.start_tls(
ssl.create_default_context(),
server_hostname="localhost.",
timeout=1,
)

assert recorded["server_hostname"] == "localhost"
assert isinstance(wrapped, AnyIOStream)


@pytest.mark.trio
async def test_trio_stream_start_tls_strips_trailing_dot_from_server_hostname(
monkeypatch: pytest.MonkeyPatch,
) -> None:
recorded: dict[str, str | None] = {}

class DummySSLStream:
def __init__(
self,
stream: DummyAsyncStream,
*,
ssl_context: ssl.SSLContext,
server_hostname: str | None,
https_compatible: bool,
server_side: bool,
) -> None:
self.transport_stream = stream
recorded["server_hostname"] = server_hostname

async def do_handshake(self) -> None:
return None

async def aclose(self) -> None:
pass

monkeypatch.setattr(trio, "SSLStream", DummySSLStream)

stream = TrioStream(DummyAsyncStream())
wrapped = await stream.start_tls(
ssl.create_default_context(),
server_hostname="localhost.",
timeout=1,
)

assert recorded["server_hostname"] == "localhost"
assert isinstance(wrapped, TrioStream)
40 changes: 40 additions & 0 deletions tests/_sync/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,46 @@ def test_request_to_incorrect_origin():
conn.request("GET", "https://other.com/")


def test_connection_strips_trailing_dot_from_sni_hostname():
class CaptureHostnameBackend(MockBackend):
def __init__(self, buffer: typing.List[bytes]) -> None:
super().__init__(buffer)
self.server_hostname: typing.Optional[str] = None

def connect_tcp(
self,
host: str,
port: int,
timeout: typing.Optional[float] = None,
local_address: typing.Optional[str] = None,
socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
) -> MockStream:
backend = self

class CaptureHostnameStream(MockStream):
def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
) -> NetworkStream:
backend.server_hostname = server_hostname
return super().start_tls(ssl_context, server_hostname, timeout)

return CaptureHostnameStream(list(self._buffer), http2=self._http2)

origin = Origin(b"https", b"example.com.", 443)
network_backend = CaptureHostnameBackend(
[b"HTTP/1.1 200 OK\r\n", b"Content-Length: 0\r\n", b"\r\n"]
)

with HTTPConnection(origin=origin, network_backend=network_backend) as conn:
response = conn.request("GET", "https://example.com./")
assert response.status == 200

assert network_backend.server_hostname == "example.com"


class NeedsRetryBackend(MockBackend):
def __init__(
self,
Expand Down
Loading
Loading