You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat(tools): Add an option to disallow propagating runner plugins to AgentTool runner
Merge https://github.com/google/adk-python/pull/2779 Fixes #2780 ### testing plan not available as is doesn't introduce new functionality Co-authored-by: Wei Sun (Jack) <weisun@google.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2779 from davidkl97:feature/agent-tool-plugins a602c808789f3daeed6244e352a6fb8fb6972de3 PiperOrigin-RevId: 835366974
This commit is contained in:
committed by
Copybara-Service
parent
2247a45922
commit
777dba3033
@@ -45,11 +45,22 @@ class AgentTool(BaseTool):
|
||||
Attributes:
|
||||
agent: The agent to wrap.
|
||||
skip_summarization: Whether to skip summarization of the agent output.
|
||||
include_plugins: Whether to propagate plugins from the parent runner context
|
||||
to the agent's runner. When True (default), the agent will inherit all
|
||||
plugins from its parent. Set to False to run the agent with an isolated
|
||||
plugin environment.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
agent: BaseAgent,
|
||||
skip_summarization: bool = False,
|
||||
*,
|
||||
include_plugins: bool = True,
|
||||
):
|
||||
self.agent = agent
|
||||
self.skip_summarization: bool = skip_summarization
|
||||
self.include_plugins = include_plugins
|
||||
|
||||
super().__init__(name=agent.name, description=agent.description)
|
||||
|
||||
@@ -130,6 +141,11 @@ class AgentTool(BaseTool):
|
||||
invocation_context.app_name if invocation_context else None
|
||||
)
|
||||
child_app_name = parent_app_name or self.agent.name
|
||||
plugins = (
|
||||
tool_context._invocation_context.plugin_manager.plugins
|
||||
if self.include_plugins
|
||||
else None
|
||||
)
|
||||
runner = Runner(
|
||||
app_name=child_app_name,
|
||||
agent=self.agent,
|
||||
@@ -137,7 +153,7 @@ class AgentTool(BaseTool):
|
||||
session_service=InMemorySessionService(),
|
||||
memory_service=InMemoryMemoryService(),
|
||||
credential_service=tool_context._invocation_context.credential_service,
|
||||
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
state_dict = {
|
||||
@@ -192,7 +208,9 @@ class AgentTool(BaseTool):
|
||||
agent_tool_config.agent, config_abs_path
|
||||
)
|
||||
return cls(
|
||||
agent=agent, skip_summarization=agent_tool_config.skip_summarization
|
||||
agent=agent,
|
||||
skip_summarization=agent_tool_config.skip_summarization,
|
||||
include_plugins=agent_tool_config.include_plugins,
|
||||
)
|
||||
|
||||
|
||||
@@ -204,3 +222,6 @@ class AgentToolConfig(BaseToolConfig):
|
||||
|
||||
skip_summarization: bool = False
|
||||
"""Whether to skip summarization of the agent output."""
|
||||
|
||||
include_plugins: bool = True
|
||||
"""Whether to include plugins from parent runner context."""
|
||||
|
||||
@@ -570,3 +570,112 @@ def test_agent_tool_response_schema_with_input_schema_no_output_vertex_ai(
|
||||
# Should have string response schema for VERTEX_AI when no output_schema
|
||||
assert declaration.response is not None
|
||||
assert declaration.response.type == types.Type.STRING
|
||||
|
||||
|
||||
def test_include_plugins_default_true():
|
||||
"""Test that plugins are propagated by default (include_plugins=True)."""
|
||||
|
||||
# Create a test plugin that tracks callbacks
|
||||
class TrackingPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.before_agent_calls = 0
|
||||
|
||||
async def before_agent_callback(self, **kwargs):
|
||||
self.before_agent_calls += 1
|
||||
|
||||
tracking_plugin = TrackingPlugin(name='tracking')
|
||||
|
||||
mock_model = testing_utils.MockModel.create(
|
||||
responses=[function_call_no_schema, 'response1', 'response2']
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True
|
||||
)
|
||||
|
||||
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
|
||||
runner.run('test1')
|
||||
|
||||
# Plugin should be called for both root_agent and tool_agent
|
||||
assert tracking_plugin.before_agent_calls == 2
|
||||
|
||||
|
||||
def test_include_plugins_explicit_true():
|
||||
"""Test that plugins are propagated when include_plugins=True."""
|
||||
|
||||
class TrackingPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.before_agent_calls = 0
|
||||
|
||||
async def before_agent_callback(self, **kwargs):
|
||||
self.before_agent_calls += 1
|
||||
|
||||
tracking_plugin = TrackingPlugin(name='tracking')
|
||||
|
||||
mock_model = testing_utils.MockModel.create(
|
||||
responses=[function_call_no_schema, 'response1', 'response2']
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent, include_plugins=True)],
|
||||
)
|
||||
|
||||
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
|
||||
runner.run('test1')
|
||||
|
||||
# Plugin should be called for both root_agent and tool_agent
|
||||
assert tracking_plugin.before_agent_calls == 2
|
||||
|
||||
|
||||
def test_include_plugins_false():
|
||||
"""Test that plugins are NOT propagated when include_plugins=False."""
|
||||
|
||||
class TrackingPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.before_agent_calls = 0
|
||||
|
||||
async def before_agent_callback(self, **kwargs):
|
||||
self.before_agent_calls += 1
|
||||
|
||||
tracking_plugin = TrackingPlugin(name='tracking')
|
||||
|
||||
mock_model = testing_utils.MockModel.create(
|
||||
responses=[function_call_no_schema, 'response1', 'response2']
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent, include_plugins=False)],
|
||||
)
|
||||
|
||||
runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
|
||||
runner.run('test1')
|
||||
|
||||
# Plugin should only be called for root_agent, not tool_agent
|
||||
assert tracking_plugin.before_agent_calls == 1
|
||||
|
||||
Reference in New Issue
Block a user