You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: temp-scoped state now visible to subsequent agents in same invocation
Merge https://github.com/google/adk-python/pull/4618 ## Summary Fixes #4564 When using `output_key` with a `temp:` prefix (e.g. `output_key='temp:result'`) in a `SequentialAgent`, the output was silently lost. Agent-2 could never read the temp state written by agent-1. ## Root Cause Two issues in `append_event`: 1. `_trim_temp_delta_state()` removed temp keys from the event delta **before** `_update_session_state()` could apply them to the in-memory session 2. `_update_session_state()` also explicitly skipped `temp:`-prefixed keys ```python # Before (broken ordering): async def append_event(self, session, event): event = self._trim_temp_delta_state(event) # temp keys gone! self._update_session_state(session, event) # nothing to apply ``` ## Fix Introduce `_apply_temp_state()` which writes temp-scoped keys to the in-memory `session.state` **before** the event delta is trimmed: ```python # After: async def append_event(self, session, event): self._apply_temp_state(session, event) # temp keys → session.state event = self._trim_temp_delta_state(event) # temp keys removed from delta self._update_session_state(session, event) # non-temp keys applied ``` This ensures: - ✅ Temp state is available to subsequent agents within the same invocation - ✅ Temp state is still stripped from event deltas (not persisted to storage) - ✅ All three session services (InMemory, Database, SQLite) behave consistently ## Files Changed - `src/google/adk/sessions/base_session_service.py`: Added `_apply_temp_state()`, reordered `append_event` logic, removed temp-skip in `_update_session_state` - `src/google/adk/sessions/database_session_service.py`: Added `_apply_temp_state()` call before trim - `src/google/adk/sessions/sqlite_session_service.py`: Added `_apply_temp_state()` call before trim - `tests/unittests/sessions/test_session_service.py`: Updated existing test + added new test for sequential agent scenario ## Testing All 67 session service tests pass across InMemory, Database, and SQLite backends. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4618 from stakeswky:fix/temp-state-output-key b9fc737e7a6dc07e06e99af3271a8fc026acae4a PiperOrigin-RevId: 878499263
This commit is contained in:
@@ -106,13 +106,35 @@ class BaseSessionService(abc.ABC):
|
||||
"""Appends an event to a session object."""
|
||||
if event.partial:
|
||||
return event
|
||||
# Apply temp-scoped state to the in-memory session BEFORE trimming the
|
||||
# event delta, so that subsequent agents within the same invocation can
|
||||
# read temp values (e.g. output_key='temp:my_key' in SequentialAgent).
|
||||
self._apply_temp_state(session, event)
|
||||
event = self._trim_temp_delta_state(event)
|
||||
self._update_session_state(session, event)
|
||||
session.events.append(event)
|
||||
return event
|
||||
|
||||
def _apply_temp_state(self, session: Session, event: Event) -> None:
|
||||
"""Applies temp-scoped state delta to the in-memory session state.
|
||||
|
||||
Temp state is ephemeral: it lives in the session's in-memory state for
|
||||
the duration of the current invocation but is NOT persisted to storage
|
||||
(the event delta is trimmed separately by _trim_temp_delta_state).
|
||||
"""
|
||||
if not event.actions or not event.actions.state_delta:
|
||||
return
|
||||
for key, value in event.actions.state_delta.items():
|
||||
if key.startswith(State.TEMP_PREFIX):
|
||||
session.state[key] = value
|
||||
|
||||
def _trim_temp_delta_state(self, event: Event) -> Event:
|
||||
"""Removes temporary state delta keys from the event."""
|
||||
"""Removes temporary state delta keys from the event.
|
||||
|
||||
This prevents temp-scoped state from being persisted, while the
|
||||
in-memory session state (updated by _apply_temp_state) retains the
|
||||
values for the duration of the current invocation.
|
||||
"""
|
||||
if not event.actions or not event.actions.state_delta:
|
||||
return event
|
||||
|
||||
@@ -128,6 +150,4 @@ class BaseSessionService(abc.ABC):
|
||||
if not event.actions or not event.actions.state_delta:
|
||||
return
|
||||
for key, value in event.actions.state_delta.items():
|
||||
if key.startswith(State.TEMP_PREFIX):
|
||||
continue
|
||||
session.state.update({key: value})
|
||||
|
||||
@@ -522,6 +522,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
# Apply temp state to in-memory session before trimming, so that
|
||||
# subsequent agents within the same invocation can read temp values.
|
||||
self._apply_temp_state(session, event)
|
||||
# Trim temp state before persisting
|
||||
event = self._trim_temp_delta_state(event)
|
||||
|
||||
|
||||
@@ -361,6 +361,9 @@ class SqliteSessionService(BaseSessionService):
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
# Apply temp state to in-memory session before trimming, so that
|
||||
# subsequent agents within the same invocation can read temp values.
|
||||
self._apply_temp_state(session, event)
|
||||
# Trim temp state before persisting
|
||||
event = self._trim_temp_delta_state(event)
|
||||
event_timestamp = event.timestamp
|
||||
|
||||
@@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service):
|
||||
)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
# Refetch session and check state and event
|
||||
session_got = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id='s1'
|
||||
)
|
||||
# Check session state does not contain temp keys
|
||||
assert session_got.state.get('sk') == 'v2'
|
||||
assert 'temp:k1' not in session_got.state
|
||||
# Temp state IS available in the in-memory session (same invocation)
|
||||
assert session.state.get('temp:k1') == 'v1'
|
||||
assert session.state.get('sk') == 'v2'
|
||||
|
||||
# Check event as stored in session does not contain temp keys in state_delta
|
||||
assert 'temp:k1' not in session_got.events[0].actions.state_delta
|
||||
assert session_got.events[0].actions.state_delta.get('sk') == 'v2'
|
||||
assert 'temp:k1' not in event.actions.state_delta
|
||||
assert event.actions.state_delta.get('sk') == 'v2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temp_state_visible_across_sequential_events(session_service):
|
||||
"""Temp state set by one event should be readable before the next event.
|
||||
|
||||
This simulates a SequentialAgent where agent-1 writes output_key='temp:out'
|
||||
and agent-2 needs to read it from session.state within the same invocation.
|
||||
"""
|
||||
app_name = 'my_app'
|
||||
user_id = 'u1'
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id='s_seq'
|
||||
)
|
||||
|
||||
# Agent-1 writes temp state
|
||||
event1 = Event(
|
||||
invocation_id='inv1',
|
||||
author='agent1',
|
||||
actions=EventActions(state_delta={'temp:output': 'result_from_a1'}),
|
||||
)
|
||||
await session_service.append_event(session=session, event=event1)
|
||||
|
||||
# Agent-2 should be able to read temp state from the same session object
|
||||
assert session.state.get('temp:output') == 'result_from_a1'
|
||||
|
||||
# But the event delta should NOT contain the temp key (not persisted)
|
||||
assert 'temp:output' not in event1.actions.state_delta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user