From f39df4155ea88a6023461b600b0e300089307562 Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Thu, 18 Sep 2025 18:55:04 -0700 Subject: [PATCH] feat(conformance): Supports content and state_delta in `TestCase.user_messages` and `initial_state` for session creation PiperOrigin-RevId: 808827170 --- src/google/adk/cli/conformance/cli_create.py | 19 ++++++++++++---- src/google/adk/cli/conformance/cli_test.py | 16 ++++++++++++-- src/google/adk/cli/conformance/test_case.py | 23 +++++++++++++++++++- 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/google/adk/cli/conformance/cli_create.py b/src/google/adk/cli/conformance/cli_create.py index 8db39fd4..865b223d 100644 --- a/src/google/adk/cli/conformance/cli_create.py +++ b/src/google/adk/cli/conformance/cli_create.py @@ -46,22 +46,33 @@ async def _create_conformance_test_files( async with AdkWebServerClient() as client: # Create a new session for the test session = await client.create_session( - app_name=test_case.test_spec.agent, user_id=user_id, state={} + app_name=test_case.test_spec.agent, + user_id=user_id, + state=test_case.test_spec.initial_state, ) # Run the agent with the user messages for user_message_index, user_message in enumerate( test_case.test_spec.user_messages ): - content = types.Content( - parts=[types.Part(text=user_message)], role="user" - ) + # Create content from UserMessage object + if user_message.content is not None: + content = user_message.content + elif user_message.text is not None: + content = types.UserContent(parts=[types.Part(text=user_message.text)]) + else: + raise ValueError( + f"UserMessage at index {user_message_index} has neither text nor" + " content" + ) + async for _ in client.run_agent( RunAgentRequest( app_name=test_case.test_spec.agent, user_id=user_id, session_id=session.id, new_message=content, + state_delta=user_message.state_delta, ), mode="record", test_case_dir=str(test_case_dir), diff --git a/src/google/adk/cli/conformance/cli_test.py b/src/google/adk/cli/conformance/cli_test.py index b5a86ea9..552a120c 100644 --- a/src/google/adk/cli/conformance/cli_test.py +++ b/src/google/adk/cli/conformance/cli_test.py @@ -120,7 +120,16 @@ class ConformanceTestRunner: for user_message_index, user_message in enumerate( test_case.test_spec.user_messages ): - content = types.UserContent(parts=[types.Part(text=user_message)]) + # Create content from UserMessage object + if user_message.content is not None: + content = user_message.content + elif user_message.text is not None: + content = types.UserContent(parts=[types.Part(text=user_message.text)]) + else: + raise ValueError( + f"UserMessage at index {user_message_index} has neither text nor" + " content" + ) request = RunAgentRequest( app_name=test_case.test_spec.agent, @@ -128,6 +137,7 @@ class ConformanceTestRunner: session_id=session_id, new_message=content, streaming=False, + state_delta=user_message.state_delta, ) # Run the agent but don't collect events here @@ -193,7 +203,9 @@ class ConformanceTestRunner: try: # Create session session = await self.client.create_session( - app_name=test_case.test_spec.agent, user_id=self.user_id, state={} + app_name=test_case.test_spec.agent, + user_id=self.user_id, + state=test_case.test_spec.initial_state, ) # Run each user message diff --git a/src/google/adk/cli/conformance/test_case.py b/src/google/adk/cli/conformance/test_case.py index 06012166..30aa9366 100644 --- a/src/google/adk/cli/conformance/test_case.py +++ b/src/google/adk/cli/conformance/test_case.py @@ -16,9 +16,27 @@ from __future__ import annotations from dataclasses import dataclass from pathlib import Path +from typing import Any +from typing import Optional +from google.genai import types from pydantic import BaseModel from pydantic import ConfigDict +from pydantic import Field + + +class UserMessage(BaseModel): + + # oneof fields - start + text: Optional[str] = None + """The user message in text.""" + + content: Optional[types.UserContent] = None + """The user message in types.Content.""" + # oneof fields - end + + state_delta: Optional[dict[str, Any]] = None + """The state changes when running this user message.""" class TestSpec(BaseModel): @@ -38,7 +56,10 @@ class TestSpec(BaseModel): agent: str """Name of the ADK agent to test against.""" - user_messages: list[str] + initial_state: dict[str, Any] = Field(default_factory=dict) + """The initial state key-value pairs in the creation_session request.""" + + user_messages: list[UserMessage] = Field(default_factory=list) """Sequence of user messages to send to the agent during test execution."""