You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: make LlmAgent.model optional with a default fallback
LlmAgent now resolves model from ancestors or a system default (gemini-2.5-flash) when unset. Added LlmAgent.set_default_model() to override the default globally Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 853006116
This commit is contained in:
committed by
Copybara-Service
parent
742c9265a2
commit
b28721508a
@@ -2461,7 +2461,7 @@
|
||||
}
|
||||
],
|
||||
"default": null,
|
||||
"description": "Optional. LlmAgent.model. If not set, the model will be inherited from the ancestor.",
|
||||
"description": "Optional. LlmAgent.model. Provide a model name string (e.g. \"gemini-2.0-flash\"). If not set, the model will be inherited from the ancestor or fall back to the system default (gemini-2.5-flash unless overridden via LlmAgent.set_default_model). To construct a model instance from code, use model_code.",
|
||||
"title": "Model"
|
||||
},
|
||||
"instruction": {
|
||||
@@ -4601,4 +4601,4 @@
|
||||
}
|
||||
],
|
||||
"title": "AgentConfig"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,10 +183,18 @@ async def _convert_tool_union_to_tools(
|
||||
class LlmAgent(BaseAgent):
|
||||
"""LLM-based Agent."""
|
||||
|
||||
DEFAULT_MODEL: ClassVar[str] = 'gemini-2.5-flash'
|
||||
"""System default model used when no model is set on an agent."""
|
||||
|
||||
_default_model: ClassVar[Union[str, BaseLlm]] = DEFAULT_MODEL
|
||||
"""Current default model used when an agent has no model set."""
|
||||
|
||||
model: Union[str, BaseLlm] = ''
|
||||
"""The model to use for the agent.
|
||||
|
||||
When not set, the agent will inherit the model from its ancestor.
|
||||
When not set, the agent will inherit the model from its ancestor. If no
|
||||
ancestor provides a model, the agent uses the default model configured via
|
||||
LlmAgent.set_default_model. The built-in default is gemini-2.5-flash.
|
||||
"""
|
||||
|
||||
config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig
|
||||
@@ -503,7 +511,24 @@ class LlmAgent(BaseAgent):
|
||||
if isinstance(ancestor_agent, LlmAgent):
|
||||
return ancestor_agent.canonical_model
|
||||
ancestor_agent = ancestor_agent.parent_agent
|
||||
raise ValueError(f'No model found for {self.name}.')
|
||||
return self._resolve_default_model()
|
||||
|
||||
@classmethod
|
||||
def set_default_model(cls, model: Union[str, BaseLlm]) -> None:
|
||||
"""Overrides the default model used when an agent has no model set."""
|
||||
if not isinstance(model, (str, BaseLlm)):
|
||||
raise TypeError('Default model must be a model name or BaseLlm.')
|
||||
if isinstance(model, str) and not model:
|
||||
raise ValueError('Default model must be a non-empty string.')
|
||||
cls._default_model = model
|
||||
|
||||
@classmethod
|
||||
def _resolve_default_model(cls) -> BaseLlm:
|
||||
"""Resolves the current default model to a BaseLlm instance."""
|
||||
default_model = cls._default_model
|
||||
if isinstance(default_model, BaseLlm):
|
||||
return default_model
|
||||
return LLMRegistry.new_llm(default_model)
|
||||
|
||||
async def canonical_instruction(
|
||||
self, ctx: ReadonlyContext
|
||||
@@ -575,10 +600,11 @@ class LlmAgent(BaseAgent):
|
||||
# because the built-in tools cannot be used together with other tools.
|
||||
# TODO(b/448114567): Remove once the workaround is no longer needed.
|
||||
multiple_tools = len(self.tools) > 1
|
||||
model = self.canonical_model
|
||||
for tool_union in self.tools:
|
||||
resolved_tools.extend(
|
||||
await _convert_tool_union_to_tools(
|
||||
tool_union, ctx, self.model, multiple_tools
|
||||
tool_union, ctx, model, multiple_tools
|
||||
)
|
||||
)
|
||||
return resolved_tools
|
||||
|
||||
@@ -56,8 +56,9 @@ class LlmAgentConfig(BaseAgentConfig):
|
||||
description=(
|
||||
'Optional. LlmAgent.model. Provide a model name string (e.g.'
|
||||
' "gemini-2.0-flash"). If not set, the model will be inherited from'
|
||||
' the ancestor. To construct a model instance from code, use'
|
||||
' model_code.'
|
||||
' the ancestor or fall back to the system default (gemini-2.5-flash'
|
||||
' unless overridden via LlmAgent.set_default_model). To construct a'
|
||||
' model instance from code, use model_code.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
|
||||
if (
|
||||
not agent.output_schema
|
||||
or not agent.tools
|
||||
or can_use_output_schema_with_tools(agent.model)
|
||||
or can_use_output_schema_with_tools(agent.canonical_model)
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
@@ -476,7 +476,11 @@ class BaseLlmFlow(ABC):
|
||||
# We may need to wrap some built-in tools if there are other tools
|
||||
# because the built-in tools cannot be used together with other tools.
|
||||
# TODO(b/448114567): Remove once the workaround is no longer needed.
|
||||
if not agent.tools:
|
||||
return
|
||||
|
||||
multiple_tools = len(agent.tools) > 1
|
||||
model = agent.canonical_model
|
||||
for tool_union in agent.tools:
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
@@ -492,7 +496,7 @@ class BaseLlmFlow(ABC):
|
||||
tools = await _convert_tool_union_to_tools(
|
||||
tool_union,
|
||||
ReadonlyContext(invocation_context),
|
||||
agent.model,
|
||||
model,
|
||||
multiple_tools,
|
||||
)
|
||||
for tool in tools:
|
||||
|
||||
@@ -35,15 +35,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
|
||||
llm_request.model = (
|
||||
agent.canonical_model
|
||||
if isinstance(agent.canonical_model, str)
|
||||
else agent.canonical_model.model
|
||||
)
|
||||
model = agent.canonical_model
|
||||
llm_request.model = model if isinstance(model, str) else model.model
|
||||
llm_request.config = (
|
||||
agent.generate_content_config.model_copy(deep=True)
|
||||
if agent.generate_content_config
|
||||
@@ -54,7 +48,7 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
# both output_schema and tools at the same time. see
|
||||
# _output_schema_processor.py for details
|
||||
if agent.output_schema:
|
||||
if not agent.tools or can_use_output_schema_with_tools(agent.model):
|
||||
if not agent.tools or can_use_output_schema_with_tools(model):
|
||||
llm_request.set_output_schema(agent.output_schema)
|
||||
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
|
||||
@@ -53,9 +53,10 @@ class InteractionsRequestProcessor(BaseLlmRequestProcessor):
|
||||
# Only process if using Gemini with interactions API
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
if not isinstance(agent.model, Gemini):
|
||||
model = agent.canonical_model
|
||||
if not isinstance(model, Gemini):
|
||||
return
|
||||
if not agent.model.use_interactions_api:
|
||||
if not model.use_interactions_api:
|
||||
return
|
||||
# Extract previous interaction ID from session events
|
||||
previous_interaction_id = self._find_previous_interaction_id(
|
||||
|
||||
@@ -52,11 +52,24 @@ async def _create_readonly_context(
|
||||
return ReadonlyContext(invocation_context)
|
||||
|
||||
|
||||
def test_canonical_model_empty():
|
||||
agent = LlmAgent(name='test_agent')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = agent.canonical_model
|
||||
@pytest.mark.parametrize(
|
||||
('default_model', 'expected_model_name', 'expected_model_type'),
|
||||
[
|
||||
(LlmAgent.DEFAULT_MODEL, LlmAgent.DEFAULT_MODEL, Gemini),
|
||||
('gemini-2.0-flash', 'gemini-2.0-flash', Gemini),
|
||||
],
|
||||
)
|
||||
def test_canonical_model_default_fallback(
|
||||
default_model, expected_model_name, expected_model_type
|
||||
):
|
||||
original_default = LlmAgent._default_model
|
||||
LlmAgent.set_default_model(default_model)
|
||||
try:
|
||||
agent = LlmAgent(name='test_agent')
|
||||
assert isinstance(agent.canonical_model, expected_model_type)
|
||||
assert agent.canonical_model.model == expected_model_name
|
||||
finally:
|
||||
LlmAgent.set_default_model(original_default)
|
||||
|
||||
|
||||
def test_canonical_model_str():
|
||||
|
||||
@@ -110,7 +110,9 @@ class TestBasicLlmRequestProcessor:
|
||||
assert llm_request.config.response_mime_type != 'application/json'
|
||||
|
||||
# Should have checked if output schema can be used with tools
|
||||
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
|
||||
can_use_output_schema_with_tools.assert_called_once_with(
|
||||
agent.canonical_model
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_output_schema_when_tools_present(self, mocker):
|
||||
@@ -141,7 +143,9 @@ class TestBasicLlmRequestProcessor:
|
||||
assert llm_request.config.response_mime_type == 'application/json'
|
||||
|
||||
# Should have checked if output schema can be used with tools
|
||||
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
|
||||
can_use_output_schema_with_tools.assert_called_once_with(
|
||||
agent.canonical_model
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_output_schema_no_tools(self):
|
||||
|
||||
@@ -191,7 +191,9 @@ async def test_output_schema_request_processor(
|
||||
assert not llm_request.config.system_instruction
|
||||
|
||||
# Should have checked if output schema can be used with tools
|
||||
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
|
||||
can_use_output_schema_with_tools.assert_called_once_with(
|
||||
agent.canonical_model
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user