Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import functools
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, MutableMapping
from collections.abc import Mapping, Sequence
from typing import Any, Callable, MutableMapping, cast

from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
Expand All @@ -27,6 +28,33 @@
logger = logging.getLogger(__name__)


def _extract_result(operation: Operation) -> str | None:
if operation.step_details and operation.step_details.result is not None:
return operation.step_details.result
if operation.callback_details and operation.callback_details.result is not None:
return operation.callback_details.result
if (
operation.chained_invoke_details
and operation.chained_invoke_details.result is not None
):
return operation.chained_invoke_details.result
if operation.context_details and operation.context_details.result is not None:
return operation.context_details.result
return None


def _extract_error(operation: Operation) -> ErrorObject | None:
if operation.step_details and operation.step_details.error:
return operation.step_details.error
if operation.callback_details and operation.callback_details.error:
return operation.callback_details.error
if operation.chained_invoke_details and operation.chained_invoke_details.error:
return operation.chained_invoke_details.error
if operation.context_details and operation.context_details.error:
return operation.context_details.error
return None


@dataclass(frozen=True)
class OperationInfo:
operation_id: str
Expand All @@ -37,6 +65,33 @@ class OperationInfo:
start_time: datetime.datetime | None
is_replayed: bool
status: OperationStatus
end_time: datetime.datetime | None = field(default=None, kw_only=True)
result: str | None = field(default=None, kw_only=True)
error: ErrorObject | None = field(default=None, kw_only=True)
attempt: int | None = field(default=None, kw_only=True)

@staticmethod
def from_operation(
operation: Operation,
*,
is_replayed: bool = False,
) -> OperationInfo:
return OperationInfo(
operation_id=operation.operation_id,
operation_type=operation.operation_type,
sub_type=operation.sub_type,
name=operation.name,
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
end_time=operation.end_timestamp,
result=_extract_result(operation),
error=_extract_error(operation),
attempt=(
operation.step_details.attempt if operation.step_details else None
),
is_replayed=is_replayed,
status=operation.status,
)


@dataclass(frozen=True)
Expand All @@ -46,8 +101,15 @@ class OperationStartInfo(OperationInfo):

@dataclass(frozen=True)
class OperationEndInfo(OperationInfo):
end_time: datetime.datetime | None
error: ErrorObject | None
pass


@dataclass(frozen=True)
class OperationChangeInfo:
request_id: str | None
execution_arn: str | None
updated_operations: dict[str, OperationInfo]
operations: dict[str, OperationInfo]


class UserFunctionOutcome(Enum):
Expand All @@ -66,20 +128,14 @@ class UserFunctionStartInfo(OperationInfo):
is_replay_children: bool = (
False # True if user function is called to replay children (MAP/PARALLEL)
)
attempt: int | None = (
None # None for user function called more than once in CONTEXT
)


@dataclass(frozen=True)
class UserFunctionEndInfo(OperationInfo):
is_replay_children: (
bool # True if user function is called to replay children (MAP/PARALLEL)
)
attempt: int | None # None for user function called more than once in CONTEXT
outcome: UserFunctionOutcome
end_time: datetime.datetime | None
error: ErrorObject | None

@classmethod
def from_start_info(
Expand Down Expand Up @@ -178,6 +234,16 @@ def on_operation_end(self, info: OperationEndInfo) -> None:
"""
pass

def on_operation_change(self, info: OperationChangeInfo) -> None:
"""
Called when checkpointed operations change after a checkpoint response is merged.
This is called NOT within the thread that runs operation.

Args:
info: Updated operations and the full operation map for the invocation.
"""
pass

def on_user_function_start(self, info: UserFunctionStartInfo) -> None:
"""Called when an operation starts to execute user provided function. This is called within the thread that runs user provided function.

Expand Down Expand Up @@ -229,6 +295,8 @@ def _dispatch_plugin(plugin: DurableInstrumentationPlugin, info) -> None:
plugin.on_operation_start(info)
case OperationEndInfo():
plugin.on_operation_end(info)
case OperationChangeInfo():
plugin.on_operation_change(info)
case UserFunctionStartInfo():
plugin.on_user_function_start(info)
case UserFunctionEndInfo():
Expand Down Expand Up @@ -362,52 +430,104 @@ def on_operation_replay(self, operation: Operation) -> None:
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
end_time=operation.end_timestamp,
result=_extract_result(operation),
status=operation.status,
error=self._extract_error(operation),
attempt=(
operation.step_details.attempt
if operation.step_details
else None
),
is_replayed=True,
),
sync=True,
)

