feat: Support streaming function call arguments in progressive SSE streaming feature

Co-authored-by: Xuan Yang <xygoogle@google.com>
PiperOrigin-RevId: 837172244
This commit is contained in:
Xuan Yang
2025-11-26 10:14:25 -08:00
committed by Copybara-Service
parent 73e5687b9a
commit 786aaed335
5 changed files with 492 additions and 3 deletions
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
+55
View File
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import Agent
from google.genai import types
def concat_number_and_string(num: int, s: str) -> str:
  """Concatenate a number and a string.

  Args:
    num: The number to concatenate.
    s: The string to concatenate.

  Returns:
    The concatenated string, formatted as "<num>: <s>".
  """
  return f'{num}: {s}'
# Demo agent for the streaming-function-call-arguments sample. It exposes a
# single tool and is configured so the model's function call (with streamed
# arguments) is surfaced to the client rather than executed automatically.
root_agent = Agent(
    model='gemini-3-pro-preview',
    name='hello_world_stream_fc_args',
    description='Demo agent showcasing streaming function call arguments.',
    instruction="""
You are a helpful assistant.
You can use the `concat_number_and_string` tool to concatenate a number and a string.
You should always call the concat_number_and_string tool to concatenate a number and a string.
You should never concatenate on your own.
""",
    tools=[
        concat_number_and_string,
    ],
    generate_content_config=types.GenerateContentConfig(
        # Disable automatic function calling so the raw function call is
        # returned to the caller instead of being invoked by the SDK.
        automatic_function_calling=types.AutomaticFunctionCallingConfig(
            disable=True,
        ),
        tool_config=types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(
                # Opt in to incremental streaming of function call arguments.
                stream_function_call_arguments=True,
            ),
        ),
    ),
)
+1 -1
View File
@@ -41,7 +41,7 @@ dependencies = [
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
"google-genai>=1.45.0, <2.0.0", # Google GenAI SDK
"google-genai>=1.51.0, <2.0.0", # Google GenAI SDK
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
"mcp>=1.10.0, <2.0.0", # For MCP Toolset
+179 -2
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
from typing import Any
from typing import AsyncGenerator
from typing import Optional
@@ -43,6 +44,12 @@ class StreamingResponseAggregator:
self._current_text_is_thought: Optional[bool] = None
self._finish_reason: Optional[types.FinishReason] = None
# For streaming function call arguments
self._current_fc_name: Optional[str] = None
self._current_fc_args: dict[str, Any] = {}
self._current_fc_id: Optional[str] = None
self._current_thought_signature: Optional[str] = None
def _flush_text_buffer_to_sequence(self):
"""Flush current text buffer to parts sequence.
@@ -61,6 +68,171 @@ class StreamingResponseAggregator:
self._current_text_buffer = ''
self._current_text_is_thought = None
def _get_value_from_partial_arg(
self, partial_arg: types.PartialArg, json_path: str
):
"""Extract value from a partial argument.
Args:
partial_arg: The partial argument object
json_path: JSONPath for this argument
Returns:
Tuple of (value, has_value) where has_value indicates if a value exists
"""
value = None
has_value = False
if partial_arg.string_value is not None:
# For streaming strings, append chunks to existing value
string_chunk = partial_arg.string_value
has_value = True
# Get current value for this path (if any)
path_without_prefix = (
json_path[2:] if json_path.startswith('$.') else json_path
)
path_parts = path_without_prefix.split('.')
# Try to get existing value
existing_value = self._current_fc_args
for part in path_parts:
if isinstance(existing_value, dict) and part in existing_value:
existing_value = existing_value[part]
else:
existing_value = None
break
# Append to existing string or set new value
if isinstance(existing_value, str):
value = existing_value + string_chunk
else:
value = string_chunk
elif partial_arg.number_value is not None:
value = partial_arg.number_value
has_value = True
elif partial_arg.bool_value is not None:
value = partial_arg.bool_value
has_value = True
elif partial_arg.null_value is not None:
value = None
has_value = True
return value, has_value
def _set_value_by_json_path(self, json_path: str, value: Any):
"""Set a value in _current_fc_args using JSONPath notation.
Args:
json_path: JSONPath string like "$.location" or "$.location.latitude"
value: The value to set
"""
# Remove leading "$." from jsonPath
if json_path.startswith('$.'):
path = json_path[2:]
else:
path = json_path
# Split path into components
path_parts = path.split('.')
# Navigate to the correct location and set the value
current = self._current_fc_args
for part in path_parts[:-1]:
if part not in current:
current[part] = {}
current = current[part]
# Set the final value
current[path_parts[-1]] = value
def _flush_function_call_to_sequence(self):
  """Flush current function call to parts sequence.

  This creates a complete FunctionCall part from the partial arguments
  accumulated so far, then clears the streaming-FC state. It is a no-op
  when no function name has been seen (i.e. no streaming FC is in flight).
  """
  if self._current_fc_name:
    # Create function call part with accumulated args. copy() decouples the
    # emitted part from the mutable accumulation dict.
    fc_part = types.Part.from_function_call(
        name=self._current_fc_name,
        args=self._current_fc_args.copy(),
    )
    # Set the ID if provided (directly on the function_call object).
    if self._current_fc_id and fc_part.function_call:
      fc_part.function_call.id = self._current_fc_id
    # Set thought_signature if provided (it lives on the Part, not on the
    # FunctionCall).
    if self._current_thought_signature:
      fc_part.thought_signature = self._current_thought_signature
    self._parts_sequence.append(fc_part)
    # Reset FC state so the next streaming call starts from a clean slate.
    self._current_fc_name = None
    self._current_fc_args = {}
    self._current_fc_id = None
    self._current_thought_signature = None
def _process_streaming_function_call(self, fc: types.FunctionCall):
"""Process a streaming function call with partialArgs.
Args:
fc: The function call object with partial_args
"""
# Save function name if present (first chunk)
if fc.name:
self._current_fc_name = fc.name
if fc.id:
self._current_fc_id = fc.id
# Process each partial argument
for partial_arg in getattr(fc, 'partial_args', []):
json_path = partial_arg.json_path
if not json_path:
continue
# Extract value from partial arg
value, has_value = self._get_value_from_partial_arg(
partial_arg, json_path
)
# Set the value using JSONPath (only if a value was provided)
if has_value:
self._set_value_by_json_path(json_path, value)
# Check if function call is complete
fc_will_continue = getattr(fc, 'will_continue', False)
if not fc_will_continue:
# Function call complete, flush it
self._flush_text_buffer_to_sequence()
self._flush_function_call_to_sequence()
def _process_function_call_part(self, part: types.Part):
"""Process a function call part (streaming or non-streaming).
Args:
part: The part containing a function call
"""
fc = part.function_call
# Check if this is a streaming FC (has partialArgs)
if hasattr(fc, 'partial_args') and fc.partial_args:
# Streaming function call arguments
# Save thought_signature from the part (first chunk should have it)
if part.thought_signature and not self._current_thought_signature:
self._current_thought_signature = part.thought_signature
self._process_streaming_function_call(fc)
else:
# Non-streaming function call (standard format with args)
# Skip empty function calls (used as streaming end markers)
if fc.name:
# Flush any buffered text first, then add the FC part
self._flush_text_buffer_to_sequence()
self._parts_sequence.append(part)
async def process_response(
self, response: types.GenerateContentResponse
) -> AsyncGenerator[LlmResponse, None]:
@@ -101,8 +273,12 @@ class StreamingResponseAggregator:
if not self._current_text_buffer:
self._current_text_is_thought = part.thought
self._current_text_buffer += part.text
elif part.function_call:
# Process function call (handles both streaming Args and
# non-streaming Args)
self._process_function_call_part(part)
else:
# Non-text part (function_call, bytes, etc.)
# Other non-text parts (bytes, etc.)
# Flush any buffered text first, then add the non-text part
self._flush_text_buffer_to_sequence()
self._parts_sequence.append(part)
@@ -155,8 +331,9 @@ class StreamingResponseAggregator:
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
# Always generate final aggregated response in progressive mode
if self._response and self._response.candidates:
# Flush any remaining text buffer to complete the sequence
# Flush any remaining buffers to complete the sequence
self._flush_text_buffer_to_sequence()
self._flush_function_call_to_sequence()
# Use the parts sequence which preserves original ordering
final_parts = self._parts_sequence
@@ -397,3 +397,245 @@ def test_progressive_sse_preserves_part_ordering():
# Part 4: Second function call (from chunk8)
assert parts[4].function_call.name == "get_weather"
assert parts[4].function_call.args["location"] == "New York"
def test_progressive_sse_streaming_function_call_arguments():
  """Streaming FC arguments are assembled across chunks.

  Simulates the streamFunctionCallArguments feature, where one function
  call's arguments arrive incrementally:

    Chunk 1: FC name + partial location argument ("New ")
    Chunk 2: location continues ("York") -> concatenated to "New York"
    Chunk 3: unit argument ("celsius"), willContinue=False -> FC complete

  The final aggregated response must contain a single
  FunctionCall(name="get_weather",
               args={"location": "New York", "unit": "celsius"},
               id="fc_001").
  """
  import asyncio

  aggregator = StreamingResponseAggregator()

  def make_chunk(fc, done=False):
    # Wrap a FunctionCall into a one-part model response chunk.
    return types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    role="model", parts=[types.Part(function_call=fc)]
                ),
                finish_reason=types.FinishReason.STOP if done else None,
            )
        ]
    )

  chunks = [
      # Chunk 1: FC name + first slice of the location argument.
      make_chunk(
          types.FunctionCall(
              name="get_weather",
              id="fc_001",
              partial_args=[
                  types.PartialArg(json_path="$.location", string_value="New ")
              ],
              will_continue=True,
          )
      ),
      # Chunk 2: the location argument continues streaming.
      make_chunk(
          types.FunctionCall(
              partial_args=[
                  types.PartialArg(json_path="$.location", string_value="York")
              ],
              will_continue=True,
          )
      ),
      # Chunk 3: unit argument arrives and the FC completes.
      make_chunk(
          types.FunctionCall(
              partial_args=[
                  types.PartialArg(json_path="$.unit", string_value="celsius")
              ],
              will_continue=False,
          ),
          done=True,
      ),
  ]

  async def drain(chunk):
    return [resp async for resp in aggregator.process_response(chunk)]

  # Feed each chunk through the aggregator (one run per chunk, as a client
  # processing an SSE stream would).
  for chunk in chunks:
    asyncio.run(drain(chunk))

  final_response = aggregator.close()

  # The aggregated response carries exactly the one completed FC.
  assert final_response is not None
  assert len(final_response.content.parts) == 1
  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "get_weather"
  assert fc_part.function_call.id == "fc_001"
  # Arguments were correctly assembled from the streamed chunks.
  args = fc_part.function_call.args
  assert args["location"] == "New York"  # "New " + "York" concatenated
  assert args["unit"] == "celsius"
def test_progressive_sse_preserves_thought_signature():
  """thought_signature survives aggregation of streamed FC arguments.

  When a streaming function call carries a thought_signature on its Part,
  the final aggregated FunctionCall part must keep that signature.
  """
  import asyncio

  aggregator = StreamingResponseAggregator()

  # Simulated signature as returned by Gemini (bytes, base64 encoded).
  signature = b"test_signature_abc123"

  fc = types.FunctionCall(
      name="add_5_numbers",
      id="fc_003",
      partial_args=[
          types.PartialArg(json_path="$.num1", number_value=10),
          types.PartialArg(json_path="$.num2", number_value=20),
      ],
      will_continue=False,
  )
  chunk = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=types.Content(
                  role="model",
                  # Part carries both the function_call AND the signature.
                  parts=[
                      types.Part(function_call=fc, thought_signature=signature)
                  ],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  async def drain():
    async for _ in aggregator.process_response(chunk):
      pass

  asyncio.run(drain())

  final_response = aggregator.close()

  # The signature must be preserved on the aggregated Part.
  assert final_response is not None
  assert len(final_response.content.parts) == 1
  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "add_5_numbers"
  assert fc_part.thought_signature == signature
def test_progressive_sse_handles_empty_function_call():
  """Empty functionCall chunks (streaming end markers) are skipped.

  With streamFunctionCallArguments enabled, Gemini may emit an empty
  functionCall: {} as the last chunk to signal that streaming is done.
  Such chunks must be ignored without producing a part or raising.
  """
  import asyncio

  aggregator = StreamingResponseAggregator()

  # Chunk 1: a real streaming FC with its complete set of partial args.
  real_fc = types.FunctionCall(
      name="concat_number_and_string",
      id="fc_001",
      partial_args=[
          types.PartialArg(json_path="$.num", number_value=100),
          types.PartialArg(json_path="$.s", string_value="ADK"),
      ],
      will_continue=False,
  )
  chunk1 = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=types.Content(
                  role="model", parts=[types.Part(function_call=real_fc)]
              )
          )
      ]
  )
  # Chunk 2: empty FC acting purely as a streaming end marker.
  chunk2 = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=types.Content(
                  role="model",
                  parts=[types.Part(function_call=types.FunctionCall())],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  async def drain():
    for chunk in (chunk1, chunk2):
      async for _ in aggregator.process_response(chunk):
        pass

  asyncio.run(drain())

  final_response = aggregator.close()

  # Only the real FC made it into the aggregate; the empty one was dropped.
  assert final_response is not None
  assert len(final_response.content.parts) == 1
  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "concat_number_and_string"
  assert fc_part.function_call.id == "fc_001"
  args = fc_part.function_call.args
  assert args["num"] == 100
  assert args["s"] == "ADK"