From bb89466623531ce91ef10a3260bac91902e27bb5 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 5 Feb 2026 11:04:01 -0800 Subject: [PATCH] chore: Improve type hints and handle None values in ADK utils Co-authored-by: George Weale PiperOrigin-RevId: 866025998 --- pyproject.toml | 1 - src/google/adk/a2a/converters/utils.py | 6 ++-- src/google/adk/models/cache_metadata.py | 6 ++++ src/google/adk/utils/_client_labels_utils.py | 5 ++-- src/google/adk/utils/content_utils.py | 8 +++--- src/google/adk/utils/feature_decorator.py | 26 +++++++++-------- src/google/adk/utils/output_schema_utils.py | 2 +- src/google/adk/utils/streaming_utils.py | 30 +++++++++++--------- 8 files changed, 48 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 55efe71f..2564e8ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,7 +219,6 @@ asyncio_mode = "auto" python_version = "3.10" exclude = ["tests/", "contributing/samples/"] plugins = ["pydantic.mypy"] -# Start with non-strict mode, and swtich to strict mode later. strict = true disable_error_code = ["import-not-found", "import-untyped", "unused-ignore"] follow_imports = "skip" diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py index ba971560..00111f83 100644 --- a/src/google/adk/a2a/converters/utils.py +++ b/src/google/adk/a2a/converters/utils.py @@ -59,7 +59,9 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: ) -def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: +def _from_a2a_context_id( + context_id: str | None, +) -> tuple[str, str, str] | tuple[None, None, None]: """Converts an A2A context id to app name, user id and session id. if context_id is None, return None, None, None if context_id is not None, but not in the format of @@ -69,7 +71,7 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: context_id: The A2A context id. Returns: - The app name, user id and session id. + The app name, user id and session id, or (None, None, None) if invalid. """ if not context_id: return None, None, None diff --git a/src/google/adk/models/cache_metadata.py b/src/google/adk/models/cache_metadata.py index 95b64611..c6b1dc91 100644 --- a/src/google/adk/models/cache_metadata.py +++ b/src/google/adk/models/cache_metadata.py @@ -113,6 +113,12 @@ class CacheMetadata(BaseModel): f"fingerprint={self.fingerprint[:8]}..." ) cache_id = self.cache_name.split("/")[-1] + if self.expire_time is None: + return ( + f"Cache {cache_id}: used {self.invocations_used} invocations, " + f"cached {self.contents_count} contents, " + "expires unknown" + ) time_until_expiry_minutes = (self.expire_time - time.time()) / 60 return ( f"Cache {cache_id}: used {self.invocations_used} invocations, " diff --git a/src/google/adk/utils/_client_labels_utils.py b/src/google/adk/utils/_client_labels_utils.py index 7fbeda11..f4243e61 100644 --- a/src/google/adk/utils/_client_labels_utils.py +++ b/src/google/adk/utils/_client_labels_utils.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Iterator from contextlib import contextmanager import contextvars import os @@ -32,7 +33,7 @@ EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}" """Label used to denote calls emerging to external system as a part of Evals.""" # The ContextVar holds client label collected for the current request. -_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar( +_LABEL_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar( "_LABEL_CONTEXT", default=None ) @@ -49,7 +50,7 @@ def _get_default_labels() -> List[str]: @contextmanager -def client_label_context(client_label: str): +def client_label_context(client_label: str) -> Iterator[None]: """Runs the operation within the context of the given client label.""" current_client_label = _LABEL_CONTEXT.get() diff --git a/src/google/adk/utils/content_utils.py b/src/google/adk/utils/content_utils.py index 379c31ec..011269ae 100644 --- a/src/google/adk/utils/content_utils.py +++ b/src/google/adk/utils/content_utils.py @@ -19,12 +19,12 @@ from google.genai import types def is_audio_part(part: types.Part) -> bool: return ( - part.inline_data - and part.inline_data.mime_type + part.inline_data is not None + and part.inline_data.mime_type is not None and part.inline_data.mime_type.startswith('audio/') ) or ( - part.file_data - and part.file_data.mime_type + part.file_data is not None + and part.file_data.mime_type is not None and part.file_data.mime_type.startswith('audio/') ) diff --git a/src/google/adk/utils/feature_decorator.py b/src/google/adk/utils/feature_decorator.py index 38b79d9f..7dbbc3bd 100644 --- a/src/google/adk/utils/feature_decorator.py +++ b/src/google/adk/utils/feature_decorator.py @@ -14,16 +14,16 @@ from __future__ import annotations +from collections.abc import Callable import functools import os -from typing import Callable +from typing import Any from typing import cast from typing import Optional from typing import TypeVar -from typing import Union import warnings -T = TypeVar("T", bound=Union[Callable, type]) +T = TypeVar("T") def _is_truthy_env(var_name: str) -> bool: @@ -39,8 +39,8 @@ def _make_feature_decorator( default_message: str, block_usage: bool = False, bypass_env_var: Optional[str] = None, -) -> Callable: - def decorator_factory(message_or_obj=None): +) -> Callable[..., Any]: + def decorator_factory(message_or_obj: Any = None) -> Any: # Case 1: Used as @decorator without parentheses # message_or_obj is the decorated class/function if message_or_obj is not None and ( @@ -68,10 +68,11 @@ def _create_decorator( msg = f"[{label.upper()}] {obj_name}: {message}" if isinstance(obj, type): # decorating a class - orig_init = obj.__init__ + cls = cast(type[Any], obj) + orig_init = cast(Any, cls).__init__ @functools.wraps(orig_init) - def new_init(self, *args, **kwargs): + def new_init(self: Any, *args: Any, **kwargs: Any) -> Any: # Check if usage should be bypassed via environment variable at call time should_bypass = bypass_env_var is not None and _is_truthy_env( bypass_env_var @@ -86,13 +87,14 @@ def _create_decorator( warnings.warn(msg, category=UserWarning, stacklevel=2) return orig_init(self, *args, **kwargs) - obj.__init__ = new_init # type: ignore[attr-defined] - return cast(T, obj) + cast(Any, cls).__init__ = new_init + return cast(T, cls) elif callable(obj): # decorating a function or method + func = cast(Callable[..., Any], obj) - @functools.wraps(obj) - def wrapper(*args, **kwargs): + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: # Check if usage should be bypassed via environment variable at call time should_bypass = bypass_env_var is not None and _is_truthy_env( bypass_env_var @@ -105,7 +107,7 @@ def _create_decorator( raise RuntimeError(msg) else: warnings.warn(msg, category=UserWarning, stacklevel=2) - return obj(*args, **kwargs) + return func(*args, **kwargs) return cast(T, wrapper) diff --git a/src/google/adk/utils/output_schema_utils.py b/src/google/adk/utils/output_schema_utils.py index ab97dd6d..7c494f92 100644 --- a/src/google/adk/utils/output_schema_utils.py +++ b/src/google/adk/utils/output_schema_utils.py @@ -28,7 +28,7 @@ from .variant_utils import get_google_llm_variant from .variant_utils import GoogleLLMVariant -def can_use_output_schema_with_tools(model: Union[str, BaseLlm]): +def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool: """Returns True if output schema with tools is supported.""" model_string = model if isinstance(model, str) else model.model diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py index d3751df9..09689a4d 100644 --- a/src/google/adk/utils/streaming_utils.py +++ b/src/google/adk/utils/streaming_utils.py @@ -32,7 +32,7 @@ class StreamingResponseAggregator: individual (partial) model responses, as well as for aggregated content. """ - def __init__(self): + def __init__(self) -> None: self._text = '' self._thought_text = '' self._usage_metadata = None @@ -48,9 +48,9 @@ class StreamingResponseAggregator: self._current_fc_name: Optional[str] = None self._current_fc_args: dict[str, Any] = {} self._current_fc_id: Optional[str] = None - self._current_thought_signature: Optional[str] = None + self._current_thought_signature: Optional[bytes] = None - def _flush_text_buffer_to_sequence(self): + def _flush_text_buffer_to_sequence(self) -> None: """Flush current text buffer to parts sequence. This helper is used in progressive SSE mode to maintain part ordering. @@ -70,7 +70,7 @@ class StreamingResponseAggregator: def _get_value_from_partial_arg( self, partial_arg: types.PartialArg, json_path: str - ): + ) -> tuple[Any, bool]: """Extract value from a partial argument. Args: @@ -80,7 +80,7 @@ class StreamingResponseAggregator: Returns: Tuple of (value, has_value) where has_value indicates if a value exists """ - value = None + value: Any = None has_value = False if partial_arg.string_value is not None: @@ -95,12 +95,11 @@ class StreamingResponseAggregator: path_parts = path_without_prefix.split('.') # Try to get existing value - existing_value = self._current_fc_args + existing_value: Any = self._current_fc_args for part in path_parts: if isinstance(existing_value, dict) and part in existing_value: existing_value = existing_value[part] else: - existing_value = None break # Append to existing string or set new value @@ -121,7 +120,7 @@ class StreamingResponseAggregator: return value, has_value - def _set_value_by_json_path(self, json_path: str, value: Any): + def _set_value_by_json_path(self, json_path: str, value: Any) -> None: """Set a value in _current_fc_args using JSONPath notation. Args: @@ -147,7 +146,7 @@ class StreamingResponseAggregator: # Set the final value current[path_parts[-1]] = value - def _flush_function_call_to_sequence(self): + def _flush_function_call_to_sequence(self) -> None: """Flush current function call to parts sequence. This creates a complete FunctionCall part from accumulated partial args. @@ -175,7 +174,7 @@ class StreamingResponseAggregator: self._current_fc_id = None self._current_thought_signature = None - def _process_streaming_function_call(self, fc: types.FunctionCall): + def _process_streaming_function_call(self, fc: types.FunctionCall) -> None: """Process a streaming function call with partialArgs. Args: @@ -208,14 +207,14 @@ class StreamingResponseAggregator: self._flush_text_buffer_to_sequence() self._flush_function_call_to_sequence() - def _process_function_call_part(self, part: types.Part): + def _process_function_call_part(self, part: types.Part) -> None: """Process a function call part (streaming or non-streaming). Args: part: The part containing a function call """ fc = part.function_call - if not fc: + if fc is None: return # Check if this is a streaming FC (has partialArgs or will_continue=True) @@ -298,10 +297,11 @@ class StreamingResponseAggregator: and llm_response.content.parts[0].text ): part0 = llm_response.content.parts[0] + part_text = part0.text or '' if part0.thought: - self._thought_text += part0.text + self._thought_text += part_text else: - self._text += part0.text + self._text += part_text llm_response.partial = True elif (self._thought_text or self._text) and ( not llm_response.content @@ -382,3 +382,5 @@ class StreamingResponseAggregator: else candidate.finish_message, usage_metadata=self._usage_metadata, ) + + return None