You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Improve type hints and handle None values in ADK utils
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 866025998
This commit is contained in:
committed by
Copybara-Service
parent
adbc37fea1
commit
bb89466623
@@ -219,7 +219,6 @@ asyncio_mode = "auto"
|
||||
python_version = "3.10"
|
||||
exclude = ["tests/", "contributing/samples/"]
|
||||
plugins = ["pydantic.mypy"]
|
||||
# Start with non-strict mode, and swtich to strict mode later.
|
||||
strict = true
|
||||
disable_error_code = ["import-not-found", "import-untyped", "unused-ignore"]
|
||||
follow_imports = "skip"
|
||||
|
||||
@@ -59,7 +59,9 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
|
||||
def _from_a2a_context_id(
|
||||
context_id: str | None,
|
||||
) -> tuple[str, str, str] | tuple[None, None, None]:
|
||||
"""Converts an A2A context id to app name, user id and session id.
|
||||
if context_id is None, return None, None, None
|
||||
if context_id is not None, but not in the format of
|
||||
@@ -69,7 +71,7 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
|
||||
context_id: The A2A context id.
|
||||
|
||||
Returns:
|
||||
The app name, user id and session id.
|
||||
The app name, user id and session id, or (None, None, None) if invalid.
|
||||
"""
|
||||
if not context_id:
|
||||
return None, None, None
|
||||
|
||||
@@ -113,6 +113,12 @@ class CacheMetadata(BaseModel):
|
||||
f"fingerprint={self.fingerprint[:8]}..."
|
||||
)
|
||||
cache_id = self.cache_name.split("/")[-1]
|
||||
if self.expire_time is None:
|
||||
return (
|
||||
f"Cache {cache_id}: used {self.invocations_used} invocations, "
|
||||
f"cached {self.contents_count} contents, "
|
||||
"expires unknown"
|
||||
)
|
||||
time_until_expiry_minutes = (self.expire_time - time.time()) / 60
|
||||
return (
|
||||
f"Cache {cache_id}: used {self.invocations_used} invocations, "
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
import contextvars
|
||||
import os
|
||||
@@ -32,7 +33,7 @@ EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}"
|
||||
"""Label used to denote calls emerging to external system as a part of Evals."""
|
||||
|
||||
# The ContextVar holds client label collected for the current request.
|
||||
_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
_LABEL_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"_LABEL_CONTEXT", default=None
|
||||
)
|
||||
|
||||
@@ -49,7 +50,7 @@ def _get_default_labels() -> List[str]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def client_label_context(client_label: str):
|
||||
def client_label_context(client_label: str) -> Iterator[None]:
|
||||
"""Runs the operation within the context of the given client label."""
|
||||
current_client_label = _LABEL_CONTEXT.get()
|
||||
|
||||
|
||||
@@ -19,12 +19,12 @@ from google.genai import types
|
||||
|
||||
def is_audio_part(part: types.Part) -> bool:
|
||||
return (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
part.inline_data is not None
|
||||
and part.inline_data.mime_type is not None
|
||||
and part.inline_data.mime_type.startswith('audio/')
|
||||
) or (
|
||||
part.file_data
|
||||
and part.file_data.mime_type
|
||||
part.file_data is not None
|
||||
and part.file_data.mime_type is not None
|
||||
and part.file_data.mime_type.startswith('audio/')
|
||||
)
|
||||
|
||||
|
||||
@@ -14,16 +14,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
T = TypeVar("T", bound=Union[Callable, type])
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _is_truthy_env(var_name: str) -> bool:
|
||||
@@ -39,8 +39,8 @@ def _make_feature_decorator(
|
||||
default_message: str,
|
||||
block_usage: bool = False,
|
||||
bypass_env_var: Optional[str] = None,
|
||||
) -> Callable:
|
||||
def decorator_factory(message_or_obj=None):
|
||||
) -> Callable[..., Any]:
|
||||
def decorator_factory(message_or_obj: Any = None) -> Any:
|
||||
# Case 1: Used as @decorator without parentheses
|
||||
# message_or_obj is the decorated class/function
|
||||
if message_or_obj is not None and (
|
||||
@@ -68,10 +68,11 @@ def _create_decorator(
|
||||
msg = f"[{label.upper()}] {obj_name}: {message}"
|
||||
|
||||
if isinstance(obj, type): # decorating a class
|
||||
orig_init = obj.__init__
|
||||
cls = cast(type[Any], obj)
|
||||
orig_init = cast(Any, cls).__init__
|
||||
|
||||
@functools.wraps(orig_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Check if usage should be bypassed via environment variable at call time
|
||||
should_bypass = bypass_env_var is not None and _is_truthy_env(
|
||||
bypass_env_var
|
||||
@@ -86,13 +87,14 @@ def _create_decorator(
|
||||
warnings.warn(msg, category=UserWarning, stacklevel=2)
|
||||
return orig_init(self, *args, **kwargs)
|
||||
|
||||
obj.__init__ = new_init # type: ignore[attr-defined]
|
||||
return cast(T, obj)
|
||||
cast(Any, cls).__init__ = new_init
|
||||
return cast(T, cls)
|
||||
|
||||
elif callable(obj): # decorating a function or method
|
||||
func = cast(Callable[..., Any], obj)
|
||||
|
||||
@functools.wraps(obj)
|
||||
def wrapper(*args, **kwargs):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Check if usage should be bypassed via environment variable at call time
|
||||
should_bypass = bypass_env_var is not None and _is_truthy_env(
|
||||
bypass_env_var
|
||||
@@ -105,7 +107,7 @@ def _create_decorator(
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
warnings.warn(msg, category=UserWarning, stacklevel=2)
|
||||
return obj(*args, **kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from .variant_utils import get_google_llm_variant
|
||||
from .variant_utils import GoogleLLMVariant
|
||||
|
||||
|
||||
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]):
|
||||
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool:
|
||||
"""Returns True if output schema with tools is supported."""
|
||||
model_string = model if isinstance(model, str) else model.model
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class StreamingResponseAggregator:
|
||||
individual (partial) model responses, as well as for aggregated content.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._text = ''
|
||||
self._thought_text = ''
|
||||
self._usage_metadata = None
|
||||
@@ -48,9 +48,9 @@ class StreamingResponseAggregator:
|
||||
self._current_fc_name: Optional[str] = None
|
||||
self._current_fc_args: dict[str, Any] = {}
|
||||
self._current_fc_id: Optional[str] = None
|
||||
self._current_thought_signature: Optional[str] = None
|
||||
self._current_thought_signature: Optional[bytes] = None
|
||||
|
||||
def _flush_text_buffer_to_sequence(self):
|
||||
def _flush_text_buffer_to_sequence(self) -> None:
|
||||
"""Flush current text buffer to parts sequence.
|
||||
|
||||
This helper is used in progressive SSE mode to maintain part ordering.
|
||||
@@ -70,7 +70,7 @@ class StreamingResponseAggregator:
|
||||
|
||||
def _get_value_from_partial_arg(
|
||||
self, partial_arg: types.PartialArg, json_path: str
|
||||
):
|
||||
) -> tuple[Any, bool]:
|
||||
"""Extract value from a partial argument.
|
||||
|
||||
Args:
|
||||
@@ -80,7 +80,7 @@ class StreamingResponseAggregator:
|
||||
Returns:
|
||||
Tuple of (value, has_value) where has_value indicates if a value exists
|
||||
"""
|
||||
value = None
|
||||
value: Any = None
|
||||
has_value = False
|
||||
|
||||
if partial_arg.string_value is not None:
|
||||
@@ -95,12 +95,11 @@ class StreamingResponseAggregator:
|
||||
path_parts = path_without_prefix.split('.')
|
||||
|
||||
# Try to get existing value
|
||||
existing_value = self._current_fc_args
|
||||
existing_value: Any = self._current_fc_args
|
||||
for part in path_parts:
|
||||
if isinstance(existing_value, dict) and part in existing_value:
|
||||
existing_value = existing_value[part]
|
||||
else:
|
||||
existing_value = None
|
||||
break
|
||||
|
||||
# Append to existing string or set new value
|
||||
@@ -121,7 +120,7 @@ class StreamingResponseAggregator:
|
||||
|
||||
return value, has_value
|
||||
|
||||
def _set_value_by_json_path(self, json_path: str, value: Any):
|
||||
def _set_value_by_json_path(self, json_path: str, value: Any) -> None:
|
||||
"""Set a value in _current_fc_args using JSONPath notation.
|
||||
|
||||
Args:
|
||||
@@ -147,7 +146,7 @@ class StreamingResponseAggregator:
|
||||
# Set the final value
|
||||
current[path_parts[-1]] = value
|
||||
|
||||
def _flush_function_call_to_sequence(self):
|
||||
def _flush_function_call_to_sequence(self) -> None:
|
||||
"""Flush current function call to parts sequence.
|
||||
|
||||
This creates a complete FunctionCall part from accumulated partial args.
|
||||
@@ -175,7 +174,7 @@ class StreamingResponseAggregator:
|
||||
self._current_fc_id = None
|
||||
self._current_thought_signature = None
|
||||
|
||||
def _process_streaming_function_call(self, fc: types.FunctionCall):
|
||||
def _process_streaming_function_call(self, fc: types.FunctionCall) -> None:
|
||||
"""Process a streaming function call with partialArgs.
|
||||
|
||||
Args:
|
||||
@@ -208,14 +207,14 @@ class StreamingResponseAggregator:
|
||||
self._flush_text_buffer_to_sequence()
|
||||
self._flush_function_call_to_sequence()
|
||||
|
||||
def _process_function_call_part(self, part: types.Part):
|
||||
def _process_function_call_part(self, part: types.Part) -> None:
|
||||
"""Process a function call part (streaming or non-streaming).
|
||||
|
||||
Args:
|
||||
part: The part containing a function call
|
||||
"""
|
||||
fc = part.function_call
|
||||
if not fc:
|
||||
if fc is None:
|
||||
return
|
||||
|
||||
# Check if this is a streaming FC (has partialArgs or will_continue=True)
|
||||
@@ -298,10 +297,11 @@ class StreamingResponseAggregator:
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
part0 = llm_response.content.parts[0]
|
||||
part_text = part0.text or ''
|
||||
if part0.thought:
|
||||
self._thought_text += part0.text
|
||||
self._thought_text += part_text
|
||||
else:
|
||||
self._text += part0.text
|
||||
self._text += part_text
|
||||
llm_response.partial = True
|
||||
elif (self._thought_text or self._text) and (
|
||||
not llm_response.content
|
||||
@@ -382,3 +382,5 @@ class StreamingResponseAggregator:
|
||||
else candidate.finish_message,
|
||||
usage_metadata=self._usage_metadata,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user