feat: Add experimental agent tool simulator

PiperOrigin-RevId: 866100611
This commit is contained in:
Google Team Member
2026-02-05 13:57:11 -08:00
committed by Copybara-Service
parent 3686a3a98f
commit 6645aa07fd
14 changed files with 1146 additions and 0 deletions
@@ -0,0 +1,17 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.tools.agent_simulator.agent_simulator_factory import AgentSimulatorFactory

# Export the name this module actually defines/imports. The previous value
# ("AgentSimulator") did not exist in this module, so `from ... import *`
# would raise AttributeError.
__all__ = ["AgentSimulatorFactory"]
@@ -0,0 +1,158 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from google.genai import types as genai_types
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic_core import ValidationError
class InjectedError(BaseModel):
  """An error to be injected into a tool call.

  When an injection carrying this error is hit, the simulated tool response
  is a dict of the form ``{"error_code": ..., "error_message": ...}`` instead
  of a real tool result.
  """

  injected_http_error_code: int
  """Inject http error code to the tool call. Will present as "error_code"
  in the tool response dict."""

  error_message: str
  """Inject error message to the tool call. Will present as
  "error_message" in the tool response dict."""
class InjectionConfig(BaseModel):
  """Injection configuration for a tool.

  Exactly one of ``injected_error`` or ``injected_response`` must be set.
  """

  injection_probability: float = 1.0
  """Probability of injecting the injected_value."""

  match_args: Optional[Dict[str, Any]] = None
  """Only apply injection if the request matches the match_args.
  If match_args is not provided, the injection will be applied to all
  requests."""

  injected_latency_seconds: float = Field(default=0.0, le=120.0)
  """Inject latency to the tool call. Please note it may not be accurate if
  the interceptor is applied as after tool callback."""

  random_seed: Optional[int] = None
  """The random seed to use for this injection."""

  injected_error: Optional[InjectedError] = None
  """The injected error."""

  injected_response: Optional[Dict[str, Any]] = None
  """The injected response."""

  @model_validator(mode="after")
  def check_injected_error_or_response(self) -> InjectionConfig:
    """Checks that exactly one of injected_error or injected_response is set.

    Raises:
      ValueError: If both or neither are set.
    """
    # Fix: the original annotated the return as `Self`, but `Self` is never
    # imported in this module. With `from __future__ import annotations` the
    # class name itself is a safe, resolvable forward reference.
    if bool(self.injected_error) == bool(self.injected_response):
      raise ValueError(
          "Either injected_error or injected_response must be set, but not"
          " both, and not neither."
      )
    return self
class MockStrategy(enum.Enum):
  """Strategies available for producing a simulated tool response."""

  # No strategy chosen; only valid when injection configs are present.
  MOCK_STRATEGY_UNSPECIFIED = 0

  # Derive a plausible response purely from the tool's declared schema.
  MOCK_STRATEGY_TOOL_SPEC = 1

  # Combine a recorded tracing file (path supplied on the command line)
  # with the tool specifications to mock the response via the LLM.
  MOCK_STRATEGY_TRACING = 2
class ToolSimulationConfig(BaseModel):
  """Simulation configuration for a single tool."""

  tool_name: str
  """Name of the tool to be simulated."""

  injection_configs: List[InjectionConfig] = Field(default_factory=list)
  """Injection configuration for the tool. If provided, the tool will be
  injected with the injected_value with the injection_probability first,
  the mock_strategy will be applied if no injection config is hit."""

  mock_strategy_type: MockStrategy = MockStrategy.MOCK_STRATEGY_UNSPECIFIED
  """The mock strategy to use."""

  @model_validator(mode="after")
  def check_mock_strategy_type(self) -> ToolSimulationConfig:
    """Checks that mock_strategy_type is not UNSPECIFIED if no injections.

    Raises:
      ValueError: If there are no injection configs and no mock strategy,
        which would make the simulation a guaranteed no-op.
    """
    # Fix: the original annotated the return as `Self`, but `Self` is never
    # imported in this module; the class name is a safe forward reference
    # under `from __future__ import annotations`.
    if (
        not self.injection_configs
        and self.mock_strategy_type == MockStrategy.MOCK_STRATEGY_UNSPECIFIED
    ):
      raise ValueError(
          "If injection_configs is empty, mock_strategy_type cannot be"
          " MOCK_STRATEGY_UNSPECIFIED."
      )
    return self
class AgentSimulatorConfig(BaseModel):
  """Top-level configuration for AgentSimulator."""

  tool_simulation_configs: List[ToolSimulationConfig] = Field(
      default_factory=list
  )
  """A list of tool simulation configurations."""

  simulation_model: str = Field(default="gemini-2.5-flash")
  """The model to use for internal simulator LLM calls (tool analysis, mock responses)."""

  simulation_model_configuration: genai_types.GenerateContentConfig = Field(
      default_factory=lambda: genai_types.GenerateContentConfig(
          thinking_config=genai_types.ThinkingConfig(
              include_thoughts=True,
              thinking_budget=10240,
          )
      ),
  )
  """The configuration for the internal simulator LLM calls."""

  tracing_path: Optional[str] = None
  """The path to the tracing file to be used for mocking. Only used if the
  mock_strategy_type is MOCK_STRATEGY_TRACING."""

  @field_validator("tool_simulation_configs")
  @classmethod
  def check_tool_simulation_configs(cls, v: List[ToolSimulationConfig]):
    """Requires at least one config and rejects duplicate tool names."""
    if not v:
      raise ValueError("tool_simulation_configs must be provided.")
    names_so_far = set()
    for sim_config in v:
      name = sim_config.tool_name
      # Report the first name observed a second time, matching first-hit
      # duplicate semantics.
      if name in names_so_far:
        raise ValueError(f"Duplicate tool_name found: {name}")
      names_so_far.add(name)
    return v
@@ -0,0 +1,132 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import concurrent.futures
import logging
import random
import time
from typing import Any
from typing import Dict
from typing import Optional
agent_simulator_logger = logging.getLogger("agent_simulator_logger")
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy as MockStrategyEnum
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
from google.adk.tools.agent_simulator.strategies import base as base_mock_strategies
from google.adk.tools.agent_simulator.strategies import tool_spec_mock_strategy
from google.adk.tools.agent_simulator.tool_connection_analyzer import ToolConnectionAnalyzer
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
from google.adk.tools.base_tool import BaseTool
def _create_mock_strategy(
    mock_strategy_type: MockStrategyEnum,
    llm_name: str,
    llm_config: genai_types.GenerateContentConfig,
) -> base_mock_strategies.MockStrategy:
  """Creates a mock strategy instance for the given strategy type.

  Args:
    mock_strategy_type: Which mocking approach to instantiate.
    llm_name: Model name used by LLM-backed strategies.
    llm_config: Generation config used by LLM-backed strategies.

  Returns:
    A concrete MockStrategy implementation.

  Raises:
    ValueError: If the strategy type is unknown (including UNSPECIFIED).
  """
  if mock_strategy_type == MockStrategyEnum.MOCK_STRATEGY_TOOL_SPEC:
    return tool_spec_mock_strategy.ToolSpecMockStrategy(llm_name, llm_config)
  if mock_strategy_type == MockStrategyEnum.MOCK_STRATEGY_TRACING:
    # Fix: TracingMockStrategy.__init__ requires (llm_name, llm_config);
    # constructing it with no arguments raised a TypeError at runtime.
    return base_mock_strategies.TracingMockStrategy(llm_name, llm_config)
  raise ValueError(f"Unknown mock strategy type: {mock_strategy_type}")
