feat: add citation_metadata to LlmResponse

PiperOrigin-RevId: 811997009
This commit is contained in:
Google Team Member
2025-09-26 16:30:21 -07:00
committed by Copybara-Service
parent 7b707cebea
commit 3f28e30c6d
2 changed files with 86 additions and 0 deletions
+8
View File
@@ -126,6 +126,12 @@ class LlmResponse(BaseModel):
This field is automatically populated when context caching is enabled.
"""
citation_metadata: Optional[types.CitationMetadata] = None
"""Citation metadata for the response.
This field is automatically populated when citation is enabled.
"""
@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
@@ -148,6 +154,7 @@ class LlmResponse(BaseModel):
grounding_metadata=candidate.grounding_metadata,
usage_metadata=usage_metadata,
finish_reason=candidate.finish_reason,
citation_metadata=candidate.citation_metadata,
avg_logprobs=candidate.avg_logprobs,
logprobs_result=candidate.logprobs_result,
)
@@ -155,6 +162,7 @@ class LlmResponse(BaseModel):
return LlmResponse(
error_code=candidate.finish_reason,
error_message=candidate.finish_message,
citation_metadata=candidate.citation_metadata,
usage_metadata=usage_metadata,
finish_reason=candidate.finish_reason,
avg_logprobs=candidate.avg_logprobs,
@@ -239,3 +239,81 @@ def test_llm_response_create_with_partial_logprobs_result():
assert len(response.logprobs_result.top_candidates) == 0
assert response.logprobs_result.chosen_candidates[0].token == 'Hello'
assert response.logprobs_result.chosen_candidates[1].token == ' world'
def test_llm_response_create_with_citation_metadata():
"""Test LlmResponse.create() extracts citation_metadata from candidate."""
citation_metadata = types.CitationMetadata(
citations=[
types.Citation(
start_index=0,
end_index=10,
uri='https://example.com',
)
]
)
generate_content_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(parts=[types.Part(text='Response text')]),
finish_reason=types.FinishReason.STOP,
citation_metadata=citation_metadata,
)
]
)
response = LlmResponse.create(generate_content_response)
assert response.citation_metadata == citation_metadata
assert response.content.parts[0].text == 'Response text'
def test_llm_response_create_without_citation_metadata():
"""Test LlmResponse.create() handles missing citation_metadata gracefully."""
generate_content_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(parts=[types.Part(text='Response text')]),
finish_reason=types.FinishReason.STOP,
citation_metadata=None,
)
]
)
response = LlmResponse.create(generate_content_response)
assert response.citation_metadata is None
assert response.content.parts[0].text == 'Response text'
def test_llm_response_create_error_case_with_citation_metadata():
"""Test LlmResponse.create() includes citation_metadata in error cases."""
citation_metadata = types.CitationMetadata(
citations=[
types.Citation(
start_index=0,
end_index=10,
uri='https://example.com',
)
]
)
generate_content_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=None, # No content - blocked case
finish_reason=types.FinishReason.RECITATION,
finish_message='Response blocked due to recitation triggered',
citation_metadata=citation_metadata,
)
]
)
response = LlmResponse.create(generate_content_response)
assert response.citation_metadata == citation_metadata
assert response.error_code == types.FinishReason.RECITATION
assert (
response.error_message == 'Response blocked due to recitation triggered'
)