Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 162 additions & 69 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import weakref
from collections import defaultdict
from collections.abc import AsyncGenerator, Collection, Coroutine, Mapping, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -65,6 +64,7 @@
from pymongo.asynchronous.helpers import (
_RetryPolicy,
)
from pymongo.asynchronous.pool import _PoolCheckout
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
Expand Down Expand Up @@ -1785,39 +1785,8 @@ async def _get_topology(self) -> Topology:
self._opened = True
return self._topology

@contextlib.asynccontextmanager
async def _checkout(
self, server: Server, session: Optional[AsyncClientSession]
) -> AsyncGenerator[AsyncConnection, None]:
in_txn = session and session.in_transaction
async with _MongoClientErrorHandler(self, server, session) as err_handler:
# Reuse the pinned connection, if it exists.
if in_txn and session and session._pinned_connection:
err_handler.contribute_socket(session._pinned_connection)
yield session._pinned_connection
return
async with await server.checkout(handler=err_handler) as conn:
# Pin this session to the selected server or connection.
if (
in_txn
and session
and server.description.server_type
in (
SERVER_TYPE.Mongos,
SERVER_TYPE.LoadBalancer,
)
):
session._pin(server, conn)
err_handler.contribute_socket(conn)
if (
self._encrypter
and not self._encrypter._bypass_auto_encryption
and conn.max_wire_version < 8
):
raise ConfigurationError(
"Auto-encryption requires a minimum MongoDB version of 4.2"
)
yield conn
def _checkout(self, server: Server, session: Optional[AsyncClientSession]) -> _ClientCheckout:
return _ClientCheckout(self, server, session)

async def _select_server(
self,
Expand Down Expand Up @@ -1868,41 +1837,22 @@ async def _select_server(

async def _conn_for_writes(
self, session: Optional[AsyncClientSession], operation: str
) -> AbstractAsyncContextManager[AsyncConnection]:
) -> _ClientCheckout:
server = await self._select_server(writable_server_selector, session, operation)
return self._checkout(server, session)

@contextlib.asynccontextmanager
async def _conn_from_server(
def _conn_from_server(
self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession]
) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]:
) -> _ClientReadCheckout:
assert read_preference is not None, "read_preference must not be None"
# Get a connection for a server matching the read preference, and yield
# conn with the effective read preference. The Server Selection
# Spec says not to send any $readPreference to standalones and to
# always send primaryPreferred when directly connected to a repl set
# member.
# Thread safe: if the type is single it cannot change.
# NOTE: We already opened the Topology when selecting a server so there's no need
# to call _get_topology() again.
single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single
async with self._checkout(server, session) as conn:
if single:
if conn.is_repl and not (session and session.in_transaction):
# Use primary preferred to ensure any repl set member
# can handle the request.
read_preference = ReadPreference.PRIMARY_PREFERRED
elif conn.is_standalone:
# Don't send read preference to standalones.
read_preference = ReadPreference.PRIMARY
yield conn, read_preference
return _ClientReadCheckout(self, server, session, read_preference)

async def _conn_for_reads(
self,
read_preference: _ServerMode,
session: Optional[AsyncClientSession],
operation: str,
) -> AbstractAsyncContextManager[tuple[AsyncConnection, _ServerMode]]:
) -> _ClientReadCheckout:
assert read_preference is not None, "read_preference must not be None"
server = await self._select_server(read_preference, session, operation)
return self._conn_from_server(read_preference, server, session)
Expand Down Expand Up @@ -1930,8 +1880,12 @@ async def _run_operation(
)

