diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 4ef8ae6c..875c3cb7 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -276,7 +276,6 @@ def get_fast_api_app( memory_service = InMemoryMemoryService() # Build the Session service - agent_engine_id = "" if session_service_uri: if session_service_uri.startswith("agentengine://"): # Create vertex session service @@ -285,8 +284,9 @@ def get_fast_api_app( raise click.ClickException("Agent engine id can not be empty.") envs.load_dotenv_for_agent("", agents_dir) session_service = VertexAiSessionService( - os.environ["GOOGLE_CLOUD_PROJECT"], - os.environ["GOOGLE_CLOUD_LOCATION"], + project=os.environ["GOOGLE_CLOUD_PROJECT"], + location=os.environ["GOOGLE_CLOUD_LOCATION"], + agent_engine_id=agent_engine_id, ) else: session_service = DatabaseSessionService(db_url=session_service_uri) @@ -357,8 +357,6 @@ def get_fast_api_app( async def get_session( app_name: str, user_id: str, session_id: str ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -371,8 +369,6 @@ def get_fast_api_app( response_model_exclude_none=True, ) async def list_sessions(app_name: str, user_id: str) -> list[Session]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name list_sessions_response = await session_service.list_sessions( app_name=app_name, user_id=user_id ) @@ -393,8 +389,6 @@ def get_fast_api_app( session_id: str, state: Optional[dict[str, Any]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name if ( await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id @@ -419,8 +413,6 @@ def get_fast_api_app( user_id: str, state: Optional[dict[str, Any]] = None, ) -> Session: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name logger.info("New session created") return await session_service.create_session( app_name=app_name, user_id=user_id, state=state @@ -660,8 +652,6 @@ def get_fast_api_app( @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") async def delete_session(app_name: str, user_id: str, session_id: str): - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name await session_service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -677,7 +667,6 @@ def get_fast_api_app( artifact_name: str, version: Optional[int] = Query(None), ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -700,7 +689,6 @@ def get_fast_api_app( artifact_name: str, version_id: int, ) -> Optional[types.Part]: - app_name = agent_engine_id if agent_engine_id else app_name artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, @@ -719,7 +707,6 @@ def get_fast_api_app( async def list_artifact_names( app_name: str, user_id: str, session_id: str ) -> list[str]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_artifact_keys( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -731,7 +718,6 @@ def get_fast_api_app( async def list_artifact_versions( app_name: str, user_id: str, session_id: str, artifact_name: str ) -> list[int]: - app_name = agent_engine_id if agent_engine_id else app_name return await artifact_service.list_versions( app_name=app_name, user_id=user_id, @@ -745,7 +731,6 @@ def get_fast_api_app( async def delete_artifact( app_name: str, user_id: str, session_id: str, artifact_name: str ): - app_name = agent_engine_id if agent_engine_id else app_name await artifact_service.delete_artifact( app_name=app_name, user_id=user_id, @@ -755,10 +740,8 @@ def get_fast_api_app( @app.post("/run", response_model_exclude_none=True) async def agent_run(req: AgentRunRequest) -> list[Event]: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -776,11 +759,9 @@ def get_fast_api_app( @app.post("/run_sse") async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else req.app_name # SSE endpoint session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -818,8 +799,6 @@ def get_fast_api_app( async def get_event_graph( app_name: str, user_id: str, session_id: str, event_id: str ): - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -875,8 +854,6 @@ def get_fast_api_app( ) -> None: await websocket.accept() - # Connect to managed session if agent_engine_id is set. - app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -940,7 +917,7 @@ def get_fast_api_app( return runner_dict[app_name] root_agent = agent_loader.load_agent(app_name) runner = Runner( - app_name=agent_engine_id if agent_engine_id else app_name, + app_name=app_name, agent=root_agent, artifact_service=artifact_service, session_service=session_service, diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 5d6bed2e..258dcd93 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -22,7 +22,6 @@ from typing import Optional import urllib.parse from dateutil import parser -from google.genai import types from typing_extensions import override from google import genai @@ -40,15 +39,27 @@ logger = logging.getLogger('google_adk.' + __name__) class VertexAiSessionService(BaseSessionService): - """Connects to the managed Vertex AI Session Service.""" + """Connects to the Vertex AI Agent Engine Session Service using GenAI API client. + + https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview + """ def __init__( self, - project: str = None, - location: str = None, + project: Optional[str] = None, + location: Optional[str] = None, + agent_engine_id: Optional[str] = None, ): - self.project = project - self.location = location + """Initializes the VertexAiSessionService. + + Args: + project: The project id of the project to use. + location: The location of the project to use. + agent_engine_id: The resource ID of the agent engine to use. + """ + self._project = project + self._location = location + self._agent_engine_id = agent_engine_id @override async def create_session( @@ -64,14 +75,13 @@ class VertexAiSessionService(BaseSessionService): 'User-provided Session id is not supported for' ' VertexAISessionService.' ) - - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() session_json_dict = {'user_id': user_id} if state: session_json_dict['session_state'] = state - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', @@ -130,10 +140,10 @@ class VertexAiSessionService(BaseSessionService): session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() # Get session resource - api_client = _get_api_client(self.project, self.location) get_session_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', @@ -203,14 +213,14 @@ class VertexAiSessionService(BaseSessionService): async def list_sessions( self, *, app_name: str, user_id: str ) -> ListSessionsResponse: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() path = f'reasoningEngines/{reasoning_engine_id}/sessions' if user_id: parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='') path = path + f'?filter=user_id={parsed_user_id}' - api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='GET', path=path, @@ -236,8 +246,9 @@ class VertexAiSessionService(BaseSessionService): async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: - reasoning_engine_id = _parse_reasoning_engine_id(app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(app_name) + api_client = self._get_api_client() + try: await api_client.async_request( http_method='DELETE', @@ -253,8 +264,8 @@ class VertexAiSessionService(BaseSessionService): # Update the in-memory session. await super().append_event(session=session, event=event) - reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - api_client = _get_api_client(self.project, self.location) + reasoning_engine_id = self._get_reasoning_engine_id(session.app_name) + api_client = self._get_api_client() await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', @@ -262,15 +273,34 @@ class VertexAiSessionService(BaseSessionService): ) return event + def _get_reasoning_engine_id(self, app_name: str): + if self._agent_engine_id: + return self._agent_engine_id -def _get_api_client(project: str, location: str): - """Instantiates an API client for the given project and location. + if app_name.isdigit(): + return app_name - It needs to be instantiated inside each request so that the event loop - management. - """ - client = genai.Client(vertexai=True, project=project, location=location) - return client._api_client + pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' + match = re.fullmatch(pattern, app_name) + + if not bool(match): + raise ValueError( + f'App name {app_name} is not valid. It should either be the full' + ' ReasoningEngine resource name, or the reasoning engine id.' + ) + + return match.groups()[-1] + + def _get_api_client(self): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management can be properly propagated. + """ + client = genai.Client( + vertexai=True, project=self._project, location=self._location + ) + return client._api_client def _convert_event_to_json(event: Event) -> Dict[str, Any]: @@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event: ) return event - - -def _parse_reasoning_engine_id(app_name: str): - if app_name.isdigit(): - return app_name - - pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' - match = re.fullmatch(pattern, app_name) - - if not bool(match): - raise ValueError( - f'App name {app_name} is not valid. It should either be the full' - ' ReasoningEngine resource name, or the reasoning engine id.' - ) - - return match.groups()[-1] diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 92f6a29d..6a9e0b46 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -245,8 +245,14 @@ class MockApiClient: raise ValueError(f'Unsupported http method: {http_method}') -def mock_vertex_ai_session_service(): +def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): """Creates a mock Vertex AI Session service for testing.""" + if agent_engine_id: + return VertexAiSessionService( + project='test-project', + location='test-location', + agent_engine_id=agent_engine_id, + ) return VertexAiSessionService( project='test-project', location='test-location' ) @@ -265,7 +271,7 @@ def mock_get_api_client(): '2': (MOCK_EVENT_JSON_2, 'my_token'), } with mock.patch( - 'google.adk.sessions.vertex_ai_session_service._get_api_client', + 'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client', return_value=api_client, ): yield @@ -273,8 +279,12 @@ def mock_get_api_client(): @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') -async def test_get_empty_session(): - session_service = mock_vertex_ai_session_service() +@pytest.mark.parametrize('agent_engine_id', [None, '123']) +async def test_get_empty_session(agent_engine_id): + if agent_engine_id: + session_service = mock_vertex_ai_session_service(agent_engine_id) + else: + session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: await session_service.get_session( app_name='123', user_id='user', session_id='0'