fix: Adds plugin to save artifacts for issue #2176

PiperOrigin-RevId: 810522939
This commit is contained in:
George Weale
2025-09-23 11:45:57 -07:00
committed by Copybara-Service
parent c944a12e31
commit 657369cffe
6 changed files with 419 additions and 3 deletions
@@ -21,6 +21,8 @@ from google.adk.apps import App
from google.adk.models.llm_request import LlmRequest
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
from google.adk.tools import load_artifacts
from google.adk.tools.tool_context import ToolContext
from google.genai import types
@@ -97,6 +99,7 @@ root_agent = Agent(
tools=[
roll_die,
check_prime,
load_artifacts,
],
# planner=BuiltInPlanner(
# thinking_config=types.ThinkingConfig(
@@ -145,5 +148,6 @@ app = App(
plugins=[
CountInvocationPlugin(),
ContextFilterPlugin(num_invocations_to_keep=3),
SaveFilesAsArtifactsPlugin(),
],
)
+1 -1
View File
@@ -65,7 +65,7 @@ async def main():
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
run_config=RunConfig(save_input_blobs_as_artifacts=False),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
+9 -2
View File
@@ -48,8 +48,15 @@ class RunConfig(BaseModel):
response_modalities: Optional[list[str]] = None
"""The output modalities. If not set, it's default to AUDIO."""
save_input_blobs_as_artifacts: bool = False
"""Whether or not to save the input blobs as artifacts."""
save_input_blobs_as_artifacts: bool = Field(
default=False,
deprecated=True,
description=(
'Whether or not to save the input blobs as artifacts. DEPRECATED: Use'
' SaveFilesAsArtifactsPlugin instead for better control and'
' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin.'
),
)
support_cfc: bool = False
"""
@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Optional
from google.genai import types
from ..agents.invocation_context import InvocationContext
from .base_plugin import BasePlugin
logger = logging.getLogger('google_adk.' + __name__)
class SaveFilesAsArtifactsPlugin(BasePlugin):
"""A plugin that saves files embedded in user messages as artifacts.
This is useful to allow users to upload files in the chat experience and have
those files available to the agent.
We use Blob.display_name to determine
the file name. Artifacts with the same name will be overwritten. A placeholder
with the artifact name will be put in place of the embedded file in the user
message so the model knows where to find the file. You may want to add
load_artifacts tool to the agent, or load the artifacts in your own tool to
use the files.
"""
def __init__(self, name: str = 'save_files_as_artifacts_plugin'):
"""Initialize the save files as artifacts plugin.
Args:
name: The name of the plugin instance.
"""
super().__init__(name)
async def on_user_message_callback(
self,
*,
invocation_context: InvocationContext,
user_message: types.Content,
) -> Optional[types.Content]:
"""Process user message and save any attached files as artifacts."""
if not invocation_context.artifact_service:
logger.warning(
'Artifact service is not set. SaveFilesAsArtifactsPlugin'
' will not be enabled.'
)
return user_message
if not user_message.parts:
return user_message
for i, part in enumerate(user_message.parts):
if part.inline_data is None:
continue
try:
# Use display_name if available, otherwise generate a filename
file_name = part.inline_data.display_name
if not file_name:
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
logger.info(
f'No display_name found, using generated filename: {file_name}'
)
await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
user_id=invocation_context.user_id,
session_id=invocation_context.session.id,
filename=file_name,
artifact=part,
)
# Replace the inline data with a placeholder text
user_message.parts[i] = types.Part(
text=f'[Uploaded Artifact: "{file_name}"]'
)
logger.info(f'Successfully saved artifact: {file_name}')
except Exception as e:
logger.error(f'Failed to save artifact for part {i}: {e}')
# Keep the original part if saving fails
continue
return user_message
+9
View File
@@ -419,6 +419,15 @@ class Runner:
raise ValueError('No parts in the new_message.')
if self.artifact_service and save_input_blobs_as_artifacts:
# Issue deprecation warning
warnings.warn(
"The 'save_input_blobs_as_artifacts' parameter is deprecated. Use"
' SaveFilesAsArtifactsPlugin instead for better control and'
' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for'
' migration guidance.',
DeprecationWarning,
stacklevel=3,
)
# The runner directly saves the artifacts (if applicable) in the
# user message and replaces the artifact data with a file name
# placeholder.
@@ -0,0 +1,297 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from unittest.mock import AsyncMock
from unittest.mock import Mock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
from google.genai import types
import pytest
class TestSaveFilesAsArtifactsPlugin:
"""Test suite for SaveFilesAsArtifactsPlugin."""
def setup_method(self):
"""Set up test fixtures."""
self.plugin = SaveFilesAsArtifactsPlugin()
# Mock invocation context
self.mock_context = Mock(spec=InvocationContext)
self.mock_context.artifact_service = AsyncMock()
self.mock_context.app_name = "test_app"
self.mock_context.user_id = "test_user"
self.mock_context.invocation_id = "test_invocation_123"
self.mock_context.session = Mock()
self.mock_context.session.id = "test_session"
@pytest.mark.asyncio
async def test_save_files_with_display_name(self):
"""Test saving files when inline_data has display_name."""
# Create a message with inline data
inline_data = types.Blob(
display_name="test_document.pdf",
data=b"test data",
mime_type="application/pdf",
)
original_part = types.Part(inline_data=inline_data)
user_message = types.Content(parts=[original_part])
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Verify artifact was saved with correct filename
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
app_name="test_app",
user_id="test_user",
session_id="test_session",
filename="test_document.pdf",
artifact=original_part,
)
# Verify message was modified with placeholder
assert result.parts[0].text == '[Uploaded Artifact: "test_document.pdf"]'
@pytest.mark.asyncio
async def test_save_files_without_display_name(self):
"""Test saving files when inline_data has no display_name."""
# Create inline data without display_name
inline_data = types.Blob(
display_name=None, data=b"test data", mime_type="application/pdf"
)
original_part = types.Part(inline_data=inline_data)
user_message = types.Content(parts=[original_part])
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Verify artifact was saved with generated filename
expected_filename = "artifact_test_invocation_123_0"
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
app_name="test_app",
user_id="test_user",
session_id="test_session",
filename=expected_filename,
artifact=original_part,
)
# Verify message was modified with generated filename
assert result.parts[0].text == f'[Uploaded Artifact: "{expected_filename}"]'
@pytest.mark.asyncio
async def test_multiple_files_in_message(self):
"""Test handling multiple files in a single message."""
# Create message with multiple inline data parts
inline_data1 = types.Blob(
display_name="file1.txt", data=b"file1 content", mime_type="text/plain"
)
inline_data2 = types.Blob(
display_name="file2.jpg", data=b"file2 content", mime_type="image/jpeg"
)
user_message = types.Content(
parts=[
types.Part(inline_data=inline_data1),
types.Part(text="Some text between files"),
types.Part(inline_data=inline_data2),
]
)
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Verify both artifacts were saved
assert self.mock_context.artifact_service.save_artifact.call_count == 2
# Check first file
first_call = (
self.mock_context.artifact_service.save_artifact.call_args_list[0]
)
assert first_call[1]["filename"] == "file1.txt"
# Check second file
second_call = (
self.mock_context.artifact_service.save_artifact.call_args_list[1]
)
assert second_call[1]["filename"] == "file2.jpg"
# Verify message parts were modified correctly
assert result.parts[0].text == '[Uploaded Artifact: "file1.txt"]'
assert result.parts[1].text == "Some text between files" # Unchanged
assert result.parts[2].text == '[Uploaded Artifact: "file2.jpg"]'
@pytest.mark.asyncio
async def test_no_artifact_service(self):
"""Test behavior when artifact service is not available."""
# Set artifact service to None
self.mock_context.artifact_service = None
inline_data = types.Blob(
display_name="test.pdf", data=b"test data", mime_type="application/pdf"
)
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Should return original message unchanged
assert result == user_message
assert result.parts[0].inline_data == inline_data
@pytest.mark.asyncio
async def test_no_parts_in_message(self):
"""Test behavior when message has no parts."""
user_message = types.Content(parts=[])
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Should return original message unchanged
assert result == user_message
assert result.parts == []
# Should not try to save any artifacts
self.mock_context.artifact_service.save_artifact.assert_not_called()
@pytest.mark.asyncio
async def test_parts_without_inline_data(self):
"""Test behavior with parts that don't have inline_data."""
user_message = types.Content(
parts=[types.Part(text="Hello world"), types.Part(text="No files here")]
)
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Should return original message unchanged
assert result == user_message
assert result.parts[0].text == "Hello world"
assert result.parts[1].text == "No files here"
# Should not try to save any artifacts
self.mock_context.artifact_service.save_artifact.assert_not_called()
@pytest.mark.asyncio
async def test_save_artifact_failure(self):
"""Test behavior when saving artifact fails."""
# Mock save_artifact to raise an exception
self.mock_context.artifact_service.save_artifact.side_effect = Exception(
"Storage error"
)
inline_data = types.Blob(
display_name="test.pdf", data=b"test data", mime_type="application/pdf"
)
original_part = types.Part(inline_data=inline_data)
user_message = types.Content(parts=[original_part])
# Execute the plugin - should not raise exception
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Should preserve original part when saving fails
assert result.parts[0] == original_part
assert result.parts[0].inline_data == inline_data
@pytest.mark.asyncio
async def test_mixed_success_and_failure(self):
"""Test behavior when some files save successfully and others fail."""
# Mock save_artifact to succeed on first call, fail on second
save_calls = 0
def mock_save_artifact(*_args, **_kwargs):
nonlocal save_calls
save_calls += 1
if save_calls == 2:
raise Exception("Storage error on second file")
return AsyncMock()
self.mock_context.artifact_service.save_artifact.side_effect = (
mock_save_artifact
)
inline_data1 = types.Blob(
display_name="success.pdf",
data=b"success data",
mime_type="application/pdf",
)
inline_data2 = types.Blob(
display_name="failure.pdf",
data=b"failure data",
mime_type="application/pdf",
)
original_part2 = types.Part(inline_data=inline_data2)
user_message = types.Content(
parts=[types.Part(inline_data=inline_data1), original_part2]
)
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# First file should be replaced with placeholder
assert result.parts[0].text == '[Uploaded Artifact: "success.pdf"]'
# Second file should remain unchanged due to failure
assert result.parts[1] == original_part2
assert result.parts[1].inline_data == inline_data2
@pytest.mark.asyncio
async def test_placeholder_text_format(self):
"""Test that placeholder text is formatted correctly."""
inline_data = types.Blob(
display_name="test file with spaces.docx",
data=b"document data",
mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
)
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
# Execute the plugin
result = await self.plugin.on_user_message_callback(
invocation_context=self.mock_context, user_message=user_message
)
# Verify exact format of placeholder text
expected_text = '[Uploaded Artifact: "test file with spaces.docx"]'
assert result.parts[0].text == expected_text
def test_plugin_name_default(self):
"""Test that plugin has correct default name."""
plugin = SaveFilesAsArtifactsPlugin()
assert plugin.name == "save_files_as_artifacts_plugin"