Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from ...tools.base_toolset import BaseToolset
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
from ...utils import model_name_utils
from .audio_cache_manager import AudioCacheManager
from .functions import build_auth_request_event

Expand Down Expand Up @@ -516,6 +517,20 @@ async def run_live(
)
llm_request.live_connect_config.session_resumption.transparent = True

if (
isinstance(llm, Gemini)
and llm._api_backend == GoogleLLMVariant.GEMINI_API
and model_name_utils.is_gemini_3_1_flash_live(llm_request.model)
and llm_request.contents
and not invocation_context.live_session_resumption_handle
):
if llm_request.live_connect_config is None:
llm_request.live_connect_config = types.LiveConnectConfig()
if llm_request.live_connect_config.history_config is None:
llm_request.live_connect_config.history_config = types.HistoryConfig(
initial_history_in_client_content=True
)

logger.info(
'Establishing live connection for agent: %s',
invocation_context.agent.name,
Expand Down
7 changes: 5 additions & 2 deletions src/google/adk/flows/llm_flows/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...utils import model_name_utils
from ...utils.output_schema_utils import can_use_output_schema_with_tools
from ._base_llm_processor import BaseLlmRequestProcessor

Expand Down Expand Up @@ -78,11 +79,13 @@ def _build_basic_request(
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
llm_request.live_connect_config.enable_affective_dialog = (
invocation_context.run_config.enable_affective_dialog
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
invocation_context.run_config.proactivity
None if is_gemini_31 else invocation_context.run_config.proactivity
)
llm_request.live_connect_config.session_resumption = (
invocation_context.run_config.session_resumption
Expand Down
70 changes: 61 additions & 9 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,24 @@ async def send_history(self, history: list[types.Content]):
]

if contents:
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
self._model_version
)
# The Gemini Enterprise Agent Platform does not support history_config in the
# SDK. To seed a live session with prior history without triggering a 1007
# protocol error (invalid role mid-session), collapse the earlier multi-turn
# conversation into a single user-role turn that carries it as a text
# preamble. Only text parts are carried over; each turn is prefixed with its
# original role.
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
collapsed_text = "Previous conversation history:\n"
for c in contents:
text_parts = "".join(p.text for p in c.parts if p.text)
collapsed_text += f'[{c.role}]: {text_parts}\n'
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]

logger.debug('Sending history to live connection: %s', contents)
await self._gemini_session.send_client_content(
turns=contents,
turn_complete=contents[-1].role == 'user',
turn_complete=True if is_gemini_31 else (contents[-1].role == 'user'),
)
else:
logger.info('no content is sent')
Expand Down Expand Up @@ -159,14 +173,21 @@ async def send_realtime(self, input: RealtimeInput):
else:
raise ValueError('Unsupported input type: %s' % type(input))

def __build_full_text_response(self, text: str):
def __build_full_text_response(
self,
text: str,
is_thought: bool = False,
grounding_metadata: types.GroundingMetadata | None = None,
):
"""Builds a full text response.

The text should not be partial and the returned LlmResponse is not
partial.

Args:
text: The text to be included in the response.
is_thought: Whether the text is a thought.
grounding_metadata: The grounding metadata to include.

Returns:
An LlmResponse containing the full text.
Expand All @@ -176,6 +197,8 @@ def __build_full_text_response(self, text: str):
role='model',
parts=[types.Part.from_text(text=text)],
),
grounding_metadata=grounding_metadata,
partial=False,
live_session_id=self._gemini_session.session_id,
)

Expand All @@ -188,6 +211,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:

text = ''
tool_call_parts = []
pending_grounding_metadata = None
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
Expand All @@ -203,6 +227,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
if message.server_content:
content = message.server_content.model_turn
if message.server_content.grounding_metadata:
pending_grounding_metadata = (
message.server_content.grounding_metadata
)

# Standalone grounding_metadata event (when content is empty)
if (
Expand All @@ -215,6 +243,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
interrupted=message.server_content.interrupted,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)

if content and content.parts:
Expand All @@ -223,19 +254,31 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
interrupted=message.server_content.interrupted,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)
# grounding_metadata is yielded again at turn_complete,
# so avoid duplicating it here if turn_complete is true.
if not message.server_content.turn_complete:
llm_response.grounding_metadata = (
message.server_content.grounding_metadata
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
has_inline_data = any(p.inline_data for p in content.parts)
for part in content.parts:
if part.text:
current_is_thought = getattr(part, 'thought', False)
if text and current_is_thought != is_thought:
yield self.__build_full_text_response(text, is_thought)
text = ''
is_thought = False

text += part.text
is_thought = current_is_thought
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
if text and not any(p.text for p in content.parts) and not has_inline_data:
yield self.__build_full_text_response(text, is_thought)
text = ''
yield llm_response
# Note: in some cases, tool_call may arrive before
Expand Down Expand Up @@ -324,9 +367,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
self._output_transcription_text = ''
if message.server_content.turn_complete:
g_metadata_to_yield = pending_grounding_metadata
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, is_thought, g_metadata_to_yield
)
text = ''
is_thought = False
g_metadata_to_yield = None
if tool_call_parts:
logger.debug('Returning aggregated tool_call_parts')
yield LlmResponse(
Expand All @@ -338,9 +386,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
grounding_metadata=message.server_content.grounding_metadata,
grounding_metadata=message.server_content.grounding_metadata
or g_metadata_to_yield,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)
break
# in case of empty content or parts, we still surface it
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/models/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ class LlmResponse(BaseModel):
Only used for streaming mode.
"""

turn_complete_reason: Optional[types.TurnCompleteReason] = None
"""The reason why the turn is complete.

Only used for streaming mode.
"""

finish_reason: Optional[types.FinishReason] = None
"""The finish reason of the response."""

Expand Down
Loading