Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions temporalio/contrib/google_adk_agents/_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from datetime import timedelta

from google.adk.models import BaseLlm, LLMRegistry
Expand Down Expand Up @@ -40,20 +40,37 @@ class TemporalModel(BaseLlm):
"""A Temporal-based LLM model that executes model invocations as activities."""

def __init__(
self, model_name: str, activity_config: ActivityConfig | None = None
self,
model_name: str,
activity_config: ActivityConfig | None = None,
*,
summary_fn: Callable[[LlmRequest], str | None] | None = None,
) -> None:
"""Initialize the TemporalModel.

Args:
model_name: The name of the model to use.
activity_config: Configuration options for the activity execution.
summary_fn: Optional callable that receives the LlmRequest and
returns a summary string (or None) for the activity. Must be
deterministic as it is called during workflow execution. If
the callable raises, the exception will propagate and fail
the workflow task.

Raises:
ValueError: If both ``ActivityConfig["summary"]`` and ``summary_fn`` are set.
"""
super().__init__(model=model_name)
self._model_name = model_name
self._summary_fn = summary_fn
self._activity_config = ActivityConfig(
start_to_close_timeout=timedelta(seconds=60)
)
if activity_config:
if activity_config is not None:
if summary_fn is not None and activity_config.get("summary") is not None:
raise ValueError(
"Cannot specify both ActivityConfig 'summary' and 'summary_fn'"
)
self._activity_config.update(activity_config)

async def generate_content_async(
Expand All @@ -76,10 +93,20 @@ async def generate_content_async(
yield response
return

config = self._activity_config.copy()
if self._summary_fn is not None:
summary = self._summary_fn(llm_request)
if summary is not None:
config["summary"] = summary
elif "summary" not in config:
if llm_request.config and llm_request.config.labels:
agent_name = llm_request.config.labels.get("adk_agent_name")
if agent_name:
config["summary"] = agent_name
responses = await workflow.execute_activity(
invoke_model,
args=[llm_request],
**self._activity_config,
**config,
)
for response in responses:
yield response
97 changes: 97 additions & 0 deletions tests/contrib/google_adk_agents/test_google_adk_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,103 @@ async def test_unsetting_timeout():
assert model._activity_config.get("start_to_close_timeout", None) is None


class SummaryFnModel(TestModel):
"""Returns a single text response for summary_fn testing."""

def responses(self) -> list[LlmResponse]:
return [
LlmResponse(content=Content(role="model", parts=[Part(text="response")])),
]

@classmethod
def supported_models(cls) -> list[str]:
return ["summary_fn_model"]


@workflow.defn
class SummaryTestWorkflow:
@workflow.run
async def run(self, model_name: str) -> None:
modes = [
("dynamic", lambda req: f"Invoking {req.model}"),
("none", lambda req: None),
("empty", lambda req: ""),
("label_fallback", None),
]
for mode_name, summary_fn in modes:
agent = Agent(
name=f"summary_test_{mode_name}",
model=TemporalModel(model_name, summary_fn=summary_fn),
)
runner = InMemoryRunner(agent=agent, app_name=f"summary_{mode_name}")
session = await runner.session_service.create_session(
app_name=f"summary_{mode_name}", user_id="test"
)
async with Aclosing(
runner.run_async(
user_id="test",
session_id=session.id,
new_message=types.Content(
role="user", parts=[types.Part(text="hi")]
),
)
) as agen:
async for _ in agen:
pass


@pytest.mark.asyncio
async def test_summary_fn_variants(client: Client):
"""Test summary_fn with dynamic, None, empty string, and label fallback."""
new_config = client.config()
new_config["plugins"] = [GoogleAdkPlugin()]
client = Client(**new_config)
LLMRegistry.register(SummaryFnModel)

async with Worker(
client,
task_queue="adk-summary-test",
workflows=[SummaryTestWorkflow],
max_cached_workflows=0,
):
handle = await client.start_workflow(
SummaryTestWorkflow.run,
"summary_fn_model",
id=f"summary-test-{uuid.uuid4()}",
task_queue="adk-summary-test",
execution_timeout=timedelta(seconds=60),
)
await handle.result()

summaries = []
async for e in handle.fetch_history_events():
if e.HasField("activity_task_scheduled_event_attributes"):
attrs = e.activity_task_scheduled_event_attributes
if attrs.activity_type.name == "invoke_model":
summaries.append(e.user_metadata.summary.data)

assert len(summaries) == 4
assert summaries[0] == b'"Invoking summary_fn_model"' # dynamic
assert summaries[1] == b"" # none
assert summaries[2] == b"" # empty
assert (
summaries[3] == b'"summary_test_label_fallback"'
) # label fallback agent name


def test_summary_and_summary_fn_raises():
"""Cannot specify both summary and summary_fn."""
with pytest.raises(
ValueError,
match="Cannot specify both ActivityConfig 'summary' and 'summary_fn'",
):
TemporalModel(
"m",
activity_config=ActivityConfig(summary="static"),
summary_fn=lambda req: "dynamic",
)


@pytest.mark.asyncio
async def test_agent_outside_workflow():
"""Test that an agent using TemporalModel and activity_tool works outside a Temporal workflow."""
Expand Down
Loading