diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index c1f7cef91f..4441deb726 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -550,12 +550,9 @@ def _convert_google_genai_response_to_chatmessage(response: types.GenerateConten usage["thoughts_token_count"] = usage_metadata.thoughts_token_count # Add cached content token count if available (implicit or explicit context caching) - if ( - usage_metadata - and hasattr(usage_metadata, "cached_content_token_count") - and usage_metadata.cached_content_token_count - ): - usage["cached_content_token_count"] = usage_metadata.cached_content_token_count + cached_content_token_count = getattr(usage_metadata, "cached_content_token_count", None) if usage_metadata else None + if cached_content_token_count is not None: + usage["cached_content_token_count"] = cached_content_token_count usage.update(_convert_usage_metadata_to_serializable(usage_metadata)) @@ -625,6 +622,11 @@ def _convert_google_chunk_to_streaming_chunk( if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: usage["thoughts_token_count"] = usage_metadata.thoughts_token_count + # Add cached content token count if available (context caching) + cached_content_token_count = getattr(usage_metadata, "cached_content_token_count", None) if usage_metadata else None + if cached_content_token_count is not None: + usage["cached_content_token_count"] = cached_content_token_count + if candidate.content and candidate.content.parts: tc_index = -1 for part_index, part in enumerate(candidate.content.parts): @@ -717,6 +719,7 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> reasoning_text_parts: list[str] = [] thought_signatures: list[dict[str, Any]] = [] thoughts_token_count = None + cached_content_token_count = None for chunk in chunks: # Extract reasoning from the StreamingChunk.reasoning field @@ -731,11 +734,13 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> # We'll keep the last set of signatures as they represent the complete state thought_signatures = signature_deltas - # Extract thinking token usage (from the last chunk that has it) + # Extract token usage metadata (from the last chunk that has it) if chunk.meta and "usage" in chunk.meta: chunk_usage = chunk.meta["usage"] if "thoughts_token_count" in chunk_usage: thoughts_token_count = chunk_usage["thoughts_token_count"] + if "cached_content_token_count" in chunk_usage: + cached_content_token_count = chunk_usage["cached_content_token_count"] # Add thinking token count to usage if present if thoughts_token_count is not None and "usage" in message.meta: @@ -743,6 +748,12 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> message.meta["usage"] = {} message.meta["usage"]["thoughts_token_count"] = thoughts_token_count + # Add cached content token count to usage if present + if cached_content_token_count is not None and "usage" in message.meta: + if message.meta["usage"] is None: + message.meta["usage"] = {} + message.meta["usage"]["cached_content_token_count"] = cached_content_token_count + # Add thought signatures to meta if present (for multi-turn context preservation) if thought_signatures: message.meta["thought_signatures"] = thought_signatures diff --git a/integrations/google_genai/tests/test_chat_generator_utils.py b/integrations/google_genai/tests/test_chat_generator_utils.py index 52a77f80f5..6bf12b06ec 100644 --- a/integrations/google_genai/tests/test_chat_generator_utils.py +++ b/integrations/google_genai/tests/test_chat_generator_utils.py @@ -702,6 +702,76 @@ def test_aggregate_streaming_chunks_with_thought_signatures_and_thinking_tokens( assert "thought_signatures" in result.meta assert result.meta["thought_signatures"][0]["signature"] == "sig_xyz" + def test_convert_google_chunk_to_streaming_chunk_with_cached_tokens(self, monkeypatch): + """cached_content_token_count from usage_metadata is included in the streaming chunk's usage.""" + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component_info = ComponentInfo.from_component(GoogleGenAIChatGenerator()) + + mock_usage = Mock() + mock_usage.prompt_token_count = 1000 + mock_usage.candidates_token_count = 10 + mock_usage.total_token_count = 1010 + mock_usage.thoughts_token_count = None + mock_usage.cached_content_token_count = 800 + + mock_part = Mock() + mock_part.text = "The answer is 4." + mock_part.function_call = None + mock_part.thought = False + mock_part.thought_signature = None + mock_content = Mock() + mock_content.parts = [mock_part] + mock_candidate = Mock() + mock_candidate.content = mock_content + mock_candidate.finish_reason = "STOP" + + mock_chunk = Mock() + mock_chunk.candidates = [mock_candidate] + mock_chunk.usage_metadata = mock_usage + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, + index=0, + component_info=component_info, + model="gemini-2.5-flash", + ) + + assert chunk.meta["usage"]["prompt_tokens"] == 1000 + assert chunk.meta["usage"]["completion_tokens"] == 10 + assert chunk.meta["usage"]["total_tokens"] == 1010 + assert chunk.meta["usage"]["cached_content_token_count"] == 800 + + def test_aggregate_streaming_chunks_with_cached_tokens(self, monkeypatch): + """cached_content_token_count from the final chunk is propagated to the aggregated message.""" + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component_info = ComponentInfo.from_component(GoogleGenAIChatGenerator()) + + chunk1 = StreamingChunk( + content="Hello", + component_info=component_info, + index=0, + meta={"usage": {"prompt_tokens": 1000, "completion_tokens": 5, "total_tokens": 1005}}, + ) + final_chunk = StreamingChunk( + content=" world", + component_info=component_info, + index=1, + meta={ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 10, + "total_tokens": 1010, + "cached_content_token_count": 800, + }, + "model": "gemini-2.5-flash", + }, + ) + + result = _aggregate_streaming_chunks_with_reasoning([chunk1, final_chunk]) + + assert result.text == "Hello world" + assert result.meta["usage"]["cached_content_token_count"] == 800 + class TestConvertMessageToGoogleGenAI: def test_convert_message_to_google_genai_format_complex(self):