You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Expose log probs of candidates in LlmResponse
fixes https://github.com/google/adk-python/issues/2764 PiperOrigin-RevId: 807516910
This commit is contained in:
committed by
Copybara-Service
parent
1ce043a278
commit
f7bd3c111c
@@ -42,6 +42,8 @@ class LlmResponse(BaseModel):
|
||||
custom_metadata: The custom metadata of the LlmResponse.
|
||||
input_transcription: Audio transcription of user input.
|
||||
output_transcription: Audio transcription of model output.
|
||||
avg_logprobs: Average log probability of the generated tokens.
|
||||
logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -109,6 +111,12 @@ class LlmResponse(BaseModel):
|
||||
output_transcription: Optional[types.Transcription] = None
|
||||
"""Audio transcription of model output."""
|
||||
|
||||
avg_logprobs: Optional[float] = None
|
||||
"""Average log probability of the generated tokens."""
|
||||
|
||||
logprobs_result: Optional[types.LogprobsResult] = None
|
||||
"""Detailed log probabilities for chosen and top candidate tokens."""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
generate_content_response: types.GenerateContentResponse,
|
||||
@@ -131,6 +139,8 @@ class LlmResponse(BaseModel):
|
||||
grounding_metadata=candidate.grounding_metadata,
|
||||
usage_metadata=usage_metadata,
|
||||
finish_reason=candidate.finish_reason,
|
||||
avg_logprobs=candidate.avg_logprobs,
|
||||
logprobs_result=candidate.logprobs_result,
|
||||
)
|
||||
else:
|
||||
return LlmResponse(
|
||||
@@ -138,6 +148,8 @@ class LlmResponse(BaseModel):
|
||||
error_message=candidate.finish_message,
|
||||
usage_metadata=usage_metadata,
|
||||
finish_reason=candidate.finish_reason,
|
||||
avg_logprobs=candidate.avg_logprobs,
|
||||
logprobs_result=candidate.logprobs_result,
|
||||
)
|
||||
else:
|
||||
if generate_content_response.prompt_feedback:
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for LlmResponse, including log probabilities feature."""
|
||||
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def test_llm_response_create_with_logprobs():
|
||||
"""Test LlmResponse.create() extracts logprobs from candidate."""
|
||||
avg_logprobs = -0.75
|
||||
logprobs_result = types.LogprobsResult(
|
||||
chosen_candidates=[], top_candidates=[]
|
||||
)
|
||||
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text='Response text')]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
avg_logprobs=avg_logprobs,
|
||||
logprobs_result=logprobs_result,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
assert response.avg_logprobs == avg_logprobs
|
||||
assert response.logprobs_result == logprobs_result
|
||||
assert response.content.parts[0].text == 'Response text'
|
||||
assert response.finish_reason == types.FinishReason.STOP
|
||||
|
||||
|
||||
def test_llm_response_create_without_logprobs():
|
||||
"""Test LlmResponse.create() handles missing logprobs gracefully."""
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text='Response text')]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
avg_logprobs=None,
|
||||
logprobs_result=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
assert response.avg_logprobs is None
|
||||
assert response.logprobs_result is None
|
||||
assert response.content.parts[0].text == 'Response text'
|
||||
|
||||
|
||||
def test_llm_response_create_error_case_with_logprobs():
|
||||
"""Test LlmResponse.create() includes logprobs in error cases."""
|
||||
avg_logprobs = -2.1
|
||||
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=None, # No content - error case
|
||||
finish_reason=types.FinishReason.SAFETY,
|
||||
finish_message='Safety filter triggered',
|
||||
avg_logprobs=avg_logprobs,
|
||||
logprobs_result=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
assert response.avg_logprobs == avg_logprobs
|
||||
assert response.logprobs_result is None
|
||||
assert response.error_code == types.FinishReason.SAFETY
|
||||
assert response.error_message == 'Safety filter triggered'
|
||||
|
||||
|
||||
def test_llm_response_create_no_candidates():
|
||||
"""Test LlmResponse.create() with no candidates."""
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[],
|
||||
prompt_feedback=types.GenerateContentResponsePromptFeedback(
|
||||
block_reason=types.BlockedReason.SAFETY,
|
||||
block_reason_message='Prompt blocked for safety',
|
||||
),
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
# No candidates means no logprobs
|
||||
assert response.avg_logprobs is None
|
||||
assert response.logprobs_result is None
|
||||
assert response.error_code == types.BlockedReason.SAFETY
|
||||
assert response.error_message == 'Prompt blocked for safety'
|
||||
|
||||
|
||||
def test_llm_response_create_with_concrete_logprobs_result():
|
||||
"""Test LlmResponse.create() with detailed logprobs_result containing actual token data."""
|
||||
# Create realistic logprobs data
|
||||
chosen_candidates = [
|
||||
types.LogprobsResultCandidate(
|
||||
token='The', log_probability=-0.1, token_id=123
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token=' capital', log_probability=-0.5, token_id=456
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token=' of', log_probability=-0.2, token_id=789
|
||||
),
|
||||
]
|
||||
|
||||
top_candidates = [
|
||||
types.LogprobsResultTopCandidates(
|
||||
candidates=[
|
||||
types.LogprobsResultCandidate(
|
||||
token='The', log_probability=-0.1, token_id=123
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token='A', log_probability=-2.3, token_id=124
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token='This', log_probability=-3.1, token_id=125
|
||||
),
|
||||
]
|
||||
),
|
||||
types.LogprobsResultTopCandidates(
|
||||
candidates=[
|
||||
types.LogprobsResultCandidate(
|
||||
token=' capital', log_probability=-0.5, token_id=456
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token=' city', log_probability=-1.2, token_id=457
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token=' main', log_probability=-2.8, token_id=458
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
avg_logprobs = -0.27 # Average of -0.1, -0.5, -0.2
|
||||
logprobs_result = types.LogprobsResult(
|
||||
chosen_candidates=chosen_candidates, top_candidates=top_candidates
|
||||
)
|
||||
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
parts=[types.Part(text='The capital of France is Paris.')]
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
avg_logprobs=avg_logprobs,
|
||||
logprobs_result=logprobs_result,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
assert response.avg_logprobs == avg_logprobs
|
||||
assert response.logprobs_result is not None
|
||||
|
||||
# Test chosen candidates
|
||||
assert len(response.logprobs_result.chosen_candidates) == 3
|
||||
assert response.logprobs_result.chosen_candidates[0].token == 'The'
|
||||
assert response.logprobs_result.chosen_candidates[0].log_probability == -0.1
|
||||
assert response.logprobs_result.chosen_candidates[0].token_id == 123
|
||||
assert response.logprobs_result.chosen_candidates[1].token == ' capital'
|
||||
assert response.logprobs_result.chosen_candidates[1].log_probability == -0.5
|
||||
assert response.logprobs_result.chosen_candidates[1].token_id == 456
|
||||
|
||||
# Test top candidates
|
||||
assert len(response.logprobs_result.top_candidates) == 2
|
||||
assert (
|
||||
len(response.logprobs_result.top_candidates[0].candidates) == 3
|
||||
) # 3 alternatives for first token
|
||||
assert response.logprobs_result.top_candidates[0].candidates[0].token == 'The'
|
||||
assert (
|
||||
response.logprobs_result.top_candidates[0].candidates[0].token_id == 123
|
||||
)
|
||||
assert response.logprobs_result.top_candidates[0].candidates[1].token == 'A'
|
||||
assert (
|
||||
response.logprobs_result.top_candidates[0].candidates[1].token_id == 124
|
||||
)
|
||||
assert (
|
||||
response.logprobs_result.top_candidates[0].candidates[2].token == 'This'
|
||||
)
|
||||
assert (
|
||||
response.logprobs_result.top_candidates[0].candidates[2].token_id == 125
|
||||
)
|
||||
|
||||
|
||||
def test_llm_response_create_with_partial_logprobs_result():
|
||||
"""Test LlmResponse.create() with logprobs_result having only chosen_candidates."""
|
||||
chosen_candidates = [
|
||||
types.LogprobsResultCandidate(
|
||||
token='Hello', log_probability=-0.05, token_id=111
|
||||
),
|
||||
types.LogprobsResultCandidate(
|
||||
token=' world', log_probability=-0.8, token_id=222
|
||||
),
|
||||
]
|
||||
|
||||
logprobs_result = types.LogprobsResult(
|
||||
chosen_candidates=chosen_candidates,
|
||||
top_candidates=[], # Empty top candidates
|
||||
)
|
||||
|
||||
generate_content_response = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(parts=[types.Part(text='Hello world')]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
avg_logprobs=-0.425, # Average of -0.05 and -0.8
|
||||
logprobs_result=logprobs_result,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
response = LlmResponse.create(generate_content_response)
|
||||
|
||||
assert response.avg_logprobs == -0.425
|
||||
assert response.logprobs_result is not None
|
||||
assert len(response.logprobs_result.chosen_candidates) == 2
|
||||
assert len(response.logprobs_result.top_candidates) == 0
|
||||
assert response.logprobs_result.chosen_candidates[0].token == 'Hello'
|
||||
assert response.logprobs_result.chosen_candidates[1].token == ' world'
|
||||
Reference in New Issue
Block a user