You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Adds plugin to save artifacts for issue #2176
PiperOrigin-RevId: 810522939
This commit is contained in:
committed by
Copybara-Service
parent
c944a12e31
commit
657369cffe
@@ -21,6 +21,8 @@ from google.adk.apps import App
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.plugins.context_filter_plugin import ContextFilterPlugin
|
||||
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
|
||||
from google.adk.tools import load_artifacts
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
@@ -97,6 +99,7 @@ root_agent = Agent(
|
||||
tools=[
|
||||
roll_die,
|
||||
check_prime,
|
||||
load_artifacts,
|
||||
],
|
||||
# planner=BuiltInPlanner(
|
||||
# thinking_config=types.ThinkingConfig(
|
||||
@@ -145,5 +148,6 @@ app = App(
|
||||
plugins=[
|
||||
CountInvocationPlugin(),
|
||||
ContextFilterPlugin(num_invocations_to_keep=3),
|
||||
SaveFilesAsArtifactsPlugin(),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ async def main():
|
||||
user_id=user_id_1,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
run_config=RunConfig(save_input_blobs_as_artifacts=True),
|
||||
run_config=RunConfig(save_input_blobs_as_artifacts=False),
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
|
||||
@@ -48,8 +48,15 @@ class RunConfig(BaseModel):
|
||||
response_modalities: Optional[list[str]] = None
|
||||
"""The output modalities. If not set, it's default to AUDIO."""
|
||||
|
||||
save_input_blobs_as_artifacts: bool = False
|
||||
"""Whether or not to save the input blobs as artifacts."""
|
||||
save_input_blobs_as_artifacts: bool = Field(
|
||||
default=False,
|
||||
deprecated=True,
|
||||
description=(
|
||||
'Whether or not to save the input blobs as artifacts. DEPRECATED: Use'
|
||||
' SaveFilesAsArtifactsPlugin instead for better control and'
|
||||
' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin.'
|
||||
),
|
||||
)
|
||||
|
||||
support_cfc: bool = False
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .base_plugin import BasePlugin
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class SaveFilesAsArtifactsPlugin(BasePlugin):
|
||||
"""A plugin that saves files embedded in user messages as artifacts.
|
||||
|
||||
This is useful to allow users to upload files in the chat experience and have
|
||||
those files available to the agent.
|
||||
|
||||
We use Blob.display_name to determine
|
||||
the file name. Artifacts with the same name will be overwritten. A placeholder
|
||||
with the artifact name will be put in place of the embedded file in the user
|
||||
message so the model knows where to find the file. You may want to add
|
||||
load_artifacts tool to the agent, or load the artifacts in your own tool to
|
||||
use the files.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = 'save_files_as_artifacts_plugin'):
|
||||
"""Initialize the save files as artifacts plugin.
|
||||
|
||||
Args:
|
||||
name: The name of the plugin instance.
|
||||
"""
|
||||
super().__init__(name)
|
||||
|
||||
async def on_user_message_callback(
|
||||
self,
|
||||
*,
|
||||
invocation_context: InvocationContext,
|
||||
user_message: types.Content,
|
||||
) -> Optional[types.Content]:
|
||||
"""Process user message and save any attached files as artifacts."""
|
||||
if not invocation_context.artifact_service:
|
||||
logger.warning(
|
||||
'Artifact service is not set. SaveFilesAsArtifactsPlugin'
|
||||
' will not be enabled.'
|
||||
)
|
||||
return user_message
|
||||
|
||||
if not user_message.parts:
|
||||
return user_message
|
||||
|
||||
for i, part in enumerate(user_message.parts):
|
||||
if part.inline_data is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use display_name if available, otherwise generate a filename
|
||||
file_name = part.inline_data.display_name
|
||||
if not file_name:
|
||||
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
|
||||
logger.info(
|
||||
f'No display_name found, using generated filename: {file_name}'
|
||||
)
|
||||
|
||||
await invocation_context.artifact_service.save_artifact(
|
||||
app_name=invocation_context.app_name,
|
||||
user_id=invocation_context.user_id,
|
||||
session_id=invocation_context.session.id,
|
||||
filename=file_name,
|
||||
artifact=part,
|
||||
)
|
||||
|
||||
# Replace the inline data with a placeholder text
|
||||
user_message.parts[i] = types.Part(
|
||||
text=f'[Uploaded Artifact: "{file_name}"]'
|
||||
)
|
||||
logger.info(f'Successfully saved artifact: {file_name}')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to save artifact for part {i}: {e}')
|
||||
# Keep the original part if saving fails
|
||||
continue
|
||||
|
||||
return user_message
|
||||
@@ -419,6 +419,15 @@ class Runner:
|
||||
raise ValueError('No parts in the new_message.')
|
||||
|
||||
if self.artifact_service and save_input_blobs_as_artifacts:
|
||||
# Issue deprecation warning
|
||||
warnings.warn(
|
||||
"The 'save_input_blobs_as_artifacts' parameter is deprecated. Use"
|
||||
' SaveFilesAsArtifactsPlugin instead for better control and'
|
||||
' flexibility. See google.adk.plugins.SaveFilesAsArtifactsPlugin for'
|
||||
' migration guidance.',
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
# The runner directly saves the artifacts (if applicable) in the
|
||||
# user message and replaces the artifact data with a file name
|
||||
# placeholder.
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSaveFilesAsArtifactsPlugin:
|
||||
"""Test suite for SaveFilesAsArtifactsPlugin."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.plugin = SaveFilesAsArtifactsPlugin()
|
||||
|
||||
# Mock invocation context
|
||||
self.mock_context = Mock(spec=InvocationContext)
|
||||
self.mock_context.artifact_service = AsyncMock()
|
||||
self.mock_context.app_name = "test_app"
|
||||
self.mock_context.user_id = "test_user"
|
||||
self.mock_context.invocation_id = "test_invocation_123"
|
||||
self.mock_context.session = Mock()
|
||||
self.mock_context.session.id = "test_session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_files_with_display_name(self):
|
||||
"""Test saving files when inline_data has display_name."""
|
||||
# Create a message with inline data
|
||||
inline_data = types.Blob(
|
||||
display_name="test_document.pdf",
|
||||
data=b"test data",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
original_part = types.Part(inline_data=inline_data)
|
||||
user_message = types.Content(parts=[original_part])
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Verify artifact was saved with correct filename
|
||||
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
session_id="test_session",
|
||||
filename="test_document.pdf",
|
||||
artifact=original_part,
|
||||
)
|
||||
|
||||
# Verify message was modified with placeholder
|
||||
assert result.parts[0].text == '[Uploaded Artifact: "test_document.pdf"]'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_files_without_display_name(self):
|
||||
"""Test saving files when inline_data has no display_name."""
|
||||
# Create inline data without display_name
|
||||
inline_data = types.Blob(
|
||||
display_name=None, data=b"test data", mime_type="application/pdf"
|
||||
)
|
||||
|
||||
original_part = types.Part(inline_data=inline_data)
|
||||
user_message = types.Content(parts=[original_part])
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Verify artifact was saved with generated filename
|
||||
expected_filename = "artifact_test_invocation_123_0"
|
||||
self.mock_context.artifact_service.save_artifact.assert_called_once_with(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
session_id="test_session",
|
||||
filename=expected_filename,
|
||||
artifact=original_part,
|
||||
)
|
||||
|
||||
# Verify message was modified with generated filename
|
||||
assert result.parts[0].text == f'[Uploaded Artifact: "{expected_filename}"]'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_files_in_message(self):
|
||||
"""Test handling multiple files in a single message."""
|
||||
# Create message with multiple inline data parts
|
||||
inline_data1 = types.Blob(
|
||||
display_name="file1.txt", data=b"file1 content", mime_type="text/plain"
|
||||
)
|
||||
|
||||
inline_data2 = types.Blob(
|
||||
display_name="file2.jpg", data=b"file2 content", mime_type="image/jpeg"
|
||||
)
|
||||
|
||||
user_message = types.Content(
|
||||
parts=[
|
||||
types.Part(inline_data=inline_data1),
|
||||
types.Part(text="Some text between files"),
|
||||
types.Part(inline_data=inline_data2),
|
||||
]
|
||||
)
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Verify both artifacts were saved
|
||||
assert self.mock_context.artifact_service.save_artifact.call_count == 2
|
||||
|
||||
# Check first file
|
||||
first_call = (
|
||||
self.mock_context.artifact_service.save_artifact.call_args_list[0]
|
||||
)
|
||||
assert first_call[1]["filename"] == "file1.txt"
|
||||
|
||||
# Check second file
|
||||
second_call = (
|
||||
self.mock_context.artifact_service.save_artifact.call_args_list[1]
|
||||
)
|
||||
assert second_call[1]["filename"] == "file2.jpg"
|
||||
|
||||
# Verify message parts were modified correctly
|
||||
assert result.parts[0].text == '[Uploaded Artifact: "file1.txt"]'
|
||||
assert result.parts[1].text == "Some text between files" # Unchanged
|
||||
assert result.parts[2].text == '[Uploaded Artifact: "file2.jpg"]'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_artifact_service(self):
|
||||
"""Test behavior when artifact service is not available."""
|
||||
# Set artifact service to None
|
||||
self.mock_context.artifact_service = None
|
||||
|
||||
inline_data = types.Blob(
|
||||
display_name="test.pdf", data=b"test data", mime_type="application/pdf"
|
||||
)
|
||||
|
||||
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Should return original message unchanged
|
||||
assert result == user_message
|
||||
assert result.parts[0].inline_data == inline_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_parts_in_message(self):
|
||||
"""Test behavior when message has no parts."""
|
||||
user_message = types.Content(parts=[])
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Should return original message unchanged
|
||||
assert result == user_message
|
||||
assert result.parts == []
|
||||
|
||||
# Should not try to save any artifacts
|
||||
self.mock_context.artifact_service.save_artifact.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parts_without_inline_data(self):
|
||||
"""Test behavior with parts that don't have inline_data."""
|
||||
user_message = types.Content(
|
||||
parts=[types.Part(text="Hello world"), types.Part(text="No files here")]
|
||||
)
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Should return original message unchanged
|
||||
assert result == user_message
|
||||
assert result.parts[0].text == "Hello world"
|
||||
assert result.parts[1].text == "No files here"
|
||||
|
||||
# Should not try to save any artifacts
|
||||
self.mock_context.artifact_service.save_artifact.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_artifact_failure(self):
|
||||
"""Test behavior when saving artifact fails."""
|
||||
# Mock save_artifact to raise an exception
|
||||
self.mock_context.artifact_service.save_artifact.side_effect = Exception(
|
||||
"Storage error"
|
||||
)
|
||||
|
||||
inline_data = types.Blob(
|
||||
display_name="test.pdf", data=b"test data", mime_type="application/pdf"
|
||||
)
|
||||
|
||||
original_part = types.Part(inline_data=inline_data)
|
||||
user_message = types.Content(parts=[original_part])
|
||||
|
||||
# Execute the plugin - should not raise exception
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Should preserve original part when saving fails
|
||||
assert result.parts[0] == original_part
|
||||
assert result.parts[0].inline_data == inline_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_success_and_failure(self):
|
||||
"""Test behavior when some files save successfully and others fail."""
|
||||
# Mock save_artifact to succeed on first call, fail on second
|
||||
save_calls = 0
|
||||
|
||||
def mock_save_artifact(*_args, **_kwargs):
|
||||
nonlocal save_calls
|
||||
save_calls += 1
|
||||
if save_calls == 2:
|
||||
raise Exception("Storage error on second file")
|
||||
return AsyncMock()
|
||||
|
||||
self.mock_context.artifact_service.save_artifact.side_effect = (
|
||||
mock_save_artifact
|
||||
)
|
||||
|
||||
inline_data1 = types.Blob(
|
||||
display_name="success.pdf",
|
||||
data=b"success data",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
inline_data2 = types.Blob(
|
||||
display_name="failure.pdf",
|
||||
data=b"failure data",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
original_part2 = types.Part(inline_data=inline_data2)
|
||||
user_message = types.Content(
|
||||
parts=[types.Part(inline_data=inline_data1), original_part2]
|
||||
)
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# First file should be replaced with placeholder
|
||||
assert result.parts[0].text == '[Uploaded Artifact: "success.pdf"]'
|
||||
|
||||
# Second file should remain unchanged due to failure
|
||||
assert result.parts[1] == original_part2
|
||||
assert result.parts[1].inline_data == inline_data2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_placeholder_text_format(self):
|
||||
"""Test that placeholder text is formatted correctly."""
|
||||
inline_data = types.Blob(
|
||||
display_name="test file with spaces.docx",
|
||||
data=b"document data",
|
||||
mime_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
)
|
||||
|
||||
user_message = types.Content(parts=[types.Part(inline_data=inline_data)])
|
||||
|
||||
# Execute the plugin
|
||||
result = await self.plugin.on_user_message_callback(
|
||||
invocation_context=self.mock_context, user_message=user_message
|
||||
)
|
||||
|
||||
# Verify exact format of placeholder text
|
||||
expected_text = '[Uploaded Artifact: "test file with spaces.docx"]'
|
||||
assert result.parts[0].text == expected_text
|
||||
|
||||
def test_plugin_name_default(self):
|
||||
"""Test that plugin has correct default name."""
|
||||
plugin = SaveFilesAsArtifactsPlugin()
|
||||
assert plugin.name == "save_files_as_artifacts_plugin"
|
||||
Reference in New Issue
Block a user