feat: Add ComputerUseToolset

PiperOrigin-RevId: 785646071
This commit is contained in:
Xiang (Sean) Zhou
2025-07-21 18:07:31 -07:00
committed by Copybara-Service
parent b2c2f1bd33
commit 083dcb4465
10 changed files with 2291 additions and 5 deletions
+29 -2
View File
@@ -88,7 +88,7 @@ class Gemini(BaseLlm):
Yields:
LlmResponse: The model response.
"""
self._preprocess_request(llm_request)
await self._preprocess_request(llm_request)
self._maybe_append_user_content(llm_request)
logger.info(
'Sending out request, model: %s, backend: %s, stream: %s',
@@ -269,7 +269,22 @@ class Gemini(BaseLlm):
) as live_session:
yield GeminiLlmConnection(live_session)
def _preprocess_request(self, llm_request: LlmRequest) -> None:
async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
"""Adapt the google computer use predefined functions to the adk computer use toolset."""
from ..tools.computer_use.computer_use_toolset import ComputerUseToolset
async def convert_wait_to_wait_5_seconds(wait_func):
async def wait_5_seconds():
return await wait_func(5)
return wait_5_seconds
await ComputerUseToolset.adapt_computer_use_tool(
'wait', convert_wait_to_wait_5_seconds, llm_request
)
async def _preprocess_request(self, llm_request: LlmRequest) -> None:
if self._api_backend == GoogleLLMVariant.GEMINI_API:
# Using API key from Google AI Studio to call model doesn't support labels.
@@ -284,6 +299,18 @@ class Gemini(BaseLlm):
_remove_display_name_if_present(part.inline_data)
_remove_display_name_if_present(part.file_data)
# Initialize config if needed
if llm_request.config and llm_request.config.tools:
# Check if computer use is configured
for tool in llm_request.config.tools:
if (
isinstance(tool, (types.Tool, types.ToolDict))
and hasattr(tool, 'computer_use')
and tool.computer_use
):
llm_request.config.system_instruction = None
await self._adapt_computer_use_tool(llm_request)
def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,265 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from enum import Enum
from typing import Literal
from typing import Optional
import pydantic
from ...utils.feature_decorator import experimental
@experimental
class ComputerEnvironment(str, Enum):
"""Case insensitive enum for computer environments."""
ENVIRONMENT_UNSPECIFIED = "ENVIRONMENT_UNSPECIFIED"
"""Defaults to browser."""
ENVIRONMENT_BROWSER = "ENVIRONMENT_BROWSER"
"""Operates in a web browser."""
@experimental
class ComputerState(pydantic.BaseModel):
"""Represents the current state of the computer environment.
Attributes:
screenshot: The screenshot in PNG format as bytes.
url: The current URL of the webpage being displayed.
"""
screenshot: bytes = pydantic.Field(
default=None, description="Screenshot in PNG format"
)
url: Optional[str] = pydantic.Field(
default=None, description="Current webpage URL"
)
@experimental
class BaseComputer(abc.ABC):
"""async defines an interface for computer environments.
This abstract base class async defines the standard interface for controlling
computer environments, including web browsers and other interactive systems.
"""
@abc.abstractmethod
async def screen_size(self) -> tuple[int, int]:
"""Returns the screen size of the environment.
Returns:
A tuple of (width, height) in pixels.
"""
@abc.abstractmethod
async def open_web_browser(self) -> ComputerState:
"""Opens the web browser.
Returns:
The current state after opening the browser.
"""
@abc.abstractmethod
async def click_at(self, x: int, y: int) -> ComputerState:
"""Clicks at a specific x, y coordinate on the webpage.
The 'x' and 'y' values are absolute values, scaled to the height and width of the screen.
Args:
x: The x-coordinate to click at.
y: The y-coordinate to click at.
Returns:
The current state after clicking.
"""
@abc.abstractmethod
async def hover_at(self, x: int, y: int) -> ComputerState:
"""Hovers at a specific x, y coordinate on the webpage.
May be used to explore sub-menus that appear on hover.
The 'x' and 'y' values are absolute values, scaled to the height and width of the screen.
Args:
x: The x-coordinate to hover at.
y: The y-coordinate to hover at.
Returns:
The current state after hovering.
"""
@abc.abstractmethod
async def type_text_at(
self,
x: int,
y: int,
text: str,
press_enter: bool = True,
clear_before_typing: bool = True,
) -> ComputerState:
"""Types text at a specific x, y coordinate.
The system automatically presses ENTER after typing. To disable this, set `press_enter` to False.
The system automatically clears any existing content before typing the specified `text`. To disable this, set `clear_before_typing` to False.
The 'x' and 'y' values are absolute values, scaled to the height and width of the screen.
Args:
x: The x-coordinate to type at.
y: The y-coordinate to type at.
text: The text to type.
press_enter: Whether to press ENTER after typing.
clear_before_typing: Whether to clear existing content before typing.
Returns:
The current state after typing.
"""
@abc.abstractmethod
async def scroll_document(
self, direction: Literal["up", "down", "left", "right"]
) -> ComputerState:
"""Scrolls the entire webpage "up", "down", "left" or "right" based on direction.
Args:
direction: The direction to scroll.
Returns:
The current state after scrolling.
"""
@abc.abstractmethod
async def scroll_at(
self,
x: int,
y: int,
direction: Literal["up", "down", "left", "right"],
magnitude: int,
) -> ComputerState:
"""Scrolls up, down, right, or left at a x, y coordinate by magnitude.
The 'x' and 'y' values are absolute values, scaled to the height and width of the screen.
Args:
x: The x-coordinate to scroll at.
y: The y-coordinate to scroll at.
direction: The direction to scroll.
magnitude: The amount to scroll.
Returns:
The current state after scrolling.
"""
@abc.abstractmethod
async def wait(self, seconds: int) -> ComputerState:
"""Waits for n seconds to allow unfinished webpage processes to complete.
Args:
seconds: The number of seconds to wait.
Returns:
The current state after waiting.
"""
@abc.abstractmethod
async def go_back(self) -> ComputerState:
"""Navigates back to the previous webpage in the browser history.
Returns:
The current state after navigating back.
"""
@abc.abstractmethod
async def go_forward(self) -> ComputerState:
"""Navigates forward to the next webpage in the browser history.
Returns:
The current state after navigating forward.
"""
@abc.abstractmethod
async def search(self) -> ComputerState:
"""Directly jumps to a search engine home page.
Used when you need to start with a search. For example, this is used when
the current website doesn't have the information needed or because a new
task is being started.
Returns:
The current state after navigating to search.
"""
@abc.abstractmethod
async def navigate(self, url: str) -> ComputerState:
"""Navigates directly to a specified URL.
Args:
url: The URL to navigate to.
Returns:
The current state after navigation.
"""
@abc.abstractmethod
async def key_combination(self, keys: list[str]) -> ComputerState:
"""Presses keyboard keys and combinations, such as "control+c" or "enter".
Args:
keys: List of keys to press in combination.
Returns:
The current state after key press.
"""
@abc.abstractmethod
async def drag_and_drop(
self, x: int, y: int, destination_x: int, destination_y: int
) -> ComputerState:
"""Drag and drop an element from a x, y coordinate to a destination destination_y, destination_x coordinate.
The 'x', 'y', 'destination_y' and 'destination_x' values are absolute values, scaled to the height and width of the screen.
Args:
x: The x-coordinate to start dragging from.
y: The y-coordinate to start dragging from.
destination_x: The x-coordinate to drop at.
destination_y: The y-coordinate to drop at.
Returns:
The current state after drag and drop.
"""
@abc.abstractmethod
async def current_state(self) -> ComputerState:
"""Returns the current state of the current webpage.
Returns:
The current environment state.
"""
async def initialize(self) -> None:
"""Initialize the computer."""
pass
async def close(self) -> None:
"""Cleanup resource of the computer."""
pass
@abc.abstractmethod
async def environment(self) -> ComputerEnvironment:
"""Returns the environment of the computer."""
@@ -0,0 +1,166 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import logging
from typing import Any
from typing import Callable
from google.genai import types
from typing_extensions import override
from ...models.llm_request import LlmRequest
from ...utils.feature_decorator import experimental
from ..function_tool import FunctionTool
from ..tool_context import ToolContext
from .base_computer import ComputerState
logger = logging.getLogger("google_adk." + __name__)
@experimental
class ComputerUseTool(FunctionTool):
"""A tool that wraps computer control functions for use with LLMs.
This tool automatically normalizes coordinates from a virtual coordinate space
(by default 1000x1000) to the actual screen size. This allows LLMs to work
with a consistent coordinate system regardless of the actual screen dimensions,
making their output more predictable and easier to handle.
"""
def __init__(
self,
*,
func: Callable[..., Any],
screen_size: tuple[int, int],
virtual_screen_size: tuple[int, int] = (1000, 1000),
):
"""Initialize the ComputerUseTool.
Args:
func: The computer control function to wrap.
screen_size: The actual screen size as (width, height) in pixels.
This represents the real dimensions of the target screen/display.
virtual_screen_size: The virtual coordinate space dimensions as (width, height)
that the LLM uses to specify coordinates. Coordinates from the LLM are
automatically normalized from this virtual space to the actual screen_size.
Default is (1000, 1000), meaning the LLM thinks it's working with a
1000x1000 pixel screen regardless of the actual screen dimensions.
Raises:
ValueError: If screen_size or virtual_screen_size is not a valid tuple
of positive integers.
"""
super().__init__(func=func)
self._screen_size = screen_size
self._coordinate_space = virtual_screen_size
# Validate screen size
if not isinstance(screen_size, tuple) or len(screen_size) != 2:
raise ValueError("screen_size must be a tuple of (width, height)")
if screen_size[0] <= 0 or screen_size[1] <= 0:
raise ValueError("screen_size dimensions must be positive")
# Validate virtual screen size
if (
not isinstance(virtual_screen_size, tuple)
or len(virtual_screen_size) != 2
):
raise ValueError("virtual_screen_size must be a tuple of (width, height)")
if virtual_screen_size[0] <= 0 or virtual_screen_size[1] <= 0:
raise ValueError("virtual_screen_size dimensions must be positive")
def _normalize_x(self, x: int) -> int:
"""Normalize x coordinate from virtual screen space to actual screen width."""
if not isinstance(x, (int, float)):
raise ValueError(f"x coordinate must be numeric, got {type(x)}")
normalized = int(x / self._coordinate_space[0] * self._screen_size[0])
# Clamp to screen bounds
return max(0, min(normalized, self._screen_size[0] - 1))
def _normalize_y(self, y: int) -> int:
"""Normalize y coordinate from virtual screen space to actual screen height."""
if not isinstance(y, (int, float)):
raise ValueError(f"y coordinate must be numeric, got {type(y)}")
normalized = int(y / self._coordinate_space[1] * self._screen_size[1])
# Clamp to screen bounds
return max(0, min(normalized, self._screen_size[1] - 1))
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
"""Run the computer control function with normalized coordinates."""
try:
# Normalize coordinates if present
if "x" in args:
original_x = args["x"]
args["x"] = self._normalize_x(args["x"])
logger.debug("Normalized x: %s -> %s", original_x, args["x"])
if "y" in args:
original_y = args["y"]
args["y"] = self._normalize_y(args["y"])
logger.debug("Normalized y: %s -> %s", original_y, args["y"])
# Handle destination coordinates for drag and drop
if "destination_x" in args:
original_dest_x = args["destination_x"]
args["destination_x"] = self._normalize_x(args["destination_x"])
logger.debug(
"Normalized destination_x: %s -> %s",
original_dest_x,
args["destination_x"],
)
if "destination_y" in args:
original_dest_y = args["destination_y"]
args["destination_y"] = self._normalize_y(args["destination_y"])
logger.debug(
"Normalized destination_y: %s -> %s",
original_dest_y,
args["destination_y"],
)
# Execute the actual computer control function
result = await super().run_async(args=args, tool_context=tool_context)
# Process the result if it's an EnvironmentState
if isinstance(result, ComputerState):
return {
"image": {
"mimetype": "image/png",
"data": base64.b64encode(result.screenshot).decode("utf-8"),
},
"url": result.url,
}
return result
except Exception as e:
logger.error("Error in ComputerUseTool.run_async: %s", e)
raise
@override
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
"""ComputerUseToolset will add this tool to the LLM request and add computer
use configuration to the LLM request."""
pass
@@ -0,0 +1,220 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union
from google.genai import types
from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from ...models.llm_request import LlmRequest
from ...utils.feature_decorator import experimental
from ..base_toolset import BaseToolset
from ..tool_context import ToolContext
from .base_computer import BaseComputer
from .computer_use_tool import ComputerUseTool
# Methods that should be excluded when creating tools from BaseComputer methods
EXCLUDED_METHODS = {"screen_size", "environment", "close"}
logger = logging.getLogger("google_adk." + __name__)
@experimental
class ComputerUseToolset(BaseToolset):
def __init__(
self,
*,
computer: BaseComputer,
):
super().__init__()
self._computer = computer
self._initialized = False
self._tools = None
async def _ensure_initialized(self) -> None:
if not self._initialized:
await self._computer.initialize()
self._initialized = True
@staticmethod
async def adapt_computer_use_tool(
method_name: str,
adapter_func: Union[
Callable[[Callable[..., Any]], Callable[..., Any]],
Callable[[Callable[..., Any]], Any],
],
llm_request: LlmRequest,
) -> None:
"""Adapt a computer use tool by replacing it with a modified version.
Args:
method_name: The name of the method (of BaseComputer class) to adapt (e.g. 'wait').
adapter_func: A function that accepts existing computer use async function and returns a new computer use async function.
Can be either sync or async function. The name of the returned function will be used as the new tool name.
llm_request: The LLM request containing the tools dictionary.
"""
# Validate that the method is a valid BaseComputer method
if method_name in EXCLUDED_METHODS:
logger.warning(
"Method %s is not a valid BaseComputer method", method_name
)
return
# Check if it's a method defined in BaseComputer class
attr = getattr(BaseComputer, method_name, None)
if attr is None or not callable(attr):
logger.warning(
"Method %s is not a valid BaseComputer method", method_name
)
return
if method_name not in llm_request.tools_dict:
logger.warning("Method %s not found in tools_dict", method_name)
return
original_tool = llm_request.tools_dict[method_name]
# Create the adapted function using the adapter
# Handle both sync and async adapter functions
if asyncio.iscoroutinefunction(adapter_func):
# If adapter_func is async, await it to get the adapted function
adapted_func = await adapter_func(original_tool.func)
else:
# If adapter_func is sync, call it directly
adapted_func = adapter_func(original_tool.func)
# Get the name from the adapted function
new_method_name = adapted_func.__name__
# Create a new ComputerUseTool with the adapted function
adapted_tool = ComputerUseTool(
func=adapted_func,
screen_size=original_tool._screen_size,
virtual_screen_size=original_tool._coordinate_space,
)
# Add the adapted tool and remove the original
llm_request.tools_dict[new_method_name] = adapted_tool
del llm_request.tools_dict[method_name]
logger.debug(
"Adapted tool %s to %s with adapter function",
method_name,
new_method_name,
)
@override
async def get_tools(
self,
readonly_context: Optional[ReadonlyContext] = None,
) -> list[ComputerUseTool]:
if self._tools:
return self._tools
await self._ensure_initialized()
# Get screen size for tool configuration
screen_size = await self._computer.screen_size()
# Get all methods defined in Computer abstract base class, excluding specified methods
computer_methods = []
# Get all methods defined in the Computer ABC interface
for method_name in dir(BaseComputer):
# Skip private methods (starting with underscore)
if method_name.startswith("_"):
continue
# Skip excluded methods
if method_name in EXCLUDED_METHODS:
continue
# Check if it's a method defined in Computer class
attr = getattr(BaseComputer, method_name, None)
if attr is not None and callable(attr):
# Get the corresponding method from the concrete instance
instance_method = getattr(self._computer, method_name)
computer_methods.append(instance_method)
# Create ComputerUseTool instances for each method
self._tools = [
ComputerUseTool(
func=method,
screen_size=screen_size,
)
for method in computer_methods
]
return self._tools
@override
async def close(self) -> None:
await self._computer.close()
@override
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
"""Add its tools to the LLM request and add computer
use configuration to the LLM request."""
try:
# Add this tool to the tools dictionary
if not self._tools:
await self.get_tools()
for tool in self._tools:
llm_request.tools_dict[tool.name] = tool
# Initialize config if needed
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
# Check if computer use is already configured
for tool in llm_request.config.tools:
if (
isinstance(tool, (types.Tool, types.ToolDict))
and hasattr(tool, "computer_use")
and tool.computer_use
):
logger.debug("Computer use already configured in LLM request")
return
# Add computer use tool configuration
computer_environment = await self._computer.environment()
environment = getattr(
types.Environment,
computer_environment.name,
types.Environment.ENVIRONMENT_BROWSER,
)
llm_request.config.tools.append(
types.Tool(
computer_use=types.ToolComputerUse(environment=environment)
)
)
logger.debug(
"Added computer use tool with environment: %s",
environment,
)
except Exception as e:
logger.error("Error in ComputerUseToolset.process_llm_request: %s", e)
raise
+186 -3
View File
@@ -16,6 +16,7 @@ import os
import sys
from typing import Optional
from unittest import mock
from unittest.mock import AsyncMock
from google.adk import version as adk_version
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
@@ -28,7 +29,6 @@ from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai import version as genai_version
from google.genai.types import Content
from google.genai.types import FinishReason
from google.genai.types import Part
import pytest
@@ -66,6 +66,26 @@ def llm_request():
)
@pytest.fixture
def llm_request_with_computer_use():
return LlmRequest(
model="gemini-1.5-flash",
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
config=types.GenerateContentConfig(
temperature=0.1,
response_modalities=[types.Modality.TEXT],
system_instruction="You are a helpful assistant",
tools=[
types.Tool(
computer_use=types.ToolComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER
)
)
],
),
)
@pytest.fixture
def mock_os_environ():
initial_env = os.environ.copy()
@@ -614,7 +634,8 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request):
),
],
)
def test_preprocess_request_handles_backend_specific_fields(
@pytest.mark.asyncio
async def test_preprocess_request_handles_backend_specific_fields(
gemini_llm: Gemini,
api_backend: GoogleLLMVariant,
expected_file_display_name: Optional[str],
@@ -662,7 +683,7 @@ def test_preprocess_request_handles_backend_specific_fields(
mock_backend.return_value = api_backend
# Act: Run the preprocessing method
gemini_llm._preprocess_request(llm_request_with_files)
await gemini_llm._preprocess_request(llm_request_with_files)
# Assert: Check if the fields were correctly processed
file_part = llm_request_with_files.contents[0].parts[0]
@@ -1535,3 +1556,165 @@ async def test_generate_content_async_stream_two_separate_text_aggregations():
function_call_responses[0].content.parts[0].function_call.name
== "divide"
)
@pytest.mark.asyncio
async def test_computer_use_removes_system_instruction():
"""Test that system instruction is set to None when computer use is configured."""
llm = Gemini()
llm_request = LlmRequest(
model="gemini-1.5-flash",
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="Hello")])
],
config=types.GenerateContentConfig(
system_instruction="You are a helpful assistant",
tools=[
types.Tool(
computer_use=types.ToolComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER
)
)
],
),
)
await llm._preprocess_request(llm_request)
# System instruction should be set to None when computer use is configured
assert llm_request.config.system_instruction is None
@pytest.mark.asyncio
async def test_computer_use_preserves_system_instruction_when_no_computer_use():
"""Test that system instruction is preserved when computer use is not configured."""
llm = Gemini()
original_instruction = "You are a helpful assistant"
llm_request = LlmRequest(
model="gemini-1.5-flash",
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="Hello")])
],
config=types.GenerateContentConfig(
system_instruction=original_instruction,
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(name="test", description="test")
]
)
],
),
)
await llm._preprocess_request(llm_request)
# System instruction should be preserved when no computer use
assert llm_request.config.system_instruction == original_instruction
@pytest.mark.asyncio
async def test_computer_use_with_no_config():
"""Test that preprocessing works when config is None."""
llm = Gemini()
llm_request = LlmRequest(
model="gemini-1.5-flash",
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="Hello")])
],
config=None,
)
# Should not raise an exception
await llm._preprocess_request(llm_request)
@pytest.mark.asyncio
async def test_computer_use_with_no_tools():
"""Test that preprocessing works when config.tools is None."""
llm = Gemini()
original_instruction = "You are a helpful assistant"
llm_request = LlmRequest(
model="gemini-1.5-flash",
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="Hello")])
],
config=types.GenerateContentConfig(
system_instruction=original_instruction,
tools=None,
),
)
await llm._preprocess_request(llm_request)
# System instruction should be preserved when no tools
assert llm_request.config.system_instruction == original_instruction
@pytest.mark.asyncio
async def test_adapt_computer_use_tool_wait():
"""Test that _adapt_computer_use_tool correctly adapts wait to wait_5_seconds."""
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
llm = Gemini()
# Create a mock wait tool
mock_wait_func = AsyncMock()
mock_wait_func.return_value = "mock_result"
original_wait_tool = ComputerUseTool(
func=mock_wait_func,
screen_size=(1920, 1080),
virtual_screen_size=(1000, 1000),
)
llm_request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(),
)
# Add wait to tools_dict
llm_request.tools_dict["wait"] = original_wait_tool
# Call the adaptation method (now async)
await llm._adapt_computer_use_tool(llm_request)
# Verify wait was removed and wait_5_seconds was added
assert "wait" not in llm_request.tools_dict
assert "wait_5_seconds" in llm_request.tools_dict
# Verify the new tool has correct properties
wait_5_seconds_tool = llm_request.tools_dict["wait_5_seconds"]
assert isinstance(wait_5_seconds_tool, ComputerUseTool)
assert wait_5_seconds_tool._screen_size == (1920, 1080)
assert wait_5_seconds_tool._coordinate_space == (1000, 1000)
# Verify calling the new tool calls the original with 5 seconds
result = await wait_5_seconds_tool.func()
assert result == "mock_result"
mock_wait_func.assert_awaited_once_with(5)
@pytest.mark.asyncio
async def test_adapt_computer_use_tool_no_wait():
"""Test that _adapt_computer_use_tool does nothing when wait is not present."""
llm = Gemini()
llm_request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(),
)
# Don't add any tools
original_tools_dict = llm_request.tools_dict.copy()
# Call the adaptation method (now async)
await llm._adapt_computer_use_tool(llm_request)
# Verify tools_dict is unchanged
assert llm_request.tools_dict == original_tools_dict
assert "wait_5_seconds" not in llm_request.tools_dict
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,341 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for base_computer module."""
from typing import Literal
from google.adk.tools.computer_use.base_computer import BaseComputer
from google.adk.tools.computer_use.base_computer import ComputerEnvironment
from google.adk.tools.computer_use.base_computer import ComputerState
import pytest
class TestComputerEnvironment:
"""Test cases for ComputerEnvironment enum."""
def test_valid_environments(self):
"""Test valid environment values."""
assert (
ComputerEnvironment.ENVIRONMENT_UNSPECIFIED == "ENVIRONMENT_UNSPECIFIED"
)
assert ComputerEnvironment.ENVIRONMENT_BROWSER == "ENVIRONMENT_BROWSER"
def test_invalid_environment_raises(self):
"""Test that invalid environment values raise ValueError."""
with pytest.raises(ValueError):
ComputerEnvironment("INVALID_ENVIRONMENT")
def test_string_representation(self):
"""Test string representation of enum values."""
assert (
ComputerEnvironment.ENVIRONMENT_BROWSER.value == "ENVIRONMENT_BROWSER"
)
assert (
ComputerEnvironment.ENVIRONMENT_UNSPECIFIED.value
== "ENVIRONMENT_UNSPECIFIED"
)
class TestComputerState:
"""Test cases for ComputerState Pydantic model."""
def test_default_initialization(self):
"""Test ComputerState with default values."""
state = ComputerState()
assert state.screenshot is None
assert state.url is None
def test_initialization_with_screenshot(self):
"""Test ComputerState with screenshot data."""
screenshot_data = b"fake_png_data"
state = ComputerState(screenshot=screenshot_data)
assert state.screenshot == screenshot_data
assert state.url is None
def test_initialization_with_url(self):
"""Test ComputerState with URL."""
url = "https://example.com"
state = ComputerState(url=url)
assert state.screenshot is None
assert state.url == url
def test_initialization_with_all_fields(self):
"""Test ComputerState with all fields provided."""
screenshot_data = b"fake_png_data"
url = "https://example.com"
state = ComputerState(screenshot=screenshot_data, url=url)
assert state.screenshot == screenshot_data
assert state.url == url
def test_field_validation(self):
"""Test field validation for ComputerState."""
# Test that bytes are accepted for screenshot
state = ComputerState(screenshot=b"test_data")
assert state.screenshot == b"test_data"
# Test that string is accepted for URL
state = ComputerState(url="https://test.com")
assert state.url == "https://test.com"
def test_model_serialization(self):
"""Test that ComputerState can be serialized."""
state = ComputerState(screenshot=b"test", url="https://example.com")
# Should not raise an exception
model_dict = state.model_dump()
assert "screenshot" in model_dict
assert "url" in model_dict
class MockComputer(BaseComputer):
"""Mock implementation of BaseComputer for testing."""
def __init__(self):
self.initialized = False
self.closed = False
async def screen_size(self) -> tuple[int, int]:
return (1920, 1080)
async def open_web_browser(self) -> ComputerState:
return ComputerState(url="https://example.com")
async def click_at(self, x: int, y: int) -> ComputerState:
return ComputerState(url="https://example.com")
async def hover_at(self, x: int, y: int) -> ComputerState:
return ComputerState(url="https://example.com")
async def type_text_at(
self,
x: int,
y: int,
text: str,
press_enter: bool = True,
clear_before_typing: bool = True,
) -> ComputerState:
return ComputerState(url="https://example.com")
async def scroll_document(
self, direction: Literal["up", "down", "left", "right"]
) -> ComputerState:
return ComputerState(url="https://example.com")
async def scroll_at(
self,
x: int,
y: int,
direction: Literal["up", "down", "left", "right"],
magnitude: int,
) -> ComputerState:
return ComputerState(url="https://example.com")
async def wait(self, seconds: int) -> ComputerState:
return ComputerState(url="https://example.com")
async def go_back(self) -> ComputerState:
return ComputerState(url="https://example.com")
async def go_forward(self) -> ComputerState:
return ComputerState(url="https://example.com")
async def search(self) -> ComputerState:
return ComputerState(url="https://search.example.com")
async def navigate(self, url: str) -> ComputerState:
return ComputerState(url=url)
async def key_combination(self, keys: list[str]) -> ComputerState:
return ComputerState(url="https://example.com")
async def drag_and_drop(
self, x: int, y: int, destination_x: int, destination_y: int
) -> ComputerState:
return ComputerState(url="https://example.com")
async def current_state(self) -> ComputerState:
return ComputerState(
url="https://example.com", screenshot=b"screenshot_data"
)
async def initialize(self) -> None:
self.initialized = True
async def close(self) -> None:
self.closed = True
async def environment(self) -> ComputerEnvironment:
return ComputerEnvironment.ENVIRONMENT_BROWSER
class TestBaseComputer:
"""Test cases for BaseComputer abstract base class."""
@pytest.fixture
def mock_computer(self) -> MockComputer:
"""Fixture providing a mock computer implementation."""
return MockComputer()
def test_cannot_instantiate_abstract_class(self):
"""Test that BaseComputer cannot be instantiated directly."""
import pytest
with pytest.raises(TypeError):
BaseComputer() # Should raise TypeError because it's abstract
@pytest.mark.asyncio
async def test_screen_size(self, mock_computer):
"""Test screen_size method."""
size = await mock_computer.screen_size()
assert size == (1920, 1080)
assert isinstance(size, tuple)
assert len(size) == 2
@pytest.mark.asyncio
async def test_open_web_browser(self, mock_computer):
"""Test open_web_browser method."""
state = await mock_computer.open_web_browser()
assert isinstance(state, ComputerState)
assert state.url == "https://example.com"
@pytest.mark.asyncio
async def test_click_at(self, mock_computer):
"""Test click_at method."""
state = await mock_computer.click_at(100, 200)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_hover_at(self, mock_computer):
"""Test hover_at method."""
state = await mock_computer.hover_at(150, 250)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_type_text_at(self, mock_computer):
"""Test type_text_at method with different parameters."""
# Test with default parameters
state = await mock_computer.type_text_at(100, 200, "Hello World")
assert isinstance(state, ComputerState)
# Test with custom parameters
state = await mock_computer.type_text_at(
100, 200, "Hello", press_enter=False, clear_before_typing=False
)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_scroll_document(self, mock_computer):
"""Test scroll_document method with different directions."""
directions = ["up", "down", "left", "right"]
for direction in directions:
state = await mock_computer.scroll_document(direction)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_scroll_at(self, mock_computer):
"""Test scroll_at method."""
state = await mock_computer.scroll_at(100, 200, "down", 5)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_wait(self, mock_computer):
"""Test wait method."""
state = await mock_computer.wait(5)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_go_back(self, mock_computer):
"""Test go_back method."""
state = await mock_computer.go_back()
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_go_forward(self, mock_computer):
"""Test go_forward method."""
state = await mock_computer.go_forward()
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_search(self, mock_computer):
"""Test search method."""
state = await mock_computer.search()
assert isinstance(state, ComputerState)
assert state.url == "https://search.example.com"
@pytest.mark.asyncio
async def test_navigate(self, mock_computer):
"""Test navigate method."""
test_url = "https://test.example.com"
state = await mock_computer.navigate(test_url)
assert isinstance(state, ComputerState)
assert state.url == test_url
@pytest.mark.asyncio
async def test_key_combination(self, mock_computer):
"""Test key_combination method."""
state = await mock_computer.key_combination(["ctrl", "c"])
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_drag_and_drop(self, mock_computer):
"""Test drag_and_drop method."""
state = await mock_computer.drag_and_drop(100, 200, 300, 400)
assert isinstance(state, ComputerState)
@pytest.mark.asyncio
async def test_current_state(self, mock_computer):
"""Test current_state method."""
state = await mock_computer.current_state()
assert isinstance(state, ComputerState)
assert state.url == "https://example.com"
assert state.screenshot == b"screenshot_data"
@pytest.mark.asyncio
async def test_initialize(self, mock_computer):
"""Test initialize method."""
assert not mock_computer.initialized
await mock_computer.initialize()
assert mock_computer.initialized
@pytest.mark.asyncio
async def test_close(self, mock_computer):
"""Test close method."""
assert not mock_computer.closed
await mock_computer.close()
assert mock_computer.closed
@pytest.mark.asyncio
async def test_environment(self, mock_computer):
"""Test environment method."""
env = await mock_computer.environment()
assert env == ComputerEnvironment.ENVIRONMENT_BROWSER
assert isinstance(env, ComputerEnvironment)
@pytest.mark.asyncio
async def test_lifecycle_methods(self, mock_computer):
"""Test the lifecycle of a computer instance."""
# Initially not initialized or closed
assert not mock_computer.initialized
assert not mock_computer.closed
# Initialize
await mock_computer.initialize()
assert mock_computer.initialized
assert not mock_computer.closed
# Close
await mock_computer.close()
assert mock_computer.initialized
assert mock_computer.closed
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff