diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index fb9a3a5163..c9c6ac6500 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -159,7 +159,12 @@ async def send_realtime(self, input: RealtimeInput): else: raise ValueError('Unsupported input type: %s' % type(input)) - def __build_full_text_response(self, text: str): + def __build_full_text_response( + self, + text: str, + is_thought: bool = False, + grounding_metadata: types.GroundingMetadata | None = None, + ): """Builds a full text response. The text should not be partial and the returned LlmResponse is not @@ -167,6 +172,8 @@ def __build_full_text_response(self, text: str): Args: text: The text to be included in the response. + is_thought: Whether the text is a thought. + grounding_metadata: The grounding metadata to include. Returns: An LlmResponse containing the full text. @@ -176,6 +183,8 @@ def __build_full_text_response(self, text: str): role='model', parts=[types.Part.from_text(text=text)], ), + grounding_metadata=grounding_metadata, + partial=False, live_session_id=self._gemini_session.session_id, ) @@ -188,6 +197,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: text = '' tool_call_parts = [] + pending_grounding_metadata = None async with Aclosing(self._gemini_session.receive()) as agen: # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate # partial content and emit responses as needed. @@ -203,6 +213,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) if message.server_content: content = message.server_content.model_turn + if message.server_content.grounding_metadata: + pending_grounding_metadata = ( + message.server_content.grounding_metadata + ) # Standalone grounding_metadata event (when content is empty) if ( @@ -215,6 +229,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) if content and content.parts: @@ -223,6 +240,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) # grounding_metadata is yielded again at turn_complete, # so avoid duplicating it here if turn_complete is true. @@ -324,9 +344,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) self._output_transcription_text = '' if message.server_content.turn_complete: + g_metadata_to_yield = pending_grounding_metadata if text: - yield self.__build_full_text_response(text) + yield self.__build_full_text_response( + text, is_thought, g_metadata_to_yield + ) text = '' + is_thought = False + g_metadata_to_yield = None if tool_call_parts: logger.debug('Returning aggregated tool_call_parts') yield LlmResponse( @@ -338,9 +363,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield LlmResponse( turn_complete=True, interrupted=message.server_content.interrupted, - grounding_metadata=message.server_content.grounding_metadata, + grounding_metadata=message.server_content.grounding_metadata + or g_metadata_to_yield, model_version=self._model_version, live_session_id=live_session_id, + turn_complete_reason=getattr( + message.server_content, 'turn_complete_reason', None + ), ) break # in case of empty content or parts, we still surface it diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index c921f197c3..333034565f 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -81,6 +81,12 @@ class LlmResponse(BaseModel): Only used for streaming mode. """ + turn_complete_reason: Optional[types.TurnCompleteReason] = None + """The reason why the turn is complete. + + Only used for streaming mode. + """ + finish_reason: Optional[types.FinishReason] = None """The finish reason of the response.""" diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 58aace30ed..6d28c7a0df 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1262,3 +1262,200 @@ async def mock_receive_generator(): content_response = next((r for r in responses if r.content), None) assert content_response is not None assert content_response.content == mock_content + + +@pytest.mark.asyncio +async def test_receive_grounding_metadata_pending( + gemini_connection, mock_gemini_session +): + """Test that grounding metadata in partial chunks is pending and yielded on full text.""" + grounding_metadata = types.GroundingMetadata( + web_search_queries=['stock price of google'], + ) + + def make_msg(text=None, g_meta=None, tc=False): + msg = mock.Mock( + usage_metadata=None, + tool_call=None, + session_resumption_update=None, + go_away=None, + ) + msg.server_content = mock.Mock( + interrupted=False, + input_transcription=None, + output_transcription=None, + generation_complete=False, + turn_complete=tc, + grounding_metadata=g_meta, + model_turn=types.Content( + role='model', parts=[types.Part.from_text(text=text)] + ) + if text + else None, + ) + return msg + + msg1 = make_msg(text='hello', g_meta=grounding_metadata) + msg2 = make_msg(text=' world') + msg3 = make_msg(tc=True) + + async def gen(): + yield msg1 + yield msg2 + yield msg3 + + mock_gemini_session.receive = mock.Mock(return_value=gen()) + + responses = [resp async for resp in gemini_connection.receive()] + + # Expected responses: + # 1. Msg 1 partial (hello) with grounding_metadata + # 2. Msg 2 partial ( world) without grounding_metadata + # 3. Full text response (hello world) with PENDING grounding_metadata + # 4. Turn complete response without grounding_metadata (already cleared) + assert len(responses) == 4 + + assert responses[0].content.parts[0].text == 'hello' + assert responses[0].partial is True + assert responses[0].grounding_metadata == grounding_metadata + + assert responses[1].content.parts[0].text == ' world' + assert responses[1].partial is True + assert responses[1].grounding_metadata is None + + assert responses[2].content.parts[0].text == 'hello world' + assert responses[2].partial is False + assert responses[2].grounding_metadata == grounding_metadata + + assert responses[3].turn_complete is True + assert responses[3].grounding_metadata is None + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = True + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].turn_complete is True + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_standalone_grounding( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata.""" + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = None + mock_server_content.grounding_metadata = mock.create_autospec( + types.GroundingMetadata, instance=True + ) + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].grounding_metadata is not None + assert responses[0].turn_complete is None + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + ) + + +@pytest.mark.asyncio +async def test_receive_populates_turn_complete_reason_with_content( + gemini_connection, mock_gemini_session +): + """Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts.""" + mock_content = types.Content( + role='model', + parts=[types.Part.from_text(text='hello')], + ) + mock_server_content = mock.create_autospec( + types.LiveServerContent, instance=True + ) + mock_server_content.model_turn = mock_content + mock_server_content.grounding_metadata = None + mock_server_content.turn_complete = False + mock_server_content.interrupted = False + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + mock_server_content.generation_complete = False + mock_server_content.turn_complete_reason = ( + types.TurnCompleteReason.RESPONSE_REJECTED + ) + + mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) + mock_message.usage_metadata = None + mock_message.server_content = mock_server_content + mock_message.tool_call = None + mock_message.session_resumption_update = None + mock_message.go_away = None + + async def mock_receive_generator(): + yield mock_message + + mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) + + responses = [resp async for resp in gemini_connection.receive()] + + assert len(responses) == 1 + assert responses[0].content == mock_content + assert ( + responses[0].turn_complete_reason + == types.TurnCompleteReason.RESPONSE_REJECTED + )