feat: Add api key argument to Vertex Session and Memory services for Express Mode support

We also change VertexAiSessionService and VertexAiMemoryBankService to both use keyword arguments for project, location, agent engine id, and express mode api key

PiperOrigin-RevId: 825719331
This commit is contained in:
Google Team Member
2025-10-29 14:53:37 -07:00
committed by Copybara-Service
parent d45b31fb45
commit 9014a849ea
6 changed files with 258 additions and 26 deletions
@@ -22,6 +22,7 @@ from google.genai import types
from typing_extensions import override
import vertexai
from ..utils.vertex_ai_utils import get_express_mode_api_key
from .base_memory_service import BaseMemoryService
from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry
@@ -40,6 +41,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
project: Optional[str] = None,
location: Optional[str] = None,
agent_engine_id: Optional[str] = None,
*,
express_mode_api_key: Optional[str] = None,
):
"""Initializes a VertexAiMemoryBankService.
@@ -49,10 +52,19 @@ class VertexAiMemoryBankService(BaseMemoryService):
agent_engine_id: The ID of the agent engine to use for the Memory Bank.
e.g. '456' in
'projects/my-project/locations/us-central1/reasoningEngines/456'.
express_mode_api_key: The API key to use for Express Mode. If not
provided, the API key from the GOOGLE_API_KEY environment variable will
be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true.
Do not use Google AI Studio API key for this field. For more details,
visit
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
"""
self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
self._express_mode_api_key = get_express_mode_api_key(
project, location, express_mode_api_key
)
@override
async def add_session_to_memory(self, session: Session):
@@ -123,11 +135,14 @@ class VertexAiMemoryBankService(BaseMemoryService):
It needs to be instantiated inside each request so that the event loop
management can be properly propagated.
Returns:
An API client for the given project and location.
An API client for the given project and location or express mode api key.
"""
return vertexai.Client(project=self._project, location=self._location)
return vertexai.Client(
project=self._project,
location=self._location,
api_key=self._express_mode_api_key,
)
def _should_filter_out_event(content: types.Content) -> bool:
@@ -17,7 +17,6 @@ import asyncio
import datetime
import json
import logging
import os
import re
from typing import Any
from typing import Optional
@@ -35,7 +34,8 @@ import vertexai
from . import _session_util
from ..events.event import Event
from ..events.event_actions import EventActions
from ..utils.env_utils import is_env_enabled
from ..utils.vertex_ai_utils import get_express_mode_api_key
from ..utils.vertex_ai_utils import is_vertex_express_mode
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
@@ -55,6 +55,8 @@ class VertexAiSessionService(BaseSessionService):
project: Optional[str] = None,
location: Optional[str] = None,
agent_engine_id: Optional[str] = None,
*,
express_mode_api_key: Optional[str] = None,
):
"""Initializes the VertexAiSessionService.
@@ -62,10 +64,19 @@ class VertexAiSessionService(BaseSessionService):
project: The project id of the project to use.
location: The location of the project to use.
agent_engine_id: The resource ID of the agent engine to use.
express_mode_api_key: The API key to use for Express Mode. If not
provided, the API key from the GOOGLE_API_KEY environment variable will
be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true.
Do not use Google AI Studio API key for this field. For more details,
visit
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
"""
self._project = project
self._location = location
self._agent_engine_id = agent_engine_id
self._express_mode_api_key = get_express_mode_api_key(
project, location, express_mode_api_key
)
@override
async def create_session(
@@ -104,7 +115,9 @@ class VertexAiSessionService(BaseSessionService):
config = {'session_state': state} if state else {}
config.update(kwargs)
if _is_vertex_express_mode(self._project, self._location):
if is_vertex_express_mode(
self._project, self._location, self._express_mode_api_key
):
config['wait_for_completion'] = False
api_response = await api_client.aio.agent_engines.sessions.create(
name=f'reasoningEngines/{reasoning_engine_id}',
@@ -351,27 +364,16 @@ class VertexAiSessionService(BaseSessionService):
"""Instantiates an API client for the given project and location.
Returns:
An API client for the given project and location.
An API client for the given project and location or express mode api key.
"""
return vertexai.Client(
project=self._project,
location=self._location,
http_options=self._api_client_http_options_override(),
api_key=self._express_mode_api_key,
)
def _is_vertex_express_mode(
project: Optional[str], location: Optional[str]
) -> bool:
"""Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode."""
return (
is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI')
and os.environ.get('GOOGLE_API_KEY', None) is not None
and project is None
and location is None
)
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
"""Converts an API event object to an Event object."""
actions = getattr(api_event_obj, 'actions', None)
+55
View File
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Vertex AI. Includes helper functions for Express Mode.
This module is for ADK internal use only.
Please do not rely on the implementation details.
"""
from __future__ import annotations
import os
from typing import Optional
from ..utils.env_utils import is_env_enabled
def is_vertex_express_mode(
project: Optional[str], location: Optional[str], api_key: Optional[str]
) -> bool:
"""Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode."""
return (
is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI')
and api_key is not None
and project is None
and location is None
)
def get_express_mode_api_key(
project: Optional[str],
location: Optional[str],
express_mode_api_key: Optional[str],
) -> Optional[str]:
"""Validates and returns the API key for Express Mode."""
if (project or location) and express_mode_api_key:
raise ValueError(
'Cannot specify project or location and express_mode_api_key. '
'Either use project and location, or just the express_mode_api_key.'
)
if is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI'):
return express_mode_api_key or os.environ.get('GOOGLE_API_KEY', None)
else:
return None
@@ -13,6 +13,7 @@
# limitations under the License.
from datetime import datetime
from typing import Optional
from unittest import mock
from google.adk.events.event import Event
@@ -69,12 +70,18 @@ MOCK_SESSION_WITH_EMPTY_EVENTS = Session(
)
def mock_vertex_ai_memory_bank_service():
def mock_vertex_ai_memory_bank_service(
project: Optional[str] = 'test-project',
location: Optional[str] = 'test-location',
agent_engine_id: Optional[str] = '123',
express_mode_api_key: Optional[str] = None,
):
"""Creates a mock Vertex AI Memory Bank service for testing."""
return VertexAiMemoryBankService(
project='test-project',
location='test-location',
agent_engine_id='123',
project=project,
location=location,
agent_engine_id=agent_engine_id,
express_mode_api_key=express_mode_api_key,
)
@@ -90,6 +97,21 @@ def mock_vertexai_client():
yield mock_client
@pytest.mark.asyncio
async def test_initialize_with_project_location_and_api_key_error():
with pytest.raises(ValueError) as excinfo:
mock_vertex_ai_memory_bank_service(
project='test-project',
location='test-location',
express_mode_api_key='test-api-key',
)
assert (
'Cannot specify project or location and express_mode_api_key. Either use'
' project and location, or just the express_mode_api_key.'
in str(excinfo.value)
)
@pytest.mark.asyncio
async def test_add_session_to_memory(mock_vertexai_client):
memory_service = mock_vertex_ai_memory_bank_service()
@@ -356,12 +356,18 @@ class MockApiClient:
self.event_dict[session_id] = ([event_json], None)
def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
def mock_vertex_ai_session_service(
project: Optional[str] = 'test-project',
location: Optional[str] = 'test-location',
agent_engine_id: Optional[str] = None,
express_mode_api_key: Optional[str] = None,
):
"""Creates a mock Vertex AI Session service for testing."""
return VertexAiSessionService(
project='test-project',
location='test-location',
project=project,
location=location,
agent_engine_id=agent_engine_id,
express_mode_api_key=express_mode_api_key,
)
@@ -393,6 +399,21 @@ def mock_get_api_client(mock_api_client_instance):
yield
@pytest.mark.asyncio
async def test_initialize_with_project_location_and_api_key_error():
with pytest.raises(ValueError) as excinfo:
mock_vertex_ai_session_service(
project='test-project',
location='test-location',
express_mode_api_key='test-api-key',
)
assert (
'Cannot specify project or location and express_mode_api_key. Either use'
' project and location, or just the express_mode_api_key.'
in str(excinfo.value)
)
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
@@ -0,0 +1,117 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for vertex_utils."""
from unittest import mock
from google.adk.utils import vertex_ai_utils
import pytest
@pytest.mark.parametrize(
('use_vertexai_env', 'project', 'location', 'api_key', 'expected'),
[
('true', None, None, 'test-key', True),
('1', None, None, 'test-key', True),
('false', None, None, 'test-key', False),
('0', None, None, 'test-key', False),
(None, None, None, 'test-key', False),
('true', 'test-project', None, 'test-key', False),
('true', None, 'test-location', 'test-key', False),
('true', None, None, None, False),
],
)
def test_is_vertex_express_mode(
use_vertexai_env, project, location, api_key, expected
):
env_vars = {}
if use_vertexai_env:
env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
with mock.patch.dict('os.environ', env_vars, clear=True):
assert (
vertex_ai_utils.is_vertex_express_mode(project, location, api_key)
== expected
)
def test_get_express_mode_api_key_value_error():
with pytest.raises(ValueError) as excinfo:
vertex_ai_utils.get_express_mode_api_key(
project='test-project', location=None, express_mode_api_key='key'
)
assert (
'Cannot specify project or location and express_mode_api_key. Either use'
' project and location, or just the express_mode_api_key.'
in str(excinfo.value)
)
with pytest.raises(ValueError) as excinfo:
vertex_ai_utils.get_express_mode_api_key(
project=None, location='test-location', express_mode_api_key='key'
)
assert (
'Cannot specify project or location and express_mode_api_key. Either use'
' project and location, or just the express_mode_api_key.'
in str(excinfo.value)
)
with pytest.raises(ValueError) as excinfo:
vertex_ai_utils.get_express_mode_api_key(
project='test-project',
location='test-location',
express_mode_api_key='key',
)
assert (
'Cannot specify project or location and express_mode_api_key. Either use'
' project and location, or just the express_mode_api_key.'
in str(excinfo.value)
)
@pytest.mark.parametrize(
(
'use_vertexai_env',
'google_api_key_env',
'express_mode_api_key',
'expected',
),
[
('true', None, 'express_key', 'express_key'),
('1', 'google_key', 'express_key', 'express_key'),
('true', 'google_key', None, 'google_key'),
('1', None, None, None),
('false', 'google_key', 'express_key', None),
('0', 'google_key', None, None),
(None, 'google_key', 'express_key', None),
],
)
def test_get_express_mode_api_key(
use_vertexai_env,
google_api_key_env,
express_mode_api_key,
expected,
):
env_vars = {}
if use_vertexai_env:
env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
if google_api_key_env:
env_vars['GOOGLE_API_KEY'] = google_api_key_env
with mock.patch.dict('os.environ', env_vars, clear=True):
assert (
vertex_ai_utils.get_express_mode_api_key(
project=None,
location=None,
express_mode_api_key=express_mode_api_key,
)
== expected
)