def concat_number_and_string(num: int, s: str) -> str:
  """Join a number and a string into one ``"<num>: <s>"`` string.

  Args:
    num: Number placed before the separator.
    s: Text placed after the separator.

  Returns:
    The text form of ``num``, followed by ``': '`` and then ``s``.
  """
  return f'{num}: {s}'


# Request config shared by the demo agent: automatic function calling is
# disabled and streaming of function call arguments is switched on.
_STREAMING_FC_CONFIG = types.GenerateContentConfig(
    automatic_function_calling=types.AutomaticFunctionCallingConfig(
        disable=True,
    ),
    tool_config=types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            stream_function_call_arguments=True,
        ),
    ),
)

root_agent = Agent(
    model='gemini-3-pro-preview',
    name='hello_world_stream_fc_args',
    description='Demo agent showcasing streaming function call arguments.',
    instruction="""
      You are a helpful assistant.
      You can use the `concat_number_and_string` tool to concatenate a number and a string.
      You should always call the concat_number_and_string tool to concatenate a number and a string.
      You should never concatenate on your own.
    """,
    tools=[
        concat_number_and_string,
    ],
    generate_content_config=_STREAMING_FC_CONFIG,
)
def _parse_json_path(self, json_path: str) -> list[Any]:
  """Split a JSONPath into the object keys and list indexes it addresses.

  Handles the dotted form with optional bracketed indexes, e.g.
  ``$.location``, ``$.location.latitude`` or ``$.foo.bar[0].data``.

  Args:
    json_path: JSONPath string; a leading ``$.`` is optional.

  Returns:
    A non-empty list of tokens: ``str`` for object keys and ``int`` for list
    indexes. A segment whose brackets cannot be parsed as decimal indexes is
    kept verbatim as a single key.
  """
  path = json_path[2:] if json_path.startswith('$.') else json_path

  tokens = []
  for segment in path.split('.'):
    name, bracket, rest = segment.partition('[')
    segment_tokens = [name] if name else []
    while bracket:
      selector, closing, rest = rest.partition(']')
      if not closing or not selector.isdecimal():
        # Unterminated or non-numeric selector: treat the segment as a key.
        segment_tokens = [segment]
        break
      segment_tokens.append(int(selector))
      leftover, bracket, rest = rest.partition('[')
      if leftover:
        # Text between "]" and the next "[" is not a form handled here.
        segment_tokens = [segment]
        break
    tokens.extend(segment_tokens or [segment])
  return tokens


def _get_value_from_partial_arg(
    self, partial_arg: 'types.PartialArg', json_path: str
) -> tuple[Any, bool]:
  """Extract the value carried by one partial argument.

  String values are treated as chunks: when a string is already stored at
  ``json_path`` in ``self._current_fc_args`` the new chunk is appended to it.
  Number and bool values are returned as-is, and a set ``null_value`` yields
  ``None``.

  Args:
    partial_arg: The partial argument object.
    json_path: JSONPath for this argument.

  Returns:
    Tuple of (value, has_value) where has_value indicates whether the partial
    argument carried any value at all.
  """
  if partial_arg.string_value is not None:
    # Look up what has been accumulated so far at this path (if anything).
    existing = self._current_fc_args
    for token in self._parse_json_path(json_path):
      if (
          isinstance(existing, list)
          and isinstance(token, int)
          and token < len(existing)
      ):
        existing = existing[token]
      elif isinstance(existing, dict) and token in existing:
        existing = existing[token]
      else:
        existing = None
        break

    if isinstance(existing, str):
      return existing + partial_arg.string_value, True
    return partial_arg.string_value, True

  if partial_arg.number_value is not None:
    return partial_arg.number_value, True
  if partial_arg.bool_value is not None:
    return partial_arg.bool_value, True
  if partial_arg.null_value is not None:
    return None, True
  return None, False


def _set_value_by_json_path(self, json_path: str, value: Any):
  """Set a value in ``self._current_fc_args`` using JSONPath notation.

  Missing containers along the path are created: a list when the following
  token is an index, a dict otherwise. An existing intermediate value of the
  wrong type is replaced by a fresh container instead of raising. Lists are
  padded with ``None`` up to the addressed index.

  Args:
    json_path: JSONPath string like "$.location", "$.location.latitude" or
      "$.items[0].name".
    value: The value to set.
  """
  tokens = self._parse_json_path(json_path)
  last_position = len(tokens) - 1

  container = self._current_fc_args
  for position, token in enumerate(tokens):
    if isinstance(container, list):
      # A list container is only ever entered through an int token; pad it
      # so the index is assignable even when elements arrive out of order.
      container.extend([None] * (token + 1 - len(container)))

    if position == last_position:
      container[token] = value
      return

    child_type = list if isinstance(tokens[position + 1], int) else dict
    if isinstance(container, list):
      child = container[token]
    else:
      child = container.get(token)
    if not isinstance(child, child_type):
      child = child_type()
      container[token] = child
    container = child


