chore: Set agent_engine_id in the service constructor, also use the agent_engine_id field instead of overriding app_name in FastAPI endpoint

PiperOrigin-RevId: 770427903
This commit is contained in:
Shangjie Chen
2025-06-11 19:45:46 -07:00
committed by Copybara-Service
parent a7ea374dfb
commit fc65873d7c
3 changed files with 74 additions and 73 deletions
+6 -29
View File
@@ -276,7 +276,6 @@ def get_fast_api_app(
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
if session_service_uri:
if session_service_uri.startswith("agentengine://"):
# Create vertex session service
@@ -285,8 +284,9 @@ def get_fast_api_app(
raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agents_dir)
session_service = VertexAiSessionService(
os.environ["GOOGLE_CLOUD_PROJECT"],
os.environ["GOOGLE_CLOUD_LOCATION"],
project=os.environ["GOOGLE_CLOUD_PROJECT"],
location=os.environ["GOOGLE_CLOUD_LOCATION"],
agent_engine_id=agent_engine_id,
)
else:
session_service = DatabaseSessionService(db_url=session_service_uri)
@@ -357,8 +357,6 @@ def get_fast_api_app(
async def get_session(
app_name: str, user_id: str, session_id: str
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -371,8 +369,6 @@ def get_fast_api_app(
response_model_exclude_none=True,
)
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
@@ -393,8 +389,6 @@ def get_fast_api_app(
session_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
@@ -419,8 +413,6 @@ def get_fast_api_app(
user_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state
@@ -660,8 +652,6 @@ def get_fast_api_app(
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
async def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -677,7 +667,6 @@ def get_fast_api_app(
artifact_name: str,
version: Optional[int] = Query(None),
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = await artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
@@ -700,7 +689,6 @@ def get_fast_api_app(
artifact_name: str,
version_id: int,
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = await artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
@@ -719,7 +707,6 @@ def get_fast_api_app(
async def list_artifact_names(
app_name: str, user_id: str, session_id: str
) -> list[str]:
app_name = agent_engine_id if agent_engine_id else app_name
return await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -731,7 +718,6 @@ def get_fast_api_app(
async def list_artifact_versions(
app_name: str, user_id: str, session_id: str, artifact_name: str
) -> list[int]:
app_name = agent_engine_id if agent_engine_id else app_name
return await artifact_service.list_versions(
app_name=app_name,
user_id=user_id,
@@ -745,7 +731,6 @@ def get_fast_api_app(
async def delete_artifact(
app_name: str, user_id: str, session_id: str, artifact_name: str
):
app_name = agent_engine_id if agent_engine_id else app_name
await artifact_service.delete_artifact(
app_name=app_name,
user_id=user_id,
@@ -755,10 +740,8 @@ def get_fast_api_app(
@app.post("/run", response_model_exclude_none=True)
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else req.app_name
session = await session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
@@ -776,11 +759,9 @@ def get_fast_api_app(
@app.post("/run_sse")
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = await session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
@@ -818,8 +799,6 @@ def get_fast_api_app(
async def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -875,8 +854,6 @@ def get_fast_api_app(
) -> None:
await websocket.accept()
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -940,7 +917,7 @@ def get_fast_api_app(
return runner_dict[app_name]
root_agent = agent_loader.load_agent(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
@@ -22,7 +22,6 @@ from typing import Optional
import urllib.parse
from dateutil import parser
from google.genai import types
from typing_extensions import override
from google import genai
@@ -40,15 +39,27 @@ logger = logging.getLogger('google_adk.' + __name__)
class VertexAiSessionService(BaseSessionService):
"""Connects to the managed Vertex AI Session Service."""
"""Connects to the Vertex AI Agent Engine Session Service using GenAI API client.
https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview
"""
def __init__(
self,
project: str = None,
location: str = None,
project: Optional[str] = None,
location: Optional[str] = None,
agent_engine_id: Optional[str] = None,
):
self.project = project
self.location = location
"""Initializes the VertexAiSessionService.
Args:
project: The project id of the project to use.
location: The location of the project to use.
agent_engine_id: The resource ID of the agent engine to use.
"""
self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
@override
async def create_session(
@@ -64,14 +75,13 @@ class VertexAiSessionService(BaseSessionService):
'User-provided Session id is not supported for'
' VertexAISessionService.'
)
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
session_json_dict = {'user_id': user_id}
if state:
session_json_dict['session_state'] = state
api_client = _get_api_client(self.project, self.location)
api_response = await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
@@ -130,10 +140,10 @@ class VertexAiSessionService(BaseSessionService):
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
# Get session resource
api_client = _get_api_client(self.project, self.location)
get_session_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
@@ -203,14 +213,14 @@ class VertexAiSessionService(BaseSessionService):
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
if user_id:
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
path = path + f'?filter=user_id={parsed_user_id}'
api_client = _get_api_client(self.project, self.location)
api_response = await api_client.async_request(
http_method='GET',
path=path,
@@ -236,8 +246,9 @@ class VertexAiSessionService(BaseSessionService):
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
api_client = _get_api_client(self.project, self.location)
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
api_client = self._get_api_client()
try:
await api_client.async_request(
http_method='DELETE',
@@ -253,8 +264,8 @@ class VertexAiSessionService(BaseSessionService):
# Update the in-memory session.
await super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
api_client = _get_api_client(self.project, self.location)
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)
api_client = self._get_api_client()
await api_client.async_request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
@@ -262,15 +273,34 @@ class VertexAiSessionService(BaseSessionService):
)
return event
def _get_reasoning_engine_id(self, app_name: str):
if self._agent_engine_id:
return self._agent_engine_id
def _get_api_client(project: str, location: str):
"""Instantiates an API client for the given project and location.
if app_name.isdigit():
return app_name
It needs to be instantiated inside each request so that the event loop
management.
"""
client = genai.Client(vertexai=True, project=project, location=location)
return client._api_client
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
match = re.fullmatch(pattern, app_name)
if not bool(match):
raise ValueError(
f'App name {app_name} is not valid. It should either be the full'
' ReasoningEngine resource name, or the reasoning engine id.'
)
return match.groups()[-1]
def _get_api_client(self):
"""Instantiates an API client for the given project and location.
It needs to be instantiated inside each request so that the event loop
management can be properly propagated.
"""
client = genai.Client(
vertexai=True, project=self._project, location=self._location
)
return client._api_client
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
@@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event:
)
return event
def _parse_reasoning_engine_id(app_name: str):
if app_name.isdigit():
return app_name
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
match = re.fullmatch(pattern, app_name)
if not bool(match):
raise ValueError(
f'App name {app_name} is not valid. It should either be the full'
' ReasoningEngine resource name, or the reasoning engine id.'
)
return match.groups()[-1]
@@ -245,8 +245,14 @@ class MockApiClient:
raise ValueError(f'Unsupported http method: {http_method}')
def mock_vertex_ai_session_service():
def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
"""Creates a mock Vertex AI Session service for testing."""
if agent_engine_id:
return VertexAiSessionService(
project='test-project',
location='test-location',
agent_engine_id=agent_engine_id,
)
return VertexAiSessionService(
project='test-project', location='test-location'
)
@@ -265,7 +271,7 @@ def mock_get_api_client():
'2': (MOCK_EVENT_JSON_2, 'my_token'),
}
with mock.patch(
'google.adk.sessions.vertex_ai_session_service._get_api_client',
'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client',
return_value=api_client,
):
yield
@@ -273,8 +279,12 @@ def mock_get_api_client():
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_empty_session():
session_service = mock_vertex_ai_session_service()
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
async def test_get_empty_session(agent_engine_id):
if agent_engine_id:
session_service = mock_vertex_ai_session_service(agent_engine_id)
else:
session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo:
await session_service.get_session(
app_name='123', user_id='user', session_id='0'