You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat(conformance): Supports content and state_delta in TestCase.user_messages and initial_state for session creation
PiperOrigin-RevId: 808827170
This commit is contained in:
committed by
Copybara-Service
parent
1a91bb2a59
commit
f39df4155e
@@ -46,22 +46,33 @@ async def _create_conformance_test_files(
|
||||
async with AdkWebServerClient() as client:
|
||||
# Create a new session for the test
|
||||
session = await client.create_session(
|
||||
app_name=test_case.test_spec.agent, user_id=user_id, state={}
|
||||
app_name=test_case.test_spec.agent,
|
||||
user_id=user_id,
|
||||
state=test_case.test_spec.initial_state,
|
||||
)
|
||||
|
||||
# Run the agent with the user messages
|
||||
for user_message_index, user_message in enumerate(
|
||||
test_case.test_spec.user_messages
|
||||
):
|
||||
content = types.Content(
|
||||
parts=[types.Part(text=user_message)], role="user"
|
||||
)
|
||||
# Create content from UserMessage object
|
||||
if user_message.content is not None:
|
||||
content = user_message.content
|
||||
elif user_message.text is not None:
|
||||
content = types.UserContent(parts=[types.Part(text=user_message.text)])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"UserMessage at index {user_message_index} has neither text nor"
|
||||
" content"
|
||||
)
|
||||
|
||||
async for _ in client.run_agent(
|
||||
RunAgentRequest(
|
||||
app_name=test_case.test_spec.agent,
|
||||
user_id=user_id,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
state_delta=user_message.state_delta,
|
||||
),
|
||||
mode="record",
|
||||
test_case_dir=str(test_case_dir),
|
||||
|
||||
@@ -120,7 +120,16 @@ class ConformanceTestRunner:
|
||||
for user_message_index, user_message in enumerate(
|
||||
test_case.test_spec.user_messages
|
||||
):
|
||||
content = types.UserContent(parts=[types.Part(text=user_message)])
|
||||
# Create content from UserMessage object
|
||||
if user_message.content is not None:
|
||||
content = user_message.content
|
||||
elif user_message.text is not None:
|
||||
content = types.UserContent(parts=[types.Part(text=user_message.text)])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"UserMessage at index {user_message_index} has neither text nor"
|
||||
" content"
|
||||
)
|
||||
|
||||
request = RunAgentRequest(
|
||||
app_name=test_case.test_spec.agent,
|
||||
@@ -128,6 +137,7 @@ class ConformanceTestRunner:
|
||||
session_id=session_id,
|
||||
new_message=content,
|
||||
streaming=False,
|
||||
state_delta=user_message.state_delta,
|
||||
)
|
||||
|
||||
# Run the agent but don't collect events here
|
||||
@@ -193,7 +203,9 @@ class ConformanceTestRunner:
|
||||
try:
|
||||
# Create session
|
||||
session = await self.client.create_session(
|
||||
app_name=test_case.test_spec.agent, user_id=self.user_id, state={}
|
||||
app_name=test_case.test_spec.agent,
|
||||
user_id=self.user_id,
|
||||
state=test_case.test_spec.initial_state,
|
||||
)
|
||||
|
||||
# Run each user message
|
||||
|
||||
@@ -16,9 +16,27 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
|
||||
# oneof fields - start
|
||||
text: Optional[str] = None
|
||||
"""The user message in text."""
|
||||
|
||||
content: Optional[types.UserContent] = None
|
||||
"""The user message in types.Content."""
|
||||
# oneof fields - end
|
||||
|
||||
state_delta: Optional[dict[str, Any]] = None
|
||||
"""The state changes when running this user message."""
|
||||
|
||||
|
||||
class TestSpec(BaseModel):
|
||||
@@ -38,7 +56,10 @@ class TestSpec(BaseModel):
|
||||
agent: str
|
||||
"""Name of the ADK agent to test against."""
|
||||
|
||||
user_messages: list[str]
|
||||
initial_state: dict[str, Any] = Field(default_factory=dict)
|
||||
"""The initial state key-value pairs in the creation_session request."""
|
||||
|
||||
user_messages: list[UserMessage] = Field(default_factory=list)
|
||||
"""Sequence of user messages to send to the agent during test execution."""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user