fix: Decode image data from ComputerUse tool response into image blobs

PiperOrigin-RevId: 875292001
This commit is contained in:
Google Team Member
2026-02-25 12:45:52 -08:00
committed by Copybara-Service
parent 35366f4e2a
commit d7cfd8fe4d
2 changed files with 118 additions and 9 deletions
+56 -1
View File
@@ -17,6 +17,8 @@
from __future__ import annotations
import asyncio
import base64
import binascii
from concurrent.futures import ThreadPoolExecutor
import copy
import functools
@@ -31,6 +33,7 @@ from typing import Optional
from typing import TYPE_CHECKING
import uuid
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
from google.genai import types
from ...agents.active_streaming_tool import ActiveStreamingTool
@@ -991,6 +994,50 @@ def _get_tool_and_context(
return (tool, tool_context)
def _try_decode_computer_use_image(
    tool: BaseTool,
    function_result: dict[str, object],
) -> Optional[list[types.FunctionResponsePart]]:
  """Decodes the image from the function result for a computer use tool.

  The image is expected under the 'image' key as a dict holding a
  base64-encoded 'data' value and a 'mimetype' value. Any other shape is
  treated as "no image" rather than as an error.

  Args:
    tool: The tool that produced the function result.
    function_result: The dictionary containing the function's result. This
      dictionary may be modified in-place to remove the 'image' key if an image
      is successfully decoded.

  Returns:
    A list containing a `types.FunctionResponsePart` with the decoded image
    data, or None if no image was found or decoding failed. On failure the
    'image' entry is left in `function_result` untouched.
  """
  if not isinstance(tool, ComputerUseTool) or not isinstance(
      function_result, dict
  ):
    return None

  image = function_result.get('image')
  # Guard against non-dict payloads (None, str, bytes, ...): a membership
  # test on those either raises TypeError or does substring matching, neither
  # of which means "this is an image record".
  if not isinstance(image, dict):
    return None
  if 'data' not in image or 'mimetype' not in image:
    return None

  try:
    image_data = base64.b64decode(image['data'])
    part = types.FunctionResponsePart.from_bytes(
        data=image_data, mime_type=image['mimetype']
    )
  except (binascii.Error, TypeError, ValueError):
    # binascii.Error: malformed base64; TypeError: 'data' is not str or
    # bytes-like; ValueError: non-ASCII text or a rejected part payload.
    logger.exception('Failed to decode image from computer use tool')
    return None

  # Only drop the raw base64 payload once a blob has been built from it, so a
  # failed decode never loses the original data.
  del function_result['image']
  return [part]
async def __call_tool_live(
tool: BaseTool,
args: dict[str, object],
@@ -1028,8 +1075,16 @@ def __build_response_event(
if not isinstance(function_result, dict):
function_result = {'result': function_result}
function_response_parts = None
if isinstance(tool, ComputerUseTool):
function_response_parts = _try_decode_computer_use_image(
tool, function_result
)
part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result
name=tool.name,
response=function_result,
parts=function_response_parts,
)
part_function_response.function_response.id = tool_context.function_call_id
@@ -19,7 +19,10 @@ from typing import Callable
from google.adk.agents.llm_agent import Agent
from google.adk.events.event import Event
from google.adk.flows.llm_flows.functions import find_matching_function_call
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live
from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
@@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses():
@pytest.mark.asyncio
async def test_function_call_args_not_modified():
"""Test that function_call.args is not modified when making a copy."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live
def simple_fn(**kwargs) -> dict:
return {'result': 'test'}
@@ -455,8 +456,6 @@ async def test_function_call_args_not_modified():
@pytest.mark.asyncio
async def test_function_call_args_none_handling():
"""Test that function_call.args=None is handled correctly."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live
def simple_fn(**kwargs) -> dict:
return {'result': 'test'}
@@ -504,8 +503,6 @@ async def test_function_call_args_none_handling():
@pytest.mark.asyncio
async def test_function_call_args_copy_behavior():
"""Test that modifying the copied args doesn't affect the original."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live
def simple_fn(test_param: str, other_param: int) -> dict:
# Modify the args to test that the copy prevents affecting the original
@@ -565,8 +562,6 @@ async def test_function_call_args_copy_behavior():
@pytest.mark.asyncio
async def test_function_call_args_deep_copy_behavior():
"""Test that deep copy behavior works correctly with nested structures."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live
def simple_fn(nested_dict: dict, list_param: list) -> dict:
# Modify the nested structures to test deep copy
@@ -1141,3 +1136,62 @@ async def test_mixed_function_types_execution_order():
'yield_E',
'yield_F',
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'handle_function_calls',
    [
        handle_function_calls_async,
        handle_function_calls_live,
    ],
)
async def test_computer_use_tool_decoding_behavior(handle_function_calls):
  """Tests that computer use tools automatically decode base64 images.

  Runs against both the async and the live function-call handlers and checks
  that the 'image' entry is moved out of the response dict into a decoded
  blob part, while unrelated keys are left in place.
  """
  # Base64 of a minimal 1x1 GIF; the decoded bytes start with b'GIF89a'.
  valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7'
  mime_type = 'image/gif'

  # Make the tool return a dictionary with the image.
  async def mock_run(*args, **kwargs):
    return {
        'image': {'data': valid_b64, 'mimetype': mime_type},
        'url': 'https://example.com',
    }

  # Create a ComputerUseTool.
  tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768))

  model = testing_utils.MockModel.create(responses=[])
  agent = Agent(
      name='test_agent',
      model=model,
      tools=[tool],
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content=''
  )

  # Create the function call event that targets the tool.
  function_call = types.FunctionCall(name=tool.name, args={})
  content = types.Content(parts=[types.Part(function_call=function_call)])
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=content,
  )
  tools_dict = {tool.name: tool}

  result = await handle_function_calls(
      invocation_context,
      event,
      tools_dict,
  )

  assert result is not None
  response_part = result.content.parts[0].function_response

  # Verify original image data is removed from the dict response.
  assert 'image' not in response_part.response
  assert 'url' in response_part.response

  # Verify the image was converted to a blob.
  assert len(response_part.parts) == 1
  blob = response_part.parts[0].inline_data
  assert blob is not None
  # The blob must hold the decoded bytes (not the base64 text) and carry the
  # MIME type reported by the tool.
  assert blob.mime_type == mime_type
  assert blob.data.startswith(b'GIF89a')