class AgentSimulatorEngine:
  """Core engine to handle the simulation logic.

  Holds the per-tool simulation configs, a lazily built tool-connection map,
  and a shared state store used by stateful mock strategies.
  """

  def __init__(self, config: AgentSimulatorConfig):
    self._config = config
    # Index configs by tool name for O(1) lookup on every tool call.
    self._tool_sim_configs = {
        c.tool_name: c for c in config.tool_simulation_configs
    }
    self._is_analyzed = False
    self._tool_connection_map: Optional[ToolConnectionMap] = None
    self._analyzer = ToolConnectionAnalyzer(
        llm_name=config.simulation_model,
        llm_config=config.simulation_model_configuration,
    )
    # Shared state (e.g. IDs produced by "creating" tools) visible to all
    # mock strategies.
    self._state_store = {}
    self._random_generator = random.Random()

  async def simulate(
      self, tool: BaseTool, args: Dict[str, Any], tool_context: Any
  ) -> Optional[Dict[str, Any]]:
    """Simulates a tool call.

    Args:
      tool: The tool being invoked.
      args: The arguments of the tool call.
      tool_context: The ADK tool context for this invocation.

    Returns:
      A simulated response dict, or None when the tool is not simulated or
      when no injection hit and no mock strategy is configured.
    """
    if tool.name not in self._tool_sim_configs:
      return None
    tool_sim_config = self._tool_sim_configs[tool.name]

    # Lazily analyze tool connections once, and only if any tool uses a
    # mock strategy (pure injections do not need the connection map).
    if not self._is_analyzed and any(
        c.mock_strategy_type != MockStrategyEnum.MOCK_STRATEGY_UNSPECIFIED
        for c in self._config.tool_simulation_configs
    ):
      agent = tool_context._invocation_context.agent
      if isinstance(agent, LlmAgent):
        tools = await agent.canonical_tools(tool_context)
        self._tool_connection_map = await self._analyzer.analyze(tools)
      self._is_analyzed = True

    for injection_config in tool_sim_config.injection_configs:
      if injection_config.match_args:
        # Skip this injection unless every match_arg pair is present in args.
        if not all(
            item in args.items() for item in injection_config.match_args.items()
        ):
          continue
      if injection_config.random_seed is not None:
        self._random_generator.seed(injection_config.random_seed)
      if (
          self._random_generator.random()
          < injection_config.injection_probability
      ):
        # Fix: use asyncio.sleep — time.sleep would block the entire event
        # loop while simulating latency inside this async callback.
        await asyncio.sleep(injection_config.injected_latency_seconds)
        if injection_config.injected_error:
          return {
              "error_code": (
                  injection_config.injected_error.injected_http_error_code
              ),
              "error_message": injection_config.injected_error.error_message,
          }
        if injection_config.injected_response:
          return injection_config.injected_response

    # If no injection was applied, fall back to the mock strategy.
    if (
        tool_sim_config.mock_strategy_type
        == MockStrategyEnum.MOCK_STRATEGY_UNSPECIFIED
    ):
      agent_simulator_logger.warning(
          "Tool '%s' did not hit any injection config and has no mock strategy"
          " configured. Returning no-op.",
          tool.name,
      )
      return None

    mock_strategy = _create_mock_strategy(
        tool_sim_config.mock_strategy_type,
        self._config.simulation_model,
        self._config.simulation_model_configuration,
    )
    return await mock_strategy.mock(
        tool, args, tool_context, self._tool_connection_map, self._state_store
    )
@@ -0,0 +1,71 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Optional
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
from google.adk.tools.base_tool import BaseTool
from ...utils.feature_decorator import experimental
@experimental
class AgentSimulatorFactory:
  """Factory helpers that wire an AgentSimulatorEngine into ADK hooks."""

  @staticmethod
  def create_callback(
      config: AgentSimulatorConfig,
  ) -> Callable[
      [BaseTool, Dict[str, Any], Any], Awaitable[Optional[Dict[str, Any]]]
  ]:
    """Creates a callback function for AgentSimulator.

    Args:
      config: The configuration for the AgentSimulator.

    Returns:
      A coroutine function usable as a before_tool_callback or
      after_tool_callback; it closes over a dedicated engine instance.
    """
    engine = AgentSimulatorEngine(config)

    async def _simulate(
        tool: BaseTool, args: Dict[str, Any], tool_context: Any
    ) -> Optional[Dict[str, Any]]:
      return await engine.simulate(tool, args, tool_context)

    return _simulate

  @staticmethod
  def create_plugin(
      config: AgentSimulatorConfig,
  ) -> AgentSimulatorPlugin:
    """Creates an ADK Plugin for AgentSimulator.

    Args:
      config: The configuration for the AgentSimulator.

    Returns:
      An AgentSimulatorPlugin wrapping a fresh engine instance.
    """
    return AgentSimulatorPlugin(AgentSimulatorEngine(config))
@@ -0,0 +1,40 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.plugins import BasePlugin
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
class AgentSimulatorPlugin(BasePlugin):
  """ADK Plugin that routes tool calls through an AgentSimulatorEngine."""

  name: str = "AgentSimulator"

  def __init__(self, simulator_engine: AgentSimulatorEngine):
    """Initializes the plugin.

    Args:
      simulator_engine: The engine that performs the actual simulation.
    """
    # Fix: BasePlugin expects to be initialized with the plugin name;
    # relying on the class attribute alone skipped base-class setup.
    super().__init__(name=self.name)
    self._simulator_engine = simulator_engine

  async def before_tool_callback(
      self, tool: BaseTool, tool_args: dict[str, Any], tool_context: ToolContext
  ) -> Optional[Dict[str, Any]]:
    """Invokes the AgentSimulatorEngine before a tool call."""
    return await self._simulator_engine.simulate(tool, tool_args, tool_context)