def on_operation_update(self, operation: Operation | None):
"""Execute any registered plugins for a given operation when it receives an update
def on_operation_update(
self,
operation_or_operations: Operation | Sequence[Operation] | None,
operations: Mapping[str, Operation] | None = None,
previous_operations: Mapping[str, Operation] | None = None,
):
"""Execute any registered plugins for operation updates.

Updates such as STARTED might be omitted because START and completion action (e.g. SUCCEED/FAIL) may be
checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED).

Note: the operation may not be up-to-date if the checkpoint is called asynchronously.

Args:
operation: the operation is just checkpointed
operation_or_operations: operation or operations that were just checkpointed.
operations: full operation map after the update, when available.
previous_operations: operation map before the update, when available.
"""
if operation and self._is_terminal_status(operation.status):
self.execute_plugins(
OperationEndInfo(
operation_id=operation.operation_id,
operation_type=operation.operation_type,
sub_type=operation.sub_type,
name=operation.name,
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
end_time=operation.end_timestamp,
status=operation.status,
error=self._extract_error(operation),
is_replayed=False,
),
sync=True,
)
if operation_or_operations is None:
return

updated_operations: list[Operation] = (
cast(list[Operation], list(operation_or_operations))
if isinstance(operation_or_operations, list | tuple)
else [cast(Operation, operation_or_operations)]
)
for operation in updated_operations:
if self._is_terminal_status(operation.status):
self.execute_plugins(
OperationEndInfo(
operation_id=operation.operation_id,
operation_type=operation.operation_type,
sub_type=operation.sub_type,
name=operation.name,
parent_id=operation.parent_id,
start_time=operation.start_timestamp,
end_time=operation.end_timestamp,
result=_extract_result(operation),
status=operation.status,
error=self._extract_error(operation),
attempt=(
operation.step_details.attempt
if operation.step_details
else None
),
is_replayed=False,
),
sync=True,
)

if (
operations is None
or previous_operations is None
or self._invocation_status is None
):
return

changed_operations = [
operation
for operation in updated_operations
if previous_operations.get(operation.operation_id) is None
or previous_operations[operation.operation_id].status != operation.status
]
if not changed_operations:
return

self.execute_plugins(
OperationChangeInfo(
request_id=self._invocation_status.request_id,
execution_arn=self._invocation_status.execution_arn,
updated_operations={
operation.operation_id: OperationInfo.from_operation(operation)
for operation in changed_operations
},
operations={
operation_id: OperationInfo.from_operation(operation)
for operation_id, operation in operations.items()
},
),
sync=True,
)

@staticmethod
def _extract_error(operation: Operation):
if operation.step_details and operation.step_details.error:
return operation.step_details.error
if operation.callback_details and operation.callback_details.error:
return operation.callback_details.error
if operation.chained_invoke_details and operation.chained_invoke_details.error:
return operation.chained_invoke_details.error
if operation.context_details and operation.context_details.error:
return operation.context_details.error
return None
return _extract_error(operation)

@staticmethod
def _is_terminal_status(status):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,18 +744,22 @@ def checkpoint_batches_forever(self) -> None:
# Update local token for next iteration
current_checkpoint_token = output.checkpoint_token

previous_operations = self.operations

# Fetch new operations from the API before unblocking sync waiters
updated_operations = self.fetch_paginated_operations(
output.new_execution_state.operations,
output.checkpoint_token,
output.new_execution_state.next_marker,
)

for update in updates:
self._plugin_executor.on_operation_action(update)

for operation in updated_operations:
self._plugin_executor.on_operation_update(operation)
self._plugin_executor.on_operation_update(
updated_operations,
self.operations,
previous_operations,
)

# Signal completion for any synchronous operations
for queued_op in batch:
Expand Down
Loading