diff --git a/agentex/openapi.yaml b/agentex/openapi.yaml index aed30cfb..ac50e362 100644 --- a/agentex/openapi.yaml +++ b/agentex/openapi.yaml @@ -3804,6 +3804,15 @@ components: type: array title: The messages to send to the task. The order of the messages will be the order they are added to the task. + created_at: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Optional caller-supplied base creation timestamp for the batch + description: Optional base timestamp. Each message in the batch is stamped + with base + i milliseconds to guarantee unique, monotonic ordering. If + omitted, the server stamps datetime.now(UTC) at insert time. type: object required: - task_id @@ -4248,6 +4257,17 @@ components: - DONE - type: 'null' title: The streaming status of the message + created_at: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Optional caller-supplied creation timestamp + description: Optional timestamp for the message. Workflow callers should + pass workflow.now() (Temporal's deterministic monotonic clock) so that + two awaited messages.create calls from the same workflow are guaranteed + to have monotonic timestamps regardless of HTTP scheduling at the server. + If omitted, the server's wall clock at insert time is used. type: object required: - task_id diff --git a/agentex/src/adapters/crud_store/adapter_mongodb.py b/agentex/src/adapters/crud_store/adapter_mongodb.py index fe75078e..aab286fe 100644 --- a/agentex/src/adapters/crud_store/adapter_mongodb.py +++ b/agentex/src/adapters/crud_store/adapter_mongodb.py @@ -207,10 +207,14 @@ async def create(self, item: T) -> T: if "_id" in data: del data["_id"] - # Add timestamps + # Timestamps: respect caller-supplied values (allowing deterministic + # ordering across concurrent requests, e.g. workflow.now() from + # Temporal). Only fall back to server time when missing/None. now = datetime.now(UTC) - data["created_at"] = now - data["updated_at"] = now + if data.get("created_at") is None: + data["created_at"] = now + if data.get("updated_at") is None: + data["updated_at"] = now result = self.collection.insert_one(data) @@ -218,13 +222,16 @@ async def create(self, item: T) -> T: # Set the .id field with the string representation of _id if hasattr(item, "id"): item.id = str(result.inserted_id) - # Set timestamps on the returned object - item.created_at = now - item.updated_at = now + if getattr(item, "created_at", None) is None: + item.created_at = data["created_at"] + if getattr(item, "updated_at", None) is None: + item.updated_at = data["updated_at"] elif isinstance(item, dict): item["id"] = str(result.inserted_id) - item["created_at"] = now - item["updated_at"] = now + if item.get("created_at") is None: + item["created_at"] = data["created_at"] + if item.get("updated_at") is None: + item["updated_at"] = data["updated_at"] return item except pymongo.errors.DuplicateKeyError as e: @@ -258,9 +265,12 @@ async def batch_create(self, items: list[T]) -> list[T]: if "_id" not in data or data["_id"] is None: data.pop("_id", None) - # Add timestamps - data["created_at"] = now - data["updated_at"] = now + # Timestamps: respect caller-supplied values per-item, fall back + # to server time when missing/None. See create() for rationale. + if data.get("created_at") is None: + data["created_at"] = now + if data.get("updated_at") is None: + data["updated_at"] = now data_list.append(data) @@ -268,15 +278,19 @@ async def batch_create(self, items: list[T]) -> list[T]: # Update items with generated IDs (as strings) for idx, inserted_id in enumerate(result.inserted_ids): - # Set the .id field with the string representation of _id + persisted = data_list[idx] if hasattr(items[idx], "id"): items[idx].id = str(inserted_id) - items[idx].created_at = now - items[idx].updated_at = now + if getattr(items[idx], "created_at", None) is None: + items[idx].created_at = persisted["created_at"] + if getattr(items[idx], "updated_at", None) is None: + items[idx].updated_at = persisted["updated_at"] elif isinstance(items[idx], dict): items[idx]["id"] = str(inserted_id) - items[idx]["created_at"] = now - items[idx]["updated_at"] = now + if items[idx].get("created_at") is None: + items[idx]["created_at"] = persisted["created_at"] + if items[idx].get("updated_at") is None: + items[idx]["updated_at"] = persisted["updated_at"] return items except pymongo.errors.DuplicateKeyError as e: diff --git a/agentex/src/api/routes/messages.py b/agentex/src/api/routes/messages.py index 545d27de..2f842758 100644 --- a/agentex/src/api/routes/messages.py +++ b/agentex/src/api/routes/messages.py @@ -94,6 +94,7 @@ async def batch_create_messages( task_message_entities = await message_use_case.create_batch( task_id=request.task_id, contents=converted_contents, + created_at=request.created_at, ) return [ TaskMessage.model_validate(task_message_entity) @@ -137,6 +138,7 @@ async def create_message( task_id=request.task_id, content=convert_task_message_content_to_entity(request.content.root), streaming_status=request.streaming_status, + created_at=request.created_at, ) return TaskMessage.model_validate(task_message_entity) diff --git a/agentex/src/api/schemas/task_messages.py b/agentex/src/api/schemas/task_messages.py index f05b6e39..3d4db964 100644 --- a/agentex/src/api/schemas/task_messages.py +++ b/agentex/src/api/schemas/task_messages.py @@ -142,6 +142,18 @@ class CreateTaskMessageRequest(BaseModel): None, title="The streaming status of the message", ) + created_at: datetime | None = Field( + None, + title="Optional caller-supplied creation timestamp", + description=( + "Optional timestamp for the message. Workflow callers should pass " + "workflow.now() (Temporal's deterministic monotonic clock) so that " + "two awaited messages.create calls from the same workflow are " + "guaranteed to have monotonic timestamps regardless of HTTP " + "scheduling at the server. If omitted, the server's wall clock at " + "insert time is used." + ), + ) class UpdateTaskMessageRequest(BaseModel): @@ -168,6 +180,16 @@ class BatchCreateTaskMessagesRequest(BaseModel): ..., title="The messages to send to the task. The order of the messages will be the order they are added to the task.", ) + created_at: datetime | None = Field( + None, + title="Optional caller-supplied base creation timestamp for the batch", + description=( + "Optional base timestamp. Each message in the batch is stamped " + "with base + i milliseconds to guarantee unique, monotonic " + "ordering. If omitted, the server stamps datetime.now(UTC) at " + "insert time." + ), + ) class BatchUpdateTaskMessagesRequest(BaseModel): diff --git a/agentex/src/domain/services/task_message_service.py b/agentex/src/domain/services/task_message_service.py index 4cd23b34..c6ea3c37 100644 --- a/agentex/src/domain/services/task_message_service.py +++ b/agentex/src/domain/services/task_message_service.py @@ -102,13 +102,21 @@ async def append_message( task_id: str, content: TaskMessageContentEntity, streaming_status: Literal["IN_PROGRESS", "DONE"] | None = None, + created_at: datetime | None = None, ) -> TaskMessageEntity: """ Append a message to the task's message list. Args: task_id: The task ID - message: The message to append + content: The message content + streaming_status: Optional streaming status + created_at: Optional caller-supplied timestamp. Workflow callers + should pass workflow.now() (Temporal's deterministic monotonic + clock) so two awaited messages.create calls from the same + workflow are guaranteed to have monotonic timestamps regardless + of HTTP request scheduling at the server. If omitted, the + adapter falls back to the server's wall clock at insert time. Returns: The created TaskMessageEntity with ID and metadata @@ -117,6 +125,8 @@ async def append_message( task_id=task_id, content=content, streaming_status=streaming_status, + created_at=created_at, + updated_at=created_at, ) return await self.repository.create(task_message) @@ -126,24 +136,30 @@ async def append_messages( task_id: str, contents: list[TaskMessageContentEntity], streaming_status: Literal["IN_PROGRESS", "DONE"] | None = None, + created_at: datetime | None = None, ) -> list[TaskMessageEntity]: """ Append multiple messages to the task's message list. Args: task_id: The task ID - messages: The messages to append + contents: The message contents to append + streaming_status: Optional streaming status + created_at: Optional base timestamp for the batch. Each message in + the batch is stamped with base + i milliseconds to guarantee + unique, monotonic ordering. If omitted, datetime.now(UTC) is + used as the base. Returns: The created TaskMessageEntity objects with IDs and metadata """ - # Add a small time increment to each message to ensure unique ordering - current_time = datetime.now(UTC) + base_time = created_at if created_at is not None else datetime.now(UTC) task_messages = [] for i, message in enumerate(contents): - # Add i microseconds to ensure unique timestamps within the batch - adjusted_time = current_time + timedelta(microseconds=i) + # MongoDB BSON Date is millisecond-precision; stagger by ms (not µs) + # so the stored ordering is durable across re-fetches. + adjusted_time = base_time + timedelta(milliseconds=i) task_message = TaskMessageEntity( task_id=task_id, content=message, diff --git a/agentex/src/domain/use_cases/messages_use_case.py b/agentex/src/domain/use_cases/messages_use_case.py index 4f4612dc..7b66835c 100644 --- a/agentex/src/domain/use_cases/messages_use_case.py +++ b/agentex/src/domain/use_cases/messages_use_case.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Annotated, Any, Literal from fastapi import Depends @@ -84,6 +85,7 @@ async def create( task_id: str, content: TaskMessageContentEntity, streaming_status: Literal["IN_PROGRESS", "DONE"] | None, + created_at: datetime | None = None, ) -> TaskMessageEntity: """ Create a new message for a task. @@ -91,6 +93,8 @@ async def create( Args: task_id: The task ID content: The task message content to create + streaming_status: Optional streaming status + created_at: Optional caller-supplied timestamp (see service docstring) Returns: The created TaskMessageEntity with ID and metadata @@ -99,6 +103,7 @@ async def create( task_id=task_id, content=content, streaming_status=streaming_status, + created_at=created_at, ) async def update( @@ -125,14 +130,18 @@ async def update( ) async def create_batch( - self, task_id: str, contents: list[TaskMessageContentEntity] + self, + task_id: str, + contents: list[TaskMessageContentEntity], + created_at: datetime | None = None, ) -> list[TaskMessageEntity]: """ Create multiple messages for a task. Args: task_id: The task ID - messages: The messages to create + contents: The messages to create + created_at: Optional base timestamp (see service docstring) Returns: The created TaskMessageEntity objects with IDs and metadata @@ -140,6 +149,7 @@ async def create_batch( return await self.task_message_service.append_messages( task_id=task_id, contents=contents, + created_at=created_at, ) async def update_batch( diff --git a/agentex/tests/unit/repositories/test_task_message_repository.py b/agentex/tests/unit/repositories/test_task_message_repository.py index 6f9d4578..c239da00 100644 --- a/agentex/tests/unit/repositories/test_task_message_repository.py +++ b/agentex/tests/unit/repositories/test_task_message_repository.py @@ -1,5 +1,7 @@ # Import the repository and entities we need to test +from datetime import UTC, datetime, timedelta + import pytest from src.adapters.crud_store.exceptions import ItemDoesNotExist from src.api.schemas.task_messages import DataContent @@ -186,3 +188,88 @@ async def test_task_message_repository_list_by_task_id_pagination( # Cleanup for message_id in created_ids: await repo.delete(id=message_id) + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_create_preserves_caller_supplied_timestamps( + task_message_repository, +): + """The Mongo adapter must respect caller-supplied created_at/updated_at and + not clobber them with datetime.now(UTC) at insert time. This is the + server-side fix for the cross-request race that flipped message ordering + in the UI when two messages.create calls arrived within milliseconds.""" + repo = task_message_repository + task_id = orm_id() + + caller_time = datetime(2025, 4, 1, 8, 30, 0, tzinfo=UTC) + text_content = TextContentEntity( + type=TaskMessageContentType.TEXT, + content="caller-stamped", + author=MessageAuthor.USER, + style=MessageStyle.STATIC, + format=TextFormat.PLAIN, + ) + task_message = TaskMessageEntity( + task_id=task_id, + content=text_content, + streaming_status="DONE", + created_at=caller_time, + updated_at=caller_time, + ) + + # pymongo strips tzinfo on read (BSON Date stores UTC but returns naive + # datetimes by default), so compare against the naive UTC equivalent. + expected_naive = caller_time.replace(tzinfo=None) + + created = await repo.create(task_message) + assert created.created_at.replace(tzinfo=None) == expected_naive + assert created.updated_at.replace(tzinfo=None) == expected_naive + + fetched = await repo.get(id=created.id) + assert fetched.created_at.replace(tzinfo=None) == expected_naive + assert fetched.updated_at.replace(tzinfo=None) == expected_naive + + await repo.delete(id=created.id) + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_batch_create_preserves_per_item_timestamps( + task_message_repository, +): + """batch_create must preserve per-item caller-supplied timestamps so that + the service-layer microsecond/millisecond stagger is durable in storage.""" + repo = task_message_repository + task_id = orm_id() + + base = datetime(2025, 4, 2, 12, 0, 0, tzinfo=UTC) + messages = [ + TaskMessageEntity( + task_id=task_id, + content=TextContentEntity( + type=TaskMessageContentType.TEXT, + content=f"msg-{i}", + author=MessageAuthor.USER, + style=MessageStyle.STATIC, + format=TextFormat.PLAIN, + ), + streaming_status="DONE", + # Insert with timestamps in *descending* order to prove that the + # adapter is not assigning a single now() to all items. + created_at=base + timedelta(milliseconds=10 - i), + updated_at=base + timedelta(milliseconds=10 - i), + ) + for i in range(3) + ] + + created = await repo.batch_create(messages) + assert len(created) == 3 + for i, c in enumerate(created): + expected = (base + timedelta(milliseconds=10 - i)).replace(tzinfo=None) + assert c.created_at.replace(tzinfo=None) == expected + assert c.updated_at.replace(tzinfo=None) == expected + + # Cleanup + for c in created: + await repo.delete(id=c.id) diff --git a/agentex/tests/unit/services/test_task_message_service.py b/agentex/tests/unit/services/test_task_message_service.py index 50bf29eb..ac8916a7 100644 --- a/agentex/tests/unit/services/test_task_message_service.py +++ b/agentex/tests/unit/services/test_task_message_service.py @@ -217,17 +217,98 @@ async def test_append_messages_batch_success( # Then assert len(result) == 3 - # Check that all messages were created correctly + # Check that all messages were created correctly. The service stamps + # base + i milliseconds, and BSON Date is millisecond-precision, so the + # persisted timestamps must be strictly increasing across the batch. for i, message in enumerate(result): assert message.id is not None assert message.task_id == sample_task_id assert message.content.content == contents[i].content assert message.created_at is not None - # Check that timestamps are incremented (or at least not decreasing) if i > 0: - # Later messages should have timestamps >= previous (microsecond precision may be identical) - assert message.created_at >= result[i - 1].created_at + assert message.created_at > result[i - 1].created_at + + async def test_append_message_preserves_caller_created_at( + self, task_message_service, sample_task_id, sample_message_content + ): + """Caller-supplied created_at must round-trip through the adapter + (regression test for the Mongo race where the adapter overwrote + caller timestamps with datetime.now(UTC) at insert time).""" + # Given: an explicit, well-in-the-past timestamp the caller wants to set. + caller_time = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC) + + # When + result = await task_message_service.append_message( + task_id=sample_task_id, + content=sample_message_content, + created_at=caller_time, + ) + + # Then: the persisted value must equal what the caller supplied. The + # input is at second precision so BSON ms truncation is a no-op; tzinfo + # gets stripped on read by pymongo, so compare naively. + expected = caller_time.replace(tzinfo=None) + assert result.created_at.replace(tzinfo=None) == expected + assert result.updated_at.replace(tzinfo=None) == expected + + # And the value survives a fresh fetch from the store. + fetched = await task_message_service.get_message(result.id) + assert fetched.created_at.replace(tzinfo=None) == expected + + async def test_append_messages_uses_caller_base_time( + self, task_message_service, sample_task_id + ): + """Batch append must use the caller-supplied base timestamp and stagger + each message by 1 ms on top of it.""" + # Given + base = datetime(2025, 6, 15, 9, 0, 0, tzinfo=UTC) + contents = [ + TextContent(content=f"Message {i}", author=MessageAuthor.USER) + for i in range(3) + ] + + # When + result = await task_message_service.append_messages( + task_id=sample_task_id, contents=contents, created_at=base + ) + + # Then (compare naively — pymongo strips tzinfo on read) + assert len(result) == 3 + for i, message in enumerate(result): + expected = (base + timedelta(milliseconds=i)).replace(tzinfo=None) + assert message.created_at.replace(tzinfo=None) == expected + assert message.updated_at.replace(tzinfo=None) == expected + + async def test_concurrent_appends_with_caller_timestamps_preserve_order( + self, task_message_service, sample_task_id + ): + """Simulate the OneEdge race: two messages.create calls fired + back-to-back, with the *second one* reaching the adapter first. + Because the caller supplies created_at, the persisted ordering must + still reflect caller intent rather than insert order.""" + first_time = datetime(2025, 3, 1, 10, 0, 0, tzinfo=UTC) + second_time = first_time + timedelta(milliseconds=1) + + # Insert in the *reverse* of caller-intended order to prove that the + # adapter no longer uses now() at insert time. + second = await task_message_service.append_message( + task_id=sample_task_id, + content=TextContent(content="agent", author=MessageAuthor.AGENT), + created_at=second_time, + ) + first = await task_message_service.append_message( + task_id=sample_task_id, + content=TextContent(content="user", author=MessageAuthor.USER), + created_at=first_time, + ) + + # pymongo strips tzinfo on read; normalize for comparison. + assert first.created_at.replace(tzinfo=None) == first_time.replace(tzinfo=None) + assert second.created_at.replace(tzinfo=None) == second_time.replace( + tzinfo=None + ) + assert first.created_at < second.created_at async def test_update_message_success( self, task_message_service, sample_task_id, sample_message_content