@@ -0,0 +1,13 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,57 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
class MockStrategy:
  """Abstract interface for tool-response mock strategies."""

  async def mock(
      self,
      tool: BaseTool,
      args: Dict[str, Any],
      tool_context: Any,
      tool_connection_map: Optional[ToolConnectionMap],
      state_store: Dict[str, Any],
  ) -> Dict[str, Any]:
    """Produces a mock response dict for the given tool call.

    Subclasses must override this coroutine; the base implementation only
    signals that no strategy was provided.
    """
    raise NotImplementedError()
class TracingMockStrategy(MockStrategy):
  """Tracing-driven, LLM-assisted mock strategy (implementation pending)."""

  def __init__(
      self, llm_name: str, llm_config: genai_types.GenerateContentConfig
  ):
    # Retained for the future tracing-based implementation.
    self._llm_name = llm_name
    self._llm_config = llm_config

  async def mock(
      self,
      tool: BaseTool,
      args: Dict[str, Any],
      tool_context: Any,
      tool_connection_map: Optional[ToolConnectionMap],
      state_store: Dict[str, Any],
  ) -> Dict[str, Any]:
    """Returns a placeholder error until tracing-based mocking lands."""
    # TODO: Implement tracing LLM-based mocking.
    placeholder = {"status": "error", "error_message": "Not implemented"}
    return placeholder
@@ -0,0 +1,152 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import concurrent.futures
import json
import re
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.models.llm_request import LlmRequest
from google.adk.models.registry import LLMRegistry
from google.adk.tools.agent_simulator.strategies.base import MockStrategy
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
from google.adk.tools.base_tool import BaseTool
from google.adk.utils.context_utils import Aclosing
from google.genai import types as genai_types
# Prompt for ToolSpecMockStrategy: asks the simulator LLM for a raw JSON
# mock response for one tool call, kept consistent with the shared state
# store and the tool-connection map. Placeholders are filled via
# str.format; literal braces are escaped as '{{' / '}}'.
_TOOL_SPEC_MOCK_PROMPT_TEMPLATE = """
You are a stateful tool simulator. Your task is to generate a
realistic JSON response for a tool call, maintaining consistency based
on a shared state.
Here is the map of how tools connect via stateful parameters:
{tool_connection_map_json}
Here is the current state of all stateful parameters:
{state_store_json}
You are now simulating the following tool call:
Tool Name: {tool_name}
Tool Description: {tool_description}
Tool Schema: {tool_schema_json}
Tool Arguments: {tool_arguments_json}
Your instructions:
1. Analyze the tool call. Is it a "creating" or "consuming" tool
based on the connection map?
2. If it's a "consuming" tool, check the provided arguments against
the state store. If an ID is provided that does not exist in the
state, return a realistic error (e.g., a 404 Not Found error).
Otherwise, use the data from the state to generate the response.
3. If it's a "creating" tool, generate a new, unique ID for the
stateful parameter (e.g., a random string for a ticket_id). Include
this new ID in your response. I will then update the state with it.
4. Generate a convincing, valid JSON object that mocks the tool's
response. The response must be only the JSON object, without any
additional text or formatting.
5. The response must start with '{{' and end with '}}'.
"""
class ToolSpecMockStrategy(MockStrategy):
  """Mocks a tool response based on the tool's specification.

  Prompts the simulator LLM with the tool's declaration, the call args,
  the tool-connection map, and the shared state store, then parses the
  LLM output as JSON and records newly created stateful IDs.
  """

  def __init__(
      self, llm_name: str, llm_config: genai_types.GenerateContentConfig
  ):
    self._llm_name = llm_name
    self._llm_config = llm_config
    # Resolve the model class from the registry and instantiate it once.
    llm_registry = LLMRegistry()
    llm_class = llm_registry.resolve(self._llm_name)
    self._llm = llm_class(model=self._llm_name)

  async def mock(
      self,
      tool: BaseTool,
      args: Dict[str, Any],
      tool_context: Any,
      tool_connection_map: Optional[ToolConnectionMap],
      state_store: Dict[str, Any],
  ) -> Dict[str, Any]:
    """Generates a mock response for a tool call via the simulator LLM.

    Returns:
      The parsed JSON mock response, or a {"status": "error", ...} dict
      when the declaration is missing or the LLM output is not valid JSON.
    """
    declaration = tool._get_declaration()
    if not declaration:
      return {
          "status": "error",
          "error_message": "Could not get tool declaration.",
      }
    tool_connection_map_json = (
        json.dumps(tool_connection_map.model_dump(exclude_none=True), indent=2)
        if tool_connection_map
        else "''"
    )
    state_store_json = json.dumps(state_store, indent=2)
    tool_schema_json = json.dumps(declaration.model_dump(), indent=2)
    tool_arguments_json = json.dumps(args, indent=2)
    prompt = _TOOL_SPEC_MOCK_PROMPT_TEMPLATE.format(
        tool_connection_map_json=tool_connection_map_json,
        state_store_json=state_store_json,
        tool_name=tool.name,
        tool_description=tool.description,
        tool_schema_json=tool_schema_json,
        tool_arguments_json=tool_arguments_json,
    )
    request_contents = [
        genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
    ]
    # NOTE(review): both `config` and a separate `generation_config` are
    # passed to LlmRequest — confirm the latter is actually honored and not
    # silently ignored or shadowed by `config`.
    request = LlmRequest(
        contents=request_contents,
        model=self._llm_name,
        config=self._llm_config,
        generation_config=genai_types.GenerateContentConfig(
            response_mime_type="application/json"
        ),
    )
    # Accumulate the streamed text parts into a single response string.
    response_text = ""
    async with Aclosing(self._llm.generate_content_async(request)) as agen:
      async for llm_response in agen:
        generated_content: genai_types.Content = llm_response.content
        if generated_content.parts:
          for part in generated_content.parts:
            if part.text:
              response_text += part.text
    try:
      # Strip optional markdown code fences before parsing as JSON.
      clean_json_text = re.sub(r"^```[a-zA-Z]*\n", "", response_text)
      clean_json_text = re.sub(r"\n```$", "", clean_json_text)
      mock_response = json.loads(clean_json_text.strip())
      # After getting the response, update the state if this was a creating tool.
      if tool_connection_map:
        for param_info in tool_connection_map.stateful_parameters:
          param_name = param_info.parameter_name
          if tool.name in param_info.creating_tools:
            if param_name in mock_response:
              # Key the full mock response by the newly minted ID so that
              # consuming tools can later look it up.
              param_value = mock_response[param_name]
              if param_name not in state_store:
                state_store[param_name] = {}
              state_store[param_name][param_value] = mock_response
      return mock_response
    except json.JSONDecodeError:
      return {
          "status": "error",
          "error_message": "Failed to generate valid JSON mock response.",
          "llm_output": response_text,
      }
