diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index d8d24fac..56eb6318 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -126,6 +126,12 @@ class LlmResponse(BaseModel): This field is automatically populated when context caching is enabled. """ + citation_metadata: Optional[types.CitationMetadata] = None + """Citation metadata for the response. + + This field is automatically populated when citation is enabled. + """ + @staticmethod def create( generate_content_response: types.GenerateContentResponse, @@ -148,6 +154,7 @@ class LlmResponse(BaseModel): grounding_metadata=candidate.grounding_metadata, usage_metadata=usage_metadata, finish_reason=candidate.finish_reason, + citation_metadata=candidate.citation_metadata, avg_logprobs=candidate.avg_logprobs, logprobs_result=candidate.logprobs_result, ) @@ -155,6 +162,7 @@ class LlmResponse(BaseModel): return LlmResponse( error_code=candidate.finish_reason, error_message=candidate.finish_message, + citation_metadata=candidate.citation_metadata, usage_metadata=usage_metadata, finish_reason=candidate.finish_reason, avg_logprobs=candidate.avg_logprobs, diff --git a/tests/unittests/models/test_llm_response.py b/tests/unittests/models/test_llm_response.py index f53de5b6..8ab5e6aa 100644 --- a/tests/unittests/models/test_llm_response.py +++ b/tests/unittests/models/test_llm_response.py @@ -239,3 +239,81 @@ def test_llm_response_create_with_partial_logprobs_result(): assert len(response.logprobs_result.top_candidates) == 0 assert response.logprobs_result.chosen_candidates[0].token == 'Hello' assert response.logprobs_result.chosen_candidates[1].token == ' world' + + +def test_llm_response_create_with_citation_metadata(): + """Test LlmResponse.create() extracts citation_metadata from candidate.""" + citation_metadata = types.CitationMetadata( + citations=[ + types.Citation( + start_index=0, + end_index=10, + uri='https://example.com', + ) + ] + ) + + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content(parts=[types.Part(text='Response text')]), + finish_reason=types.FinishReason.STOP, + citation_metadata=citation_metadata, + ) + ] + ) + + response = LlmResponse.create(generate_content_response) + + assert response.citation_metadata == citation_metadata + assert response.content.parts[0].text == 'Response text' + + +def test_llm_response_create_without_citation_metadata(): + """Test LlmResponse.create() handles missing citation_metadata gracefully.""" + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content(parts=[types.Part(text='Response text')]), + finish_reason=types.FinishReason.STOP, + citation_metadata=None, + ) + ] + ) + + response = LlmResponse.create(generate_content_response) + + assert response.citation_metadata is None + assert response.content.parts[0].text == 'Response text' + + +def test_llm_response_create_error_case_with_citation_metadata(): + """Test LlmResponse.create() includes citation_metadata in error cases.""" + citation_metadata = types.CitationMetadata( + citations=[ + types.Citation( + start_index=0, + end_index=10, + uri='https://example.com', + ) + ] + ) + + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=None, # No content - blocked case + finish_reason=types.FinishReason.RECITATION, + finish_message='Response blocked due to recitation triggered', + citation_metadata=citation_metadata, + ) + ] + ) + + response = LlmResponse.create(generate_content_response) + + assert response.citation_metadata == citation_metadata + assert response.error_code == types.FinishReason.RECITATION + assert ( + response.error_message == 'Response blocked due to recitation triggered' + )