diff --git a/contributing/samples/hello_world_app/agent.py b/contributing/samples/hello_world_app/agent.py index 61295d13..0232f263 100755 --- a/contributing/samples/hello_world_app/agent.py +++ b/contributing/samples/hello_world_app/agent.py @@ -21,6 +21,8 @@ from google.adk.apps import App from google.adk.models.llm_request import LlmRequest from google.adk.plugins.base_plugin import BasePlugin from google.adk.plugins.context_filter_plugin import ContextFilterPlugin +from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin +from google.adk.tools import load_artifacts from google.adk.tools.tool_context import ToolContext from google.genai import types @@ -97,6 +99,7 @@ root_agent = Agent( tools=[ roll_die, check_prime, + load_artifacts, ], # planner=BuiltInPlanner( # thinking_config=types.ThinkingConfig( @@ -145,5 +148,6 @@ app = App( plugins=[ CountInvocationPlugin(), ContextFilterPlugin(num_invocations_to_keep=3), + SaveFilesAsArtifactsPlugin(), ], ) diff --git a/contributing/samples/hello_world_app/main.py b/contributing/samples/hello_world_app/main.py index b9e30355..f9a2ac78 100755 --- a/contributing/samples/hello_world_app/main.py +++ b/contributing/samples/hello_world_app/main.py @@ -65,7 +65,7 @@ async def main(): user_id=user_id_1, session_id=session.id, new_message=content, - run_config=RunConfig(save_input_blobs_as_artifacts=True), + run_config=RunConfig(save_input_blobs_as_artifacts=False), ): if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index b65cde90..9fe82fab 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -48,8 +48,15 @@ class RunConfig(BaseModel): response_modalities: Optional[list[str]] = None """The output modalities. If not set, it's default to AUDIO.""" - save_input_blobs_as_artifacts: bool = False - """Whether or not to save the input blobs as artifacts.""" + save_input_blobs_as_artifacts: bool = Field( + default=False, + deprecated=True, + description=( + 'Whether or not to save the input blobs as artifacts. DEPRECATED: Use' + ' SaveFilesAsArtifactsPlugin instead for better control and' + ' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin.' + ), + ) support_cfc: bool = False """ diff --git a/src/google/adk/plugins/save_files_as_artifacts_plugin.py b/src/google/adk/plugins/save_files_as_artifacts_plugin.py new file mode 100644 index 00000000..1dd908ef --- /dev/null +++ b/src/google/adk/plugins/save_files_as_artifacts_plugin.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional + +from google.genai import types + +from ..agents.invocation_context import InvocationContext +from .base_plugin import BasePlugin + +logger = logging.getLogger('google_adk.' + __name__) + + +class SaveFilesAsArtifactsPlugin(BasePlugin): + """A plugin that saves files embedded in user messages as artifacts. + + This is useful to allow users to upload files in the chat experience and have + those files available to the agent. + + We use Blob.display_name to determine + the file name. Artifacts with the same name will be overwritten. A placeholder + with the artifact name will be put in place of the embedded file in the user + message so the model knows where to find the file. You may want to add + load_artifacts tool to the agent, or load the artifacts in your own tool to + use the files. + """ + + def __init__(self, name: str = 'save_files_as_artifacts_plugin'): + """Initialize the save files as artifacts plugin. + + Args: + name: The name of the plugin instance. + """ + super().__init__(name) + + async def on_user_message_callback( + self, + *, + invocation_context: InvocationContext, + user_message: types.Content, + ) -> Optional[types.Content]: + """Process user message and save any attached files as artifacts.""" + if not invocation_context.artifact_service: + logger.warning( + 'Artifact service is not set. SaveFilesAsArtifactsPlugin' + ' will not be enabled.' + ) + return user_message + + if not user_message.parts: + return user_message + + for i, part in enumerate(user_message.parts): + if part.inline_data is None: + continue + + try: + # Use display_name if available, otherwise generate a filename + file_name = part.inline_data.display_name + if not file_name: + file_name = f'artifact_{invocation_context.invocation_id}_{i}' + logger.info( + f'No display_name found, using generated filename: {file_name}' + ) + + await invocation_context.artifact_service.save_artifact( + app_name=invocation_context.app_name, + user_id=invocation_context.user_id, + session_id=invocation_context.session.id, + filename=file_name, + artifact=part, + ) + + # Replace the inline data with a placeholder text + user_message.parts[i] = types.Part( + text=f'[Uploaded Artifact: "{file_name}"]' + ) + logger.info(f'Successfully saved artifact: {file_name}') + + except Exception as e: + logger.error(f'Failed to save artifact for part {i}: {e}') + # Keep the original part if saving fails + continue + + return user_message diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 8fd34eee..e9be5a5c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -419,6 +419,15 @@ class Runner: raise ValueError('No parts in the new_message.') if self.artifact_service and save_input_blobs_as_artifacts: + # Issue deprecation warning + warnings.warn( + "The 'save_input_blobs_as_artifacts' parameter is deprecated. Use" + ' SaveFilesAsArtifactsPlugin instead for better control and' + ' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for' + ' migration guidance.', + DeprecationWarning, + stacklevel=3, + ) # The runner directly saves the artifacts (if applicable) in the # user message and replaces the artifact data with a file name # placeholder. diff --git a/tests/unittests/plugins/test_save_files_as_artifacts.py b/tests/unittests/plugins/test_save_files_as_artifacts.py new file mode 100644 index 00000000..fe229cc0 --- /dev/null +++ b/tests/unittests/plugins/test_save_files_as_artifacts.py @@ -0,0 +1,297 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin +from google.genai import types +import pytest + + +class TestSaveFilesAsArtifactsPlugin: + """Test suite for SaveFilesAsArtifactsPlugin.""" + + def setup_method(self): + """Set up test fixtures.""" + self.plugin = SaveFilesAsArtifactsPlugin() + + # Mock invocation context + self.mock_context = Mock(spec=InvocationContext) + self.mock_context.artifact_service = AsyncMock() + self.mock_context.app_name = "test_app" + self.mock_context.user_id = "test_user" + self.mock_context.invocation_id = "test_invocation_123" + self.mock_context.session = Mock() + self.mock_context.session.id = "test_session" + + @pytest.mark.asyncio + async def test_save_files_with_display_name(self): + """Test saving files when inline_data has display_name.""" + # Create a message with inline data + inline_data = types.Blob( + display_name="test_document.pdf", + data=b"test data", + mime_type="application/pdf", + ) + + original_part = types.Part(inline_data=inline_data) + user_message = types.Content(parts=[original_part]) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Verify artifact was saved with correct filename + self.mock_context.artifact_service.save_artifact.assert_called_once_with( + app_name="test_app", + user_id="test_user", + session_id="test_session", + filename="test_document.pdf", + artifact=original_part, + ) + + # Verify message was modified with placeholder + assert result.parts[0].text == '[Uploaded Artifact: "test_document.pdf"]' + + @pytest.mark.asyncio + async def test_save_files_without_display_name(self): + """Test saving files when inline_data has no display_name.""" + # Create inline data without display_name + inline_data = types.Blob( + display_name=None, data=b"test data", mime_type="application/pdf" + ) + + original_part = types.Part(inline_data=inline_data) + user_message = types.Content(parts=[original_part]) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Verify artifact was saved with generated filename + expected_filename = "artifact_test_invocation_123_0" + self.mock_context.artifact_service.save_artifact.assert_called_once_with( + app_name="test_app", + user_id="test_user", + session_id="test_session", + filename=expected_filename, + artifact=original_part, + ) + + # Verify message was modified with generated filename + assert result.parts[0].text == f'[Uploaded Artifact: "{expected_filename}"]' + + @pytest.mark.asyncio + async def test_multiple_files_in_message(self): + """Test handling multiple files in a single message.""" + # Create message with multiple inline data parts + inline_data1 = types.Blob( + display_name="file1.txt", data=b"file1 content", mime_type="text/plain" + ) + + inline_data2 = types.Blob( + display_name="file2.jpg", data=b"file2 content", mime_type="image/jpeg" + ) + + user_message = types.Content( + parts=[ + types.Part(inline_data=inline_data1), + types.Part(text="Some text between files"), + types.Part(inline_data=inline_data2), + ] + ) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Verify both artifacts were saved + assert self.mock_context.artifact_service.save_artifact.call_count == 2 + + # Check first file + first_call = ( + self.mock_context.artifact_service.save_artifact.call_args_list[0] + ) + assert first_call[1]["filename"] == "file1.txt" + + # Check second file + second_call = ( + self.mock_context.artifact_service.save_artifact.call_args_list[1] + ) + assert second_call[1]["filename"] == "file2.jpg" + + # Verify message parts were modified correctly + assert result.parts[0].text == '[Uploaded Artifact: "file1.txt"]' + assert result.parts[1].text == "Some text between files" # Unchanged + assert result.parts[2].text == '[Uploaded Artifact: "file2.jpg"]' + + @pytest.mark.asyncio + async def test_no_artifact_service(self): + """Test behavior when artifact service is not available.""" + # Set artifact service to None + self.mock_context.artifact_service = None + + inline_data = types.Blob( + display_name="test.pdf", data=b"test data", mime_type="application/pdf" + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should return original message unchanged + assert result == user_message + assert result.parts[0].inline_data == inline_data + + @pytest.mark.asyncio + async def test_no_parts_in_message(self): + """Test behavior when message has no parts.""" + user_message = types.Content(parts=[]) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should return original message unchanged + assert result == user_message + assert result.parts == [] + + # Should not try to save any artifacts + self.mock_context.artifact_service.save_artifact.assert_not_called() + + @pytest.mark.asyncio + async def test_parts_without_inline_data(self): + """Test behavior with parts that don't have inline_data.""" + user_message = types.Content( + parts=[types.Part(text="Hello world"), types.Part(text="No files here")] + ) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should return original message unchanged + assert result == user_message + assert result.parts[0].text == "Hello world" + assert result.parts[1].text == "No files here" + + # Should not try to save any artifacts + self.mock_context.artifact_service.save_artifact.assert_not_called() + + @pytest.mark.asyncio + async def test_save_artifact_failure(self): + """Test behavior when saving artifact fails.""" + # Mock save_artifact to raise an exception + self.mock_context.artifact_service.save_artifact.side_effect = Exception( + "Storage error" + ) + + inline_data = types.Blob( + display_name="test.pdf", data=b"test data", mime_type="application/pdf" + ) + + original_part = types.Part(inline_data=inline_data) + user_message = types.Content(parts=[original_part]) + + # Execute the plugin - should not raise exception + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Should preserve original part when saving fails + assert result.parts[0] == original_part + assert result.parts[0].inline_data == inline_data + + @pytest.mark.asyncio + async def test_mixed_success_and_failure(self): + """Test behavior when some files save successfully and others fail.""" + # Mock save_artifact to succeed on first call, fail on second + save_calls = 0 + + def mock_save_artifact(*_args, **_kwargs): + nonlocal save_calls + save_calls += 1 + if save_calls == 2: + raise Exception("Storage error on second file") + return AsyncMock() + + self.mock_context.artifact_service.save_artifact.side_effect = ( + mock_save_artifact + ) + + inline_data1 = types.Blob( + display_name="success.pdf", + data=b"success data", + mime_type="application/pdf", + ) + + inline_data2 = types.Blob( + display_name="failure.pdf", + data=b"failure data", + mime_type="application/pdf", + ) + + original_part2 = types.Part(inline_data=inline_data2) + user_message = types.Content( + parts=[types.Part(inline_data=inline_data1), original_part2] + ) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # First file should be replaced with placeholder + assert result.parts[0].text == '[Uploaded Artifact: "success.pdf"]' + + # Second file should remain unchanged due to failure + assert result.parts[1] == original_part2 + assert result.parts[1].inline_data == inline_data2 + + @pytest.mark.asyncio + async def test_placeholder_text_format(self): + """Test that placeholder text is formatted correctly.""" + inline_data = types.Blob( + display_name="test file with spaces.docx", + data=b"document data", + mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) + + user_message = types.Content(parts=[types.Part(inline_data=inline_data)]) + + # Execute the plugin + result = await self.plugin.on_user_message_callback( + invocation_context=self.mock_context, user_message=user_message + ) + + # Verify exact format of placeholder text + expected_text = '[Uploaded Artifact: "test file with spaces.docx"]' + assert result.parts[0].text == expected_text + + def test_plugin_name_default(self): + """Test that plugin has correct default name.""" + plugin = SaveFilesAsArtifactsPlugin() + assert plugin.name == "save_files_as_artifacts_plugin"