@@ -0,0 +1,121 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import concurrent.futures
import json
import logging
import re
from typing import Any
from typing import Dict
from typing import List
from google.adk.models.llm_request import LlmRequest
from google.adk.models.registry import LLMRegistry
from google.adk.tools.agent_simulator.tool_connection_map import ToolConnectionMap
from google.adk.tools.base_tool import BaseTool
from google.adk.utils.context_utils import Aclosing
from google.genai import types as genai_types
# Prompt for ToolConnectionAnalyzer: asks the LLM to find parameters
# produced by one tool and consumed by another, returned as raw JSON that
# parses into ToolConnectionMap. Placeholders are filled via str.format;
# literal braces are escaped as '{{' / '}}'.
_TOOL_CONNECTION_ANALYSIS_PROMPT_TEMPLATE = """
You are a software architect analyzing a set of tools to understand
how they connect to enable stateful operations. Your task is to
identify parameters that are generated by one tool and consumed by
another.
For example, a "create_ticket" tool might output a "ticket_id", which is
then used as input for "get_ticket" or "close_ticket" tools.
Analyze the following tool schemas:
{tool_schemas_json}
Based on this analysis, generate a JSON object that describes these
stateful parameters. The JSON object should have a single key,
"stateful_parameters", which is a list. Each item in the list
should represent a stateful parameter and have the following keys:
- "parameter_name": The name of the shared parameter (e.g., "ticket_id").
- "creating_tools": A list of tools that generate this parameter.
- "consuming_tools": A list of tools that use this parameter as input.
Return only the raw JSON object.
Your response must start with '{{' and end with '}}'.
"""
class ToolConnectionAnalyzer:
  """Uses an LLM to analyze stateful connections between tools.

  For example, get_ticket will consume a ticket_id created by create_ticket;
  the analyzer builds a list of such connections.
  """

  def __init__(
      self, llm_name: str, llm_config: genai_types.GenerateContentConfig
  ):
    self._llm_name = llm_name
    self._llm_config = llm_config
    # Resolve the model class from the registry and instantiate it once.
    llm_registry = LLMRegistry()
    llm_class = llm_registry.resolve(self._llm_name)
    self._llm = llm_class(model=self._llm_name)

  async def analyze(self, tools: List[BaseTool]) -> ToolConnectionMap:
    """Analyzes a list of tools and returns a map of their connections.

    Returns:
      The parsed ToolConnectionMap; an empty map if the LLM output cannot
      be parsed as JSON.
    """
    # Call _get_declaration() only once per tool (the original called it
    # twice) and skip tools that have no declaration.
    tool_schemas = [
        declaration.model_dump(exclude_none=True)
        for tool in tools
        if (declaration := tool._get_declaration())
    ]
    tool_schemas_json = json.dumps(tool_schemas, indent=2)
    prompt = _TOOL_CONNECTION_ANALYSIS_PROMPT_TEMPLATE.format(
        tool_schemas_json=tool_schemas_json
    )
    request_contents = [
        genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
    ]
    request = LlmRequest(
        contents=request_contents,
        model=self._llm_name,
        config=self._llm_config,
        generation_config=genai_types.GenerateContentConfig(
            response_mime_type="application/json"
        ),
    )
    # Accumulate the streamed text parts into a single response string.
    response_text = ""
    async with Aclosing(self._llm.generate_content_async(request)) as agen:
      async for llm_response in agen:
        generated_content: genai_types.Content = llm_response.content
        if not generated_content.parts:
          continue
        for part in generated_content.parts:
          if part.text:
            response_text += part.text
    try:
      # Strip optional markdown code fences before parsing as JSON.
      clean_json_text = re.sub(r"^```[a-zA-Z]*\n", "", response_text)
      clean_json_text = re.sub(r"\n```$", "", clean_json_text)
      response_json = json.loads(clean_json_text.strip())
    except json.JSONDecodeError as e:
      # Fix: the handler referenced `e` without binding it (`except ...:`
      # instead of `except ... as e:`), raising NameError instead of
      # logging this warning.
      logging.warning(
          "Failed to parse tool connection analysis from LLM. Proceeding"
          " without connection map. Error: %s\nLLM Output:\n%s",
          e,
          response_text,
      )
      return ToolConnectionMap(stateful_parameters=[])
    return ToolConnectionMap.model_validate(response_json)
