diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 03f4c392..6667c2de 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -22,6 +22,7 @@ from google.genai import types from typing_extensions import override import vertexai +from ..utils.vertex_ai_utils import get_express_mode_api_key from .base_memory_service import BaseMemoryService from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry @@ -40,6 +41,8 @@ class VertexAiMemoryBankService(BaseMemoryService): project: Optional[str] = None, location: Optional[str] = None, agent_engine_id: Optional[str] = None, + *, + express_mode_api_key: Optional[str] = None, ): """Initializes a VertexAiMemoryBankService. @@ -49,10 +52,19 @@ class VertexAiMemoryBankService(BaseMemoryService): agent_engine_id: The ID of the agent engine to use for the Memory Bank. e.g. '456' in 'projects/my-project/locations/us-central1/reasoningEngines/456'. + express_mode_api_key: The API key to use for Express Mode. If not + provided, the API key from the GOOGLE_API_KEY environment variable will + be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. + Do not use Google AI Studio API key for this field. For more details, + visit + https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ self._project = project self._location = location self._agent_engine_id = agent_engine_id + self._express_mode_api_key = get_express_mode_api_key( + project, location, express_mode_api_key + ) @override async def add_session_to_memory(self, session: Session): @@ -123,11 +135,14 @@ class VertexAiMemoryBankService(BaseMemoryService): It needs to be instantiated inside each request so that the event loop management can be properly propagated. - Returns: - An API client for the given project and location. + An API client for the given project and location or express mode api key. """ - return vertexai.Client(project=self._project, location=self._location) + return vertexai.Client( + project=self._project, + location=self._location, + api_key=self._express_mode_api_key, + ) def _should_filter_out_event(content: types.Content) -> bool: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 48b112b6..8025e79e 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -17,7 +17,6 @@ import asyncio import datetime import json import logging -import os import re from typing import Any from typing import Optional @@ -35,7 +34,8 @@ import vertexai from . import _session_util from ..events.event import Event from ..events.event_actions import EventActions -from ..utils.env_utils import is_env_enabled +from ..utils.vertex_ai_utils import get_express_mode_api_key +from ..utils.vertex_ai_utils import is_vertex_express_mode from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse @@ -55,6 +55,8 @@ class VertexAiSessionService(BaseSessionService): project: Optional[str] = None, location: Optional[str] = None, agent_engine_id: Optional[str] = None, + *, + express_mode_api_key: Optional[str] = None, ): """Initializes the VertexAiSessionService. @@ -62,10 +64,19 @@ class VertexAiSessionService(BaseSessionService): project: The project id of the project to use. location: The location of the project to use. agent_engine_id: The resource ID of the agent engine to use. + express_mode_api_key: The API key to use for Express Mode. If not + provided, the API key from the GOOGLE_API_KEY environment variable will + be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. + Do not use Google AI Studio API key for this field. For more details, + visit + https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ self._project = project self._location = location self._agent_engine_id = agent_engine_id + self._express_mode_api_key = get_express_mode_api_key( + project, location, express_mode_api_key + ) @override async def create_session( @@ -104,7 +115,9 @@ class VertexAiSessionService(BaseSessionService): config = {'session_state': state} if state else {} config.update(kwargs) - if _is_vertex_express_mode(self._project, self._location): + if is_vertex_express_mode( + self._project, self._location, self._express_mode_api_key + ): config['wait_for_completion'] = False api_response = await api_client.aio.agent_engines.sessions.create( name=f'reasoningEngines/{reasoning_engine_id}', @@ -351,27 +364,16 @@ class VertexAiSessionService(BaseSessionService): """Instantiates an API client for the given project and location. Returns: - An API client for the given project and location. + An API client for the given project and location or express mode api key. """ return vertexai.Client( project=self._project, location=self._location, http_options=self._api_client_http_options_override(), + api_key=self._express_mode_api_key, ) -def _is_vertex_express_mode( - project: Optional[str], location: Optional[str] -) -> bool: - """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" - return ( - is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI') - and os.environ.get('GOOGLE_API_KEY', None) is not None - and project is None - and location is None - ) - - def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: """Converts an API event object to an Event object.""" actions = getattr(api_event_obj, 'actions', None) diff --git a/src/google/adk/utils/vertex_ai_utils.py b/src/google/adk/utils/vertex_ai_utils.py new file mode 100644 index 00000000..55969dd1 --- /dev/null +++ b/src/google/adk/utils/vertex_ai_utils.py @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Vertex AI. Includes helper functions for Express Mode. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from __future__ import annotations + +import os +from typing import Optional + +from ..utils.env_utils import is_env_enabled + + +def is_vertex_express_mode( + project: Optional[str], location: Optional[str], api_key: Optional[str] +) -> bool: + """Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode.""" + return ( + is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI') + and api_key is not None + and project is None + and location is None + ) + + +def get_express_mode_api_key( + project: Optional[str], + location: Optional[str], + express_mode_api_key: Optional[str], +) -> Optional[str]: + """Validates and returns the API key for Express Mode.""" + if (project or location) and express_mode_api_key: + raise ValueError( + 'Cannot specify project or location and express_mode_api_key. ' + 'Either use project and location, or just the express_mode_api_key.' + ) + if is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI'): + return express_mode_api_key or os.environ.get('GOOGLE_API_KEY', None) + else: + return None diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index c47023df..6a1f0ccb 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -13,6 +13,7 @@ # limitations under the License. from datetime import datetime +from typing import Optional from unittest import mock from google.adk.events.event import Event @@ -69,12 +70,18 @@ MOCK_SESSION_WITH_EMPTY_EVENTS = Session( ) -def mock_vertex_ai_memory_bank_service(): +def mock_vertex_ai_memory_bank_service( + project: Optional[str] = 'test-project', + location: Optional[str] = 'test-location', + agent_engine_id: Optional[str] = '123', + express_mode_api_key: Optional[str] = None, +): """Creates a mock Vertex AI Memory Bank service for testing.""" return VertexAiMemoryBankService( - project='test-project', - location='test-location', - agent_engine_id='123', + project=project, + location=location, + agent_engine_id=agent_engine_id, + express_mode_api_key=express_mode_api_key, ) @@ -90,6 +97,21 @@ def mock_vertexai_client(): yield mock_client +@pytest.mark.asyncio +async def test_initialize_with_project_location_and_api_key_error(): + with pytest.raises(ValueError) as excinfo: + mock_vertex_ai_memory_bank_service( + project='test-project', + location='test-location', + express_mode_api_key='test-api-key', + ) + assert ( + 'Cannot specify project or location and express_mode_api_key. Either use' + ' project and location, or just the express_mode_api_key.' + in str(excinfo.value) + ) + + @pytest.mark.asyncio async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 7e76b5ec..fd004199 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -356,12 +356,18 @@ class MockApiClient: self.event_dict[session_id] = ([event_json], None) -def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None): +def mock_vertex_ai_session_service( + project: Optional[str] = 'test-project', + location: Optional[str] = 'test-location', + agent_engine_id: Optional[str] = None, + express_mode_api_key: Optional[str] = None, +): """Creates a mock Vertex AI Session service for testing.""" return VertexAiSessionService( - project='test-project', - location='test-location', + project=project, + location=location, agent_engine_id=agent_engine_id, + express_mode_api_key=express_mode_api_key, ) @@ -393,6 +399,21 @@ def mock_get_api_client(mock_api_client_instance): yield +@pytest.mark.asyncio +async def test_initialize_with_project_location_and_api_key_error(): + with pytest.raises(ValueError) as excinfo: + mock_vertex_ai_session_service( + project='test-project', + location='test-location', + express_mode_api_key='test-api-key', + ) + assert ( + 'Cannot specify project or location and express_mode_api_key. Either use' + ' project and location, or just the express_mode_api_key.' + in str(excinfo.value) + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') @pytest.mark.parametrize('agent_engine_id', [None, '123']) diff --git a/tests/unittests/utils/test_vertex_ai_utils.py b/tests/unittests/utils/test_vertex_ai_utils.py new file mode 100644 index 00000000..644a36c7 --- /dev/null +++ b/tests/unittests/utils/test_vertex_ai_utils.py @@ -0,0 +1,117 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for vertex_utils.""" + +from unittest import mock + +from google.adk.utils import vertex_ai_utils +import pytest + + +@pytest.mark.parametrize( + ('use_vertexai_env', 'project', 'location', 'api_key', 'expected'), + [ + ('true', None, None, 'test-key', True), + ('1', None, None, 'test-key', True), + ('false', None, None, 'test-key', False), + ('0', None, None, 'test-key', False), + (None, None, None, 'test-key', False), + ('true', 'test-project', None, 'test-key', False), + ('true', None, 'test-location', 'test-key', False), + ('true', None, None, None, False), + ], +) +def test_is_vertex_express_mode( + use_vertexai_env, project, location, api_key, expected +): + env_vars = {} + if use_vertexai_env: + env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env + with mock.patch.dict('os.environ', env_vars, clear=True): + assert ( + vertex_ai_utils.is_vertex_express_mode(project, location, api_key) + == expected + ) + + +def test_get_express_mode_api_key_value_error(): + with pytest.raises(ValueError) as excinfo: + vertex_ai_utils.get_express_mode_api_key( + project='test-project', location=None, express_mode_api_key='key' + ) + assert ( + 'Cannot specify project or location and express_mode_api_key. Either use' + ' project and location, or just the express_mode_api_key.' + in str(excinfo.value) + ) + with pytest.raises(ValueError) as excinfo: + vertex_ai_utils.get_express_mode_api_key( + project=None, location='test-location', express_mode_api_key='key' + ) + assert ( + 'Cannot specify project or location and express_mode_api_key. Either use' + ' project and location, or just the express_mode_api_key.' + in str(excinfo.value) + ) + with pytest.raises(ValueError) as excinfo: + vertex_ai_utils.get_express_mode_api_key( + project='test-project', + location='test-location', + express_mode_api_key='key', + ) + assert ( + 'Cannot specify project or location and express_mode_api_key. Either use' + ' project and location, or just the express_mode_api_key.' + in str(excinfo.value) + ) + + +@pytest.mark.parametrize( + ( + 'use_vertexai_env', + 'google_api_key_env', + 'express_mode_api_key', + 'expected', + ), + [ + ('true', None, 'express_key', 'express_key'), + ('1', 'google_key', 'express_key', 'express_key'), + ('true', 'google_key', None, 'google_key'), + ('1', None, None, None), + ('false', 'google_key', 'express_key', None), + ('0', 'google_key', None, None), + (None, 'google_key', 'express_key', None), + ], +) +def test_get_express_mode_api_key( + use_vertexai_env, + google_api_key_env, + express_mode_api_key, + expected, +): + env_vars = {} + if use_vertexai_env: + env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env + if google_api_key_env: + env_vars['GOOGLE_API_KEY'] = google_api_key_env + with mock.patch.dict('os.environ', env_vars, clear=True): + assert ( + vertex_ai_utils.get_express_mode_api_key( + project=None, + location=None, + express_mode_api_key=express_mode_api_key, + ) + == expected + )