def _flush_function_call_to_sequence(self):
  """Flush the accumulated function call to the parts sequence.

  Builds one FunctionCall part from the name, args, id and thought signature
  collected so far, appends it to ``self._parts_sequence`` and resets the
  function call state. Does nothing when no function name was recorded.
  """
  if not self._current_fc_name:
    return

  fc_part = types.Part.from_function_call(
      name=self._current_fc_name,
      args=self._current_fc_args.copy(),
  )

  # The id lives on the FunctionCall object itself.
  if self._current_fc_id and fc_part.function_call:
    fc_part.function_call.id = self._current_fc_id

  # The thought signature lives on the Part, not on the FunctionCall.
  if self._current_thought_signature:
    fc_part.thought_signature = self._current_thought_signature

  self._parts_sequence.append(fc_part)

  # Reset FC state so the next function call starts clean.
  self._current_fc_name = None
  self._current_fc_args = {}
  self._current_fc_id = None
  self._current_thought_signature = None


def _process_streaming_function_call(self, fc: 'types.FunctionCall'):
  """Process one chunk of a function call whose arguments are streamed.

  Records the name and id when the chunk carries them, merges every partial
  argument into ``self._current_fc_args`` and, once the chunk no longer
  reports ``will_continue``, flushes buffered text followed by the completed
  function call.

  Args:
    fc: The function call chunk; ``partial_args`` may be missing or empty.
  """
  if fc.name:
    self._current_fc_name = fc.name
  if fc.id:
    self._current_fc_id = fc.id

  # `partial_args` can exist but be None, so normalize before iterating.
  for partial_arg in getattr(fc, 'partial_args', None) or []:
    json_path = partial_arg.json_path
    if not json_path:
      continue

    value, has_value = self._get_value_from_partial_arg(
        partial_arg, json_path
    )
    if has_value:
      self._set_value_by_json_path(json_path, value)

  if not getattr(fc, 'will_continue', False):
    # Function call complete: emit pending text first to keep ordering.
    self._flush_text_buffer_to_sequence()
    self._flush_function_call_to_sequence()


def _process_function_call_part(self, part: 'types.Part'):
  """Process a function call part (streaming or non-streaming).

  A part is handled as a streaming chunk when its function call has partial
  args or reports ``will_continue``; the latter covers chunks that announce
  a call without carrying any argument data yet. A named call without either
  is appended unchanged. An empty function call (used as a streaming end
  marker) completes a streamed call that is still pending and is otherwise
  skipped.

  Args:
    part: The part containing a function call.
  """
  fc = part.function_call

  has_partial_args = bool(getattr(fc, 'partial_args', None))
  will_continue = bool(getattr(fc, 'will_continue', False))

  if has_partial_args or will_continue:
    # Keep the first thought_signature seen for this function call.
    if part.thought_signature and not self._current_thought_signature:
      self._current_thought_signature = part.thought_signature
    self._process_streaming_function_call(fc)
  elif fc.name:
    # Non-streaming function call (standard format with args).
    self._flush_text_buffer_to_sequence()
    self._parts_sequence.append(part)
  elif self._current_fc_name:
    # End marker while a streamed call is still open: finish that call.
    self._flush_text_buffer_to_sequence()
    self._flush_function_call_to_sequence()
def _function_call_chunk(fc, *, thought_signature=None, finish_reason=None):
  """Wrap a function call in a single-candidate model response chunk.

  Args:
    fc: The ``types.FunctionCall`` to place in the chunk's only part.
    thought_signature: Optional signature to attach to the part.
    finish_reason: Optional finish reason to set on the candidate.

  Returns:
    A ``types.GenerateContentResponse`` with one candidate and one part.
  """
  part_kwargs = {"function_call": fc}
  if thought_signature is not None:
    part_kwargs["thought_signature"] = thought_signature

  candidate_kwargs = {
      "content": types.Content(role="model", parts=[types.Part(**part_kwargs)])
  }
  if finish_reason is not None:
    candidate_kwargs["finish_reason"] = finish_reason

  return types.GenerateContentResponse(
      candidates=[types.Candidate(**candidate_kwargs)]
  )