@@ -0,0 +1,39 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from pydantic import BaseModel
class StatefulParameter(BaseModel):
  """Represents a stateful parameter and its connections.

  A stateful parameter is a value emitted by one tool and accepted as
  input by another (e.g. a ticket_id created by create_ticket and read by
  get_ticket).
  """

  parameter_name: str
  """The name of the shared parameter (e.g., "ticket_id")."""

  creating_tools: List[str]
  """A list of tools that generate this parameter."""

  consuming_tools: List[str]
  """A list of tools that use this parameter as input."""
class ToolConnectionMap(BaseModel):
  """Represents the map of tool connections.

  Produced by ToolConnectionAnalyzer and consumed by mock strategies to
  keep simulated responses consistent across related tool calls.
  """

  stateful_parameters: List[StatefulParameter]
  """A list of stateful parameters and their connections."""
@@ -0,0 +1,13 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,218 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
from google.adk.tools.agent_simulator.agent_simulator_config import InjectedError
from google.adk.tools.agent_simulator.agent_simulator_config import InjectionConfig
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
from google.adk.tools.agent_simulator.agent_simulator_engine import AgentSimulatorEngine
from google.genai import types as genai_types
import pytest
@patch(
"google.adk.tools.agent_simulator.agent_simulator_engine.ToolConnectionAnalyzer"
)
@patch(
"google.adk.tools.agent_simulator.agent_simulator_engine._create_mock_strategy"
)
@pytest.mark.asyncio
class TestAgentSimulatorEngineSimulate:
"""Test cases for the simulate method of AgentSimulatorEngine."""
async def test_simulate_no_op_for_unconfigured_tool(
self, mock_create_strategy, mock_analyzer
):
"""Test that simulate returns None for a tool not in the config."""
config = AgentSimulatorConfig(
tool_simulation_configs=[
ToolSimulationConfig(
tool_name="configured_tool",
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
)
],
simulation_model="test-model",
simulation_model_configuration=genai_types.GenerateContentConfig(),
)
engine = AgentSimulatorEngine(config)
mock_tool = MagicMock()
mock_tool.name = "unconfigured_tool"
result = await engine.simulate(mock_tool, {}, MagicMock())
assert result is None
async def test_injection_with_matching_args(
self, mock_create_strategy, mock_analyzer
):
"""Test that an injection is applied when match_args match."""
config = AgentSimulatorConfig(
tool_simulation_configs=[
ToolSimulationConfig(
tool_name="test_tool",
injection_configs=[
InjectionConfig(
match_args={"param": "value"},
injected_response={"injected": True},
)
],
)
],
simulation_model="test-model",
simulation_model_configuration=genai_types.GenerateContentConfig(),
)
engine = AgentSimulatorEngine(config)
mock_tool = MagicMock()
mock_tool.name = "test_tool"
result = await engine.simulate(mock_tool, {"param": "value"}, MagicMock())
assert result == {"injected": True}
  async def test_injection_not_applied_with_mismatched_args(
      self, mock_create_strategy, mock_analyzer
  ):
    """Test that an injection is not applied when match_args do not match."""
    # Stub the strategy factory so the fallback path returns a known dict.
    mock_strategy_instance = MagicMock()
    mock_strategy_instance.mock = AsyncMock(return_value={"mocked": True})
    mock_create_strategy.return_value = mock_strategy_instance
    config = AgentSimulatorConfig(
        tool_simulation_configs=[
            ToolSimulationConfig(
                tool_name="test_tool",
                injection_configs=[
                    InjectionConfig(
                        match_args={"param": "value"},
                        injected_response={"injected": True},
                    )
                ],
                mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
            )
        ],
        simulation_model="test-model",
        simulation_model_configuration=genai_types.GenerateContentConfig(),
    )
    engine = AgentSimulatorEngine(config)
    mock_tool = MagicMock()
    mock_tool.name = "test_tool"
    # Args differ from match_args, so the injection is skipped and the
    # engine falls through to the (stubbed) mock strategy.
    result = await engine.simulate(
        mock_tool, {"param": "different_value"}, MagicMock()
    )
    assert result == {"mocked": True}
    # The strategy must have been built with the tool's configured type and
    # the engine's model settings, and awaited exactly once.
    mock_create_strategy.assert_called_once_with(
        config.tool_simulation_configs[0].mock_strategy_type,
        config.simulation_model,
        config.simulation_model_configuration,
    )
    mock_strategy_instance.mock.assert_awaited_once()
  async def test_no_op_when_no_injection_hit_and_unspecified_strategy(
      self, mock_create_strategy, mock_analyzer, caplog
  ):
    """Test for no-op and warning when no injection hits and mock strategy is unspecified."""
    config = AgentSimulatorConfig(
        tool_simulation_configs=[
            ToolSimulationConfig(
                tool_name="test_tool",
                injection_configs=[
                    InjectionConfig(
                        match_args={"param": "value"},
                        injected_response={"injected": True},
                    )
                ],
                mock_strategy_type=MockStrategy.MOCK_STRATEGY_UNSPECIFIED,
            )
        ],
        simulation_model="test-model",
        simulation_model_configuration=genai_types.GenerateContentConfig(),
    )
    engine = AgentSimulatorEngine(config)
    mock_tool = MagicMock()
    mock_tool.name = "test_tool"
    # Capture warnings from the engine's dedicated logger.
    caplog.set_level(logging.WARNING, logger="agent_simulator_logger")
    with caplog.at_level(logging.WARNING, logger="agent_simulator_logger"):
      # Args differ from match_args, so no injection fires; with the
      # strategy UNSPECIFIED the engine must warn and return None.
      result = await engine.simulate(
          mock_tool, {"param": "different_value"}, MagicMock()
      )
    assert result is None
    assert (
        "did not hit any injection config and has no mock strategy"
        in caplog.text
    )
    # The strategy factory must never be invoked for UNSPECIFIED.
    mock_create_strategy.assert_not_called()
