From b28721508a41bf6bcfef52bbc042fb6193a32dfa Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 6 Jan 2026 17:55:43 -0800 Subject: [PATCH] feat: make LlmAgent.model optional with a default fallback LlmAgent now resolves its model from ancestors or a system default (gemini-2.5-flash) when unset. Added LlmAgent.set_default_model() to override the default globally. Co-authored-by: George Weale PiperOrigin-RevId: 853006116 --- .../agents/config_schemas/AgentConfig.json | 4 +-- src/google/adk/agents/llm_agent.py | 32 +++++++++++++++++-- src/google/adk/agents/llm_agent_config.py | 5 +-- .../llm_flows/_output_schema_processor.py | 2 +- .../adk/flows/llm_flows/base_llm_flow.py | 6 +++- src/google/adk/flows/llm_flows/basic.py | 12 ++----- .../flows/llm_flows/interactions_processor.py | 5 +-- .../unittests/agents/test_llm_agent_fields.py | 23 ++++++++++--- .../flows/llm_flows/test_basic_processor.py | 8 +++-- .../llm_flows/test_output_schema_processor.py | 4 ++- 10 files changed, 73 insertions(+), 28 deletions(-) diff --git a/src/google/adk/agents/config_schemas/AgentConfig.json b/src/google/adk/agents/config_schemas/AgentConfig.json index e2f353de..f912cefd 100644 --- a/src/google/adk/agents/config_schemas/AgentConfig.json +++ b/src/google/adk/agents/config_schemas/AgentConfig.json @@ -2461,7 +2461,7 @@ } ], "default": null, - "description": "Optional. LlmAgent.model. If not set, the model will be inherited from the ancestor.", + "description": "Optional. LlmAgent.model. Provide a model name string (e.g. \"gemini-2.0-flash\"). If not set, the model will be inherited from the ancestor or fall back to the system default (gemini-2.5-flash unless overridden via LlmAgent.set_default_model). 
To construct a model instance from code, use model_code.", "title": "Model" }, "instruction": { @@ -4601,4 +4601,4 @@ } ], "title": "AgentConfig" -} \ No newline at end of file +} diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5abaef58..8d85e48a 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -183,10 +183,18 @@ async def _convert_tool_union_to_tools( class LlmAgent(BaseAgent): """LLM-based Agent.""" + DEFAULT_MODEL: ClassVar[str] = 'gemini-2.5-flash' + """System default model used when no model is set on an agent.""" + + _default_model: ClassVar[Union[str, BaseLlm]] = DEFAULT_MODEL + """Current default model used when an agent has no model set.""" + model: Union[str, BaseLlm] = '' """The model to use for the agent. - When not set, the agent will inherit the model from its ancestor. + When not set, the agent will inherit the model from its ancestor. If no + ancestor provides a model, the agent uses the default model configured via + LlmAgent.set_default_model. The built-in default is gemini-2.5-flash. 
""" config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig @@ -503,7 +511,24 @@ class LlmAgent(BaseAgent): if isinstance(ancestor_agent, LlmAgent): return ancestor_agent.canonical_model ancestor_agent = ancestor_agent.parent_agent - raise ValueError(f'No model found for {self.name}.') + return self._resolve_default_model() + + @classmethod + def set_default_model(cls, model: Union[str, BaseLlm]) -> None: + """Overrides the default model used when an agent has no model set.""" + if not isinstance(model, (str, BaseLlm)): + raise TypeError('Default model must be a model name or BaseLlm.') + if isinstance(model, str) and not model: + raise ValueError('Default model must be a non-empty string.') + cls._default_model = model + + @classmethod + def _resolve_default_model(cls) -> BaseLlm: + """Resolves the current default model to a BaseLlm instance.""" + default_model = cls._default_model + if isinstance(default_model, BaseLlm): + return default_model + return LLMRegistry.new_llm(default_model) async def canonical_instruction( self, ctx: ReadonlyContext @@ -575,10 +600,11 @@ class LlmAgent(BaseAgent): # because the built-in tools cannot be used together with other tools. # TODO(b/448114567): Remove once the workaround is no longer needed. multiple_tools = len(self.tools) > 1 + model = self.canonical_model for tool_union in self.tools: resolved_tools.extend( await _convert_tool_union_to_tools( - tool_union, ctx, self.model, multiple_tools + tool_union, ctx, model, multiple_tools ) ) return resolved_tools diff --git a/src/google/adk/agents/llm_agent_config.py b/src/google/adk/agents/llm_agent_config.py index 59c6d588..160152df 100644 --- a/src/google/adk/agents/llm_agent_config.py +++ b/src/google/adk/agents/llm_agent_config.py @@ -56,8 +56,9 @@ class LlmAgentConfig(BaseAgentConfig): description=( 'Optional. LlmAgent.model. Provide a model name string (e.g.' ' "gemini-2.0-flash"). If not set, the model will be inherited from' - ' the ancestor. 
To construct a model instance from code, use' - ' model_code.' + ' the ancestor or fall back to the system default (gemini-2.5-flash' + ' unless overridden via LlmAgent.set_default_model). To construct a' + ' model instance from code, use model_code.' ), ) diff --git a/src/google/adk/flows/llm_flows/_output_schema_processor.py b/src/google/adk/flows/llm_flows/_output_schema_processor.py index 2298c044..271e350f 100644 --- a/src/google/adk/flows/llm_flows/_output_schema_processor.py +++ b/src/google/adk/flows/llm_flows/_output_schema_processor.py @@ -45,7 +45,7 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor): if ( not agent.output_schema or not agent.tools - or can_use_output_schema_with_tools(agent.model) + or can_use_output_schema_with_tools(agent.canonical_model) ): return diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 3f8dd37a..e368d0c6 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -476,7 +476,11 @@ class BaseLlmFlow(ABC): # We may need to wrap some built-in tools if there are other tools # because the built-in tools cannot be used together with other tools. # TODO(b/448114567): Remove once the workaround is no longer needed. 
+ if not agent.tools: + return + multiple_tools = len(agent.tools) > 1 + model = agent.canonical_model for tool_union in agent.tools: tool_context = ToolContext(invocation_context) @@ -492,7 +496,7 @@ class BaseLlmFlow(ABC): tools = await _convert_tool_union_to_tools( tool_union, ReadonlyContext(invocation_context), - agent.model, + model, multiple_tools, ) for tool in tools: diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 1468a7ca..8f28e31e 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -35,15 +35,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor): async def run_async( self, invocation_context: InvocationContext, llm_request: LlmRequest ) -> AsyncGenerator[Event, None]: - from ...agents.llm_agent import LlmAgent - agent = invocation_context.agent - - llm_request.model = ( - agent.canonical_model - if isinstance(agent.canonical_model, str) - else agent.canonical_model.model - ) + model = agent.canonical_model + llm_request.model = model if isinstance(model, str) else model.model llm_request.config = ( agent.generate_content_config.model_copy(deep=True) if agent.generate_content_config @@ -54,7 +48,7 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor): # both output_schema and tools at the same time. 
see # _output_schema_processor.py for details if agent.output_schema: - if not agent.tools or can_use_output_schema_with_tools(agent.model): + if not agent.tools or can_use_output_schema_with_tools(model): llm_request.set_output_schema(agent.output_schema) llm_request.live_connect_config.response_modalities = ( diff --git a/src/google/adk/flows/llm_flows/interactions_processor.py b/src/google/adk/flows/llm_flows/interactions_processor.py index 461cbb99..7e8a51bd 100644 --- a/src/google/adk/flows/llm_flows/interactions_processor.py +++ b/src/google/adk/flows/llm_flows/interactions_processor.py @@ -53,9 +53,10 @@ class InteractionsRequestProcessor(BaseLlmRequestProcessor): # Only process if using Gemini with interactions API if not isinstance(agent, LlmAgent): return - if not isinstance(agent.model, Gemini): + model = agent.canonical_model + if not isinstance(model, Gemini): return - if not agent.model.use_interactions_api: + if not model.use_interactions_api: return # Extract previous interaction ID from session events previous_interaction_id = self._find_previous_interaction_id( diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 577923f7..ad70adc6 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -52,11 +52,24 @@ async def _create_readonly_context( return ReadonlyContext(invocation_context) -def test_canonical_model_empty(): - agent = LlmAgent(name='test_agent') - - with pytest.raises(ValueError): - _ = agent.canonical_model +@pytest.mark.parametrize( + ('default_model', 'expected_model_name', 'expected_model_type'), + [ + (LlmAgent.DEFAULT_MODEL, LlmAgent.DEFAULT_MODEL, Gemini), + ('gemini-2.0-flash', 'gemini-2.0-flash', Gemini), + ], +) +def test_canonical_model_default_fallback( + default_model, expected_model_name, expected_model_type +): + original_default = LlmAgent._default_model + LlmAgent.set_default_model(default_model) + 
try: + agent = LlmAgent(name='test_agent') + assert isinstance(agent.canonical_model, expected_model_type) + assert agent.canonical_model.model == expected_model_name + finally: + LlmAgent.set_default_model(original_default) def test_canonical_model_str(): diff --git a/tests/unittests/flows/llm_flows/test_basic_processor.py b/tests/unittests/flows/llm_flows/test_basic_processor.py index e0be7781..7bb40b92 100644 --- a/tests/unittests/flows/llm_flows/test_basic_processor.py +++ b/tests/unittests/flows/llm_flows/test_basic_processor.py @@ -110,7 +110,9 @@ class TestBasicLlmRequestProcessor: assert llm_request.config.response_mime_type != 'application/json' # Should have checked if output schema can be used with tools - can_use_output_schema_with_tools.assert_called_once_with(agent.model) + can_use_output_schema_with_tools.assert_called_once_with( + agent.canonical_model + ) @pytest.mark.asyncio async def test_sets_output_schema_when_tools_present(self, mocker): @@ -141,7 +143,9 @@ class TestBasicLlmRequestProcessor: assert llm_request.config.response_mime_type == 'application/json' # Should have checked if output schema can be used with tools - can_use_output_schema_with_tools.assert_called_once_with(agent.model) + can_use_output_schema_with_tools.assert_called_once_with( + agent.canonical_model + ) @pytest.mark.asyncio async def test_no_output_schema_no_tools(self): diff --git a/tests/unittests/flows/llm_flows/test_output_schema_processor.py b/tests/unittests/flows/llm_flows/test_output_schema_processor.py index f7ad8eb3..a870a8b5 100644 --- a/tests/unittests/flows/llm_flows/test_output_schema_processor.py +++ b/tests/unittests/flows/llm_flows/test_output_schema_processor.py @@ -191,7 +191,9 @@ async def test_output_schema_request_processor( assert not llm_request.config.system_instruction # Should have checked if output schema can be used with tools - can_use_output_schema_with_tools.assert_called_once_with(agent.model) + 
can_use_output_schema_with_tools.assert_called_once_with( + agent.canonical_model + ) @pytest.mark.asyncio