You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Decode image data from ComputerUse tool response into image blobs
PiperOrigin-RevId: 875292001
This commit is contained in:
committed by
Copybara-Service
parent
35366f4e2a
commit
d7cfd8fe4d
@@ -17,6 +17,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import functools
|
||||
@@ -31,6 +33,7 @@ from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
import uuid
|
||||
|
||||
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
|
||||
from google.genai import types
|
||||
|
||||
from ...agents.active_streaming_tool import ActiveStreamingTool
|
||||
@@ -991,6 +994,50 @@ def _get_tool_and_context(
|
||||
return (tool, tool_context)
|
||||
|
||||
|
||||
def _try_decode_computer_use_image(
    tool: BaseTool,
    function_result: dict[str, object],
) -> Optional[list[types.FunctionResponsePart]]:
  """Decodes the base64 image in a computer use tool result into a blob part.

  Args:
    tool: The tool that produced the function result. Only results from a
      `ComputerUseTool` are considered.
    function_result: The dictionary containing the function's result. It is
      modified in-place: the 'image' key is removed, but only after the image
      has been decoded and wrapped successfully.

  Returns:
    A single-element list holding a `types.FunctionResponsePart` built from
    the decoded image bytes and the 'mimetype' value, or None when the tool is
    not a `ComputerUseTool`, the result has no well-formed 'image' entry, or
    decoding fails. On None, `function_result` is left untouched.
  """
  if not isinstance(tool, ComputerUseTool) or not isinstance(
      function_result, dict
  ):
    return None

  image = function_result.get('image')
  # Guard against a non-dict payload (None, str, int, ...): a bare `in` test
  # would raise TypeError on None/int and do a substring match on str.
  if not isinstance(image, dict):
    return None
  if 'data' not in image or 'mimetype' not in image:
    return None

  try:
    image_data = base64.b64decode(image['data'])
    part = types.FunctionResponsePart.from_bytes(
        data=image_data, mime_type=image['mimetype']
    )
  except (binascii.Error, ValueError, TypeError):
    # TypeError covers a 'data' value that is not str/bytes-like; treat it the
    # same as malformed base64 so a bad tool payload never aborts the flow.
    logger.exception('Failed to decode image from computer use tool')
    return None

  # Drop the raw base64 only once the blob exists, so a failure above keeps
  # the original response intact.
  del function_result['image']
  return [part]
|
||||
|
||||
|
||||
async def __call_tool_live(
|
||||
tool: BaseTool,
|
||||
args: dict[str, object],
|
||||
@@ -1028,8 +1075,16 @@ def __build_response_event(
|
||||
if not isinstance(function_result, dict):
|
||||
function_result = {'result': function_result}
|
||||
|
||||
function_response_parts = None
|
||||
if isinstance(tool, ComputerUseTool):
|
||||
function_response_parts = _try_decode_computer_use_image(
|
||||
tool, function_result
|
||||
)
|
||||
|
||||
part_function_response = types.Part.from_function_response(
|
||||
name=tool.name, response=function_result
|
||||
name=tool.name,
|
||||
response=function_result,
|
||||
parts=function_response_parts,
|
||||
)
|
||||
part_function_response.function_response.id = tool_context.function_call_id
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ from typing import Callable
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.functions import find_matching_function_call
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_live
|
||||
from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events
|
||||
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
@@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses():
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_args_not_modified():
|
||||
"""Test that function_call.args is not modified when making a copy."""
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_live
|
||||
|
||||
def simple_fn(**kwargs) -> dict:
|
||||
return {'result': 'test'}
|
||||
@@ -455,8 +456,6 @@ async def test_function_call_args_not_modified():
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_args_none_handling():
|
||||
"""Test that function_call.args=None is handled correctly."""
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_live
|
||||
|
||||
def simple_fn(**kwargs) -> dict:
|
||||
return {'result': 'test'}
|
||||
@@ -504,8 +503,6 @@ async def test_function_call_args_none_handling():
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_args_copy_behavior():
|
||||
"""Test that modifying the copied args doesn't affect the original."""
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_live
|
||||
|
||||
def simple_fn(test_param: str, other_param: int) -> dict:
|
||||
# Modify the args to test that the copy prevents affecting the original
|
||||
@@ -565,8 +562,6 @@ async def test_function_call_args_copy_behavior():
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_call_args_deep_copy_behavior():
|
||||
"""Test that deep copy behavior works correctly with nested structures."""
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||
from google.adk.flows.llm_flows.functions import handle_function_calls_live
|
||||
|
||||
def simple_fn(nested_dict: dict, list_param: list) -> dict:
|
||||
# Modify the nested structures to test deep copy
|
||||
@@ -1141,3 +1136,62 @@ async def test_mixed_function_types_execution_order():
|
||||
'yield_E',
|
||||
'yield_F',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'handle_function_calls',
    [handle_function_calls_async, handle_function_calls_live],
)
async def test_computer_use_tool_decoding_behavior(handle_function_calls):
  """Checks that a computer use tool's base64 image becomes a blob part.

  Runs against both the async and the live function-call handlers.
  """
  # Base64 of a tiny GIF; the test only needs a payload that decodes cleanly.
  valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7'

  # Tool implementation stub: returns the image alongside an unrelated key.
  # NOTE: the function name is kept as `mock_run` since the tool name may be
  # derived from it.
  async def mock_run(*args, **kwargs):
    image_payload = {'data': valid_b64, 'mimetype': 'image/png'}
    return {'image': image_payload, 'url': 'https://example.com'}

  tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768))
  agent = Agent(
      name='test_agent',
      model=testing_utils.MockModel.create(responses=[]),
      tools=[tool],
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content=''
  )

  # Event carrying a single, argument-less call to the tool.
  call_part = types.Part(
      function_call=types.FunctionCall(name=tool.name, args={})
  )
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=types.Content(parts=[call_part]),
  )

  result = await handle_function_calls(
      invocation_context,
      event,
      {tool.name: tool},
  )

  assert result is not None
  response_part = result.content.parts[0].function_response

  # The raw base64 entry is stripped from the dict response; other keys stay.
  assert 'image' not in response_part.response
  assert 'url' in response_part.response
  # The image now travels as exactly one inline-data blob part.
  assert len(response_part.parts) == 1
  assert response_part.parts[0].inline_data is not None
|
||||
|
||||
Reference in New Issue
Block a user