def _aggregate_chunks(chunks):
  """Run chunks through a fresh aggregator and return its final response.

  Args:
    chunks: Response chunks, in the order they should be processed.

  Returns:
    Whatever ``StreamingResponseAggregator.close()`` returns after all chunks
    were processed.
  """
  # Imported locally so the helper does not depend on module-level imports.
  import asyncio

  aggregator = StreamingResponseAggregator()

  async def _drain():
    # The chunks are passed in explicitly and consumed in one event loop, so
    # no closure over a loop variable is needed.
    for chunk in chunks:
      async for _ in aggregator.process_response(chunk):
        pass

  asyncio.run(_drain())
  return aggregator.close()


def test_progressive_sse_streaming_function_call_arguments():
  """Test streaming function call arguments feature.

  This test simulates the streamFunctionCallArguments feature where a function
  call's arguments are streamed incrementally across multiple chunks:

  Chunk 1: FC name + partial location argument ("New ")
  Chunk 2: Continue location argument ("York") -> concatenated to "New York"
  Chunk 3: Add unit argument ("celsius"), willContinue=False -> FC complete

  Expected result: FunctionCall(name="get_weather",
                                args={"location": "New York",
                                      "unit": "celsius"},
                                id="fc_001")
  """
  # Chunk 1: FC name + partial location argument
  chunk1 = _function_call_chunk(
      types.FunctionCall(
          name="get_weather",
          id="fc_001",
          partial_args=[
              types.PartialArg(json_path="$.location", string_value="New ")
          ],
          will_continue=True,
      )
  )

  # Chunk 2: Continue streaming location argument
  chunk2 = _function_call_chunk(
      types.FunctionCall(
          partial_args=[
              types.PartialArg(json_path="$.location", string_value="York")
          ],
          will_continue=True,
      )
  )

  # Chunk 3: Add unit argument, FC complete
  chunk3 = _function_call_chunk(
      types.FunctionCall(
          partial_args=[
              types.PartialArg(json_path="$.unit", string_value="celsius")
          ],
          will_continue=False,  # FC complete
      ),
      finish_reason=types.FinishReason.STOP,
  )

  final_response = _aggregate_chunks([chunk1, chunk2, chunk3])

  # Verify final aggregated response has complete FC
  assert final_response is not None
  assert len(final_response.content.parts) == 1

  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "get_weather"
  assert fc_part.function_call.id == "fc_001"

  # Verify arguments were correctly assembled from streaming chunks
  args = fc_part.function_call.args
  assert args["location"] == "New York"  # "New " + "York" concatenated
  assert args["unit"] == "celsius"


def test_progressive_sse_preserves_thought_signature():
  """Test that thought_signature is preserved when streaming FC arguments.

  This test verifies that when a streaming function call has a
  thought_signature in the Part, it is correctly preserved on the Part of the
  final aggregated response.
  """
  # thought_signature is passed as bytes here
  test_thought_signature = b"test_signature_abc123"

  # Single chunk carrying streaming FC args AND a thought_signature
  chunk = _function_call_chunk(
      types.FunctionCall(
          name="add_5_numbers",
          id="fc_003",
          partial_args=[
              types.PartialArg(json_path="$.num1", number_value=10),
              types.PartialArg(json_path="$.num2", number_value=20),
          ],
          will_continue=False,
      ),
      thought_signature=test_thought_signature,
      finish_reason=types.FinishReason.STOP,
  )

  final_response = _aggregate_chunks([chunk])

  # Verify thought_signature was preserved in the Part
  assert final_response is not None
  assert len(final_response.content.parts) == 1

  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "add_5_numbers"

  assert fc_part.thought_signature == test_thought_signature


def test_progressive_sse_handles_empty_function_call():
  """Test that empty function calls are skipped.

  When using streamFunctionCallArguments, Gemini may send an empty
  functionCall: {} as the final chunk to signal streaming completion.
  This test verifies that such empty function calls are properly skipped
  and don't cause errors.
  """
  # Chunk 1: Streaming FC with partial args
  chunk1 = _function_call_chunk(
      types.FunctionCall(
          name="concat_number_and_string",
          id="fc_001",
          partial_args=[
              types.PartialArg(json_path="$.num", number_value=100),
              types.PartialArg(json_path="$.s", string_value="ADK"),
          ],
          will_continue=False,
      )
  )

  # Chunk 2: Empty function call (streaming end marker)
  chunk2 = _function_call_chunk(
      types.FunctionCall(),
      finish_reason=types.FinishReason.STOP,
  )

  final_response = _aggregate_chunks([chunk1, chunk2])

  # Verify final response only has the real FC, not the empty one
  assert final_response is not None
  assert len(final_response.content.parts) == 1

  fc_part = final_response.content.parts[0]
  assert fc_part.function_call is not None
  assert fc_part.function_call.name == "concat_number_and_string"
  assert fc_part.function_call.id == "fc_001"

  # Verify arguments
  args = fc_part.function_call.args
  assert args["num"] == 100
  assert args["s"] == "ADK"