Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions agentex/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3804,6 +3804,15 @@ components:
type: array
title: The messages to send to the task. The order of the messages will
be the order they are added to the task.
created_at:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Optional caller-supplied base creation timestamp for the batch
description: Optional base timestamp. Each message in the batch is stamped
with base + i milliseconds to guarantee unique, monotonic ordering. If
omitted, the server stamps datetime.now(UTC) at insert time.
type: object
required:
- task_id
Expand Down Expand Up @@ -4248,6 +4257,17 @@ components:
- DONE
- type: 'null'
title: The streaming status of the message
created_at:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Optional caller-supplied creation timestamp
description: Optional timestamp for the message. Workflow callers should
pass workflow.now() (Temporal's deterministic monotonic clock) so that
two awaited messages.create calls from the same workflow are guaranteed
to have monotonic timestamps regardless of HTTP scheduling at the server.
If omitted, the server's wall clock at insert time is used.
type: object
required:
- task_id
Expand Down
46 changes: 30 additions & 16 deletions agentex/src/adapters/crud_store/adapter_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,31 @@ async def create(self, item: T) -> T:
if "_id" in data:
del data["_id"]

# Add timestamps
# Timestamps: respect caller-supplied values (allowing deterministic
# ordering across concurrent requests, e.g. workflow.now() from
# Temporal). Only fall back to server time when missing/None.
now = datetime.now(UTC)
data["created_at"] = now
data["updated_at"] = now
if data.get("created_at") is None:
data["created_at"] = now
if data.get("updated_at") is None:
data["updated_at"] = now

result = self.collection.insert_one(data)

# Update item with generated ID (as string)
# Set the .id field with the string representation of _id
if hasattr(item, "id"):
item.id = str(result.inserted_id)
# Set timestamps on the returned object
item.created_at = now
item.updated_at = now
if getattr(item, "created_at", None) is None:
item.created_at = data["created_at"]
if getattr(item, "updated_at", None) is None:
item.updated_at = data["updated_at"]
elif isinstance(item, dict):
item["id"] = str(result.inserted_id)
item["created_at"] = now
item["updated_at"] = now
if item.get("created_at") is None:
item["created_at"] = data["created_at"]
if item.get("updated_at") is None:
item["updated_at"] = data["updated_at"]

return item
except pymongo.errors.DuplicateKeyError as e:
Expand Down Expand Up @@ -258,25 +265,32 @@ async def batch_create(self, items: list[T]) -> list[T]:
if "_id" not in data or data["_id"] is None:
data.pop("_id", None)

# Add timestamps
data["created_at"] = now
data["updated_at"] = now
# Timestamps: respect caller-supplied values per-item, fall back
# to server time when missing/None. See create() for rationale.
if data.get("created_at") is None:
data["created_at"] = now
if data.get("updated_at") is None:
data["updated_at"] = now

data_list.append(data)

result = self.collection.insert_many(data_list)

# Update items with generated IDs (as strings)
for idx, inserted_id in enumerate(result.inserted_ids):
# Set the .id field with the string representation of _id
persisted = data_list[idx]
if hasattr(items[idx], "id"):
items[idx].id = str(inserted_id)
items[idx].created_at = now
items[idx].updated_at = now
if getattr(items[idx], "created_at", None) is None:
items[idx].created_at = persisted["created_at"]
if getattr(items[idx], "updated_at", None) is None:
items[idx].updated_at = persisted["updated_at"]
elif isinstance(items[idx], dict):
items[idx]["id"] = str(inserted_id)
items[idx]["created_at"] = now
items[idx]["updated_at"] = now
if items[idx].get("created_at") is None:
items[idx]["created_at"] = persisted["created_at"]
if items[idx].get("updated_at") is None:
items[idx]["updated_at"] = persisted["updated_at"]

return items
except pymongo.errors.DuplicateKeyError as e:
Expand Down
2 changes: 2 additions & 0 deletions agentex/src/api/routes/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def batch_create_messages(
task_message_entities = await message_use_case.create_batch(
task_id=request.task_id,
contents=converted_contents,
created_at=request.created_at,
)
return [
TaskMessage.model_validate(task_message_entity)
Expand Down Expand Up @@ -137,6 +138,7 @@ async def create_message(
task_id=request.task_id,
content=convert_task_message_content_to_entity(request.content.root),
streaming_status=request.streaming_status,
created_at=request.created_at,
)
return TaskMessage.model_validate(task_message_entity)

Expand Down
22 changes: 22 additions & 0 deletions agentex/src/api/schemas/task_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,18 @@ class CreateTaskMessageRequest(BaseModel):
None,
title="The streaming status of the message",
)
created_at: datetime | None = Field(
None,
title="Optional caller-supplied creation timestamp",
description=(
"Optional timestamp for the message. Workflow callers should pass "
"workflow.now() (Temporal's deterministic monotonic clock) so that "
"two awaited messages.create calls from the same workflow are "
"guaranteed to have monotonic timestamps regardless of HTTP "
"scheduling at the server. If omitted, the server's wall clock at "
"insert time is used."
),
)


class UpdateTaskMessageRequest(BaseModel):
Expand All @@ -168,6 +180,16 @@ class BatchCreateTaskMessagesRequest(BaseModel):
...,
title="The messages to send to the task. The order of the messages will be the order they are added to the task.",
)
created_at: datetime | None = Field(
None,
title="Optional caller-supplied base creation timestamp for the batch",
description=(
"Optional base timestamp. Each message in the batch is stamped "
"with base + i milliseconds to guarantee unique, monotonic "
"ordering. If omitted, the server stamps datetime.now(UTC) at "
"insert time."
),
)


class BatchUpdateTaskMessagesRequest(BaseModel):
Expand Down
28 changes: 22 additions & 6 deletions agentex/src/domain/services/task_message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,21 @@ async def append_message(
task_id: str,
content: TaskMessageContentEntity,
streaming_status: Literal["IN_PROGRESS", "DONE"] | None = None,
created_at: datetime | None = None,
) -> TaskMessageEntity:
"""
Append a message to the task's message list.

Args:
task_id: The task ID
message: The message to append
content: The message content
streaming_status: Optional streaming status
created_at: Optional caller-supplied timestamp. Workflow callers
should pass workflow.now() (Temporal's deterministic monotonic
clock) so two awaited messages.create calls from the same
workflow are guaranteed to have monotonic timestamps regardless
of HTTP request scheduling at the server. If omitted, the
adapter falls back to the server's wall clock at insert time.

Returns:
The created TaskMessageEntity with ID and metadata
Expand All @@ -117,6 +125,8 @@ async def append_message(
task_id=task_id,
content=content,
streaming_status=streaming_status,
created_at=created_at,
updated_at=created_at,
)

return await self.repository.create(task_message)
Expand All @@ -126,24 +136,30 @@ async def append_messages(
task_id: str,
contents: list[TaskMessageContentEntity],
streaming_status: Literal["IN_PROGRESS", "DONE"] | None = None,
created_at: datetime | None = None,
) -> list[TaskMessageEntity]:
"""
Append multiple messages to the task's message list.

Args:
task_id: The task ID
messages: The messages to append
contents: The message contents to append
streaming_status: Optional streaming status
created_at: Optional base timestamp for the batch. Each message in
the batch is stamped with base + i milliseconds to guarantee
unique, monotonic ordering. If omitted, datetime.now(UTC) is
used as the base.

Returns:
The created TaskMessageEntity objects with IDs and metadata
"""
# Add a small time increment to each message to ensure unique ordering
current_time = datetime.now(UTC)
base_time = created_at if created_at is not None else datetime.now(UTC)

task_messages = []
for i, message in enumerate(contents):
# Add i microseconds to ensure unique timestamps within the batch
adjusted_time = current_time + timedelta(microseconds=i)
# MongoDB BSON Date is millisecond-precision; stagger by ms (not µs)
# so the stored ordering is durable across re-fetches.
adjusted_time = base_time + timedelta(milliseconds=i)
task_message = TaskMessageEntity(
task_id=task_id,
content=message,
Expand Down
14 changes: 12 additions & 2 deletions agentex/src/domain/use_cases/messages_use_case.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Annotated, Any, Literal

from fastapi import Depends
Expand Down Expand Up @@ -84,13 +85,16 @@ async def create(
task_id: str,
content: TaskMessageContentEntity,
streaming_status: Literal["IN_PROGRESS", "DONE"] | None,
created_at: datetime | None = None,
) -> TaskMessageEntity:
"""
Create a new message for a task.

Args:
task_id: The task ID
content: The task message content to create
streaming_status: Optional streaming status
created_at: Optional caller-supplied timestamp (see service docstring)

Returns:
The created TaskMessageEntity with ID and metadata
Expand All @@ -99,6 +103,7 @@ async def create(
task_id=task_id,
content=content,
streaming_status=streaming_status,
created_at=created_at,
)

async def update(
Expand All @@ -125,21 +130,26 @@ async def update(
)

async def create_batch(
self, task_id: str, contents: list[TaskMessageContentEntity]
self,
task_id: str,
contents: list[TaskMessageContentEntity],
created_at: datetime | None = None,
) -> list[TaskMessageEntity]:
"""
Create multiple messages for a task.

Args:
task_id: The task ID
messages: The messages to create
contents: The messages to create
created_at: Optional base timestamp (see service docstring)

Returns:
The created TaskMessageEntity objects with IDs and metadata
"""
return await self.task_message_service.append_messages(
task_id=task_id,
contents=contents,
created_at=created_at,
)

async def update_batch(
Expand Down
87 changes: 87 additions & 0 deletions agentex/tests/unit/repositories/test_task_message_repository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Import the repository and entities we need to test

from datetime import UTC, datetime, timedelta

import pytest
from src.adapters.crud_store.exceptions import ItemDoesNotExist
from src.api.schemas.task_messages import DataContent
Expand Down Expand Up @@ -186,3 +188,88 @@ async def test_task_message_repository_list_by_task_id_pagination(
# Cleanup
for message_id in created_ids:
await repo.delete(id=message_id)


@pytest.mark.asyncio
@pytest.mark.unit
async def test_create_preserves_caller_supplied_timestamps(
task_message_repository,
):
"""The Mongo adapter must respect caller-supplied created_at/updated_at and
not clobber them with datetime.now(UTC) at insert time. This is the
server-side fix for the cross-request race that flipped message ordering
in the UI when two messages.create calls arrived within milliseconds."""
repo = task_message_repository
task_id = orm_id()

caller_time = datetime(2025, 4, 1, 8, 30, 0, tzinfo=UTC)
text_content = TextContentEntity(
type=TaskMessageContentType.TEXT,
content="caller-stamped",
author=MessageAuthor.USER,
style=MessageStyle.STATIC,
format=TextFormat.PLAIN,
)
task_message = TaskMessageEntity(
task_id=task_id,
content=text_content,
streaming_status="DONE",
created_at=caller_time,
updated_at=caller_time,
)

# pymongo strips tzinfo on read (BSON Date stores UTC but returns naive
# datetimes by default), so compare against the naive UTC equivalent.
expected_naive = caller_time.replace(tzinfo=None)

created = await repo.create(task_message)
assert created.created_at.replace(tzinfo=None) == expected_naive
assert created.updated_at.replace(tzinfo=None) == expected_naive

fetched = await repo.get(id=created.id)
assert fetched.created_at.replace(tzinfo=None) == expected_naive
assert fetched.updated_at.replace(tzinfo=None) == expected_naive

await repo.delete(id=created.id)


@pytest.mark.asyncio
@pytest.mark.unit
async def test_batch_create_preserves_per_item_timestamps(
task_message_repository,
):
"""batch_create must preserve per-item caller-supplied timestamps so that
the service-layer microsecond/millisecond stagger is durable in storage."""
repo = task_message_repository
task_id = orm_id()

base = datetime(2025, 4, 2, 12, 0, 0, tzinfo=UTC)
messages = [
TaskMessageEntity(
task_id=task_id,
content=TextContentEntity(
type=TaskMessageContentType.TEXT,
content=f"msg-{i}",
author=MessageAuthor.USER,
style=MessageStyle.STATIC,
format=TextFormat.PLAIN,
),
streaming_status="DONE",
# Insert with timestamps in *descending* order to prove that the
# adapter is not assigning a single now() to all items.
created_at=base + timedelta(milliseconds=10 - i),
updated_at=base + timedelta(milliseconds=10 - i),
)
for i in range(3)
]

created = await repo.batch_create(messages)
assert len(created) == 3
for i, c in enumerate(created):
expected = (base + timedelta(milliseconds=10 - i)).replace(tzinfo=None)
assert c.created_at.replace(tzinfo=None) == expected
assert c.updated_at.replace(tzinfo=None) == expected

# Cleanup
for c in created:
await repo.delete(id=c.id)
Loading
Loading