chore: Improve type hints and handle None values in ADK utils

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 866025998
This commit is contained in:
George Weale
2026-02-05 11:04:01 -08:00
committed by Copybara-Service
parent adbc37fea1
commit bb89466623
8 changed files with 48 additions and 36 deletions
-1
View File
@@ -219,7 +219,6 @@ asyncio_mode = "auto"
python_version = "3.10"
exclude = ["tests/", "contributing/samples/"]
plugins = ["pydantic.mypy"]
# Started in non-strict mode; switched to strict mode below.
strict = true
disable_error_code = ["import-not-found", "import-untyped", "unused-ignore"]
follow_imports = "skip"
+4 -2
View File
@@ -59,7 +59,9 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str:
)
def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
def _from_a2a_context_id(
context_id: str | None,
) -> tuple[str, str, str] | tuple[None, None, None]:
"""Converts an A2A context id to app name, user id and session id.
if context_id is None, return None, None, None
if context_id is not None, but not in the format of
@@ -69,7 +71,7 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
context_id: The A2A context id.
Returns:
The app name, user id and session id.
The app name, user id and session id, or (None, None, None) if invalid.
"""
if not context_id:
return None, None, None
+6
View File
@@ -113,6 +113,12 @@ class CacheMetadata(BaseModel):
f"fingerprint={self.fingerprint[:8]}..."
)
cache_id = self.cache_name.split("/")[-1]
if self.expire_time is None:
return (
f"Cache {cache_id}: used {self.invocations_used} invocations, "
f"cached {self.contents_count} contents, "
"expires unknown"
)
time_until_expiry_minutes = (self.expire_time - time.time()) / 60
return (
f"Cache {cache_id}: used {self.invocations_used} invocations, "
+3 -2
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
import contextvars
import os
@@ -32,7 +33,7 @@ EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}"
"""Label used to denote calls emerging to external system as a part of Evals."""
# The ContextVar holds client label collected for the current request.
_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar(
_LABEL_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_LABEL_CONTEXT", default=None
)
@@ -49,7 +50,7 @@ def _get_default_labels() -> List[str]:
@contextmanager
def client_label_context(client_label: str):
def client_label_context(client_label: str) -> Iterator[None]:
"""Runs the operation within the context of the given client label."""
current_client_label = _LABEL_CONTEXT.get()
+4 -4
View File
@@ -19,12 +19,12 @@ from google.genai import types
def is_audio_part(part: types.Part) -> bool:
return (
part.inline_data
and part.inline_data.mime_type
part.inline_data is not None
and part.inline_data.mime_type is not None
and part.inline_data.mime_type.startswith('audio/')
) or (
part.file_data
and part.file_data.mime_type
part.file_data is not None
and part.file_data.mime_type is not None
and part.file_data.mime_type.startswith('audio/')
)
+14 -12
View File
@@ -14,16 +14,16 @@
from __future__ import annotations
from collections.abc import Callable
import functools
import os
from typing import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from typing import Union
import warnings
T = TypeVar("T", bound=Union[Callable, type])
T = TypeVar("T")
def _is_truthy_env(var_name: str) -> bool:
@@ -39,8 +39,8 @@ def _make_feature_decorator(
default_message: str,
block_usage: bool = False,
bypass_env_var: Optional[str] = None,
) -> Callable:
def decorator_factory(message_or_obj=None):
) -> Callable[..., Any]:
def decorator_factory(message_or_obj: Any = None) -> Any:
# Case 1: Used as @decorator without parentheses
# message_or_obj is the decorated class/function
if message_or_obj is not None and (
@@ -68,10 +68,11 @@ def _create_decorator(
msg = f"[{label.upper()}] {obj_name}: {message}"
if isinstance(obj, type): # decorating a class
orig_init = obj.__init__
cls = cast(type[Any], obj)
orig_init = cast(Any, cls).__init__
@functools.wraps(orig_init)
def new_init(self, *args, **kwargs):
def new_init(self: Any, *args: Any, **kwargs: Any) -> Any:
# Check if usage should be bypassed via environment variable at call time
should_bypass = bypass_env_var is not None and _is_truthy_env(
bypass_env_var
@@ -86,13 +87,14 @@ def _create_decorator(
warnings.warn(msg, category=UserWarning, stacklevel=2)
return orig_init(self, *args, **kwargs)
obj.__init__ = new_init # type: ignore[attr-defined]
return cast(T, obj)
cast(Any, cls).__init__ = new_init
return cast(T, cls)
elif callable(obj): # decorating a function or method
func = cast(Callable[..., Any], obj)
@functools.wraps(obj)
def wrapper(*args, **kwargs):
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Check if usage should be bypassed via environment variable at call time
should_bypass = bypass_env_var is not None and _is_truthy_env(
bypass_env_var
@@ -105,7 +107,7 @@ def _create_decorator(
raise RuntimeError(msg)
else:
warnings.warn(msg, category=UserWarning, stacklevel=2)
return obj(*args, **kwargs)
return func(*args, **kwargs)
return cast(T, wrapper)
+1 -1
View File
@@ -28,7 +28,7 @@ from .variant_utils import get_google_llm_variant
from .variant_utils import GoogleLLMVariant
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]):
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool:
"""Returns True if output schema with tools is supported."""
model_string = model if isinstance(model, str) else model.model
+16 -14
View File
@@ -32,7 +32,7 @@ class StreamingResponseAggregator:
individual (partial) model responses, as well as for aggregated content.
"""
def __init__(self):
def __init__(self) -> None:
self._text = ''
self._thought_text = ''
self._usage_metadata = None
@@ -48,9 +48,9 @@ class StreamingResponseAggregator:
self._current_fc_name: Optional[str] = None
self._current_fc_args: dict[str, Any] = {}
self._current_fc_id: Optional[str] = None
self._current_thought_signature: Optional[str] = None
self._current_thought_signature: Optional[bytes] = None
def _flush_text_buffer_to_sequence(self):
def _flush_text_buffer_to_sequence(self) -> None:
"""Flush current text buffer to parts sequence.
This helper is used in progressive SSE mode to maintain part ordering.
@@ -70,7 +70,7 @@ class StreamingResponseAggregator:
def _get_value_from_partial_arg(
self, partial_arg: types.PartialArg, json_path: str
):
) -> tuple[Any, bool]:
"""Extract value from a partial argument.
Args:
@@ -80,7 +80,7 @@ class StreamingResponseAggregator:
Returns:
Tuple of (value, has_value) where has_value indicates if a value exists
"""
value = None
value: Any = None
has_value = False
if partial_arg.string_value is not None:
@@ -95,12 +95,11 @@ class StreamingResponseAggregator:
path_parts = path_without_prefix.split('.')
# Try to get existing value
existing_value = self._current_fc_args
existing_value: Any = self._current_fc_args
for part in path_parts:
if isinstance(existing_value, dict) and part in existing_value:
existing_value = existing_value[part]
else:
existing_value = None
break
# Append to existing string or set new value
@@ -121,7 +120,7 @@ class StreamingResponseAggregator:
return value, has_value
def _set_value_by_json_path(self, json_path: str, value: Any):
def _set_value_by_json_path(self, json_path: str, value: Any) -> None:
"""Set a value in _current_fc_args using JSONPath notation.
Args:
@@ -147,7 +146,7 @@ class StreamingResponseAggregator:
# Set the final value
current[path_parts[-1]] = value
def _flush_function_call_to_sequence(self):
def _flush_function_call_to_sequence(self) -> None:
"""Flush current function call to parts sequence.
This creates a complete FunctionCall part from accumulated partial args.
@@ -175,7 +174,7 @@ class StreamingResponseAggregator:
self._current_fc_id = None
self._current_thought_signature = None
def _process_streaming_function_call(self, fc: types.FunctionCall):
def _process_streaming_function_call(self, fc: types.FunctionCall) -> None:
"""Process a streaming function call with partialArgs.
Args:
@@ -208,14 +207,14 @@ class StreamingResponseAggregator:
self._flush_text_buffer_to_sequence()
self._flush_function_call_to_sequence()
def _process_function_call_part(self, part: types.Part):
def _process_function_call_part(self, part: types.Part) -> None:
"""Process a function call part (streaming or non-streaming).
Args:
part: The part containing a function call
"""
fc = part.function_call
if not fc:
if fc is None:
return
# Check if this is a streaming FC (has partialArgs or will_continue=True)
@@ -298,10 +297,11 @@ class StreamingResponseAggregator:
and llm_response.content.parts[0].text
):
part0 = llm_response.content.parts[0]
part_text = part0.text or ''
if part0.thought:
self._thought_text += part0.text
self._thought_text += part_text
else:
self._text += part0.text
self._text += part_text
llm_response.partial = True
elif (self._thought_text or self._text) and (
not llm_response.content
@@ -382,3 +382,5 @@ class StreamingResponseAggregator:
else candidate.finish_message,
usage_metadata=self._usage_metadata,
)
return None