diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 979a68c3..8c680b61 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -132,6 +132,7 @@ class AgentTool(BaseTool): session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), credential_service=tool_context._invocation_context.credential_service, + plugins=list(tool_context._invocation_context.plugin_manager.plugins), ) session = await runner.session_service.create_session( app_name=self.agent.name, diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index d181f72f..1f2a026e 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from google.adk.agents.callback_context import CallbackContext from google.adk.agents.llm_agent import Agent from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.agent_tool import AgentTool from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types @@ -75,6 +80,70 @@ def test_no_schema(): ] +def test_use_plugins(): + """The agent tool can use plugins from parent runner.""" + + class ModelResponseCapturePlugin(BasePlugin): + + def __init__(self): + super().__init__('plugin') + self.model_responses = {} + + async def after_model_callback( + self, + *, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> Optional[LlmResponse]: + response_text = [] + for part in llm_response.content.parts: + if not part.text: + continue + response_text.append(part.text) + if response_text: + if callback_context.agent_name not in self.model_responses: + self.model_responses[callback_context.agent_name] = [] + self.model_responses[callback_context.agent_name].append( + ''.join(response_text) + ) + + mock_model = testing_utils.MockModel.create( + responses=[ + function_call_no_schema, + 'response1', + 'response2', + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent)], + ) + + model_response_capture = ModelResponseCapturePlugin() + runner = testing_utils.InMemoryRunner( + root_agent, plugins=[model_response_capture] + ) + + assert testing_utils.simplify_events(runner.run('test1')) == [ + ('root_agent', function_call_no_schema), + ('root_agent', function_response_no_schema), + ('root_agent', 'response2'), + ] + + # should be able to capture response from both root and tool agent. + assert model_response_capture.model_responses == { + 'tool_agent': ['response1'], + 'root_agent': ['response2'], + } + + def test_update_state(): """The agent tool can read and change parent state."""