You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat(conformance): Adds a minimal AdkWebServer http client for conformance tests to interact with
PiperOrigin-RevId: 803208215
This commit is contained in:
committed by
Copybara-Service
parent
7b077ac351
commit
ebf2c98e41
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,211 @@
|
||||
"""HTTP client for interacting with the ADK web server."""
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...events.event import Event
|
||||
from ...sessions.session import Session
|
||||
from ..adk_web_server import RunAgentRequest
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class AdkWebServerClient:
|
||||
"""HTTP client for interacting with the ADK web server for conformance tests.
|
||||
|
||||
Usage patterns:
|
||||
|
||||
# Pattern 1: Manual lifecycle management
|
||||
client = AdkWebServerClient()
|
||||
session = await client.create_session(app_name="app", user_id="user")
|
||||
async for event in client.run_agent(request):
|
||||
# Process events...
|
||||
await client.close() # Optional explicit cleanup
|
||||
|
||||
# Pattern 2: Automatic cleanup with context manager (recommended)
|
||||
async with AdkWebServerClient() as client:
|
||||
session = await client.create_session(app_name="app", user_id="user")
|
||||
async for event in client.run_agent(request):
|
||||
# Process events...
|
||||
# Client automatically closed here
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_url: str = "http://127.0.0.1:8000", timeout: float = 30.0
|
||||
):
|
||||
"""Initialize the ADK web server client for conformance testing.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of the ADK web server (default: http://127.0.0.1:8000)
|
||||
timeout: Request timeout in seconds (default: 30.0)
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_client(self) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""Get or create an HTTP client with proper lifecycle management.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator yielding the HTTP client instance.
|
||||
"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=httpx.Timeout(self.timeout),
|
||||
)
|
||||
try:
|
||||
yield self._client
|
||||
finally:
|
||||
pass # Keep client alive for reuse
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client and clean up resources."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def __aenter__(self) -> "AdkWebServerClient":
|
||||
"""Async context manager entry.
|
||||
|
||||
Returns:
|
||||
The client instance for use in the async context.
|
||||
"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # pylint: disable=unused-argument
|
||||
"""Async context manager exit that closes the HTTP client."""
|
||||
await self.close()
|
||||
|
||||
async def get_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> Session:
|
||||
"""Retrieve a specific session from the ADK web server.
|
||||
|
||||
Args:
|
||||
app_name: Name of the application
|
||||
user_id: User identifier
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
The requested Session object
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails or session not found
|
||||
"""
|
||||
async with self._get_client() as client:
|
||||
response = await client.get(
|
||||
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
return Session.model_validate(response.json())
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
"""Create a new session in the ADK web server.
|
||||
|
||||
Args:
|
||||
app_name: Name of the application
|
||||
user_id: User identifier
|
||||
state: Optional initial state for the session
|
||||
|
||||
Returns:
|
||||
The newly created Session object
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails
|
||||
"""
|
||||
async with self._get_client() as client:
|
||||
payload = {}
|
||||
if state is not None:
|
||||
payload["state"] = state
|
||||
|
||||
response = await client.post(
|
||||
f"/apps/{app_name}/users/{user_id}/sessions",
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return Session.model_validate(response.json())
|
||||
|
||||
async def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
"""Delete a session from the ADK web server.
|
||||
|
||||
Args:
|
||||
app_name: Name of the application
|
||||
user_id: User identifier
|
||||
session_id: Session identifier to delete
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails or session not found
|
||||
"""
|
||||
async with self._get_client() as client:
|
||||
response = await client.delete(
|
||||
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
request: RunAgentRequest,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Run an agent with streaming Server-Sent Events response.
|
||||
|
||||
Args:
|
||||
request: The RunAgentRequest containing agent execution parameters
|
||||
|
||||
Yields:
|
||||
Event objects streamed from the agent execution
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails
|
||||
json.JSONDecodeError: If event data cannot be parsed
|
||||
"""
|
||||
# TODO: Prepare headers for conformance tracking
|
||||
headers = {}
|
||||
|
||||
async with self._get_client() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/run_sse",
|
||||
json=request.model_dump(by_alias=True, exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:") and (data := line[5:].strip()):
|
||||
try:
|
||||
event_data = json.loads(data)
|
||||
yield Event.model_validate(event_data)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
logger.warning("Failed to parse event data: %s", exc)
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,211 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.cli.adk_web_server import RunAgentRequest
|
||||
from google.adk.cli.conformance.adk_web_server_client import AdkWebServerClient
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
def test_init_default_values():
|
||||
client = AdkWebServerClient()
|
||||
assert client.base_url == "http://127.0.0.1:8000"
|
||||
assert client.timeout == 30.0
|
||||
|
||||
|
||||
def test_init_custom_values():
|
||||
client = AdkWebServerClient(
|
||||
base_url="https://custom.example.com/", timeout=60.0
|
||||
)
|
||||
assert client.base_url == "https://custom.example.com"
|
||||
assert client.timeout == 60.0
|
||||
|
||||
|
||||
def test_init_strips_trailing_slash():
|
||||
client = AdkWebServerClient(base_url="http://test.com/")
|
||||
assert client.base_url == "http://test.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"id": "test_session",
|
||||
"app_name": "test_app",
|
||||
"user_id": "test_user",
|
||||
"events": [],
|
||||
"state": {},
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
session = await client.get_session(
|
||||
app_name="test_app", user_id="test_user", session_id="test_session"
|
||||
)
|
||||
|
||||
assert isinstance(session, Session)
|
||||
assert session.id == "test_session"
|
||||
mock_client.get.assert_called_once_with(
|
||||
"/apps/test_app/users/test_user/sessions/test_session"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"id": "new_session",
|
||||
"app_name": "test_app",
|
||||
"user_id": "test_user",
|
||||
"events": [],
|
||||
"state": {"key": "value"},
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
session = await client.create_session(
|
||||
app_name="test_app", user_id="test_user", state={"key": "value"}
|
||||
)
|
||||
|
||||
assert isinstance(session, Session)
|
||||
assert session.id == "new_session"
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/apps/test_app/users/test_user/sessions",
|
||||
json={"state": {"key": "value"}},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.delete.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
await client.delete_session(
|
||||
app_name="test_app", user_id="test_user", session_id="test_session"
|
||||
)
|
||||
|
||||
mock_client.delete.assert_called_once_with(
|
||||
"/apps/test_app/users/test_user/sessions/test_session"
|
||||
)
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Create sample events
|
||||
event1 = Event(
|
||||
author="test_agent",
|
||||
invocation_id="test_invocation_1",
|
||||
content=types.Content(role="model", parts=[types.Part(text="Hello")]),
|
||||
)
|
||||
event2 = Event(
|
||||
author="test_agent",
|
||||
invocation_id="test_invocation_2",
|
||||
content=types.Content(role="model", parts=[types.Part(text="World")]),
|
||||
)
|
||||
|
||||
# Mock streaming response
|
||||
class MockStreamResponse:
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def aiter_lines(self):
|
||||
yield f"data:{json.dumps(event1.model_dump())}"
|
||||
yield "data:" # Empty line should be ignored
|
||||
yield f"data:{json.dumps(event2.model_dump())}"
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def mock_stream(*_args, **_kwargs):
|
||||
return MockStreamResponse()
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.stream = mock_stream
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
request = RunAgentRequest(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
session_id="test_session",
|
||||
new_message=types.Content(
|
||||
role="user", parts=[types.Part(text="Hello")]
|
||||
),
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in client.run_agent(request):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 2
|
||||
assert all(isinstance(event, Event) for event in events)
|
||||
assert events[0].invocation_id == "test_invocation_1"
|
||||
assert events[1].invocation_id == "test_invocation_2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Create a mock client to close
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Force client creation
|
||||
async with client._get_client():
|
||||
pass
|
||||
|
||||
# Now close should work
|
||||
await client.close()
|
||||
mock_client.aclose.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_manager():
|
||||
async with AdkWebServerClient() as client:
|
||||
assert isinstance(client, AdkWebServerClient)
|
||||
Reference in New Issue
Block a user