async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
async with _ClientCheckout.for_existing_conn(
self,
server,
operation.session, # type: ignore[arg-type]
operation.conn_mgr.conn,
):
return await server.run_operation(
operation.conn_mgr.conn,
operation,
Expand Down Expand Up @@ -2672,10 +2626,17 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong
exc_to_check._add_error_label("RetryableWriteError")


class _MongoClientErrorHandler:
"""Handle errors raised when executing an operation."""
class _ClientCheckout:
"""Context manager for checking out a connection from the pool.

Handles pool checkout, SDAM error handling, and session pinning in a
single class-based CM to eliminate generator overhead on the hot path.
"""

__slots__ = (
"_existing_conn",
"_pool_checkout",
"_server",
"client",
"completed_handshake",
"handled",
Expand Down Expand Up @@ -2709,9 +2670,12 @@ def __init__(
self.completed_handshake = False
self.service_id: Optional[ObjectId] = None
self.handled = False
self._existing_conn: Optional[AsyncConnection] = None
self._pool_checkout: Optional[_PoolCheckout] = None
self._server = server

def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None:
"""Provide socket information to the error handler."""
"""Record connection metadata needed for SDAM error handling."""
self.max_wire_version = conn.max_wire_version
self.sock_generation = conn.generation
self.service_id = conn.service_id
Expand Down Expand Up @@ -2746,21 +2710,148 @@ async def handle(
assert self.client._topology is not None
await self.client._topology.handle_error(self.server_address, err_ctx)

async def __aenter__(self) -> _MongoClientErrorHandler:
return self
async def __aenter__(self) -> AsyncConnection:
if self._existing_conn is not None:
return self._existing_conn
server = self._server
session = self.session
in_txn = session and session.in_transaction
# Reuse the pinned connection, if it exists.
if in_txn and session and session._pinned_connection:
self.contribute_socket(session._pinned_connection)
return session._pinned_connection
pool_checkout = server.pool.checkout(self)
try:
conn = await pool_checkout.__aenter__()
except BaseException as exc:
# __aenter__ raised — pool already cleaned up internally.
# Run SDAM error handling so the topology learns about the failure.
await self.handle(type(exc), exc)
raise
self._pool_checkout = pool_checkout
try:
# Pin this session to the selected server or connection.
if (
in_txn
and session
and server.description.server_type
in (
SERVER_TYPE.Mongos,
SERVER_TYPE.LoadBalancer,
)
):
session._pin(server, conn)
self.contribute_socket(conn)
if (
self.client._encrypter
and not self.client._encrypter._bypass_auto_encryption
and conn.max_wire_version < 8
):
raise ConfigurationError(
"Auto-encryption requires a minimum MongoDB version of 4.2"
)
except BaseException as exc:
try:
await self.handle(type(exc), exc)
finally:
await pool_checkout.__aexit__(type(exc), exc, exc.__traceback__)
self._pool_checkout = None
raise
return conn

async def __aexit__(
self,
exc_type: Optional[type[Exception]],
exc_val: Optional[Exception],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
return await self.handle(exc_type, exc_val)
# Perform SDAM error handling while the connection is still checked out.
try:
await self.handle(exc_type, exc_val)
finally:
if self._pool_checkout is not None:
await self._pool_checkout.__aexit__(exc_type, exc_val, exc_tb)

@classmethod
def for_existing_conn(
cls,
client: AsyncMongoClient, # type: ignore[type-arg]
server: Server,
session: Optional[AsyncClientSession],
conn: AsyncConnection,
) -> _ClientCheckout:
"""Return a _ClientCheckout for an already-checked-out connection.

Used when SDAM error handling is needed around an existing connection
without performing a new pool checkout (e.g. re-running a getMore).
"""
checkout = cls(client, server, session)
checkout.contribute_socket(conn)
checkout._existing_conn = conn
return checkout


class _ClientReadCheckout(_ClientCheckout):
"""Context manager for read operations.

Extends _ClientCheckout to apply the single-topology read preference
adjustment and return the effective read preference alongside the connection.
"""

__slots__ = ("_effective_read_pref",)

def __init__(
self,
client: AsyncMongoClient, # type: ignore[type-arg]
server: Server,
session: Optional[AsyncClientSession],
read_preference: _ServerMode,
) -> None:
super().__init__(client, server, session)
self._effective_read_pref: _ServerMode = read_preference

async def __aenter__(self) -> tuple[AsyncConnection, _ServerMode]: # type: ignore[override]
conn = await super().__aenter__()
# The Server Selection Spec says not to send any $readPreference to
# standalones and to always send primaryPreferred when directly
# connected to a replica set member.
# Thread safe: topology type cannot change once set to Single.
single = self.client._topology.description.topology_type == TOPOLOGY_TYPE.Single
if single:
if conn.is_repl and not (self.session and self.session.in_transaction):
self._effective_read_pref = ReadPreference.PRIMARY_PREFERRED
elif conn.is_standalone:
self._effective_read_pref = ReadPreference.PRIMARY
return conn, self._effective_read_pref


class _ClientConnectionRetryable(Generic[T]):
"""Responsible for executing retryable connections on read or write operations"""

__slots__ = (
"_address",
"_always_retryable",
"_attempt_number",
"_bulk",
"_client",
"_deprioritized_servers",
"_func",
"_is_aggregate_write",
"_is_read",
"_is_run_command",
"_last_error",
"_max_retries",
"_operation",
"_operation_id",
"_read_pref",
"_retry_policy",
"_retryable",
"_retrying",
"_server",
"_server_selector",
"_session",
)

def __init__(
self,
mongo_client: AsyncMongoClient, # type: ignore[type-arg]
Expand Down Expand Up @@ -2793,7 +2884,7 @@ def __init__(
)
self._address = address
self._server: Server = None # type: ignore
self._deprioritized_servers: list[Server] = []
self._deprioritized_servers: Optional[list[Server]] = None
self._operation = operation
self._operation_id = operation_id
self._attempt_number = 0
Expand Down Expand Up @@ -2933,6 +3024,8 @@ async def run(self) -> T:
self._client.topology_description.topology_type_name == "Sharded"
or (overloaded and self._client.options.enable_overload_retargeting)
):
if self._deprioritized_servers is None:
self._deprioritized_servers = []
self._deprioritized_servers.append(self._server)

self._always_retryable = always_retryable
Expand Down
Loading
Loading