You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support streaming function call arguments in progressive SSE streaming feature
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 837172244
This commit is contained in:
committed by
Copybara-Service
parent
73e5687b9a
commit
786aaed335
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import Agent
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def concat_number_and_string(num: int, s: str) -> str:
|
||||
"""Concatenate a number and a string.
|
||||
|
||||
Args:
|
||||
num: The number to concatenate.
|
||||
s: The string to concatenate.
|
||||
|
||||
Returns:
|
||||
The concatenated string.
|
||||
"""
|
||||
return str(num) + ': ' + s
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-3-pro-preview',
|
||||
name='hello_world_stream_fc_args',
|
||||
description='Demo agent showcasing streaming function call arguments.',
|
||||
instruction="""
|
||||
You are a helpful assistant.
|
||||
You can use the `concat_number_and_string` tool to concatenate a number and a string.
|
||||
You should always call the concat_number_and_string tool to concatenate a number and a string.
|
||||
You should never concatenate on your own.
|
||||
""",
|
||||
tools=[
|
||||
concat_number_and_string,
|
||||
],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True,
|
||||
),
|
||||
tool_config=types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(
|
||||
stream_function_call_arguments=True,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
+1
-1
@@ -41,7 +41,7 @@ dependencies = [
|
||||
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
|
||||
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
|
||||
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
|
||||
"google-genai>=1.45.0, <2.0.0", # Google GenAI SDK
|
||||
"google-genai>=1.51.0, <2.0.0", # Google GenAI SDK
|
||||
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
|
||||
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
|
||||
"mcp>=1.10.0, <2.0.0", # For MCP Toolset
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
@@ -43,6 +44,12 @@ class StreamingResponseAggregator:
|
||||
self._current_text_is_thought: Optional[bool] = None
|
||||
self._finish_reason: Optional[types.FinishReason] = None
|
||||
|
||||
# For streaming function call arguments
|
||||
self._current_fc_name: Optional[str] = None
|
||||
self._current_fc_args: dict[str, Any] = {}
|
||||
self._current_fc_id: Optional[str] = None
|
||||
self._current_thought_signature: Optional[str] = None
|
||||
|
||||
def _flush_text_buffer_to_sequence(self):
|
||||
"""Flush current text buffer to parts sequence.
|
||||
|
||||
@@ -61,6 +68,171 @@ class StreamingResponseAggregator:
|
||||
self._current_text_buffer = ''
|
||||
self._current_text_is_thought = None
|
||||
|
||||
def _get_value_from_partial_arg(
|
||||
self, partial_arg: types.PartialArg, json_path: str
|
||||
):
|
||||
"""Extract value from a partial argument.
|
||||
|
||||
Args:
|
||||
partial_arg: The partial argument object
|
||||
json_path: JSONPath for this argument
|
||||
|
||||
Returns:
|
||||
Tuple of (value, has_value) where has_value indicates if a value exists
|
||||
"""
|
||||
value = None
|
||||
has_value = False
|
||||
|
||||
if partial_arg.string_value is not None:
|
||||
# For streaming strings, append chunks to existing value
|
||||
string_chunk = partial_arg.string_value
|
||||
has_value = True
|
||||
|
||||
# Get current value for this path (if any)
|
||||
path_without_prefix = (
|
||||
json_path[2:] if json_path.startswith('$.') else json_path
|
||||
)
|
||||
path_parts = path_without_prefix.split('.')
|
||||
|
||||
# Try to get existing value
|
||||
existing_value = self._current_fc_args
|
||||
for part in path_parts:
|
||||
if isinstance(existing_value, dict) and part in existing_value:
|
||||
existing_value = existing_value[part]
|
||||
else:
|
||||
existing_value = None
|
||||
break
|
||||
|
||||
# Append to existing string or set new value
|
||||
if isinstance(existing_value, str):
|
||||
value = existing_value + string_chunk
|
||||
else:
|
||||
value = string_chunk
|
||||
|
||||
elif partial_arg.number_value is not None:
|
||||
value = partial_arg.number_value
|
||||
has_value = True
|
||||
elif partial_arg.bool_value is not None:
|
||||
value = partial_arg.bool_value
|
||||
has_value = True
|
||||
elif partial_arg.null_value is not None:
|
||||
value = None
|
||||
has_value = True
|
||||
|
||||
return value, has_value
|
||||
|
||||
def _set_value_by_json_path(self, json_path: str, value: Any):
|
||||
"""Set a value in _current_fc_args using JSONPath notation.
|
||||
|
||||
Args:
|
||||
json_path: JSONPath string like "$.location" or "$.location.latitude"
|
||||
value: The value to set
|
||||
"""
|
||||
# Remove leading "$." from jsonPath
|
||||
if json_path.startswith('$.'):
|
||||
path = json_path[2:]
|
||||
else:
|
||||
path = json_path
|
||||
|
||||
# Split path into components
|
||||
path_parts = path.split('.')
|
||||
|
||||
# Navigate to the correct location and set the value
|
||||
current = self._current_fc_args
|
||||
for part in path_parts[:-1]:
|
||||
if part not in current:
|
||||
current[part] = {}
|
||||
current = current[part]
|
||||
|
||||
# Set the final value
|
||||
current[path_parts[-1]] = value
|
||||
|
||||
def _flush_function_call_to_sequence(self):
|
||||
"""Flush current function call to parts sequence.
|
||||
|
||||
This creates a complete FunctionCall part from accumulated partial args.
|
||||
"""
|
||||
if self._current_fc_name:
|
||||
# Create function call part with accumulated args
|
||||
fc_part = types.Part.from_function_call(
|
||||
name=self._current_fc_name,
|
||||
args=self._current_fc_args.copy(),
|
||||
)
|
||||
|
||||
# Set the ID if provided (directly on the function_call object)
|
||||
if self._current_fc_id and fc_part.function_call:
|
||||
fc_part.function_call.id = self._current_fc_id
|
||||
|
||||
# Set thought_signature if provided (on the Part, not FunctionCall)
|
||||
if self._current_thought_signature:
|
||||
fc_part.thought_signature = self._current_thought_signature
|
||||
|
||||
self._parts_sequence.append(fc_part)
|
||||
|
||||
# Reset FC state
|
||||
self._current_fc_name = None
|
||||
self._current_fc_args = {}
|
||||
self._current_fc_id = None
|
||||
self._current_thought_signature = None
|
||||
|
||||
def _process_streaming_function_call(self, fc: types.FunctionCall):
|
||||
"""Process a streaming function call with partialArgs.
|
||||
|
||||
Args:
|
||||
fc: The function call object with partial_args
|
||||
"""
|
||||
# Save function name if present (first chunk)
|
||||
if fc.name:
|
||||
self._current_fc_name = fc.name
|
||||
if fc.id:
|
||||
self._current_fc_id = fc.id
|
||||
|
||||
# Process each partial argument
|
||||
for partial_arg in getattr(fc, 'partial_args', []):
|
||||
json_path = partial_arg.json_path
|
||||
if not json_path:
|
||||
continue
|
||||
|
||||
# Extract value from partial arg
|
||||
value, has_value = self._get_value_from_partial_arg(
|
||||
partial_arg, json_path
|
||||
)
|
||||
|
||||
# Set the value using JSONPath (only if a value was provided)
|
||||
if has_value:
|
||||
self._set_value_by_json_path(json_path, value)
|
||||
|
||||
# Check if function call is complete
|
||||
fc_will_continue = getattr(fc, 'will_continue', False)
|
||||
if not fc_will_continue:
|
||||
# Function call complete, flush it
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._flush_function_call_to_sequence()
|
||||
|
||||
def _process_function_call_part(self, part: types.Part):
|
||||
"""Process a function call part (streaming or non-streaming).
|
||||
|
||||
Args:
|
||||
part: The part containing a function call
|
||||
"""
|
||||
fc = part.function_call
|
||||
|
||||
# Check if this is a streaming FC (has partialArgs)
|
||||
if hasattr(fc, 'partial_args') and fc.partial_args:
|
||||
# Streaming function call arguments
|
||||
|
||||
# Save thought_signature from the part (first chunk should have it)
|
||||
if part.thought_signature and not self._current_thought_signature:
|
||||
self._current_thought_signature = part.thought_signature
|
||||
self._process_streaming_function_call(fc)
|
||||
else:
|
||||
# Non-streaming function call (standard format with args)
|
||||
# Skip empty function calls (used as streaming end markers)
|
||||
if fc.name:
|
||||
# Flush any buffered text first, then add the FC part
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._parts_sequence.append(part)
|
||||
|
||||
async def process_response(
|
||||
self, response: types.GenerateContentResponse
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
@@ -101,8 +273,12 @@ class StreamingResponseAggregator:
|
||||
if not self._current_text_buffer:
|
||||
self._current_text_is_thought = part.thought
|
||||
self._current_text_buffer += part.text
|
||||
elif part.function_call:
|
||||
# Process function call (handles both streaming Args and
|
||||
# non-streaming Args)
|
||||
self._process_function_call_part(part)
|
||||
else:
|
||||
# Non-text part (function_call, bytes, etc.)
|
||||
# Other non-text parts (bytes, etc.)
|
||||
# Flush any buffered text first, then add the non-text part
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._parts_sequence.append(part)
|
||||
@@ -155,8 +331,9 @@ class StreamingResponseAggregator:
|
||||
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
|
||||
# Always generate final aggregated response in progressive mode
|
||||
if self._response and self._response.candidates:
|
||||
# Flush any remaining text buffer to complete the sequence
|
||||
# Flush any remaining buffers to complete the sequence
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._flush_function_call_to_sequence()
|
||||
|
||||
# Use the parts sequence which preserves original ordering
|
||||
final_parts = self._parts_sequence
|
||||
|
||||
@@ -397,3 +397,245 @@ def test_progressive_sse_preserves_part_ordering():
|
||||
# Part 4: Second function call (from chunk8)
|
||||
assert parts[4].function_call.name == "get_weather"
|
||||
assert parts[4].function_call.args["location"] == "New York"
|
||||
|
||||
|
||||
def test_progressive_sse_streaming_function_call_arguments():
|
||||
"""Test streaming function call arguments feature.
|
||||
|
||||
This test simulates the streamFunctionCallArguments feature where a function
|
||||
call's arguments are streamed incrementally across multiple chunks:
|
||||
|
||||
Chunk 1: FC name + partial location argument ("New ")
|
||||
Chunk 2: Continue location argument ("York") -> concatenated to "New York"
|
||||
Chunk 3: Add unit argument ("celsius"), willContinue=False -> FC complete
|
||||
|
||||
Expected result: FunctionCall(name="get_weather",
|
||||
args={"location": "New York", "unit":
|
||||
"celsius"},
|
||||
id="fc_001")
|
||||
"""
|
||||
|
||||
aggregator = StreamingResponseAggregator()
|
||||
|
||||
# Chunk 1: FC name + partial location argument
|
||||
chunk1_fc = types.FunctionCall(
|
||||
name="get_weather",
|
||||
id="fc_001",
|
||||
partial_args=[
|
||||
types.PartialArg(json_path="$.location", string_value="New ")
|
||||
],
|
||||
will_continue=True,
|
||||
)
|
||||
chunk1 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=chunk1_fc)]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Chunk 2: Continue streaming location argument
|
||||
chunk2_fc = types.FunctionCall(
|
||||
partial_args=[
|
||||
types.PartialArg(json_path="$.location", string_value="York")
|
||||
],
|
||||
will_continue=True,
|
||||
)
|
||||
chunk2 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=chunk2_fc)]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Chunk 3: Add unit argument, FC complete
|
||||
chunk3_fc = types.FunctionCall(
|
||||
partial_args=[
|
||||
types.PartialArg(json_path="$.unit", string_value="celsius")
|
||||
],
|
||||
will_continue=False, # FC complete
|
||||
)
|
||||
chunk3 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=chunk3_fc)]
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Process all chunks through aggregator
|
||||
processed_chunks = []
|
||||
for chunk in [chunk1, chunk2, chunk3]:
|
||||
|
||||
async def process():
|
||||
results = []
|
||||
async for response in aggregator.process_response(chunk):
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
import asyncio
|
||||
|
||||
chunk_results = asyncio.run(process())
|
||||
processed_chunks.extend(chunk_results)
|
||||
|
||||
# Get final aggregated response
|
||||
final_response = aggregator.close()
|
||||
|
||||
# Verify final aggregated response has complete FC
|
||||
assert final_response is not None
|
||||
assert len(final_response.content.parts) == 1
|
||||
|
||||
fc_part = final_response.content.parts[0]
|
||||
assert fc_part.function_call is not None
|
||||
assert fc_part.function_call.name == "get_weather"
|
||||
assert fc_part.function_call.id == "fc_001"
|
||||
|
||||
# Verify arguments were correctly assembled from streaming chunks
|
||||
args = fc_part.function_call.args
|
||||
assert args["location"] == "New York" # "New " + "York" concatenated
|
||||
assert args["unit"] == "celsius"
|
||||
|
||||
|
||||
def test_progressive_sse_preserves_thought_signature():
|
||||
"""Test that thought_signature is preserved when streaming FC arguments.
|
||||
|
||||
This test verifies that when a streaming function call has a thought_signature
|
||||
in the Part, it is correctly preserved in the final aggregated FunctionCall.
|
||||
"""
|
||||
|
||||
aggregator = StreamingResponseAggregator()
|
||||
|
||||
# Create a thought signature (simulating what Gemini returns)
|
||||
# thought_signature is bytes (base64 encoded)
|
||||
test_thought_signature = b"test_signature_abc123"
|
||||
|
||||
# Chunk with streaming FC args and thought_signature
|
||||
chunk_fc = types.FunctionCall(
|
||||
name="add_5_numbers",
|
||||
id="fc_003",
|
||||
partial_args=[
|
||||
types.PartialArg(json_path="$.num1", number_value=10),
|
||||
types.PartialArg(json_path="$.num2", number_value=20),
|
||||
],
|
||||
will_continue=False,
|
||||
)
|
||||
|
||||
# Create Part with both function_call AND thought_signature
|
||||
chunk_part = types.Part(
|
||||
function_call=chunk_fc, thought_signature=test_thought_signature
|
||||
)
|
||||
|
||||
chunk = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(role="model", parts=[chunk_part]),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Process chunk through aggregator
|
||||
async def process():
|
||||
results = []
|
||||
async for response in aggregator.process_response(chunk):
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(process())
|
||||
|
||||
# Get final aggregated response
|
||||
final_response = aggregator.close()
|
||||
|
||||
# Verify thought_signature was preserved in the Part
|
||||
assert final_response is not None
|
||||
assert len(final_response.content.parts) == 1
|
||||
|
||||
fc_part = final_response.content.parts[0]
|
||||
assert fc_part.function_call is not None
|
||||
assert fc_part.function_call.name == "add_5_numbers"
|
||||
|
||||
assert fc_part.thought_signature == test_thought_signature
|
||||
|
||||
|
||||
def test_progressive_sse_handles_empty_function_call():
|
||||
"""Test that empty function calls are skipped.
|
||||
|
||||
When using streamFunctionCallArguments, Gemini may send an empty
|
||||
functionCall: {} as the final chunk to signal streaming completion.
|
||||
This test verifies that such empty function calls are properly skipped
|
||||
and don't cause errors.
|
||||
"""
|
||||
|
||||
aggregator = StreamingResponseAggregator()
|
||||
|
||||
# Chunk 1: Streaming FC with partial args
|
||||
chunk1_fc = types.FunctionCall(
|
||||
name="concat_number_and_string",
|
||||
id="fc_001",
|
||||
partial_args=[
|
||||
types.PartialArg(json_path="$.num", number_value=100),
|
||||
types.PartialArg(json_path="$.s", string_value="ADK"),
|
||||
],
|
||||
will_continue=False,
|
||||
)
|
||||
chunk1 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=chunk1_fc)]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Chunk 2: Empty function call (streaming end marker)
|
||||
chunk2_fc = types.FunctionCall() # Empty function call
|
||||
chunk2 = types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=chunk2_fc)]
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Process all chunks through aggregator
|
||||
async def process():
|
||||
results = []
|
||||
for chunk in [chunk1, chunk2]:
|
||||
async for response in aggregator.process_response(chunk):
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(process())
|
||||
|
||||
# Get final aggregated response
|
||||
final_response = aggregator.close()
|
||||
|
||||
# Verify final response only has the real FC, not the empty one
|
||||
assert final_response is not None
|
||||
assert len(final_response.content.parts) == 1
|
||||
|
||||
fc_part = final_response.content.parts[0]
|
||||
assert fc_part.function_call is not None
|
||||
assert fc_part.function_call.name == "concat_number_and_string"
|
||||
assert fc_part.function_call.id == "fc_001"
|
||||
|
||||
# Verify arguments
|
||||
args = fc_part.function_call.args
|
||||
assert args["num"] == 100
|
||||
assert args["s"] == "ADK"
|
||||
|
||||
Reference in New Issue
Block a user