From 2780ae2892adfbebc7580c843d2eaad29f86c335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=8D=E5=81=9A=E4=BA=86=E7=9D=A1=E5=A4=A7=E8=A7=89?= <64798754+stakeswky@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:15:15 -0800 Subject: [PATCH] fix: temp-scoped state now visible to subsequent agents in same invocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/4618 ## Summary Fixes #4564 When using `output_key` with a `temp:` prefix (e.g. `output_key='temp:result'`) in a `SequentialAgent`, the output was silently lost. Agent-2 could never read the temp state written by agent-1. ## Root Cause Two issues in `append_event`: 1. `_trim_temp_delta_state()` removed temp keys from the event delta **before** `_update_session_state()` could apply them to the in-memory session 2. `_update_session_state()` also explicitly skipped `temp:`-prefixed keys ```python # Before (broken ordering): async def append_event(self, session, event): event = self._trim_temp_delta_state(event) # temp keys gone! self._update_session_state(session, event) # nothing to apply ``` ## Fix Introduce `_apply_temp_state()` which writes temp-scoped keys to the in-memory `session.state` **before** the event delta is trimmed: ```python # After: async def append_event(self, session, event): self._apply_temp_state(session, event) # temp keys → session.state event = self._trim_temp_delta_state(event) # temp keys removed from delta self._update_session_state(session, event) # non-temp keys applied ``` This ensures: - ✅ Temp state is available to subsequent agents within the same invocation - ✅ Temp state is still stripped from event deltas (not persisted to storage) - ✅ All three session services (InMemory, Database, SQLite) behave consistently ## Files Changed - `src/google/adk/sessions/base_session_service.py`: Added `_apply_temp_state()`, reordered `append_event` logic, removed temp-skip in `_update_session_state` - `src/google/adk/sessions/database_session_service.py`: Added `_apply_temp_state()` call before trim - `src/google/adk/sessions/sqlite_session_service.py`: Added `_apply_temp_state()` call before trim - `tests/unittests/sessions/test_session_service.py`: Updated existing test + added new test for sequential agent scenario ## Testing All 67 session service tests pass across InMemory, Database, and SQLite backends. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4618 from stakeswky:fix/temp-state-output-key b9fc737e7a6dc07e06e99af3271a8fc026acae4a PiperOrigin-RevId: 878499263 --- .../adk/sessions/base_session_service.py | 26 +++++++++-- .../adk/sessions/database_session_service.py | 3 ++ .../adk/sessions/sqlite_session_service.py | 3 ++ .../sessions/test_session_service.py | 43 +++++++++++++++---- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83..eb22a83b 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,13 +106,35 @@ class BaseSessionService(abc.ABC): """Appends an event to a session object.""" if event.partial: return event + # Apply temp-scoped state to the in-memory session BEFORE trimming the + # event delta, so that subsequent agents within the same invocation can + # read temp values (e.g. output_key='temp:my_key' in SequentialAgent). + self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) self._update_session_state(session, event) session.events.append(event) return event + def _apply_temp_state(self, session: Session, event: Event) -> None: + """Applies temp-scoped state delta to the in-memory session state. + + Temp state is ephemeral: it lives in the session's in-memory state for + the duration of the current invocation but is NOT persisted to storage + (the event delta is trimmed separately by _trim_temp_delta_state). + """ + if not event.actions or not event.actions.state_delta: + return + for key, value in event.actions.state_delta.items(): + if key.startswith(State.TEMP_PREFIX): + session.state[key] = value + def _trim_temp_delta_state(self, event: Event) -> Event: - """Removes temporary state delta keys from the event.""" + """Removes temporary state delta keys from the event. + + This prevents temp-scoped state from being persisted, while the + in-memory session state (updated by _apply_temp_state) retains the + values for the duration of the current invocation. + """ if not event.actions or not event.actions.state_delta: return event @@ -128,6 +150,4 @@ class BaseSessionService(abc.ABC): if not event.actions or not event.actions.state_delta: return for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue session.state.update({key: value}) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 6b19464e..321a5cc6 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -522,6 +522,9 @@ class DatabaseSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d23c8278..600f89c4 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -361,6 +361,9 @@ class SqliteSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) event_timestamp = event.timestamp diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 4e277195..5c5aa83e 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service): ) await session_service.append_event(session=session, event=event) - # Refetch session and check state and event - session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' - ) - # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + # Temp state IS available in the in-memory session (same invocation) + assert session.state.get('temp:k1') == 'v1' + assert session.state.get('sk') == 'v2' + # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert 'temp:k1' not in event.actions.state_delta + assert event.actions.state_delta.get('sk') == 'v2' + + +@pytest.mark.asyncio +async def test_temp_state_visible_across_sequential_events(session_service): + """Temp state set by one event should be readable before the next event. + + This simulates a SequentialAgent where agent-1 writes output_key='temp:out' + and agent-2 needs to read it from session.state within the same invocation. + """ + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s_seq' + ) + + # Agent-1 writes temp state + event1 = Event( + invocation_id='inv1', + author='agent1', + actions=EventActions(state_delta={'temp:output': 'result_from_a1'}), + ) + await session_service.append_event(session=session, event=event1) + + # Agent-2 should be able to read temp state from the same session object + assert session.state.get('temp:output') == 'result_from_a1' + + # But the event delta should NOT contain the temp key (not persisted) + assert 'temp:output' not in event1.actions.state_delta @pytest.mark.asyncio