feat(conformance): Supports content and state_delta in TestCase.user_messages and initial_state for session creation

PiperOrigin-RevId: 808827170
This commit is contained in:
Wei Sun (Jack)
2025-09-18 18:55:04 -07:00
committed by Copybara-Service
parent 1a91bb2a59
commit f39df4155e
3 changed files with 51 additions and 7 deletions
+15 -4
View File
@@ -46,22 +46,33 @@ async def _create_conformance_test_files(
async with AdkWebServerClient() as client:
# Create a new session for the test
session = await client.create_session(
app_name=test_case.test_spec.agent, user_id=user_id, state={}
app_name=test_case.test_spec.agent,
user_id=user_id,
state=test_case.test_spec.initial_state,
)
# Run the agent with the user messages
for user_message_index, user_message in enumerate(
test_case.test_spec.user_messages
):
content = types.Content(
parts=[types.Part(text=user_message)], role="user"
)
# Create content from UserMessage object
if user_message.content is not None:
content = user_message.content
elif user_message.text is not None:
content = types.UserContent(parts=[types.Part(text=user_message.text)])
else:
raise ValueError(
f"UserMessage at index {user_message_index} has neither text nor"
" content"
)
async for _ in client.run_agent(
RunAgentRequest(
app_name=test_case.test_spec.agent,
user_id=user_id,
session_id=session.id,
new_message=content,
state_delta=user_message.state_delta,
),
mode="record",
test_case_dir=str(test_case_dir),
+14 -2
View File
@@ -120,7 +120,16 @@ class ConformanceTestRunner:
for user_message_index, user_message in enumerate(
test_case.test_spec.user_messages
):
content = types.UserContent(parts=[types.Part(text=user_message)])
# Create content from UserMessage object
if user_message.content is not None:
content = user_message.content
elif user_message.text is not None:
content = types.UserContent(parts=[types.Part(text=user_message.text)])
else:
raise ValueError(
f"UserMessage at index {user_message_index} has neither text nor"
" content"
)
request = RunAgentRequest(
app_name=test_case.test_spec.agent,
@@ -128,6 +137,7 @@ class ConformanceTestRunner:
session_id=session_id,
new_message=content,
streaming=False,
state_delta=user_message.state_delta,
)
# Run the agent but don't collect events here
@@ -193,7 +203,9 @@ class ConformanceTestRunner:
try:
# Create session
session = await self.client.create_session(
app_name=test_case.test_spec.agent, user_id=self.user_id, state={}
app_name=test_case.test_spec.agent,
user_id=self.user_id,
state=test_case.test_spec.initial_state,
)
# Run each user message
+22 -1
View File
@@ -16,9 +16,27 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
class UserMessage(BaseModel):
# oneof fields - start
text: Optional[str] = None
"""The user message in text."""
content: Optional[types.UserContent] = None
"""The user message in types.Content."""
# oneof fields - end
state_delta: Optional[dict[str, Any]] = None
"""The state changes when running this user message."""
class TestSpec(BaseModel):
@@ -38,7 +56,10 @@ class TestSpec(BaseModel):
agent: str
"""Name of the ADK agent to test against."""
user_messages: list[str]
initial_state: dict[str, Any] = Field(default_factory=dict)
"""The initial state key-value pairs in the creation_session request."""
user_messages: list[UserMessage] = Field(default_factory=list)
"""Sequence of user messages to send to the agent during test execution."""