diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 93d82317..8934b131 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -88,7 +88,7 @@ class Gemini(BaseLlm): Yields: LlmResponse: The model response. """ - self._preprocess_request(llm_request) + await self._preprocess_request(llm_request) self._maybe_append_user_content(llm_request) logger.info( 'Sending out request, model: %s, backend: %s, stream: %s', @@ -269,7 +269,22 @@ class Gemini(BaseLlm): ) as live_session: yield GeminiLlmConnection(live_session) - def _preprocess_request(self, llm_request: LlmRequest) -> None: + async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: + """Adapt the google computer use predefined functions to the adk computer use toolset.""" + + from ..tools.computer_use.computer_use_toolset import ComputerUseToolset + + async def convert_wait_to_wait_5_seconds(wait_func): + async def wait_5_seconds(): + return await wait_func(5) + + return wait_5_seconds + + await ComputerUseToolset.adapt_computer_use_tool( + 'wait', convert_wait_to_wait_5_seconds, llm_request + ) + + async def _preprocess_request(self, llm_request: LlmRequest) -> None: if self._api_backend == GoogleLLMVariant.GEMINI_API: # Using API key from Google AI Studio to call model doesn't support labels. @@ -284,6 +299,18 @@ class Gemini(BaseLlm): _remove_display_name_if_present(part.inline_data) _remove_display_name_if_present(part.file_data) + # Initialize config if needed + if llm_request.config and llm_request.config.tools: + # Check if computer use is configured + for tool in llm_request.config.tools: + if ( + isinstance(tool, (types.Tool, types.ToolDict)) + and hasattr(tool, 'computer_use') + and tool.computer_use + ): + llm_request.config.system_instruction = None + await self._adapt_computer_use_tool(llm_request) + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, diff --git a/src/google/adk/tools/computer_use/__init__.py b/src/google/adk/tools/computer_use/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/src/google/adk/tools/computer_use/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/tools/computer_use/base_computer.py b/src/google/adk/tools/computer_use/base_computer.py new file mode 100644 index 00000000..9e4edc82 --- /dev/null +++ b/src/google/adk/tools/computer_use/base_computer.py @@ -0,0 +1,265 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +from enum import Enum +from typing import Literal +from typing import Optional + +import pydantic + +from ...utils.feature_decorator import experimental + + +@experimental +class ComputerEnvironment(str, Enum): + """Case insensitive enum for computer environments.""" + + ENVIRONMENT_UNSPECIFIED = "ENVIRONMENT_UNSPECIFIED" + """Defaults to browser.""" + ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER" + """Operates in a web browser.""" + + +@experimental +class ComputerState(pydantic.BaseModel): + """Represents the current state of the computer environment. + + Attributes: + screenshot: The screenshot in PNG format as bytes. + url: The current URL of the webpage being displayed. + """ + + screenshot: bytes = pydantic.Field( + default=None, description="Screenshot in PNG format" + ) + url: Optional[str] = pydantic.Field( + default=None, description="Current webpage URL" + ) + + +@experimental +class BaseComputer(abc.ABC): + """async defines an interface for computer environments. + + This abstract base class async defines the standard interface for controlling + computer environments, including web browsers and other interactive systems. + """ + + @abc.abstractmethod + async def screen_size(self) -> tuple[int, int]: + """Returns the screen size of the environment. + + Returns: + A tuple of (width, height) in pixels. + """ + + @abc.abstractmethod + async def open_web_browser(self) -> ComputerState: + """Opens the web browser. + + Returns: + The current state after opening the browser. + """ + + @abc.abstractmethod + async def click_at(self, x: int, y: int) -> ComputerState: + """Clicks at a specific x, y coordinate on the webpage. + + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to click at. + y: The y-coordinate to click at. + + Returns: + The current state after clicking. + """ + + @abc.abstractmethod + async def hover_at(self, x: int, y: int) -> ComputerState: + """Hovers at a specific x, y coordinate on the webpage. + + May be used to explore sub-menus that appear on hover. + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to hover at. + y: The y-coordinate to hover at. + + Returns: + The current state after hovering. + """ + + @abc.abstractmethod + async def type_text_at( + self, + x: int, + y: int, + text: str, + press_enter: bool = True, + clear_before_typing: bool = True, + ) -> ComputerState: + """Types text at a specific x, y coordinate. + + The system automatically presses ENTER after typing. To disable this, set `press_enter` to False. + The system automatically clears any existing content before typing the specified `text`. To disable this, set `clear_before_typing` to False. + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to type at. + y: The y-coordinate to type at. + text: The text to type. + press_enter: Whether to press ENTER after typing. + clear_before_typing: Whether to clear existing content before typing. + + Returns: + The current state after typing. + """ + + @abc.abstractmethod + async def scroll_document( + self, direction: Literal["up", "down", "left", "right"] + ) -> ComputerState: + """Scrolls the entire webpage "up", "down", "left" or "right" based on direction. + + Args: + direction: The direction to scroll. + + Returns: + The current state after scrolling. + """ + + @abc.abstractmethod + async def scroll_at( + self, + x: int, + y: int, + direction: Literal["up", "down", "left", "right"], + magnitude: int, + ) -> ComputerState: + """Scrolls up, down, right, or left at a x, y coordinate by magnitude. + + The 'x' and 'y' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to scroll at. + y: The y-coordinate to scroll at. + direction: The direction to scroll. + magnitude: The amount to scroll. + + Returns: + The current state after scrolling. + """ + + @abc.abstractmethod + async def wait(self, seconds: int) -> ComputerState: + """Waits for n seconds to allow unfinished webpage processes to complete. + + Args: + seconds: The number of seconds to wait. + + Returns: + The current state after waiting. + """ + + @abc.abstractmethod + async def go_back(self) -> ComputerState: + """Navigates back to the previous webpage in the browser history. + + Returns: + The current state after navigating back. + """ + + @abc.abstractmethod + async def go_forward(self) -> ComputerState: + """Navigates forward to the next webpage in the browser history. + + Returns: + The current state after navigating forward. + """ + + @abc.abstractmethod + async def search(self) -> ComputerState: + """Directly jumps to a search engine home page. + + Used when you need to start with a search. For example, this is used when + the current website doesn't have the information needed or because a new + task is being started. + + Returns: + The current state after navigating to search. + """ + + @abc.abstractmethod + async def navigate(self, url: str) -> ComputerState: + """Navigates directly to a specified URL. + + Args: + url: The URL to navigate to. + + Returns: + The current state after navigation. + """ + + @abc.abstractmethod + async def key_combination(self, keys: list[str]) -> ComputerState: + """Presses keyboard keys and combinations, such as "control+c" or "enter". + + Args: + keys: List of keys to press in combination. + + Returns: + The current state after key press. + """ + + @abc.abstractmethod + async def drag_and_drop( + self, x: int, y: int, destination_x: int, destination_y: int + ) -> ComputerState: + """Drag and drop an element from a x, y coordinate to a destination destination_y, destination_x coordinate. + + The 'x', 'y', 'destination_y' and 'destination_x' values are absolute values, scaled to the height and width of the screen. + + Args: + x: The x-coordinate to start dragging from. + y: The y-coordinate to start dragging from. + destination_x: The x-coordinate to drop at. + destination_y: The y-coordinate to drop at. + + Returns: + The current state after drag and drop. + """ + + @abc.abstractmethod + async def current_state(self) -> ComputerState: + """Returns the current state of the current webpage. + + Returns: + The current environment state. + """ + + async def initialize(self) -> None: + """Initialize the computer.""" + pass + + async def close(self) -> None: + """Cleanup resource of the computer.""" + pass + + @abc.abstractmethod + async def environment(self) -> ComputerEnvironment: + """Returns the environment of the computer.""" diff --git a/src/google/adk/tools/computer_use/computer_use_tool.py b/src/google/adk/tools/computer_use/computer_use_tool.py new file mode 100644 index 00000000..367c10e2 --- /dev/null +++ b/src/google/adk/tools/computer_use/computer_use_tool.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 +import logging +from typing import Any +from typing import Callable + +from google.genai import types +from typing_extensions import override + +from ...models.llm_request import LlmRequest +from ...utils.feature_decorator import experimental +from ..function_tool import FunctionTool +from ..tool_context import ToolContext +from .base_computer import ComputerState + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class ComputerUseTool(FunctionTool): + """A tool that wraps computer control functions for use with LLMs. + + This tool automatically normalizes coordinates from a virtual coordinate space + (by default 1000x1000) to the actual screen size. This allows LLMs to work + with a consistent coordinate system regardless of the actual screen dimensions, + making their output more predictable and easier to handle. + """ + + def __init__( + self, + *, + func: Callable[..., Any], + screen_size: tuple[int, int], + virtual_screen_size: tuple[int, int] = (1000, 1000), + ): + """Initialize the ComputerUseTool. + + Args: + func: The computer control function to wrap. + screen_size: The actual screen size as (width, height) in pixels. + This represents the real dimensions of the target screen/display. + virtual_screen_size: The virtual coordinate space dimensions as (width, height) + that the LLM uses to specify coordinates. Coordinates from the LLM are + automatically normalized from this virtual space to the actual screen_size. + Default is (1000, 1000), meaning the LLM thinks it's working with a + 1000x1000 pixel screen regardless of the actual screen dimensions. + + Raises: + ValueError: If screen_size or virtual_screen_size is not a valid tuple + of positive integers. + """ + super().__init__(func=func) + self._screen_size = screen_size + self._coordinate_space = virtual_screen_size + + # Validate screen size + if not isinstance(screen_size, tuple) or len(screen_size) != 2: + raise ValueError("screen_size must be a tuple of (width, height)") + if screen_size[0] <= 0 or screen_size[1] <= 0: + raise ValueError("screen_size dimensions must be positive") + + # Validate virtual screen size + if ( + not isinstance(virtual_screen_size, tuple) + or len(virtual_screen_size) != 2 + ): + raise ValueError("virtual_screen_size must be a tuple of (width, height)") + if virtual_screen_size[0] <= 0 or virtual_screen_size[1] <= 0: + raise ValueError("virtual_screen_size dimensions must be positive") + + def _normalize_x(self, x: int) -> int: + """Normalize x coordinate from virtual screen space to actual screen width.""" + if not isinstance(x, (int, float)): + raise ValueError(f"x coordinate must be numeric, got {type(x)}") + + normalized = int(x / self._coordinate_space[0] * self._screen_size[0]) + # Clamp to screen bounds + return max(0, min(normalized, self._screen_size[0] - 1)) + + def _normalize_y(self, y: int) -> int: + """Normalize y coordinate from virtual screen space to actual screen height.""" + if not isinstance(y, (int, float)): + raise ValueError(f"y coordinate must be numeric, got {type(y)}") + + normalized = int(y / self._coordinate_space[1] * self._screen_size[1]) + # Clamp to screen bounds + return max(0, min(normalized, self._screen_size[1] - 1)) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Run the computer control function with normalized coordinates.""" + + try: + # Normalize coordinates if present + if "x" in args: + original_x = args["x"] + args["x"] = self._normalize_x(args["x"]) + logger.debug("Normalized x: %s -> %s", original_x, args["x"]) + + if "y" in args: + original_y = args["y"] + args["y"] = self._normalize_y(args["y"]) + logger.debug("Normalized y: %s -> %s", original_y, args["y"]) + + # Handle destination coordinates for drag and drop + if "destination_x" in args: + original_dest_x = args["destination_x"] + args["destination_x"] = self._normalize_x(args["destination_x"]) + logger.debug( + "Normalized destination_x: %s -> %s", + original_dest_x, + args["destination_x"], + ) + + if "destination_y" in args: + original_dest_y = args["destination_y"] + args["destination_y"] = self._normalize_y(args["destination_y"]) + logger.debug( + "Normalized destination_y: %s -> %s", + original_dest_y, + args["destination_y"], + ) + + # Execute the actual computer control function + result = await super().run_async(args=args, tool_context=tool_context) + + # Process the result if it's an EnvironmentState + if isinstance(result, ComputerState): + return { + "image": { + "mimetype": "image/png", + "data": base64.b64encode(result.screenshot).decode("utf-8"), + }, + "url": result.url, + } + + return result + + except Exception as e: + logger.error("Error in ComputerUseTool.run_async: %s", e) + raise + + @override + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + """ComputerUseToolset will add this tool to the LLM request and add computer + use configuration to the LLM request.""" + pass diff --git a/src/google/adk/tools/computer_use/computer_use_toolset.py b/src/google/adk/tools/computer_use/computer_use_toolset.py new file mode 100644 index 00000000..8834b5a4 --- /dev/null +++ b/src/google/adk/tools/computer_use/computer_use_toolset.py @@ -0,0 +1,220 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from typing import Callable +from typing import Optional +from typing import Union + +from google.genai import types +from typing_extensions import override + +from ...agents.readonly_context import ReadonlyContext +from ...models.llm_request import LlmRequest +from ...utils.feature_decorator import experimental +from ..base_toolset import BaseToolset +from ..tool_context import ToolContext +from .base_computer import BaseComputer +from .computer_use_tool import ComputerUseTool + +# Methods that should be excluded when creating tools from BaseComputer methods +EXCLUDED_METHODS = {"screen_size", "environment", "close"} + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class ComputerUseToolset(BaseToolset): + + def __init__( + self, + *, + computer: BaseComputer, + ): + super().__init__() + self._computer = computer + self._initialized = False + self._tools = None + + async def _ensure_initialized(self) -> None: + if not self._initialized: + await self._computer.initialize() + self._initialized = True + + @staticmethod + async def adapt_computer_use_tool( + method_name: str, + adapter_func: Union[ + Callable[[Callable[..., Any]], Callable[..., Any]], + Callable[[Callable[..., Any]], Any], + ], + llm_request: LlmRequest, + ) -> None: + """Adapt a computer use tool by replacing it with a modified version. + + Args: + method_name: The name of the method (of BaseComputer class) to adapt (e.g. 'wait'). + adapter_func: A function that accepts existing computer use async function and returns a new computer use async function. + Can be either sync or async function. The name of the returned function will be used as the new tool name. + llm_request: The LLM request containing the tools dictionary. + """ + # Validate that the method is a valid BaseComputer method + if method_name in EXCLUDED_METHODS: + logger.warning( + "Method %s is not a valid BaseComputer method", method_name + ) + return + + # Check if it's a method defined in BaseComputer class + attr = getattr(BaseComputer, method_name, None) + if attr is None or not callable(attr): + logger.warning( + "Method %s is not a valid BaseComputer method", method_name + ) + return + + if method_name not in llm_request.tools_dict: + logger.warning("Method %s not found in tools_dict", method_name) + return + + original_tool = llm_request.tools_dict[method_name] + + # Create the adapted function using the adapter + # Handle both sync and async adapter functions + if asyncio.iscoroutinefunction(adapter_func): + # If adapter_func is async, await it to get the adapted function + adapted_func = await adapter_func(original_tool.func) + else: + # If adapter_func is sync, call it directly + adapted_func = adapter_func(original_tool.func) + + # Get the name from the adapted function + new_method_name = adapted_func.__name__ + + # Create a new ComputerUseTool with the adapted function + adapted_tool = ComputerUseTool( + func=adapted_func, + screen_size=original_tool._screen_size, + virtual_screen_size=original_tool._coordinate_space, + ) + + # Add the adapted tool and remove the original + llm_request.tools_dict[new_method_name] = adapted_tool + del llm_request.tools_dict[method_name] + + logger.debug( + "Adapted tool %s to %s with adapter function", + method_name, + new_method_name, + ) + + @override + async def get_tools( + self, + readonly_context: Optional[ReadonlyContext] = None, + ) -> list[ComputerUseTool]: + if self._tools: + return self._tools + await self._ensure_initialized() + # Get screen size for tool configuration + screen_size = await self._computer.screen_size() + + # Get all methods defined in Computer abstract base class, excluding specified methods + computer_methods = [] + + # Get all methods defined in the Computer ABC interface + for method_name in dir(BaseComputer): + # Skip private methods (starting with underscore) + if method_name.startswith("_"): + continue + + # Skip excluded methods + if method_name in EXCLUDED_METHODS: + continue + + # Check if it's a method defined in Computer class + attr = getattr(BaseComputer, method_name, None) + if attr is not None and callable(attr): + # Get the corresponding method from the concrete instance + instance_method = getattr(self._computer, method_name) + computer_methods.append(instance_method) + + # Create ComputerUseTool instances for each method + + self._tools = [ + ComputerUseTool( + func=method, + screen_size=screen_size, + ) + for method in computer_methods + ] + return self._tools + + @override + async def close(self) -> None: + await self._computer.close() + + @override + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + """Add its tools to the LLM request and add computer + use configuration to the LLM request.""" + try: + + # Add this tool to the tools dictionary + if not self._tools: + await self.get_tools() + + for tool in self._tools: + llm_request.tools_dict[tool.name] = tool + + # Initialize config if needed + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + # Check if computer use is already configured + for tool in llm_request.config.tools: + if ( + isinstance(tool, (types.Tool, types.ToolDict)) + and hasattr(tool, "computer_use") + and tool.computer_use + ): + logger.debug("Computer use already configured in LLM request") + return + + # Add computer use tool configuration + computer_environment = await self._computer.environment() + environment = getattr( + types.Environment, + computer_environment.name, + types.Environment.ENVIRONMENT_BROWSER, + ) + llm_request.config.tools.append( + types.Tool( + computer_use=types.ToolComputerUse(environment=environment) + ) + ) + logger.debug( + "Added computer use tool with environment: %s", + environment, + ) + + except Exception as e: + logger.error("Error in ComputerUseToolset.process_llm_request: %s", e) + raise diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index f757b14d..8cde21fe 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -16,6 +16,7 @@ import os import sys from typing import Optional from unittest import mock +from unittest.mock import AsyncMock from google.adk import version as adk_version from google.adk.models.gemini_llm_connection import GeminiLlmConnection @@ -28,7 +29,6 @@ from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai import version as genai_version from google.genai.types import Content -from google.genai.types import FinishReason from google.genai.types import Part import pytest @@ -66,6 +66,26 @@ def llm_request(): ) +@pytest.fixture +def llm_request_with_computer_use(): + return LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + temperature=0.1, + response_modalities=[types.Modality.TEXT], + system_instruction="You are a helpful assistant", + tools=[ + types.Tool( + computer_use=types.ToolComputerUse( + environment=types.Environment.ENVIRONMENT_BROWSER + ) + ) + ], + ), + ) + + @pytest.fixture def mock_os_environ(): initial_env = os.environ.copy() @@ -614,7 +634,8 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request): ), ], ) -def test_preprocess_request_handles_backend_specific_fields( +@pytest.mark.asyncio +async def test_preprocess_request_handles_backend_specific_fields( gemini_llm: Gemini, api_backend: GoogleLLMVariant, expected_file_display_name: Optional[str], @@ -662,7 +683,7 @@ def test_preprocess_request_handles_backend_specific_fields( mock_backend.return_value = api_backend # Act: Run the preprocessing method - gemini_llm._preprocess_request(llm_request_with_files) + await gemini_llm._preprocess_request(llm_request_with_files) # Assert: Check if the fields were correctly processed file_part = llm_request_with_files.contents[0].parts[0] @@ -1535,3 +1556,165 @@ async def test_generate_content_async_stream_two_separate_text_aggregations(): function_call_responses[0].content.parts[0].function_call.name == "divide" ) + + +@pytest.mark.asyncio +async def test_computer_use_removes_system_instruction(): + """Test that system instruction is set to None when computer use is configured.""" + llm = Gemini() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + tools=[ + types.Tool( + computer_use=types.ToolComputerUse( + environment=types.Environment.ENVIRONMENT_BROWSER + ) + ) + ], + ), + ) + + await llm._preprocess_request(llm_request) + + # System instruction should be set to None when computer use is configured + assert llm_request.config.system_instruction is None + + +@pytest.mark.asyncio +async def test_computer_use_preserves_system_instruction_when_no_computer_use(): + """Test that system instruction is preserved when computer use is not configured.""" + llm = Gemini() + + original_instruction = "You are a helpful assistant" + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + system_instruction=original_instruction, + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration(name="test", description="test") + ] + ) + ], + ), + ) + + await llm._preprocess_request(llm_request) + + # System instruction should be preserved when no computer use + assert llm_request.config.system_instruction == original_instruction + + +@pytest.mark.asyncio +async def test_computer_use_with_no_config(): + """Test that preprocessing works when config is None.""" + llm = Gemini() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=None, + ) + + # Should not raise an exception + await llm._preprocess_request(llm_request) + + +@pytest.mark.asyncio +async def test_computer_use_with_no_tools(): + """Test that preprocessing works when config.tools is None.""" + llm = Gemini() + + original_instruction = "You are a helpful assistant" + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + system_instruction=original_instruction, + tools=None, + ), + ) + + await llm._preprocess_request(llm_request) + + # System instruction should be preserved when no tools + assert llm_request.config.system_instruction == original_instruction + + +@pytest.mark.asyncio +async def test_adapt_computer_use_tool_wait(): + """Test that _adapt_computer_use_tool correctly adapts wait to wait_5_seconds.""" + from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool + + llm = Gemini() + + # Create a mock wait tool + mock_wait_func = AsyncMock() + mock_wait_func.return_value = "mock_result" + + original_wait_tool = ComputerUseTool( + func=mock_wait_func, + screen_size=(1920, 1080), + virtual_screen_size=(1000, 1000), + ) + + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + # Add wait to tools_dict + llm_request.tools_dict["wait"] = original_wait_tool + + # Call the adaptation method (now async) + await llm._adapt_computer_use_tool(llm_request) + + # Verify wait was removed and wait_5_seconds was added + assert "wait" not in llm_request.tools_dict + assert "wait_5_seconds" in llm_request.tools_dict + + # Verify the new tool has correct properties + wait_5_seconds_tool = llm_request.tools_dict["wait_5_seconds"] + assert isinstance(wait_5_seconds_tool, ComputerUseTool) + assert wait_5_seconds_tool._screen_size == (1920, 1080) + assert wait_5_seconds_tool._coordinate_space == (1000, 1000) + + # Verify calling the new tool calls the original with 5 seconds + result = await wait_5_seconds_tool.func() + assert result == "mock_result" + mock_wait_func.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_adapt_computer_use_tool_no_wait(): + """Test that _adapt_computer_use_tool does nothing when wait is not present.""" + llm = Gemini() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + # Don't add any tools + original_tools_dict = llm_request.tools_dict.copy() + + # Call the adaptation method (now async) + await llm._adapt_computer_use_tool(llm_request) + + # Verify tools_dict is unchanged + assert llm_request.tools_dict == original_tools_dict + assert "wait_5_seconds" not in llm_request.tools_dict diff --git a/tests/unittests/tools/computer_use/__init__.py b/tests/unittests/tools/computer_use/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/tools/computer_use/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/computer_use/test_base_computer.py b/tests/unittests/tools/computer_use/test_base_computer.py new file mode 100644 index 00000000..8a2bcfa4 --- /dev/null +++ b/tests/unittests/tools/computer_use/test_base_computer.py @@ -0,0 +1,341 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for base_computer module.""" + +from typing import Literal + +from google.adk.tools.computer_use.base_computer import BaseComputer +from google.adk.tools.computer_use.base_computer import ComputerEnvironment +from google.adk.tools.computer_use.base_computer import ComputerState +import pytest + + +class TestComputerEnvironment: + """Test cases for ComputerEnvironment enum.""" + + def test_valid_environments(self): + """Test valid environment values.""" + assert ( + ComputerEnvironment.ENVIRONMENT_UNSPECIFIED == "ENVIRONMENT_UNSPECIFIED" + ) + assert ComputerEnvironment.ENVIRONMENT_BROWSER == "ENVIRONMENT_BROWSER" + + def test_invalid_environment_raises(self): + """Test that invalid environment values raise ValueError.""" + + with pytest.raises(ValueError): + ComputerEnvironment("INVALID_ENVIRONMENT") + + def test_string_representation(self): + """Test string representation of enum values.""" + assert ( + ComputerEnvironment.ENVIRONMENT_BROWSER.value == "ENVIRONMENT_BROWSER" + ) + assert ( + ComputerEnvironment.ENVIRONMENT_UNSPECIFIED.value + == "ENVIRONMENT_UNSPECIFIED" + ) + + +class TestComputerState: + """Test cases for ComputerState Pydantic model.""" + + def test_default_initialization(self): + """Test ComputerState with default values.""" + state = ComputerState() + assert state.screenshot is None + assert state.url is None + + def test_initialization_with_screenshot(self): + """Test ComputerState with screenshot data.""" + screenshot_data = b"fake_png_data" + state = ComputerState(screenshot=screenshot_data) + assert state.screenshot == screenshot_data + assert state.url is None + + def test_initialization_with_url(self): + """Test ComputerState with URL.""" + url = "https://example.com" + state = ComputerState(url=url) + assert state.screenshot is None + assert state.url == url + + def test_initialization_with_all_fields(self): + """Test ComputerState with all fields provided.""" + screenshot_data = b"fake_png_data" + url = "https://example.com" + state = ComputerState(screenshot=screenshot_data, url=url) + assert state.screenshot == screenshot_data + assert state.url == url + + def test_field_validation(self): + """Test field validation for ComputerState.""" + # Test that bytes are accepted for screenshot + state = ComputerState(screenshot=b"test_data") + assert state.screenshot == b"test_data" + + # Test that string is accepted for URL + state = ComputerState(url="https://test.com") + assert state.url == "https://test.com" + + def test_model_serialization(self): + """Test that ComputerState can be serialized.""" + state = ComputerState(screenshot=b"test", url="https://example.com") + # Should not raise an exception + model_dict = state.model_dump() + assert "screenshot" in model_dict + assert "url" in model_dict + + +class MockComputer(BaseComputer): + """Mock implementation of BaseComputer for testing.""" + + def __init__(self): + self.initialized = False + self.closed = False + + async def screen_size(self) -> tuple[int, int]: + return (1920, 1080) + + async def open_web_browser(self) -> ComputerState: + return ComputerState(url="https://example.com") + + async def click_at(self, x: int, y: int) -> ComputerState: + return ComputerState(url="https://example.com") + + async def hover_at(self, x: int, y: int) -> ComputerState: + return ComputerState(url="https://example.com") + + async def type_text_at( + self, + x: int, + y: int, + text: str, + press_enter: bool = True, + clear_before_typing: bool = True, + ) -> ComputerState: + return ComputerState(url="https://example.com") + + async def scroll_document( + self, direction: Literal["up", "down", "left", "right"] + ) -> ComputerState: + return ComputerState(url="https://example.com") + + async def scroll_at( + self, + x: int, + y: int, + direction: Literal["up", "down", "left", "right"], + magnitude: int, + ) -> ComputerState: + return ComputerState(url="https://example.com") + + async def wait(self, seconds: int) -> ComputerState: + return ComputerState(url="https://example.com") + + async def go_back(self) -> ComputerState: + return ComputerState(url="https://example.com") + + async def go_forward(self) -> ComputerState: + return ComputerState(url="https://example.com") + + async def search(self) -> ComputerState: + return ComputerState(url="https://search.example.com") + + async def navigate(self, url: str) -> ComputerState: + return ComputerState(url=url) + + async def key_combination(self, keys: list[str]) -> ComputerState: + return ComputerState(url="https://example.com") + + async def drag_and_drop( + self, x: int, y: int, destination_x: int, destination_y: int + ) -> ComputerState: + return ComputerState(url="https://example.com") + + async def current_state(self) -> ComputerState: + return ComputerState( + url="https://example.com", screenshot=b"screenshot_data" + ) + + async def initialize(self) -> None: + self.initialized = True + + async def close(self) -> None: + self.closed = True + + async def environment(self) -> ComputerEnvironment: + return ComputerEnvironment.ENVIRONMENT_BROWSER + + +class TestBaseComputer: + """Test cases for BaseComputer abstract base class.""" + + @pytest.fixture + def mock_computer(self) -> MockComputer: + """Fixture providing a mock computer implementation.""" + return MockComputer() + + def test_cannot_instantiate_abstract_class(self): + """Test that BaseComputer cannot be instantiated directly.""" + import pytest + + with pytest.raises(TypeError): + BaseComputer() # Should raise TypeError because it's abstract + + @pytest.mark.asyncio + async def test_screen_size(self, mock_computer): + """Test screen_size method.""" + size = await mock_computer.screen_size() + assert size == (1920, 1080) + assert isinstance(size, tuple) + assert len(size) == 2 + + @pytest.mark.asyncio + async def test_open_web_browser(self, mock_computer): + """Test open_web_browser method.""" + state = await mock_computer.open_web_browser() + assert isinstance(state, ComputerState) + assert state.url == "https://example.com" + + @pytest.mark.asyncio + async def test_click_at(self, mock_computer): + """Test click_at method.""" + state = await mock_computer.click_at(100, 200) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_hover_at(self, mock_computer): + """Test hover_at method.""" + state = await mock_computer.hover_at(150, 250) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_type_text_at(self, mock_computer): + """Test type_text_at method with different parameters.""" + # Test with default parameters + state = await mock_computer.type_text_at(100, 200, "Hello World") + assert isinstance(state, ComputerState) + + # Test with custom parameters + state = await mock_computer.type_text_at( + 100, 200, "Hello", press_enter=False, clear_before_typing=False + ) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_scroll_document(self, mock_computer): + """Test scroll_document method with different directions.""" + directions = ["up", "down", "left", "right"] + for direction in directions: + state = await mock_computer.scroll_document(direction) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_scroll_at(self, mock_computer): + """Test scroll_at method.""" + state = await mock_computer.scroll_at(100, 200, "down", 5) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_wait(self, mock_computer): + """Test wait method.""" + state = await mock_computer.wait(5) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_go_back(self, mock_computer): + """Test go_back method.""" + state = await mock_computer.go_back() + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_go_forward(self, mock_computer): + """Test go_forward method.""" + state = await mock_computer.go_forward() + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_search(self, mock_computer): + """Test search method.""" + state = await mock_computer.search() + assert isinstance(state, ComputerState) + assert state.url == "https://search.example.com" + + @pytest.mark.asyncio + async def test_navigate(self, mock_computer): + """Test navigate method.""" + test_url = "https://test.example.com" + state = await mock_computer.navigate(test_url) + assert isinstance(state, ComputerState) + assert state.url == test_url + + @pytest.mark.asyncio + async def test_key_combination(self, mock_computer): + """Test key_combination method.""" + state = await mock_computer.key_combination(["ctrl", "c"]) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_drag_and_drop(self, mock_computer): + """Test drag_and_drop method.""" + state = await mock_computer.drag_and_drop(100, 200, 300, 400) + assert isinstance(state, ComputerState) + + @pytest.mark.asyncio + async def test_current_state(self, mock_computer): + """Test current_state method.""" + state = await mock_computer.current_state() + assert isinstance(state, ComputerState) + assert state.url == "https://example.com" + assert state.screenshot == b"screenshot_data" + + @pytest.mark.asyncio + async def test_initialize(self, mock_computer): + """Test initialize method.""" + assert not mock_computer.initialized + await mock_computer.initialize() + assert mock_computer.initialized + + @pytest.mark.asyncio + async def test_close(self, mock_computer): + """Test close method.""" + assert not mock_computer.closed + await mock_computer.close() + assert mock_computer.closed + + @pytest.mark.asyncio + async def test_environment(self, mock_computer): + """Test environment method.""" + env = await mock_computer.environment() + assert env == ComputerEnvironment.ENVIRONMENT_BROWSER + assert isinstance(env, ComputerEnvironment) + + @pytest.mark.asyncio + async def test_lifecycle_methods(self, mock_computer): + """Test the lifecycle of a computer instance.""" + # Initially not initialized or closed + assert not mock_computer.initialized + assert not mock_computer.closed + + # Initialize + await mock_computer.initialize() + assert mock_computer.initialized + assert not mock_computer.closed + + # Close + await mock_computer.close() + assert mock_computer.initialized + assert mock_computer.closed diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py new file mode 100644 index 00000000..4dbdfbb5 --- /dev/null +++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py @@ -0,0 +1,500 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import inspect + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.computer_use.base_computer import ComputerState +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class TestComputerUseTool: + """Test cases for ComputerUseTool class.""" + + @pytest.fixture + async def tool_context(self): + """Fixture providing a tool context.""" + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app", user_id="test_user" + ) + agent = SequentialAgent(name="test_agent") + invocation_context = InvocationContext( + invocation_id="invocation_id", + agent=agent, + session=session, + session_service=session_service, + ) + return ToolContext(invocation_context=invocation_context) + + @pytest.fixture + def mock_computer_function(self): + """Fixture providing a mock computer function.""" + # Create a real async function instead of AsyncMock for Python 3.9 compatibility + calls = [] + + async def mock_func(*args, **kwargs): + calls.append((args, kwargs)) + # Return a default ComputerState - this will be overridden in individual tests + return ComputerState(screenshot=b"default", url="https://default.com") + + # Add attributes that tests expect + mock_func.__name__ = "test_function" + mock_func.__doc__ = "Test function documentation" + mock_func.calls = calls + + # Add assertion methods for compatibility with Mock + def assert_called_once_with(*args, **kwargs): + assert len(calls) == 1, f"Expected 1 call, got {len(calls)}" + assert calls[0] == ( + args, + kwargs, + ), f"Expected {(args, kwargs)}, got {calls[0]}" + + def assert_called_once(): + assert len(calls) == 1, f"Expected 1 call, got {len(calls)}" + + mock_func.assert_called_once_with = assert_called_once_with + mock_func.assert_called_once = assert_called_once + + return mock_func + + def test_init(self, mock_computer_function): + """Test ComputerUseTool initialization.""" + screen_size = (1920, 1080) + tool = ComputerUseTool(func=mock_computer_function, screen_size=screen_size) + + assert tool._screen_size == screen_size + assert tool.func == mock_computer_function + + def test_init_with_invalid_screen_size(self, mock_computer_function): + """Test ComputerUseTool initialization with invalid screen size.""" + with pytest.raises(ValueError, match="screen_size must be a tuple"): + ComputerUseTool(func=mock_computer_function, screen_size=[1920, 1080]) + + with pytest.raises(ValueError, match="screen_size must be a tuple"): + ComputerUseTool(func=mock_computer_function, screen_size=(1920,)) + + with pytest.raises( + ValueError, match="screen_size dimensions must be positive" + ): + ComputerUseTool(func=mock_computer_function, screen_size=(0, 1080)) + + with pytest.raises( + ValueError, match="screen_size dimensions must be positive" + ): + ComputerUseTool(func=mock_computer_function, screen_size=(1920, -1)) + + def test_init_with_invalid_virtual_screen_size(self, mock_computer_function): + """Test ComputerUseTool initialization with invalid virtual_screen_size.""" + with pytest.raises(ValueError, match="virtual_screen_size must be a tuple"): + ComputerUseTool( + func=mock_computer_function, + screen_size=(1920, 1080), + virtual_screen_size=[1000, 1000], + ) + + with pytest.raises(ValueError, match="virtual_screen_size must be a tuple"): + ComputerUseTool( + func=mock_computer_function, + screen_size=(1920, 1080), + virtual_screen_size=(1000,), + ) + + with pytest.raises( + ValueError, match="virtual_screen_size dimensions must be positive" + ): + ComputerUseTool( + func=mock_computer_function, + screen_size=(1920, 1080), + virtual_screen_size=(0, 1000), + ) + + with pytest.raises( + ValueError, match="virtual_screen_size dimensions must be positive" + ): + ComputerUseTool( + func=mock_computer_function, + screen_size=(1920, 1080), + virtual_screen_size=(1000, -1), + ) + + def test_init_with_custom_virtual_screen_size(self, mock_computer_function): + """Test ComputerUseTool initialization with custom virtual_screen_size.""" + screen_size = (1920, 1080) + virtual_screen_size = (2000, 2000) + tool = ComputerUseTool( + func=mock_computer_function, + screen_size=screen_size, + virtual_screen_size=virtual_screen_size, + ) + + assert tool._screen_size == screen_size + assert tool._coordinate_space == virtual_screen_size + assert tool.func == mock_computer_function + + def test_normalize_x(self, mock_computer_function): + """Test x coordinate normalization with default virtual screen size (1000x1000).""" + tool = ComputerUseTool( + func=mock_computer_function, screen_size=(1920, 1080) + ) + + # Test normal cases + assert tool._normalize_x(0) == 0 + assert tool._normalize_x(500) == 960 # 500/1000 * 1920 + assert tool._normalize_x(1000) == 1919 # Clamped to screen bounds + + # Test edge cases + assert tool._normalize_x(-100) == 0 # Clamped to 0 + assert tool._normalize_x(1500) == 1919 # Clamped to max + + def test_normalize_y(self, mock_computer_function): + """Test y coordinate normalization with default virtual screen size (1000x1000).""" + tool = ComputerUseTool( + func=mock_computer_function, screen_size=(1920, 1080) + ) + + # Test normal cases + assert tool._normalize_y(0) == 0 + assert tool._normalize_y(500) == 540 # 500/1000 * 1080 + assert tool._normalize_y(1000) == 1079 # Clamped to screen bounds + + # Test edge cases + assert tool._normalize_y(-100) == 0 # Clamped to 0 + assert tool._normalize_y(1500) == 1079 # Clamped to max + + def test_normalize_with_custom_virtual_screen_size( + self, mock_computer_function + ): + """Test coordinate normalization with custom virtual screen size.""" + tool = ComputerUseTool( + func=mock_computer_function, + screen_size=(1920, 1080), + virtual_screen_size=(2000, 2000), + ) + + # Test x coordinate normalization with 2000x2000 virtual space + assert tool._normalize_x(0) == 0 + assert tool._normalize_x(1000) == 960 # 1000/2000 * 1920 + assert tool._normalize_x(2000) == 1919 # Clamped to screen bounds + + # Test y coordinate normalization with 2000x2000 virtual space + assert tool._normalize_y(0) == 0 + assert tool._normalize_y(1000) == 540 # 1000/2000 * 1080 + assert tool._normalize_y(2000) == 1079 # Clamped to screen bounds + + # Test edge cases + assert tool._normalize_x(-100) == 0 # Clamped to 0 + assert tool._normalize_x(3000) == 1919 # Clamped to max + assert tool._normalize_y(-100) == 0 # Clamped to 0 + assert tool._normalize_y(3000) == 1079 # Clamped to max + + def test_normalize_with_invalid_coordinates(self, mock_computer_function): + """Test coordinate normalization with invalid inputs.""" + tool = ComputerUseTool( + func=mock_computer_function, screen_size=(1920, 1080) + ) + + with pytest.raises(ValueError, match="x coordinate must be numeric"): + tool._normalize_x("invalid") + + with pytest.raises(ValueError, match="y coordinate must be numeric"): + tool._normalize_y("invalid") + + @pytest.mark.asyncio + async def test_run_async_with_coordinates( + self, mock_computer_function, tool_context + ): + """Test run_async with coordinate normalization.""" + + # Set up a proper signature for the mock function + def dummy_func(x: int, y: int): + pass + + mock_computer_function.__name__ = "dummy_func" + mock_computer_function.__signature__ = inspect.signature(dummy_func) + + # Create a specific mock function for this test that returns the expected state + calls = [] + mock_state = ComputerState( + screenshot=b"test_screenshot", url="https://example.com" + ) + + async def specific_mock_func(x: int, y: int): + calls.append((x, y)) + return mock_state + + specific_mock_func.__name__ = "dummy_func" + specific_mock_func.__signature__ = inspect.signature(dummy_func) + specific_mock_func.calls = calls + + def assert_called_once_with(x, y): + assert len(calls) == 1, f"Expected 1 call, got {len(calls)}" + assert calls[0] == (x, y), f"Expected ({x}, {y}), got {calls[0]}" + + specific_mock_func.assert_called_once_with = assert_called_once_with + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + args = {"x": 500, "y": 300} + result = await tool.run_async(args=args, tool_context=tool_context) + + # Check that coordinates were normalized + specific_mock_func.assert_called_once_with(x=960, y=324) + + # Check return format for ComputerState + expected_result = { + "image": { + "mimetype": "image/png", + "data": base64.b64encode(b"test_screenshot").decode("utf-8"), + }, + "url": "https://example.com", + } + assert result == expected_result + + @pytest.mark.asyncio + async def test_run_async_with_drag_and_drop_coordinates( + self, mock_computer_function, tool_context + ): + """Test run_async with drag and drop coordinate normalization.""" + + # Set up a proper signature for the mock function + def dummy_func(x: int, y: int, destination_x: int, destination_y: int): + pass + + # Create a specific mock function for this test + calls = [] + mock_state = ComputerState( + screenshot=b"test_screenshot", url="https://example.com" + ) + + async def specific_mock_func( + x: int, y: int, destination_x: int, destination_y: int + ): + calls.append((x, y, destination_x, destination_y)) + return mock_state + + specific_mock_func.__name__ = "dummy_func" + specific_mock_func.__signature__ = inspect.signature(dummy_func) + specific_mock_func.calls = calls + + def assert_called_once_with(x, y, destination_x, destination_y): + assert len(calls) == 1, f"Expected 1 call, got {len(calls)}" + assert calls[0] == (x, y, destination_x, destination_y), ( + f"Expected ({x}, {y}, {destination_x}, {destination_y}), got" + f" {calls[0]}" + ) + + specific_mock_func.assert_called_once_with = assert_called_once_with + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + args = {"x": 100, "y": 200, "destination_x": 800, "destination_y": 600} + result = await tool.run_async(args=args, tool_context=tool_context) + + # Check that all coordinates were normalized + specific_mock_func.assert_called_once_with( + x=192, # 100/1000 * 1920 + y=216, # 200/1000 * 1080 + destination_x=1536, # 800/1000 * 1920 + destination_y=648, # 600/1000 * 1080 + ) + + @pytest.mark.asyncio + async def test_run_async_with_non_computer_state_result( + self, mock_computer_function, tool_context + ): + """Test run_async when function returns non-ComputerState result.""" + # Create a specific mock function that returns non-ComputerState + calls = [] + + async def specific_mock_func(*args, **kwargs): + calls.append((args, kwargs)) + return {"status": "success"} + + specific_mock_func.__name__ = "test_function" + specific_mock_func.calls = calls + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + args = {"text": "hello"} + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return the result as-is + assert result == {"status": "success"} + + @pytest.mark.asyncio + async def test_run_async_without_coordinates( + self, mock_computer_function, tool_context + ): + """Test run_async with no coordinate parameters.""" + + # Set up a proper signature for the mock function + def dummy_func(direction: str): + pass + + # Create a specific mock function for this test + calls = [] + mock_state = ComputerState( + screenshot=b"test_screenshot", url="https://example.com" + ) + + async def specific_mock_func(direction: str): + calls.append((direction,)) + return mock_state + + specific_mock_func.__name__ = "dummy_func" + specific_mock_func.__signature__ = inspect.signature(dummy_func) + specific_mock_func.calls = calls + + def assert_called_once_with(direction): + assert len(calls) == 1, f"Expected 1 call, got {len(calls)}" + assert calls[0] == ( + direction, + ), f"Expected ({direction},), got {calls[0]}" + + specific_mock_func.assert_called_once_with = assert_called_once_with + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + args = {"direction": "down"} + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should call function with original args + specific_mock_func.assert_called_once_with(direction="down") + + @pytest.mark.asyncio + async def test_run_async_with_error( + self, mock_computer_function, tool_context + ): + """Test run_async when underlying function raises an error.""" + # Create a specific mock function that raises an error + calls = [] + + async def specific_mock_func(*args, **kwargs): + calls.append((args, kwargs)) + raise ValueError("Test error") + + specific_mock_func.__name__ = "test_function" + specific_mock_func.calls = calls + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + args = {"x": 500, "y": 300} + + with pytest.raises(ValueError, match="Test error"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_process_llm_request( + self, mock_computer_function, tool_context + ): + """Test process_llm_request method.""" + tool = ComputerUseTool( + func=mock_computer_function, screen_size=(1920, 1080) + ) + llm_request = LlmRequest() + + # Should not raise any exceptions and should do nothing + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + # Verify llm_request is unchanged (process_llm_request is now a no-op) + assert llm_request.tools_dict == {} + + def test_inheritance(self, mock_computer_function): + """Test that ComputerUseTool inherits from FunctionTool.""" + from google.adk.tools.function_tool import FunctionTool + + tool = ComputerUseTool( + func=mock_computer_function, screen_size=(1920, 1080) + ) + assert isinstance(tool, FunctionTool) + + def test_custom_screen_size(self, mock_computer_function): + """Test ComputerUseTool with custom screen size and default virtual screen size.""" + custom_size = (2560, 1440) + tool = ComputerUseTool(func=mock_computer_function, screen_size=custom_size) + + # Test normalization with custom screen size and default 1000x1000 virtual space + assert tool._normalize_x(500) == 1280 # 500/1000 * 2560 + assert tool._normalize_y(500) == 720 # 500/1000 * 1440 + + def test_custom_screen_size_with_custom_virtual_screen_size( + self, mock_computer_function + ): + """Test ComputerUseTool with both custom screen size and custom virtual screen size.""" + screen_size = (2560, 1440) + virtual_screen_size = (800, 600) + tool = ComputerUseTool( + func=mock_computer_function, + screen_size=screen_size, + virtual_screen_size=virtual_screen_size, + ) + + # Test normalization: 400/800 * 2560 = 1280, 300/600 * 1440 = 720 + assert tool._normalize_x(400) == 1280 # 400/800 * 2560 + assert tool._normalize_y(300) == 720 # 300/600 * 1440 + + # Test bounds + assert ( + tool._normalize_x(800) == 2559 + ) # 800/800 * 2560, clamped to screen bounds + assert ( + tool._normalize_y(600) == 1439 + ) # 600/600 * 1440, clamped to screen bounds + + @pytest.mark.asyncio + async def test_coordinate_logging( + self, mock_computer_function, tool_context, caplog + ): + """Test that coordinate normalization is logged.""" + import logging + + # Set up a proper signature for the mock function + def dummy_func(x: int, y: int): + pass + + # Create a specific mock function for this test + calls = [] + mock_state = ComputerState( + screenshot=b"test_screenshot", url="https://example.com" + ) + + async def specific_mock_func(x: int, y: int): + calls.append((x, y)) + return mock_state + + specific_mock_func.__name__ = "dummy_func" + specific_mock_func.__signature__ = inspect.signature(dummy_func) + specific_mock_func.calls = calls + + tool = ComputerUseTool(func=specific_mock_func, screen_size=(1920, 1080)) + + # Set the specific logger used by ComputerUseTool to DEBUG level + logger_name = "google_adk.google.adk.tools.computer_use.computer_use_tool" + with caplog.at_level(logging.DEBUG, logger=logger_name): + args = {"x": 500, "y": 300} + await tool.run_async(args=args, tool_context=tool_context) + + # Check that normalization was logged + assert "Normalized x: 500 -> 960" in caplog.text + assert "Normalized y: 300 -> 324" in caplog.text diff --git a/tests/unittests/tools/computer_use/test_computer_use_toolset.py b/tests/unittests/tools/computer_use/test_computer_use_toolset.py new file mode 100644 index 00000000..803dddd0 --- /dev/null +++ b/tests/unittests/tools/computer_use/test_computer_use_toolset.py @@ -0,0 +1,558 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +from google.adk.models.llm_request import LlmRequest +# Use the actual ComputerEnvironment enum from the code +from google.adk.tools.computer_use.base_computer import BaseComputer +from google.adk.tools.computer_use.base_computer import ComputerEnvironment +from google.adk.tools.computer_use.base_computer import ComputerState +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool +from google.adk.tools.computer_use.computer_use_toolset import ComputerUseToolset +from google.genai import types +import pytest + + +class MockComputer(BaseComputer): + """Mock Computer implementation for testing.""" + + def __init__(self): + self.initialize_called = False + self.close_called = False + self._screen_size = (1920, 1080) + self._environment = ComputerEnvironment.ENVIRONMENT_BROWSER + + async def initialize(self): + self.initialize_called = True + + async def close(self): + self.close_called = True + + async def screen_size(self) -> tuple[int, int]: + return self._screen_size + + async def environment(self) -> ComputerEnvironment: + return self._environment + + # Implement all abstract methods to make this a concrete class + async def open_web_browser(self) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def click_at(self, x: int, y: int) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def hover_at(self, x: int, y: int) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def type_text_at( + self, + x: int, + y: int, + text: str, + press_enter: bool = True, + clear_before_typing: bool = True, + ) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def scroll_document(self, direction: str) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def scroll_at( + self, x: int, y: int, direction: str, magnitude: int + ) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def wait(self, seconds: int) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def go_back(self) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def go_forward(self) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def search(self) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def navigate(self, url: str) -> ComputerState: + return ComputerState(screenshot=b"test", url=url) + + async def key_combination(self, keys: list[str]) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def drag_and_drop( + self, x: int, y: int, destination_x: int, destination_y: int + ) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + async def current_state(self) -> ComputerState: + return ComputerState(screenshot=b"test", url="https://example.com") + + +class TestComputerUseToolset: + """Test cases for ComputerUseToolset class.""" + + @pytest.fixture + def mock_computer(self): + """Fixture providing a mock computer.""" + return MockComputer() + + @pytest.fixture + def toolset(self, mock_computer): + """Fixture providing a ComputerUseToolset instance.""" + return ComputerUseToolset(computer=mock_computer) + + def test_init(self, mock_computer): + """Test ComputerUseToolset initialization.""" + toolset = ComputerUseToolset(computer=mock_computer) + + assert toolset._computer == mock_computer + assert toolset._initialized is False + + @pytest.mark.asyncio + async def test_ensure_initialized(self, toolset, mock_computer): + """Test that _ensure_initialized calls computer.initialize().""" + assert not mock_computer.initialize_called + assert not toolset._initialized + + await toolset._ensure_initialized() + + assert mock_computer.initialize_called + assert toolset._initialized + + @pytest.mark.asyncio + async def test_ensure_initialized_only_once(self, toolset, mock_computer): + """Test that _ensure_initialized only calls initialize once.""" + await toolset._ensure_initialized() + + # Reset the flag to test it's not called again + mock_computer.initialize_called = False + + await toolset._ensure_initialized() + + # Should not be called again + assert not mock_computer.initialize_called + assert toolset._initialized + + @pytest.mark.asyncio + async def test_get_tools(self, toolset, mock_computer): + """Test that get_tools returns ComputerUseTool instances.""" + tools = await toolset.get_tools() + + # Should initialize the computer + assert mock_computer.initialize_called + + # Should return a list of ComputerUseTool instances + assert isinstance(tools, list) + assert len(tools) > 0 + assert all(isinstance(tool, ComputerUseTool) for tool in tools) + + # Each tool should have the correct configuration + for tool in tools: + assert tool._screen_size == (1920, 1080) + # Should use default virtual screen size + assert tool._coordinate_space == (1000, 1000) + + @pytest.mark.asyncio + async def test_get_tools_excludes_utility_methods(self, toolset): + """Test that get_tools excludes utility methods like screen_size, environment, close.""" + tools = await toolset.get_tools() + + # Get tool function names + tool_names = [tool.func.__name__ for tool in tools] + + # Should exclude utility methods + excluded_methods = {"screen_size", "environment", "close"} + for method in excluded_methods: + assert method not in tool_names + + # initialize might be included since it's a concrete method, not just abstract + # This is acceptable behavior + + # Should include action methods + expected_methods = { + "open_web_browser", + "click_at", + "hover_at", + "type_text_at", + "scroll_document", + "scroll_at", + "wait", + "go_back", + "go_forward", + "search", + "navigate", + "key_combination", + "drag_and_drop", + "current_state", + } + for method in expected_methods: + assert method in tool_names + + @pytest.mark.asyncio + async def test_get_tools_with_readonly_context(self, toolset): + """Test get_tools with readonly_context parameter.""" + from google.adk.agents.readonly_context import ReadonlyContext + + readonly_context = MagicMock(spec=ReadonlyContext) + + tools = await toolset.get_tools(readonly_context=readonly_context) + + # Should still return tools (readonly_context doesn't affect behavior currently) + assert isinstance(tools, list) + assert len(tools) > 0 + + @pytest.mark.asyncio + async def test_close(self, toolset, mock_computer): + """Test that close calls computer.close().""" + await toolset.close() + + assert mock_computer.close_called + + @pytest.mark.asyncio + async def test_get_tools_creates_tools_with_correct_methods( + self, toolset, mock_computer + ): + """Test that get_tools creates tools with the correct underlying methods.""" + tools = await toolset.get_tools() + + # Find the click_at tool + click_tool = None + for tool in tools: + if tool.func.__name__ == "click_at": + click_tool = tool + break + + assert click_tool is not None + + # The tool's function should be bound to the mock computer instance + assert click_tool.func.__self__ == mock_computer + + @pytest.mark.asyncio + async def test_get_tools_handles_custom_screen_size(self, mock_computer): + """Test get_tools with custom screen size.""" + mock_computer._screen_size = (2560, 1440) + + toolset = ComputerUseToolset(computer=mock_computer) + tools = await toolset.get_tools() + + # All tools should have the custom screen size + for tool in tools: + assert tool._screen_size == (2560, 1440) + + @pytest.mark.asyncio + async def test_get_tools_handles_custom_environment(self, mock_computer): + """Test get_tools with custom environment.""" + mock_computer._environment = ComputerEnvironment.ENVIRONMENT_UNSPECIFIED + + toolset = ComputerUseToolset(computer=mock_computer) + tools = await toolset.get_tools() + + # Should still return tools regardless of environment + assert isinstance(tools, list) + assert len(tools) > 0 + + @pytest.mark.asyncio + async def test_multiple_get_tools_calls_return_cached_instances( + self, toolset + ): + """Test that multiple get_tools calls return the same cached instances.""" + tools1 = await toolset.get_tools() + tools2 = await toolset.get_tools() + + # Should return the same list instance + assert tools1 is tools2 + + def test_inheritance(self, toolset): + """Test that ComputerUseToolset inherits from BaseToolset.""" + from google.adk.tools.base_toolset import BaseToolset + + assert isinstance(toolset, BaseToolset) + + @pytest.mark.asyncio + async def test_get_tools_method_filtering(self, toolset): + """Test that get_tools properly filters methods from BaseComputer.""" + tools = await toolset.get_tools() + + # Get all method names from the tools + tool_method_names = [tool.func.__name__ for tool in tools] + + # Should not include private methods (starting with _) + for name in tool_method_names: + assert not name.startswith("_") + + # Should not include excluded methods + excluded_methods = {"screen_size", "environment", "close"} + for excluded in excluded_methods: + assert excluded not in tool_method_names + + @pytest.mark.asyncio + async def test_computer_method_binding(self, toolset, mock_computer): + """Test that tools are properly bound to the computer instance.""" + tools = await toolset.get_tools() + + # All tools should be bound to the mock computer + for tool in tools: + assert tool.func.__self__ == mock_computer + + @pytest.mark.asyncio + async def test_toolset_handles_computer_initialization_failure( + self, mock_computer + ): + """Test that toolset handles computer initialization failure gracefully.""" + + # Make initialize raise an exception + async def failing_initialize(): + raise Exception("Initialization failed") + + mock_computer.initialize = failing_initialize + + toolset = ComputerUseToolset(computer=mock_computer) + + # Should raise the exception when trying to get tools + with pytest.raises(Exception, match="Initialization failed"): + await toolset.get_tools() + + @pytest.mark.asyncio + async def test_process_llm_request(self, toolset, mock_computer): + """Test that process_llm_request adds tools and computer use configuration.""" + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + await toolset.process_llm_request( + tool_context=MagicMock(), llm_request=llm_request + ) + + # Should add tools to the request + assert len(llm_request.tools_dict) > 0 + + # Should add computer use configuration + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) > 0 + + # Should have computer use tool + computer_use_tools = [ + tool + for tool in llm_request.config.tools + if hasattr(tool, "computer_use") and tool.computer_use + ] + assert len(computer_use_tools) == 1 + + # Should have correct environment + computer_use_tool = computer_use_tools[0] + assert ( + computer_use_tool.computer_use.environment + == types.Environment.ENVIRONMENT_BROWSER + ) + + @pytest.mark.asyncio + async def test_process_llm_request_with_existing_computer_use( + self, toolset, mock_computer + ): + """Test that process_llm_request doesn't add duplicate computer use configuration.""" + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig( + tools=[ + types.Tool( + computer_use=types.ToolComputerUse( + environment=types.Environment.ENVIRONMENT_BROWSER + ) + ) + ] + ), + ) + + original_tools_count = len(llm_request.config.tools) + + await toolset.process_llm_request( + tool_context=MagicMock(), llm_request=llm_request + ) + + # Should not add duplicate computer use configuration + assert len(llm_request.config.tools) == original_tools_count + + # Should still add the actual tools + assert len(llm_request.tools_dict) > 0 + + @pytest.mark.asyncio + async def test_process_llm_request_error_handling(self, mock_computer): + """Test that process_llm_request handles errors gracefully.""" + + # Make environment raise an exception + async def failing_environment(): + raise Exception("Environment failed") + + mock_computer.environment = failing_environment + + toolset = ComputerUseToolset(computer=mock_computer) + + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + # Should raise the exception + with pytest.raises(Exception, match="Environment failed"): + await toolset.process_llm_request( + tool_context=MagicMock(), llm_request=llm_request + ) + + @pytest.mark.asyncio + async def test_adapt_computer_use_tool_sync_adapter(self): + """Test adapt_computer_use_tool with sync adapter function.""" + # Create a mock tool + mock_func = AsyncMock() + original_tool = ComputerUseTool( + func=mock_func, + screen_size=(1920, 1080), + virtual_screen_size=(1000, 1000), + ) + + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + llm_request.tools_dict["wait"] = original_tool + + # Create a sync adapter function + def sync_adapter(original_func): + async def adapted_func(): + return await original_func(5) + + return adapted_func + + # Call the adaptation method + await ComputerUseToolset.adapt_computer_use_tool( + "wait", sync_adapter, llm_request + ) + + # Verify the original tool was replaced + assert "wait" not in llm_request.tools_dict + assert "adapted_func" in llm_request.tools_dict + + # Verify the new tool has correct properties + adapted_tool = llm_request.tools_dict["adapted_func"] + assert isinstance(adapted_tool, ComputerUseTool) + assert adapted_tool._screen_size == (1920, 1080) + assert adapted_tool._coordinate_space == (1000, 1000) + + @pytest.mark.asyncio + async def test_adapt_computer_use_tool_async_adapter(self): + """Test adapt_computer_use_tool with async adapter function.""" + # Create a mock tool + mock_func = AsyncMock() + original_tool = ComputerUseTool( + func=mock_func, + screen_size=(1920, 1080), + virtual_screen_size=(1000, 1000), + ) + + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + llm_request.tools_dict["wait"] = original_tool + + # Create an async adapter function + async def async_adapter(original_func): + async def adapted_func(): + return await original_func(5) + + return adapted_func + + # Call the adaptation method + await ComputerUseToolset.adapt_computer_use_tool( + "wait", async_adapter, llm_request + ) + + # Verify the original tool was replaced + assert "wait" not in llm_request.tools_dict + assert "adapted_func" in llm_request.tools_dict + + # Verify the new tool has correct properties + adapted_tool = llm_request.tools_dict["adapted_func"] + assert isinstance(adapted_tool, ComputerUseTool) + assert adapted_tool._screen_size == (1920, 1080) + assert adapted_tool._coordinate_space == (1000, 1000) + + @pytest.mark.asyncio + async def test_adapt_computer_use_tool_invalid_method(self): + """Test adapt_computer_use_tool with invalid method name.""" + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + def adapter(original_func): + async def adapted_func(): + return await original_func() + + return adapted_func + + # Should not raise an exception, just log a warning + await ComputerUseToolset.adapt_computer_use_tool( + "invalid_method", adapter, llm_request + ) + + # Should not add any tools + assert len(llm_request.tools_dict) == 0 + + @pytest.mark.asyncio + async def test_adapt_computer_use_tool_excluded_method(self): + """Test adapt_computer_use_tool with excluded method name.""" + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + def adapter(original_func): + async def adapted_func(): + return await original_func() + + return adapted_func + + # Should not raise an exception, just log a warning + await ComputerUseToolset.adapt_computer_use_tool( + "screen_size", adapter, llm_request + ) + + # Should not add any tools + assert len(llm_request.tools_dict) == 0 + + @pytest.mark.asyncio + async def test_adapt_computer_use_tool_method_not_in_tools_dict(self): + """Test adapt_computer_use_tool when method is not in tools_dict.""" + llm_request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(), + ) + + def adapter(original_func): + async def adapted_func(): + return await original_func() + + return adapted_func + + # Should not raise an exception, just log a warning + await ComputerUseToolset.adapt_computer_use_tool( + "wait", adapter, llm_request + ) + + # Should not add any tools + assert len(llm_request.tools_dict) == 0