From 485fcb84e3ca351f83416c012edcafcec479c1db Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 20 Feb 2026 15:19:11 -0800 Subject: [PATCH] feat: Add intra-invocation compaction and token compaction pre-request Compact session events before LLM calls when token threshold is exceeded Co-authored-by: George Weale PiperOrigin-RevId: 873095899 --- src/google/adk/agents/invocation_context.py | 7 + src/google/adk/apps/compaction.py | 397 ++++++++++++------ src/google/adk/flows/llm_flows/compaction.py | 58 +++ src/google/adk/flows/llm_flows/single_flow.py | 4 + src/google/adk/runners.py | 8 +- tests/unittests/apps/test_compaction.py | 296 ++++++++++++- .../llm_flows/test_compaction_processor.py | 346 +++++++++++++++ 7 files changed, 983 insertions(+), 133 deletions(-) create mode 100644 src/google/adk/flows/llm_flows/compaction.py create mode 100644 tests/unittests/flows/llm_flows/test_compaction_processor.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7a23a6cc..4c75e1c4 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -24,6 +24,7 @@ from pydantic import ConfigDict from pydantic import Field from pydantic import PrivateAttr +from ..apps.app import EventsCompactionConfig from ..apps.app import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService @@ -200,6 +201,12 @@ class InvocationContext(BaseModel): resumability_config: Optional[ResumabilityConfig] = None """The resumability config that applies to all agents under this invocation.""" + events_compaction_config: Optional[EventsCompactionConfig] = None + """The compaction config for this invocation.""" + + token_compaction_checked: bool = False + """Whether token-threshold compaction ran during this invocation.""" + plugin_manager: PluginManager = Field(default_factory=PluginManager) """The manager for keeping track of plugins in this invocation.""" diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index 4af7b512..61941bff 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -16,25 +16,53 @@ from __future__ import annotations import logging +from google.genai import types + +from ..agents.base_agent import BaseAgent from ..events.event import Event from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session from .app import App +from .app import EventsCompactionConfig from .llm_event_summarizer import LlmEventSummarizer logger = logging.getLogger('google_adk.' + __name__) -def _count_text_chars_in_event(event: Event) -> int: - """Returns the number of text characters in an event's content.""" +def _count_text_chars_in_content(content: types.Content | None) -> int: + """Returns the number of text characters in a content object.""" total_chars = 0 - if event.content and event.content.parts: - for part in event.content.parts: + if content and content.parts: + for part in content.parts: if part.text: total_chars += len(part.text) return total_chars +def _valid_compactions( + events: list[Event], +) -> list[tuple[int, float, float, Event]]: + """Returns compaction events with fully-defined compaction ranges.""" + compactions: list[tuple[int, float, float, Event]] = [] + for i, event in enumerate(events): + if not (event.actions and event.actions.compaction): + continue + compaction = event.actions.compaction + if ( + compaction.start_timestamp is None + or compaction.end_timestamp is None + or compaction.compacted_content is None + ): + continue + compactions.append(( + i, + compaction.start_timestamp, + compaction.end_timestamp, + event, + )) + return compactions + + def _is_compaction_subsumed( *, start_timestamp: float, @@ -60,67 +88,29 @@ def _is_compaction_subsumed( return False -def _estimate_prompt_token_count(events: list[Event]) -> int | None: +def _estimate_prompt_token_count( + *, + events: list[Event], + current_branch: str | None, + agent_name: str, +) -> int | None: """Returns an approximate prompt token count from session events. - This estimate is compaction-aware: it counts compaction summaries and only - counts raw events that would remain visible after applying compaction ranges. + This estimate mirrors the effective content-building path used by the + contents request processor. """ - compactions: list[tuple[int, float, float, Event]] = [] - for i, event in enumerate(events): - if not (event.actions and event.actions.compaction): - continue - compaction = event.actions.compaction - if ( - compaction.start_timestamp is None - or compaction.end_timestamp is None - or compaction.compacted_content is None - ): - continue - compactions.append(( - i, - compaction.start_timestamp, - compaction.end_timestamp, - Event( - timestamp=compaction.end_timestamp, - author='model', - content=compaction.compacted_content, - branch=event.branch, - invocation_id=event.invocation_id, - actions=event.actions, - ), - )) - - effective_compactions = [ - (i, start, end, summary_event) - for i, start, end, summary_event in compactions - if not _is_compaction_subsumed( - start_timestamp=start, - end_timestamp=end, - event_index=i, - compactions=compactions, - ) - ] - compaction_ranges = [ - (start, end) for _, start, end, _ in effective_compactions - ] - - def _is_timestamp_compacted(ts: float) -> bool: - for start_ts, end_ts in compaction_ranges: - if start_ts <= ts <= end_ts: - return True - return False + # Deferred import: contents depends on agents.invocation_context which + # imports from apps, so a top-level import would create a circular dependency. + from ..flows.llm_flows import contents + effective_contents = contents._get_contents( + current_branch=current_branch, + events=events, + agent_name=agent_name, + ) total_chars = 0 - for _, _, _, summary_event in effective_compactions: - total_chars += _count_text_chars_in_event(summary_event) - - for event in events: - if event.actions and event.actions.compaction: - continue - if _is_timestamp_compacted(event.timestamp): - continue - total_chars += _count_text_chars_in_event(event) + for content in effective_contents: + total_chars += _count_text_chars_in_content(content) if total_chars <= 0: return None @@ -129,7 +119,12 @@ def _estimate_prompt_token_count(events: list[Event]) -> int | None: return total_chars // 4 -def _latest_prompt_token_count(events: list[Event]) -> int | None: +def _latest_prompt_token_count( + events: list[Event], + *, + current_branch: str | None = None, + agent_name: str = '', +) -> int | None: """Returns the most recently observed prompt token count, if available.""" for event in reversed(events): if ( @@ -137,23 +132,29 @@ def _latest_prompt_token_count(events: list[Event]) -> int | None: and event.usage_metadata.prompt_token_count is not None ): return event.usage_metadata.prompt_token_count - return _estimate_prompt_token_count(events) + return _estimate_prompt_token_count( + events=events, + current_branch=current_branch, + agent_name=agent_name, + ) def _latest_compaction_event(events: list[Event]) -> Event | None: - """Returns the compaction event with the greatest covered end timestamp.""" + """Returns the latest non-subsumed compaction event by stream order.""" + compactions = _valid_compactions(events) latest_event = None - latest_end = 0.0 - for event in events: - if ( - event.actions - and event.actions.compaction - and event.actions.compaction.end_timestamp is not None + latest_index = -1 + for event_index, start_ts, end_ts, event in compactions: + if _is_compaction_subsumed( + start_timestamp=start_ts, + end_timestamp=end_ts, + event_index=event_index, + compactions=compactions, ): - end_ts = event.actions.compaction.end_timestamp - if end_ts is not None and end_ts >= latest_end: - latest_end = end_ts - latest_event = event + continue + if event_index > latest_index: + latest_index = event_index + latest_event = event return latest_event @@ -167,55 +168,73 @@ def _latest_compaction_end_timestamp(events: list[Event]) -> float: return latest_event.actions.compaction.end_timestamp -async def _run_compaction_for_token_threshold( - app: App, session: Session, session_service: BaseSessionService -): - """Runs post-invocation compaction based on a token threshold. +def _has_token_threshold_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether token-threshold compaction is fully configured.""" + return bool( + config + and config.token_threshold is not None + and config.event_retention_size is not None + ) - If triggered, this compacts older raw events and keeps the last - `event_retention_size` raw events un-compacted. - """ - config = app.events_compaction_config - if not config: - return False - if config.token_threshold is None or config.event_retention_size is None: - return False - prompt_token_count = _latest_prompt_token_count(session.events) - if prompt_token_count is None or prompt_token_count < config.token_threshold: - return False +def _has_sliding_window_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether sliding-window compaction is fully configured.""" + return bool( + config + and config.compaction_interval is not None + and config.overlap_size is not None + ) - latest_compaction_event = _latest_compaction_event(session.events) - last_compacted_end_timestamp = 0.0 - if ( - latest_compaction_event - and latest_compaction_event.actions - and latest_compaction_event.actions.compaction - and latest_compaction_event.actions.compaction.end_timestamp is not None - ): - last_compacted_end_timestamp = ( - latest_compaction_event.actions.compaction.end_timestamp + +def _ensure_compaction_summarizer( + *, config: EventsCompactionConfig, agent: BaseAgent +) -> None: + """Ensures compaction config has a summarizer initialized.""" + if config.summarizer is not None: + return + + from ..agents.llm_agent import LlmAgent + + if not isinstance(agent, LlmAgent): + raise ValueError( + 'No LlmAgent model available for event compaction summarizer.' ) + config.summarizer = LlmEventSummarizer(llm=agent.canonical_model) + + +def _events_to_compact_for_token_threshold( + *, + events: list[Event], + event_retention_size: int, +) -> list[Event]: + """Collects token-threshold compaction candidates with rolling-summary seed. + + If a previous compaction exists, include its summary as the first event so + the next summary can supersede it. + """ + latest_compaction_event = _latest_compaction_event(events) + last_compacted_end_timestamp = _latest_compaction_end_timestamp(events) + candidate_events = [ - e - for e in session.events - if not (e.actions and e.actions.compaction) - and e.timestamp > last_compacted_end_timestamp + event + for event in events + if not (event.actions and event.actions.compaction) + and event.timestamp > last_compacted_end_timestamp ] + if len(candidate_events) <= event_retention_size: + return [] - if len(candidate_events) <= config.event_retention_size: - return False - - if config.event_retention_size == 0: + if event_retention_size == 0: events_to_compact = candidate_events else: - events_to_compact = candidate_events[: -config.event_retention_size] + split_index = _safe_token_compaction_split_index( + candidate_events=candidate_events, + event_retention_size=event_retention_size, + ) + events_to_compact = candidate_events[:split_index] if not events_to_compact: - return False + return [] - # Rolling summary: if a previous compaction exists, seed the next summary with - # the previous compaction summary content so new compactions can subsume older - # ones while still keeping `event_retention_size` raw events visible. if ( latest_compaction_event and latest_compaction_event.actions @@ -231,10 +250,101 @@ async def _run_compaction_for_token_threshold( branch=latest_compaction_event.branch, invocation_id=Event.new_id(), ) - events_to_compact = [seed_event] + events_to_compact + return [seed_event] + events_to_compact - if not config.summarizer: - config.summarizer = LlmEventSummarizer(llm=app.root_agent.canonical_model) + return events_to_compact + + +def _event_function_call_ids(event: Event) -> set[str]: + """Returns function call ids found in an event.""" + function_call_ids: set[str] = set() + for function_call in event.get_function_calls(): + if function_call.id: + function_call_ids.add(function_call.id) + return function_call_ids + + +def _event_function_response_ids(event: Event) -> set[str]: + """Returns function response ids found in an event.""" + function_response_ids: set[str] = set() + for function_response in event.get_function_responses(): + if function_response.id: + function_response_ids.add(function_response.id) + return function_response_ids + + +def _safe_token_compaction_split_index( + *, + candidate_events: list[Event], + event_retention_size: int, +) -> int: + """Returns a split index that avoids orphaning retained tool responses. + + Retained events (tail of candidate events) may contain function responses. + If their matching function call events are in the compacted prefix, contents + assembly can fail. This method shifts the split earlier so matching function + call events are retained together with their responses. + + Iterates backwards through candidate_events once, maintaining a running set + of unmatched response IDs. The latest valid split point where no unmatched + responses remain is returned. + """ + initial_split = len(candidate_events) - event_retention_size + if initial_split <= 0: + return 0 + + unmatched_response_ids: set[str] = set() + best_split = 0 + + for i in range(len(candidate_events) - 1, -1, -1): + event = candidate_events[i] + unmatched_response_ids.update(_event_function_response_ids(event)) + call_ids = _event_function_call_ids(event) + unmatched_response_ids -= call_ids + + if not unmatched_response_ids and i <= initial_split: + best_split = i + break + + return best_split + + +async def _run_compaction_for_token_threshold_config( + *, + config: EventsCompactionConfig | None, + session: Session, + session_service: BaseSessionService, + agent: BaseAgent, + agent_name: str = '', + current_branch: str | None = None, +) -> bool: + """Runs token-threshold compaction for a provided compaction config.""" + if not _has_token_threshold_config(config): + return False + if config is None: + return False + + if config.token_threshold is None or config.event_retention_size is None: + return False + + prompt_token_count = _latest_prompt_token_count( + session.events, + current_branch=current_branch, + agent_name=agent_name, + ) + if prompt_token_count is None or prompt_token_count < config.token_threshold: + return False + + events_to_compact = _events_to_compact_for_token_threshold( + events=session.events, + event_retention_size=config.event_retention_size, + ) + if not events_to_compact: + return False + + _ensure_compaction_summarizer(config=config, agent=agent) + if config.summarizer is None: + return False compaction_event = await config.summarizer.maybe_summarize_events( events=events_to_compact @@ -246,8 +356,30 @@ async def _run_compaction_for_token_threshold( return False -async def _run_compaction_for_sliding_window( +async def _run_compaction_for_token_threshold( app: App, session: Session, session_service: BaseSessionService +): + """Runs post-invocation compaction based on a token threshold. + + If triggered, this compacts older raw events and keeps the last + `event_retention_size` raw events un-compacted. + """ + return await _run_compaction_for_token_threshold_config( + config=app.events_compaction_config, + session=session, + session_service=session_service, + agent=app.root_agent, + agent_name='', + current_branch=None, + ) + + +async def _run_compaction_for_sliding_window( + app: App, + session: Session, + session_service: BaseSessionService, + *, + skip_token_compaction: bool = False, ): """Runs compaction for SlidingWindowCompactor. @@ -327,22 +459,30 @@ async def _run_compaction_for_sliding_window( app: The application instance. session: The session containing events to compact. session_service: The session service for appending events. + skip_token_compaction: Whether to skip token-threshold compaction. """ events = session.events if not events: return None + config = app.events_compaction_config + if config is None: + return None + # Prefer token-threshold compaction if configured and triggered. - if ( - app.events_compaction_config - and app.events_compaction_config.token_threshold is not None - ): + if not skip_token_compaction and _has_token_threshold_config(config): token_compacted = await _run_compaction_for_token_threshold( app, session, session_service ) if token_compacted: return None + if not _has_sliding_window_config(config): + return None + + if config.compaction_interval is None or config.overlap_size is None: + return None + # Find the last compaction event and its range. last_compacted_end_timestamp = 0.0 for event in reversed(events): @@ -373,7 +513,7 @@ async def _run_compaction_for_sliding_window( if invocation_latest_timestamps[inv_id] > last_compacted_end_timestamp ] - if len(new_invocation_ids) < app.events_compaction_config.compaction_interval: + if len(new_invocation_ids) < config.compaction_interval: return None # Not enough new invocations to trigger compaction. # Determine the range of invocations to compact. @@ -385,9 +525,7 @@ async def _run_compaction_for_sliding_window( first_new_inv_id = new_invocation_ids[0] first_new_inv_idx = unique_invocation_ids.index(first_new_inv_id) - start_idx = max( - 0, first_new_inv_idx - app.events_compaction_config.overlap_size - ) + start_idx = max(0, first_new_inv_idx - config.overlap_size) start_inv_id = unique_invocation_ids[start_idx] # Find the index of the last event with end_inv_id. @@ -419,15 +557,12 @@ async def _run_compaction_for_sliding_window( if not events_to_compact: return None - if not app.events_compaction_config.summarizer: - app.events_compaction_config.summarizer = LlmEventSummarizer( - llm=app.root_agent.canonical_model - ) + _ensure_compaction_summarizer(config=config, agent=app.root_agent) + if config.summarizer is None: + return None - compaction_event = ( - await app.events_compaction_config.summarizer.maybe_summarize_events( - events=events_to_compact - ) + compaction_event = await config.summarizer.maybe_summarize_events( + events=events_to_compact ) if compaction_event: await session_service.append_event(session=session, event=compaction_event) diff --git a/src/google/adk/flows/llm_flows/compaction.py b/src/google/adk/flows/llm_flows/compaction.py new file mode 100644 index 00000000..f4b60ba9 --- /dev/null +++ b/src/google/adk/flows/llm_flows/compaction.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Request processor that runs token-threshold event compaction.""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import TYPE_CHECKING + +from ...apps.compaction import _has_token_threshold_config +from ...apps.compaction import _run_compaction_for_token_threshold_config +from ...events.event import Event +from ._base_llm_processor import BaseLlmRequestProcessor + +if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext + from ...models.llm_request import LlmRequest + + +class CompactionRequestProcessor(BaseLlmRequestProcessor): + """Compacts session events before contents are prepared for model calls.""" + + async def run_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + del llm_request + config = invocation_context.events_compaction_config + if not _has_token_threshold_config(config): + return + yield # Required for AsyncGenerator. + + token_compacted = await _run_compaction_for_token_threshold_config( + config=config, + session=invocation_context.session, + session_service=invocation_context.session_service, + agent=invocation_context.agent, + agent_name=invocation_context.agent.name, + current_branch=invocation_context.branch, + ) + if token_compacted: + invocation_context.token_compaction_checked = True + return + yield # Required for AsyncGenerator. + + +request_processor = CompactionRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 0a26cdce..e0bd00ff 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,6 +22,7 @@ from . import _code_execution from . import _nl_planning from . import _output_schema_processor from . import basic +from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -42,6 +43,9 @@ def _create_request_processors(): request_confirmation.request_processor, instructions.request_processor, identity.request_processor, + # Compaction should run before contents so compacted events are reflected + # in the model request context. + compaction.request_processor, contents.request_processor, # Context cache processor sets up cache config and finds # existing cache metadata. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a8..cdb878cf 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -553,7 +553,10 @@ class Runner: if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') await _run_compaction_for_sliding_window( - self.app, session, self.session_service + self.app, + session, + self.session_service, + skip_token_compaction=invocation_context.token_compaction_checked, ) async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: @@ -1362,6 +1365,9 @@ class Runner: credential_service=self.credential_service, plugin_manager=self.plugin_manager, context_cache_config=self.context_cache_config, + events_compaction_config=( + self.app.events_compaction_config if self.app else None + ), invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/tests/unittests/apps/test_compaction.py b/tests/unittests/apps/test_compaction.py index fadcd39d..6960c8d4 100644 --- a/tests/unittests/apps/test_compaction.py +++ b/tests/unittests/apps/test_compaction.py @@ -50,6 +50,7 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): invocation_id: str, text: str, prompt_token_count: int | None = None, + thought: bool = False, ) -> Event: usage_metadata = None if prompt_token_count is not None: @@ -60,7 +61,60 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): timestamp=timestamp, invocation_id=invocation_id, author='user', - content=Content(role='user', parts=[Part(text=text)]), + content=Content(role='user', parts=[Part(text=text, thought=thought)]), + usage_metadata=usage_metadata, + ) + + def _create_function_call_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id=function_call_id, name='tool', args={} + ) + ) + ], + ), + ) + + def _create_function_response_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + prompt_token_count: int | None = None, + ) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), usage_metadata=usage_metadata, ) @@ -249,9 +303,21 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, event_retention_size=5, ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) self.assertEqual(config.token_threshold, 50_000) self.assertEqual(config.event_retention_size, 5) + def test_events_compaction_config_accepts_sliding_window_fields(self): + config = EventsCompactionConfig( + compaction_interval=2, + overlap_size=1, + ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) + self.assertIsNone(config.token_threshold) + self.assertIsNone(config.event_retention_size) + def test_events_compaction_config_rejects_partial_token_fields( self, ): @@ -262,6 +328,23 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, ) + def test_events_compaction_config_rejects_partial_sliding_fields( + self, + ): + with pytest.raises(ValidationError): + EventsCompactionConfig( + compaction_interval=2, + ) + + with pytest.raises(ValidationError): + EventsCompactionConfig( + overlap_size=0, + ) + + def test_events_compaction_config_rejects_missing_modes(self): + with pytest.raises(ValidationError): + EventsCompactionConfig() + def test_latest_prompt_token_count_fallback_applies_compaction(self): events = [ self._create_event(1.0, 'inv1', 'a' * 40), @@ -275,6 +358,25 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): # Visible text after compaction is: 'S' + ('c' * 20) = 21 chars. self.assertEqual(estimated_token_count, 21 // 4) + def test_latest_prompt_token_count_fallback_uses_effective_contents(self): + events = [ + self._create_event(1.0, 'inv1', 'visible'), + Event( + timestamp=2.0, + invocation_id='inv2', + author='model', + content=Content( + role='model', + parts=[Part(text='hidden-thought', thought=True)], + ), + ), + ] + + estimated_token_count = compaction_module._latest_prompt_token_count(events) + + # Thought-only events are filtered by contents processing. + self.assertEqual(estimated_token_count, len('visible') // 4) + async def test_run_compaction_for_token_threshold_keeps_retention_events( self, ): @@ -324,6 +426,136 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_keeps_tool_call_pair( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'tool-call-1'), + self._create_function_response_event( + 3.0, + 'inv2', + 'tool-call-1', + prompt_token_count=100, + ), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_for_token_threshold_equal_threshold_compacts( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=100, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', prompt_token_count=100), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_skip_token_compaction(self): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', prompt_token_count=100), + ], + ) + + await _run_compaction_for_sliding_window( + app, + session, + self.mock_session_service, + skip_token_compaction=True, + ) + + self.mock_compactor.maybe_summarize_events.assert_not_called() + self.mock_session_service.append_event.assert_not_called() + async def test_run_compaction_for_token_threshold_seeds_previous_compaction( self, ): @@ -482,6 +714,68 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_uses_latest_ordered_seed( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_event(3.0, 'inv3', 'e3'), + self._create_event(4.0, 'inv4', 'e4'), + self._create_event(5.0, 'inv5', 'e5'), + self._create_event(15.0, 'inv6', 'e6'), + self._create_event(20.0, 'inv7', 'e7'), + self._create_compacted_event( + 15.0, 20.0, 'Summary 15-20', appended_ts=21.0 + ), + self._create_compacted_event( + 1.0, 5.0, 'Summary 1-5', appended_ts=22.0 + ), + self._create_event(23.0, 'inv8', 'e8'), + self._create_event(24.0, 'inv9', 'e9', prompt_token_count=120), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 23.0, 'Summary 1-23' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + compacted_events_arg[0].content.parts[0].text, 'Summary 1-5' + ) + self.assertEqual( + [e.invocation_id for e in compacted_events_arg[1:]], + ['inv6', 'inv7', 'inv8'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + def test_get_contents_with_multiple_compactions(self): # Event timestamps: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 diff --git a/tests/unittests/flows/llm_flows/test_compaction_processor.py b/tests/unittests/flows/llm_flows/test_compaction_processor.py new file mode 100644 index 00000000..9f747c4b --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_compaction_processor.py @@ -0,0 +1,346 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for request-phase token compaction processor.""" + +from unittest.mock import AsyncMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.apps.app import EventsCompactionConfig +from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.events.event import Event +from google.adk.flows.llm_flows import compaction +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.single_flow import SingleFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.session import Session +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +def _create_event( + *, + timestamp: float, + invocation_id: str, + text: str, + prompt_token_count: int | None = None, +) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='user', + content=Content(role='user', parts=[Part(text=text)]), + usage_metadata=usage_metadata, + ) + + +def test_single_flow_includes_compaction_before_contents(): + flow = SingleFlow() + + compaction_index = flow.request_processors.index(compaction.request_processor) + contents_index = flow.request_processors.index(contents.request_processor) + + assert compaction_index < contents_index + + +@pytest.mark.asyncio +async def test_compaction_request_processor_no_token_config(): + session = Session(app_name='app', user_id='user', id='session', events=[]) + session_service = AsyncMock(spec=BaseSessionService) + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + compaction_interval=2, + overlap_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + session_service.append_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_compaction_request_processor_runs_token_compaction(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + _create_event( + timestamp=3.0, + invocation_id='inv3', + text='e3', + prompt_token_count=100, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_compacts_with_latest_tool_response(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + Event( + timestamp=3.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id='call-1', name='tool', args={} + ) + ) + ], + ), + ), + Event( + timestamp=4.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id='call-1', + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_can_compact_current_user_event(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + Event( + timestamp=2.0, + invocation_id='current-inv', + author='user', + content=Content( + role='user', + parts=[Part(text='latest user message')], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'current-inv', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_not_marked_when_not_compacted(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event( + timestamp=2.0, + invocation_id='inv2', + text='e2', + prompt_token_count=40, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + mock_summarizer.maybe_summarize_events.return_value = Event( + author='compactor', + invocation_id=Event.new_id(), + ) + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + mock_summarizer.maybe_summarize_events.assert_not_called() + session_service.append_event.assert_not_called()