feat: make LlmAgent.model optional with a default fallback

LlmAgent now resolves model from ancestors or a system default (gemini-2.5-flash) when unset. Added LlmAgent.set_default_model() to override the default globally

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 853006116
This commit is contained in:
George Weale
2026-01-06 17:55:43 -08:00
committed by Copybara-Service
parent 742c9265a2
commit b28721508a
10 changed files with 73 additions and 28 deletions
@@ -2461,7 +2461,7 @@
}
],
"default": null,
"description": "Optional. LlmAgent.model. If not set, the model will be inherited from the ancestor.",
"description": "Optional. LlmAgent.model. Provide a model name string (e.g. \"gemini-2.0-flash\"). If not set, the model will be inherited from the ancestor or fall back to the system default (gemini-2.5-flash unless overridden via LlmAgent.set_default_model). To construct a model instance from code, use model_code.",
"title": "Model"
},
"instruction": {
@@ -4601,4 +4601,4 @@
}
],
"title": "AgentConfig"
}
}
+29 -3
View File
@@ -183,10 +183,18 @@ async def _convert_tool_union_to_tools(
class LlmAgent(BaseAgent):
"""LLM-based Agent."""
DEFAULT_MODEL: ClassVar[str] = 'gemini-2.5-flash'
"""System default model used when no model is set on an agent."""
_default_model: ClassVar[Union[str, BaseLlm]] = DEFAULT_MODEL
"""Current default model used when an agent has no model set."""
model: Union[str, BaseLlm] = ''
"""The model to use for the agent.
When not set, the agent will inherit the model from its ancestor.
When not set, the agent will inherit the model from its ancestor. If no
ancestor provides a model, the agent uses the default model configured via
LlmAgent.set_default_model. The built-in default is gemini-2.5-flash.
"""
config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig
@@ -503,7 +511,24 @@ class LlmAgent(BaseAgent):
if isinstance(ancestor_agent, LlmAgent):
return ancestor_agent.canonical_model
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
return self._resolve_default_model()
@classmethod
def set_default_model(cls, model: Union[str, BaseLlm]) -> None:
"""Overrides the default model used when an agent has no model set."""
if not isinstance(model, (str, BaseLlm)):
raise TypeError('Default model must be a model name or BaseLlm.')
if isinstance(model, str) and not model:
raise ValueError('Default model must be a non-empty string.')
cls._default_model = model
@classmethod
def _resolve_default_model(cls) -> BaseLlm:
"""Resolves the current default model to a BaseLlm instance."""
default_model = cls._default_model
if isinstance(default_model, BaseLlm):
return default_model
return LLMRegistry.new_llm(default_model)
async def canonical_instruction(
self, ctx: ReadonlyContext
@@ -575,10 +600,11 @@ class LlmAgent(BaseAgent):
# because the built-in tools cannot be used together with other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
multiple_tools = len(self.tools) > 1
model = self.canonical_model
for tool_union in self.tools:
resolved_tools.extend(
await _convert_tool_union_to_tools(
tool_union, ctx, self.model, multiple_tools
tool_union, ctx, model, multiple_tools
)
)
return resolved_tools
+3 -2
View File
@@ -56,8 +56,9 @@ class LlmAgentConfig(BaseAgentConfig):
description=(
'Optional. LlmAgent.model. Provide a model name string (e.g.'
' "gemini-2.0-flash"). If not set, the model will be inherited from'
' the ancestor. To construct a model instance from code, use'
' model_code.'
' the ancestor or fall back to the system default (gemini-2.5-flash'
' unless overridden via LlmAgent.set_default_model). To construct a'
' model instance from code, use model_code.'
),
)
@@ -45,7 +45,7 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
if (
not agent.output_schema
or not agent.tools
or can_use_output_schema_with_tools(agent.model)
or can_use_output_schema_with_tools(agent.canonical_model)
):
return
@@ -476,7 +476,11 @@ class BaseLlmFlow(ABC):
# We may need to wrap some built-in tools if there are other tools
# because the built-in tools cannot be used together with other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if not agent.tools:
return
multiple_tools = len(agent.tools) > 1
model = agent.canonical_model
for tool_union in agent.tools:
tool_context = ToolContext(invocation_context)
@@ -492,7 +496,7 @@ class BaseLlmFlow(ABC):
tools = await _convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
agent.model,
model,
multiple_tools,
)
for tool in tools:
+3 -9
View File
@@ -35,15 +35,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
llm_request.model = (
agent.canonical_model
if isinstance(agent.canonical_model, str)
else agent.canonical_model.model
)
model = agent.canonical_model
llm_request.model = model if isinstance(model, str) else model.model
llm_request.config = (
agent.generate_content_config.model_copy(deep=True)
if agent.generate_content_config
@@ -54,7 +48,7 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
# both output_schema and tools at the same time. see
# _output_schema_processor.py for details
if agent.output_schema:
if not agent.tools or can_use_output_schema_with_tools(agent.model):
if not agent.tools or can_use_output_schema_with_tools(model):
llm_request.set_output_schema(agent.output_schema)
llm_request.live_connect_config.response_modalities = (
@@ -53,9 +53,10 @@ class InteractionsRequestProcessor(BaseLlmRequestProcessor):
# Only process if using Gemini with interactions API
if not isinstance(agent, LlmAgent):
return
if not isinstance(agent.model, Gemini):
model = agent.canonical_model
if not isinstance(model, Gemini):
return
if not agent.model.use_interactions_api:
if not model.use_interactions_api:
return
# Extract previous interaction ID from session events
previous_interaction_id = self._find_previous_interaction_id(
@@ -52,11 +52,24 @@ async def _create_readonly_context(
return ReadonlyContext(invocation_context)
def test_canonical_model_empty():
agent = LlmAgent(name='test_agent')
with pytest.raises(ValueError):
_ = agent.canonical_model
@pytest.mark.parametrize(
('default_model', 'expected_model_name', 'expected_model_type'),
[
(LlmAgent.DEFAULT_MODEL, LlmAgent.DEFAULT_MODEL, Gemini),
('gemini-2.0-flash', 'gemini-2.0-flash', Gemini),
],
)
def test_canonical_model_default_fallback(
default_model, expected_model_name, expected_model_type
):
original_default = LlmAgent._default_model
LlmAgent.set_default_model(default_model)
try:
agent = LlmAgent(name='test_agent')
assert isinstance(agent.canonical_model, expected_model_type)
assert agent.canonical_model.model == expected_model_name
finally:
LlmAgent.set_default_model(original_default)
def test_canonical_model_str():
@@ -110,7 +110,9 @@ class TestBasicLlmRequestProcessor:
assert llm_request.config.response_mime_type != 'application/json'
# Should have checked if output schema can be used with tools
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
can_use_output_schema_with_tools.assert_called_once_with(
agent.canonical_model
)
@pytest.mark.asyncio
async def test_sets_output_schema_when_tools_present(self, mocker):
@@ -141,7 +143,9 @@ class TestBasicLlmRequestProcessor:
assert llm_request.config.response_mime_type == 'application/json'
# Should have checked if output schema can be used with tools
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
can_use_output_schema_with_tools.assert_called_once_with(
agent.canonical_model
)
@pytest.mark.asyncio
async def test_no_output_schema_no_tools(self):
@@ -191,7 +191,9 @@ async def test_output_schema_request_processor(
assert not llm_request.config.system_instruction
# Should have checked if output schema can be used with tools
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
can_use_output_schema_with_tools.assert_called_once_with(
agent.canonical_model
)
@pytest.mark.asyncio