You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add experimental agent tool simulator
PiperOrigin-RevId: 866100611
This commit is contained in:
committed by
Copybara-Service
parent
3686a3a98f
commit
6645aa07fd
@@ -0,0 +1,17 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.tools.agent_simulator.agent_simulator_factory import AgentSimulatorFactory
|
||||
|
||||
__all__ = ["AgentSimulator"]
|
||||
@@ -0,0 +1,158 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
|
||||
class InjectedError(BaseModel):
|
||||
"""An error to be injected into a tool call."""
|
||||
|
||||
injected_http_error_code: int
|
||||
"""Inject http error code to the tool call. Will present as "error_code"
|
||||
in the tool response dict."""
|
||||
|
||||
error_message: str
|
||||
"""Inject error message to the tool call. Will present as
|
||||
"error_message" in the tool response dict."""
|
||||
|
||||
|
||||
class InjectionConfig(BaseModel):
|
||||
"""Injection configuration for a tool."""
|
||||
|
||||
injection_probability: float = 1.0
|
||||
"""Probability of injecting the injected_value."""
|
||||
|
||||
match_args: Optional[Dict[str, Any]] = None
|
||||
"""Only apply injection if the request matches the match_args.
|
||||
If match_args is not provided, the injection will be applied to all
|
||||
requests."""
|
||||
|
||||
injected_latency_seconds: float = Field(default=0.0, le=120.0)
|
||||
"""Inject latency to the tool call. Please note it may not be accurate if │
|
||||
the interceptor is applied as after tool callback."""
|
||||
|
||||
random_seed: Optional[int] = None
|
||||
"""The random seed to use for this injection."""
|
||||
|
||||
injected_error: Optional[InjectedError] = None
|
||||
"""The injected error."""
|
||||
|
||||
injected_response: Optional[Dict[str, Any]] = None
|
||||
"""The injected response."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_injected_error_or_response(self) -> Self:
|
||||
"""Checks that either injected_error or injected_response is set."""
|
||||
if bool(self.injected_error) == bool(self.injected_response):
|
||||
raise ValueError(
|
||||
"Either injected_error or injected_response must be set, but not"
|
||||
" both, and not neither."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class MockStrategy(enum.Enum):
|
||||
"""Mock strategy for a tool."""
|
||||
|
||||
MOCK_STRATEGY_UNSPECIFIED = 0
|
||||
|
||||
MOCK_STRATEGY_TOOL_SPEC = 1
|
||||
"""Use tool specifications to mock the tool response."""
|
||||
|
||||
MOCK_STRATEGY_TRACING = 2
|
||||
"""Use provided tracing and tool specifications to mock the tool
|
||||
response based on llm response. Need to provide tracing path in
|
||||
command."""
|
||||
|
||||
|
||||
class ToolSimulationConfig(BaseModel):
|
||||
"""Simulation configuration for a single tool."""
|
||||
|
||||
tool_name: str
|
||||
"""Name of the tool to be simulated."""
|
||||
|
||||
injection_configs: List[InjectionConfig] = Field(default_factory=list)
|
||||
"""Injection configuration for the tool. If provided, the tool will be
|
||||
injected with the injected_value with the injection_probability first,
|
||||
the mock_strategy will be applied if no injection config is hit."""
|
||||
|
||||
mock_strategy_type: MockStrategy = MockStrategy.MOCK_STRATEGY_UNSPECIFIED
|
||||
"""The mock strategy to use."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_mock_strategy_type(self) -> Self:
|
||||
"""Checks that mock_strategy_type is not UNSPECIFIED if no injections."""
|
||||
if (
|
||||
not self.injection_configs
|
||||
and self.mock_strategy_type == MockStrategy.MOCK_STRATEGY_UNSPECIFIED
|
||||
):
|
||||
raise ValueError(
|
||||
"If injection_configs is empty, mock_strategy_type cannot be"
|
||||
" MOCK_STRATEGY_UNSPECIFIED."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AgentSimulatorConfig(BaseModel):
|
||||
"""Configuration for AgentSimulator."""
|
||||
|
||||
tool_simulation_configs: List[ToolSimulationConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
"""A list of tool simulation configurations."""
|
||||
|
||||
simulation_model: str = Field(default="gemini-2.5-flash")
|
||||
"""The model to use for internal simulator LLM calls (tool analysis, mock responses)."""
|
||||
|
||||
simulation_model_configuration: genai_types.GenerateContentConfig = Field(
|
||||
default_factory=lambda: genai_types.GenerateContentConfig(
|
||||
thinking_config=genai_types.ThinkingConfig(
|
||||
include_thoughts=True,
|
||||
thinking_budget=10240,
|
||||
)
|
||||
),
|
||||
)
|
||||
"""The configuration for the internal simulator LLM calls."""
|
||||
|
||||
tracing_path: Optional[str] = None
|
||||
"""The path to the tracing file to be used for mocking. Only used if the
|
||||
mock_strategy_type is MOCK_STRATEGY_TRACING."""
|
||||
|
||||
@field_validator("tool_simulation_configs")
|
||||
@classmethod
|
||||
def check_tool_simulation_configs(cls, v: List[ToolSimulationConfig]):
|
||||
"""Checks that tool_simulation_configs is not empty."""
|
||||
if not v:
|
||||
raise ValueError("tool_simulation_configs must be provided.")
|
||||
seen_tool_names = set()
|
||||
for tool_sim_config in v:
|
||||
if tool_sim_config.tool_name in seen_tool_names:
|
||||
raise ValueError(
|
||||
f"Duplicate tool_name found: {tool_sim_config.tool_name}"
|
||||
)
|
||||
seen_tool_names.add(tool_sim_config.tool_name)
|
||||
return v
|
||||
@@ -0,0 +1,132 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
agent_simulator_logger = logging.getLogger("agent_simulator_logger")
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy as MockStrategyEnum
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
|
||||
from google.adk.tools.agent_simulator.strategies import base as base_mock_strategies
|
||||
from google.adk.tools.agent_simulator.strategies import tool_spec_mock_strategy
|
||||
from google.adk.tools.agent_simulator.tool_connection_analyzer import ToolConnectionAnalyzer
|
||||
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
def _create_mock_strategy(
|
||||
mock_strategy_type: MockStrategyEnum,
|
||||
llm_name: str,
|
||||
llm_config: genai_types.GenerateContentConfig,
|
||||
) -> base_mock_strategies.MockStrategy:
|
||||
"""Creates a mock strategy based on the given type."""
|
||||
if mock_strategy_type == MockStrategyEnum.MOCK_STRATEGY_TOOL_SPEC:
|
||||
return tool_spec_mock_strategy.ToolSpecMockStrategy(llm_name, llm_config)
|
||||
if mock_strategy_type == MockStrategyEnum.MOCK_STRATEGY_TRACING:
|
||||
return base_mock_strategies.TracingMockStrategy()
|
||||
raise ValueError(f"Unknown mock strategy type: {mock_strategy_type}")
|
||||
|
||||
|
||||
class AgentSimulatorEngine:
|
||||
"""Core engine to handle the simulation logic."""
|
||||
|
||||
def __init__(self, config: AgentSimulatorConfig):
|
||||
self._config = config
|
||||
self._tool_sim_configs = {
|
||||
c.tool_name: c for c in config.tool_simulation_configs
|
||||
}
|
||||
self._is_analyzed = False
|
||||
self._tool_connection_map: Optional[ToolConnectionMap] = None
|
||||
self._analyzer = ToolConnectionAnalyzer(
|
||||
llm_name=config.simulation_model,
|
||||
llm_config=config.simulation_model_configuration,
|
||||
)
|
||||
self._state_store = {}
|
||||
self._random_generator = random.Random()
|
||||
|
||||
async def simulate(
|
||||
self, tool: BaseTool, args: Dict[str, Any], tool_context: Any
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Simulates a tool call."""
|
||||
if tool.name not in self._tool_sim_configs:
|
||||
return None
|
||||
|
||||
tool_sim_config = self._tool_sim_configs[tool.name]
|
||||
|
||||
if not self._is_analyzed and any(
|
||||
c.mock_strategy_type != MockStrategyEnum.MOCK_STRATEGY_UNSPECIFIED
|
||||
for c in self._config.tool_simulation_configs
|
||||
):
|
||||
agent = tool_context._invocation_context.agent
|
||||
if isinstance(agent, LlmAgent):
|
||||
tools = await agent.canonical_tools(tool_context)
|
||||
self._tool_connection_map = await self._analyzer.analyze(tools)
|
||||
self._is_analyzed = True
|
||||
|
||||
for injection_config in tool_sim_config.injection_configs:
|
||||
if injection_config.match_args:
|
||||
if not all(
|
||||
item in args.items() for item in injection_config.match_args.items()
|
||||
):
|
||||
continue
|
||||
|
||||
if injection_config.random_seed is not None:
|
||||
self._random_generator.seed(injection_config.random_seed)
|
||||
|
||||
if (
|
||||
self._random_generator.random()
|
||||
< injection_config.injection_probability
|
||||
):
|
||||
time.sleep(injection_config.injected_latency_seconds)
|
||||
if injection_config.injected_error:
|
||||
return {
|
||||
"error_code": (
|
||||
injection_config.injected_error.injected_http_error_code
|
||||
),
|
||||
"error_message": injection_config.injected_error.error_message,
|
||||
}
|
||||
if injection_config.injected_response:
|
||||
return injection_config.injected_response
|
||||
|
||||
# If no injection was applied, fall back to the mock strategy.
|
||||
if (
|
||||
tool_sim_config.mock_strategy_type
|
||||
== MockStrategyEnum.MOCK_STRATEGY_UNSPECIFIED
|
||||
):
|
||||
agent_simulator_logger.warning(
|
||||
"Tool '%s' did not hit any injection config and has no mock strategy"
|
||||
" configured. Returning no-op.",
|
||||
tool.name,
|
||||
)
|
||||
return None
|
||||
|
||||
mock_strategy = _create_mock_strategy(
|
||||
tool_sim_config.mock_strategy_type,
|
||||
self._config.simulation_model,
|
||||
self._config.simulation_model_configuration,
|
||||
)
|
||||
return await mock_strategy.mock(
|
||||
tool, args, tool_context, self._tool_connection_map, self._state_store
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
|
||||
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
|
||||
from ...utils.feature_decorator import experimental
|
||||
|
||||
|
||||
@experimental
|
||||
class AgentSimulatorFactory:
|
||||
"""Factory for creating AgentSimulator instances."""
|
||||
|
||||
@staticmethod
|
||||
def create_callback(
|
||||
config: AgentSimulatorConfig,
|
||||
) -> Callable[
|
||||
[BaseTool, Dict[str, Any], Any], Awaitable[Optional[Dict[str, Any]]]
|
||||
]:
|
||||
"""Creates a callback function for AgentSimulator.
|
||||
|
||||
Args:
|
||||
config: The configuration for the AgentSimulator.
|
||||
|
||||
Returns:
|
||||
A callable that can be used as a before_tool_callback or after_tool_callback.
|
||||
"""
|
||||
simulator_engine = AgentSimulatorEngine(config)
|
||||
|
||||
async def _agent_simulator_callback(
|
||||
tool: BaseTool, args: Dict[str, Any], tool_context: Any
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
return await simulator_engine.simulate(tool, args, tool_context)
|
||||
|
||||
return _agent_simulator_callback
|
||||
|
||||
@staticmethod
|
||||
def create_plugin(
|
||||
config: AgentSimulatorConfig,
|
||||
) -> AgentSimulatorPlugin:
|
||||
"""Creates an ADK Plugin for AgentSimulator.
|
||||
|
||||
Args:
|
||||
config: The configuration for the AgentSimulator.
|
||||
|
||||
Returns:
|
||||
An instance of AgentSimulatorPlugin that can be used as an ADK plugin.
|
||||
"""
|
||||
simulator_engine = AgentSimulatorEngine(config)
|
||||
return AgentSimulatorPlugin(simulator_engine)
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.plugins import BasePlugin
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
|
||||
|
||||
class AgentSimulatorPlugin(BasePlugin):
|
||||
"""ADK Plugin for AgentSimulator."""
|
||||
|
||||
name: str = "AgentSimulator"
|
||||
|
||||
def __init__(self, simulator_engine: AgentSimulatorEngine):
|
||||
self._simulator_engine = simulator_engine
|
||||
|
||||
async def before_tool_callback(
|
||||
self, tool: BaseTool, tool_args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Invokes the AgentSimulatorEngine before a tool call."""
|
||||
return await self._simulator_engine.simulate(tool, tool_args, tool_context)
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
|
||||
|
||||
|
||||
class MockStrategy:
|
||||
"""Base class for mock strategies."""
|
||||
|
||||
async def mock(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: Any,
|
||||
tool_connection_map: Optional[ToolConnectionMap],
|
||||
state_store: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Generates a mock response for a tool call."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TracingMockStrategy(MockStrategy):
|
||||
"""Mocks a tool response based on tracing and an LLM."""
|
||||
|
||||
def __init__(
|
||||
self, llm_name: str, llm_config: genai_types.GenerateContentConfig
|
||||
):
|
||||
self._llm_name = llm_name
|
||||
self._llm_config = llm_config
|
||||
|
||||
async def mock(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: Any,
|
||||
tool_connection_map: Optional[ToolConnectionMap],
|
||||
state_store: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
# TODO: Implement tracing LLM-based mocking.
|
||||
return {"status": "error", "error_message": "Not implemented"}
|
||||
@@ -0,0 +1,152 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.registry import LLMRegistry
|
||||
from google.adk.tools.agent_simulator.strategies.base import MockStrategy
|
||||
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.utils.context_utils import Aclosing
|
||||
from google.genai import types as genai_types
|
||||
|
||||
_TOOL_SPEC_MOCK_PROMPT_TEMPLATE = """
|
||||
You are a stateful tool simulator. Your task is to generate a
|
||||
realistic JSON response for a tool call, maintaining consistency based
|
||||
on a shared state.
|
||||
|
||||
Here is the map of how tools connect via stateful parameters:
|
||||
{tool_connection_map_json}
|
||||
|
||||
Here is the current state of all stateful parameters:
|
||||
{state_store_json}
|
||||
|
||||
You are now simulating the following tool call:
|
||||
Tool Name: {tool_name}
|
||||
Tool Description: {tool_description}
|
||||
Tool Schema: {tool_schema_json}
|
||||
Tool Arguments: {tool_arguments_json}
|
||||
|
||||
Your instructions:
|
||||
1. Analyze the tool call. Is it a "creating" or "consuming" tool
|
||||
based on the connection map?
|
||||
2. If it's a "consuming" tool, check the provided arguments against
|
||||
the state store. If an ID is provided that does not exist in the
|
||||
state, return a realistic error (e.g., a 404 Not Found error).
|
||||
Otherwise, use the data from the state to generate the response.
|
||||
3. If it's a "creating" tool, generate a new, unique ID for the
|
||||
stateful parameter (e.g., a random string for a ticket_id). Include
|
||||
this new ID in your response. I will then update the state with it.
|
||||
4. Generate a convincing, valid JSON object that mocks the tool's
|
||||
response. The response must be only the JSON object, without any
|
||||
additional text or formatting.
|
||||
5. The response must start with '{{' and end with '}}'.
|
||||
"""
|
||||
|
||||
|
||||
class ToolSpecMockStrategy(MockStrategy):
|
||||
"""Mocks a tool response based on the tool's specification."""
|
||||
|
||||
def __init__(
|
||||
self, llm_name: str, llm_config: genai_types.GenerateContentConfig
|
||||
):
|
||||
self._llm_name = llm_name
|
||||
self._llm_config = llm_config
|
||||
llm_registry = LLMRegistry()
|
||||
llm_class = llm_registry.resolve(self._llm_name)
|
||||
self._llm = llm_class(model=self._llm_name)
|
||||
|
||||
async def mock(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: Any,
|
||||
tool_connection_map: Optional[ToolConnectionMap],
|
||||
state_store: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
declaration = tool._get_declaration()
|
||||
if not declaration:
|
||||
return {
|
||||
"status": "error",
|
||||
"error_message": "Could not get tool declaration.",
|
||||
}
|
||||
|
||||
tool_connection_map_json = (
|
||||
json.dumps(tool_connection_map.model_dump(exclude_none=True), indent=2)
|
||||
if tool_connection_map
|
||||
else "''"
|
||||
)
|
||||
state_store_json = json.dumps(state_store, indent=2)
|
||||
tool_schema_json = json.dumps(declaration.model_dump(), indent=2)
|
||||
tool_arguments_json = json.dumps(args, indent=2)
|
||||
|
||||
prompt = _TOOL_SPEC_MOCK_PROMPT_TEMPLATE.format(
|
||||
tool_connection_map_json=tool_connection_map_json,
|
||||
state_store_json=state_store_json,
|
||||
tool_name=tool.name,
|
||||
tool_description=tool.description,
|
||||
tool_schema_json=tool_schema_json,
|
||||
tool_arguments_json=tool_arguments_json,
|
||||
)
|
||||
|
||||
request_contents = [
|
||||
genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
|
||||
]
|
||||
request = LlmRequest(
|
||||
contents=request_contents,
|
||||
model=self._llm_name,
|
||||
config=self._llm_config,
|
||||
generation_config=genai_types.GenerateContentConfig(
|
||||
response_mime_type="application/json"
|
||||
),
|
||||
)
|
||||
response_text = ""
|
||||
async with Aclosing(self._llm.generate_content_async(request)) as agen:
|
||||
async for llm_response in agen:
|
||||
generated_content: genai_types.Content = llm_response.content
|
||||
if generated_content.parts:
|
||||
for part in generated_content.parts:
|
||||
if part.text:
|
||||
response_text += part.text
|
||||
|
||||
try:
|
||||
clean_json_text = re.sub(r"^```[a-zA-Z]*\n", "", response_text)
|
||||
clean_json_text = re.sub(r"\n```$", "", clean_json_text)
|
||||
mock_response = json.loads(clean_json_text.strip())
|
||||
# After getting the response, update the state if this was a creating tool.
|
||||
if tool_connection_map:
|
||||
for param_info in tool_connection_map.stateful_parameters:
|
||||
param_name = param_info.parameter_name
|
||||
if tool.name in param_info.creating_tools:
|
||||
if param_name in mock_response:
|
||||
param_value = mock_response[param_name]
|
||||
if param_name not in state_store:
|
||||
state_store[param_name] = {}
|
||||
state_store[param_name][param_value] = mock_response
|
||||
return mock_response
|
||||
except json.JSONDecodeError:
|
||||
return {
|
||||
"status": "error",
|
||||
"error_message": "Failed to generate valid JSON mock response.",
|
||||
"llm_output": response_text,
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.registry import LLMRegistry
|
||||
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.utils.context_utils import Aclosing
|
||||
from google.genai import types as genai_types
|
||||
|
||||
_TOOL_CONNECTION_ANALYSIS_PROMPT_TEMPLATE = """
|
||||
You are a software architect analyzing a set of tools to understand
|
||||
how they connect to enable stateful operations. Your task is to
|
||||
identify parameters that are generated by one tool and consumed by
|
||||
another.
|
||||
|
||||
For example, a "create_ticket" tool might output a "ticket_id", which is
|
||||
then used as input for "get_ticket" or "close_ticket" tools.
|
||||
|
||||
Analyze the following tool schemas:
|
||||
{tool_schemas_json}
|
||||
|
||||
Based on this analysis, generate a JSON object that describes these
|
||||
stateful parameters. The JSON object should have a single key,
|
||||
"stateful_parameters", which is a list. Each item in the list
|
||||
should represent a stateful parameter and have the following keys:
|
||||
- "parameter_name": The name of the shared parameter (e.g., "ticket_id").
|
||||
- "creating_tools": A list of tools that generate this parameter.
|
||||
- "consuming_tools": A list of tools that use this parameter as input.
|
||||
|
||||
Return only the raw JSON object.
|
||||
Your response must start with '{{' and end with '}}'.
|
||||
"""
|
||||
|
||||
|
||||
class ToolConnectionAnalyzer:
|
||||
"""
|
||||
Uses an LLM to analyze stateful connections between tools. For example,
|
||||
get_ticket will consume a ticket_id created by create_ticket, the analyzer
|
||||
will create a list of such connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, llm_name: str, llm_config: genai_types.GenerateContentConfig
|
||||
):
|
||||
self._llm_name = llm_name
|
||||
self._llm_config = llm_config
|
||||
llm_registry = LLMRegistry()
|
||||
llm_class = llm_registry.resolve(self._llm_name)
|
||||
self._llm = llm_class(model=self._llm_name)
|
||||
|
||||
async def analyze(self, tools: List[BaseTool]) -> ToolConnectionMap:
|
||||
"""
|
||||
Analyzes a list of tools and returns a map of their connections.
|
||||
"""
|
||||
tool_schemas = [
|
||||
tool._get_declaration().model_dump(exclude_none=True)
|
||||
for tool in tools
|
||||
if tool._get_declaration()
|
||||
]
|
||||
tool_schemas_json = json.dumps(tool_schemas, indent=2)
|
||||
prompt = _TOOL_CONNECTION_ANALYSIS_PROMPT_TEMPLATE.format(
|
||||
tool_schemas_json=tool_schemas_json
|
||||
)
|
||||
|
||||
request_contents = [
|
||||
genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
|
||||
]
|
||||
request = LlmRequest(
|
||||
contents=request_contents,
|
||||
model=self._llm_name,
|
||||
config=self._llm_config,
|
||||
generation_config=genai_types.GenerateContentConfig(
|
||||
response_mime_type="application/json"
|
||||
),
|
||||
)
|
||||
response_text = ""
|
||||
async with Aclosing(self._llm.generate_content_async(request)) as agen:
|
||||
async for llm_response in agen:
|
||||
generated_content: genai_types.Content = llm_response.content
|
||||
if not generated_content.parts:
|
||||
continue
|
||||
for part in generated_content.parts:
|
||||
if part.text:
|
||||
response_text += part.text
|
||||
|
||||
try:
|
||||
clean_json_text = re.sub(r"^```[a-zA-Z]*\n", "", response_text)
|
||||
clean_json_text = re.sub(r"\n```$", "", clean_json_text)
|
||||
response_json = json.loads(clean_json_text.strip())
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(
|
||||
"Failed to parse tool connection analysis from LLM. Proceeding"
|
||||
" without connection map. Error: %s\nLLM Output:\n%s",
|
||||
e,
|
||||
response_text,
|
||||
)
|
||||
return ToolConnectionMap(stateful_parameters=[])
|
||||
return ToolConnectionMap.model_validate(response_json)
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class StatefulParameter(BaseModel):
|
||||
"""Represents a stateful parameter and its connections."""
|
||||
|
||||
parameter_name: str
|
||||
"""The name of the shared parameter (e.g., "ticket_id")."""
|
||||
|
||||
creating_tools: List[str]
|
||||
"""A list of tools that generate this parameter."""
|
||||
|
||||
consuming_tools: List[str]
|
||||
"""A list of tools that use this parameter as input."""
|
||||
|
||||
|
||||
class ToolConnectionMap(BaseModel):
|
||||
"""Represents the map of tool connections."""
|
||||
|
||||
stateful_parameters: List[StatefulParameter]
|
||||
"""A list of stateful parameters and their connections."""
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,218 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law of a-specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import InjectedError
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import InjectionConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
|
||||
from google.genai import types as genai_types
|
||||
import pytest
|
||||
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.agent_simulator.agent_simulator_engine.ToolConnectionAnalyzer"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.agent_simulator.agent_simulator_engine._create_mock_strategy"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentSimulatorEngineSimulate:
|
||||
"""Test cases for the simulate method of AgentSimulatorEngine."""
|
||||
|
||||
async def test_simulate_no_op_for_unconfigured_tool(
|
||||
self, mock_create_strategy, mock_analyzer
|
||||
):
|
||||
"""Test that simulate returns None for a tool not in the config."""
|
||||
config = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="configured_tool",
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine = AgentSimulatorEngine(config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "unconfigured_tool"
|
||||
result = await engine.simulate(mock_tool, {}, MagicMock())
|
||||
assert result is None
|
||||
|
||||
async def test_injection_with_matching_args(
|
||||
self, mock_create_strategy, mock_analyzer
|
||||
):
|
||||
"""Test that an injection is applied when match_args match."""
|
||||
config = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
injection_configs=[
|
||||
InjectionConfig(
|
||||
match_args={"param": "value"},
|
||||
injected_response={"injected": True},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine = AgentSimulatorEngine(config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
result = await engine.simulate(mock_tool, {"param": "value"}, MagicMock())
|
||||
assert result == {"injected": True}
|
||||
|
||||
async def test_injection_not_applied_with_mismatched_args(
|
||||
self, mock_create_strategy, mock_analyzer
|
||||
):
|
||||
"""Test that an injection is not applied when match_args do not match."""
|
||||
mock_strategy_instance = MagicMock()
|
||||
mock_strategy_instance.mock = AsyncMock(return_value={"mocked": True})
|
||||
mock_create_strategy.return_value = mock_strategy_instance
|
||||
config = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
injection_configs=[
|
||||
InjectionConfig(
|
||||
match_args={"param": "value"},
|
||||
injected_response={"injected": True},
|
||||
)
|
||||
],
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine = AgentSimulatorEngine(config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
result = await engine.simulate(
|
||||
mock_tool, {"param": "different_value"}, MagicMock()
|
||||
)
|
||||
assert result == {"mocked": True}
|
||||
mock_create_strategy.assert_called_once_with(
|
||||
config.tool_simulation_configs[0].mock_strategy_type,
|
||||
config.simulation_model,
|
||||
config.simulation_model_configuration,
|
||||
)
|
||||
mock_strategy_instance.mock.assert_awaited_once()
|
||||
|
||||
async def test_no_op_when_no_injection_hit_and_unspecified_strategy(
|
||||
self, mock_create_strategy, mock_analyzer, caplog
|
||||
):
|
||||
"""Test for no-op and warning when no injection hits and mock strategy is unspecified."""
|
||||
config = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
injection_configs=[
|
||||
InjectionConfig(
|
||||
match_args={"param": "value"},
|
||||
injected_response={"injected": True},
|
||||
)
|
||||
],
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_UNSPECIFIED,
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine = AgentSimulatorEngine(config)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="agent_simulator_logger")
|
||||
with caplog.at_level(logging.WARNING, logger="agent_simulator_logger"):
|
||||
result = await engine.simulate(
|
||||
mock_tool, {"param": "different_value"}, MagicMock()
|
||||
)
|
||||
assert result is None
|
||||
assert (
|
||||
"did not hit any injection config and has no mock strategy"
|
||||
in caplog.text
|
||||
)
|
||||
mock_create_strategy.assert_not_called()
|
||||
|
||||
async def test_injection_with_random_seed_is_deterministic(
|
||||
self, mock_create_strategy, mock_analyzer
|
||||
):
|
||||
"""Test that an injection with a random_seed is deterministic."""
|
||||
# With seed=42, random.random() is > 0.5, so this will NOT be injected
|
||||
# and should fall back to the mock strategy.
|
||||
mock_strategy_instance = MagicMock()
|
||||
mock_strategy_instance.mock = AsyncMock(return_value={"mocked": True})
|
||||
mock_create_strategy.return_value = mock_strategy_instance
|
||||
config_mocked = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
injection_configs=[
|
||||
InjectionConfig(
|
||||
injection_probability=0.5,
|
||||
random_seed=42, # A fixed seed
|
||||
injected_response={"injected": True},
|
||||
)
|
||||
],
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine_mocked = AgentSimulatorEngine(config_mocked)
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "test_tool"
|
||||
|
||||
result1 = await engine_mocked.simulate(mock_tool, {}, MagicMock())
|
||||
assert result1 == {"mocked": True}
|
||||
mock_create_strategy.assert_called_once_with(
|
||||
config_mocked.tool_simulation_configs[0].mock_strategy_type,
|
||||
config_mocked.simulation_model,
|
||||
config_mocked.simulation_model_configuration,
|
||||
)
|
||||
mock_strategy_instance.mock.assert_awaited_once()
|
||||
|
||||
mock_create_strategy.reset_mock()
|
||||
mock_strategy_instance.mock.reset_mock()
|
||||
|
||||
# With seed=100, random.random() is < 0.5, so this WILL be injected.
|
||||
config_injected = AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
injection_configs=[
|
||||
InjectionConfig(
|
||||
injection_probability=0.5,
|
||||
random_seed=100, # A different fixed seed
|
||||
injected_response={"injected": True},
|
||||
)
|
||||
],
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
|
||||
)
|
||||
],
|
||||
simulation_model="test-model",
|
||||
simulation_model_configuration=genai_types.GenerateContentConfig(),
|
||||
)
|
||||
engine_injected = AgentSimulatorEngine(config_injected)
|
||||
result2 = await engine_injected.simulate(mock_tool, {}, MagicMock())
|
||||
assert result2 == {"injected": True}
|
||||
mock_create_strategy.assert_not_called()
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools.agent_simulator import AgentSimulatorFactory
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy
|
||||
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
|
||||
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"google.adk.tools.agent_simulator.agent_simulator_factory.AgentSimulatorEngine"
|
||||
)
|
||||
class TestAgentSimulatorFactory:
|
||||
"""Test cases for the AgentSimulator factory class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Fixture for a basic AgentSimulatorConfig."""
|
||||
return AgentSimulatorConfig(
|
||||
tool_simulation_configs=[
|
||||
ToolSimulationConfig(
|
||||
tool_name="test_tool",
|
||||
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def test_create_callback(self, mock_engine_class, mock_config):
|
||||
"""Test that create_callback returns a valid callable."""
|
||||
mock_engine_instance = MagicMock()
|
||||
mock_engine_instance.simulate = AsyncMock(return_value=None)
|
||||
mock_engine_class.return_value = mock_engine_instance
|
||||
|
||||
callback = AgentSimulatorFactory.create_callback(mock_config)
|
||||
assert callable(callback)
|
||||
await callback(MagicMock(), {}, MagicMock())
|
||||
|
||||
mock_engine_class.assert_called_once_with(mock_config)
|
||||
mock_engine_instance.simulate.assert_awaited_once()
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.agent_simulator.agent_simulator_factory.AgentSimulatorPlugin"
|
||||
)
|
||||
def test_create_plugin(
|
||||
self, mock_plugin_class, mock_engine_class, mock_config
|
||||
):
|
||||
"""Test that create_plugin returns a valid AgentSimulatorPlugin instance."""
|
||||
plugin = AgentSimulatorFactory.create_plugin(mock_config)
|
||||
mock_engine_class.assert_called_once_with(mock_config)
|
||||
mock_plugin_class.assert_called_once_with(mock_engine_class.return_value)
|
||||
assert plugin == mock_plugin_class.return_value
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentSimulatorPlugin:
|
||||
"""Test cases for the AgentSimulatorPlugin."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_simulator_engine(self):
|
||||
"""Fixture for a mock AgentSimulatorEngine."""
|
||||
engine = MagicMock()
|
||||
engine.simulate = AsyncMock()
|
||||
return engine
|
||||
|
||||
async def test_before_tool_callback(self, mock_simulator_engine):
|
||||
"""Test that the before_tool_callback calls the engine's simulate method."""
|
||||
plugin = AgentSimulatorPlugin(mock_simulator_engine)
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_args = {}
|
||||
mock_context = MagicMock()
|
||||
|
||||
await plugin.before_tool_callback(mock_tool, mock_args, mock_context)
|
||||
|
||||
mock_simulator_engine.simulate.assert_awaited_once_with(
|
||||
mock_tool, mock_args, mock_context
|
||||
)
|
||||
Reference in New Issue
Block a user