async def test_injection_with_random_seed_is_deterministic(
self, mock_create_strategy, mock_analyzer
):
"""Test that an injection with a random_seed is deterministic."""
# With seed=42, random.random() is > 0.5, so this will NOT be injected
# and should fall back to the mock strategy.
mock_strategy_instance = MagicMock()
mock_strategy_instance.mock = AsyncMock(return_value={"mocked": True})
mock_create_strategy.return_value = mock_strategy_instance
config_mocked = AgentSimulatorConfig(
tool_simulation_configs=[
ToolSimulationConfig(
tool_name="test_tool",
injection_configs=[
InjectionConfig(
injection_probability=0.5,
random_seed=42, # A fixed seed
injected_response={"injected": True},
)
],
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
)
],
simulation_model="test-model",
simulation_model_configuration=genai_types.GenerateContentConfig(),
)
engine_mocked = AgentSimulatorEngine(config_mocked)
mock_tool = MagicMock()
mock_tool.name = "test_tool"
result1 = await engine_mocked.simulate(mock_tool, {}, MagicMock())
assert result1 == {"mocked": True}
mock_create_strategy.assert_called_once_with(
config_mocked.tool_simulation_configs[0].mock_strategy_type,
config_mocked.simulation_model,
config_mocked.simulation_model_configuration,
)
mock_strategy_instance.mock.assert_awaited_once()
mock_create_strategy.reset_mock()
mock_strategy_instance.mock.reset_mock()
# With seed=100, random.random() is < 0.5, so this WILL be injected.
config_injected = AgentSimulatorConfig(
tool_simulation_configs=[
ToolSimulationConfig(
tool_name="test_tool",
injection_configs=[
InjectionConfig(
injection_probability=0.5,
random_seed=100, # A different fixed seed
injected_response={"injected": True},
)
],
mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
)
],
simulation_model="test-model",
simulation_model_configuration=genai_types.GenerateContentConfig(),
)
engine_injected = AgentSimulatorEngine(config_injected)
result2 = await engine_injected.simulate(mock_tool, {}, MagicMock())
assert result2 == {"injected": True}
mock_create_strategy.assert_not_called()
@@ -0,0 +1,69 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.tools.agent_simulator import AgentSimulatorFactory
from google.adk.tools.agent_simulator.agent_simulator_config import AgentSimulatorConfig
from google.adk.tools.agent_simulator.agent_simulator_config import MockStrategy
from google.adk.tools.agent_simulator.agent_simulator_config import ToolSimulationConfig
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
import pytest
@pytest.mark.asyncio
@patch(
    "google.adk.tools.agent_simulator.agent_simulator_factory.AgentSimulatorEngine"
)
class TestAgentSimulatorFactory:
  """Test cases for the AgentSimulator factory class."""

  @pytest.fixture
  def mock_config(self):
    """Fixture for a basic AgentSimulatorConfig."""
    tool_config = ToolSimulationConfig(
        tool_name="test_tool",
        mock_strategy_type=MockStrategy.MOCK_STRATEGY_TOOL_SPEC,
    )
    return AgentSimulatorConfig(tool_simulation_configs=[tool_config])

  async def test_create_callback(self, mock_engine_class, mock_config):
    """Test that create_callback returns a valid callable."""
    engine = MagicMock()
    engine.simulate = AsyncMock(return_value=None)
    mock_engine_class.return_value = engine

    callback = AgentSimulatorFactory.create_callback(mock_config)
    assert callable(callback)
    await callback(MagicMock(), {}, MagicMock())

    # The factory must construct exactly one engine from the config and
    # route the callback invocation to its simulate() coroutine.
    mock_engine_class.assert_called_once_with(mock_config)
    engine.simulate.assert_awaited_once()

  @patch(
      "google.adk.tools.agent_simulator.agent_simulator_factory.AgentSimulatorPlugin"
  )
  def test_create_plugin(
      self, mock_plugin_class, mock_engine_class, mock_config
  ):
    """Test that create_plugin returns a valid AgentSimulatorPlugin instance."""
    plugin = AgentSimulatorFactory.create_plugin(mock_config)

    # The plugin must wrap an engine built from the given config.
    mock_engine_class.assert_called_once_with(mock_config)
    mock_plugin_class.assert_called_once_with(mock_engine_class.return_value)
    assert plugin == mock_plugin_class.return_value
@@ -0,0 +1,46 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.tools.agent_simulator.agent_simulator_plugin import AgentSimulatorPlugin
import pytest
@pytest.mark.asyncio
class TestAgentSimulatorPlugin:
  """Test cases for the AgentSimulatorPlugin."""

  @pytest.fixture
  def mock_simulator_engine(self):
    """Fixture for a mock AgentSimulatorEngine."""
    mock_engine = MagicMock()
    mock_engine.simulate = AsyncMock()
    return mock_engine

  async def test_before_tool_callback(self, mock_simulator_engine):
    """Test that the before_tool_callback calls the engine's simulate method."""
    plugin = AgentSimulatorPlugin(mock_simulator_engine)
    tool, args, context = MagicMock(), {}, MagicMock()

    await plugin.before_tool_callback(tool, args, context)

    # The plugin must forward the exact tool, args and context it received.
    mock_simulator_engine.simulate.assert_awaited_once_with(
        tool, args, context
    )