diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 4d045fac..c228eafb 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -17,6 +17,8 @@ from __future__ import annotations import asyncio +import base64 +import binascii from concurrent.futures import ThreadPoolExecutor import copy import functools @@ -31,6 +33,7 @@ from typing import Optional from typing import TYPE_CHECKING import uuid +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.genai import types from ...agents.active_streaming_tool import ActiveStreamingTool @@ -991,6 +994,50 @@ def _get_tool_and_context( return (tool, tool_context) +def _try_decode_computer_use_image( + tool: BaseTool, + function_result: dict[str, object], +) -> Optional[list[types.FunctionResponsePart]]: + """Decodes the image from the function result for a computer use tool. + + Args: + tool: The tool that produced the function result. + function_result: The dictionary containing the function's result. This + dictionary may be modified in-place to remove the 'image' key if an image + is successfully decoded. + + Returns: + A list containing a `types.FunctionResponsePart` with the decoded image + data, or None if no image was found or decoding failed. + """ + + if not isinstance(tool, ComputerUseTool) or not isinstance( + function_result, dict + ): + return None + + if ( + 'image' not in function_result + or 'data' not in function_result['image'] + or 'mimetype' not in function_result['image'] + ): + return None + + try: + image_data = base64.b64decode(function_result['image']['data']) + mime_type = function_result['image']['mimetype'] + + part = types.FunctionResponsePart.from_bytes( + data=image_data, mime_type=mime_type + ) + + del function_result['image'] + return [part] + except (binascii.Error, ValueError): + logger.exception('Failed to decode image from computer use tool') + return None + + async def __call_tool_live( tool: BaseTool, args: dict[str, object], @@ -1028,8 +1075,16 @@ def __build_response_event( if not isinstance(function_result, dict): function_result = {'result': function_result} + function_response_parts = None + if isinstance(tool, ComputerUseTool): + function_response_parts = _try_decode_computer_use_image( + tool, function_result + ) + part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result + name=tool.name, + response=function_result, + parts=function_response_parts, ) part_function_response.function_response.id = tool_context.function_call_id diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 93f8c151..7aacb237 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -19,7 +19,10 @@ from typing import Callable from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types @@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses(): @pytest.mark.asyncio async def test_function_call_args_not_modified(): """Test that function_call.args is not modified when making a copy.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -455,8 +456,6 @@ async def test_function_call_args_not_modified(): @pytest.mark.asyncio async def test_function_call_args_none_handling(): """Test that function_call.args=None is handled correctly.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -504,8 +503,6 @@ async def test_function_call_args_none_handling(): @pytest.mark.asyncio async def test_function_call_args_copy_behavior(): """Test that modifying the copied args doesn't affect the original.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(test_param: str, other_param: int) -> dict: # Modify the args to test that the copy prevents affecting the original @@ -565,8 +562,6 @@ async def test_function_call_args_copy_behavior(): @pytest.mark.asyncio async def test_function_call_args_deep_copy_behavior(): """Test that deep copy behavior works correctly with nested structures.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(nested_dict: dict, list_param: list) -> dict: # Modify the nested structures to test deep copy @@ -1141,3 +1136,62 @@ async def test_mixed_function_types_execution_order(): 'yield_E', 'yield_F', ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_computer_use_tool_decoding_behavior(handle_function_calls): + """Tests that computer use tools automatically decode base64 images.""" + valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7' + + # make the tool return a dictionary with the image + async def mock_run(*args, **kwargs): + return { + 'image': {'data': valid_b64, 'mimetype': 'image/png'}, + 'url': 'https://example.com', + } + + # create a ComputerUseTool + tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768)) + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create function call + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result = await handle_function_calls( + invocation_context, + event, + tools_dict, + ) + + assert result is not None + response_part = result.content.parts[0].function_response + + # Verify original image data is removed from the dict response + assert 'image' not in response_part.response + assert 'url' in response_part.response + # Verify the image was converted to a blob + assert len(response_part.parts) == 1 + assert response_part.parts[0].inline_data is not None