You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add api key argument to Vertex Session and Memory services for Express Mode support
We also change VertexAiSessionService and VertexAiMemoryBankService to both use keyword arguments for project, location, agent engine id, and express mode api key PiperOrigin-RevId: 825719331
This commit is contained in:
committed by
Copybara-Service
parent
d45b31fb45
commit
9014a849ea
@@ -22,6 +22,7 @@ from google.genai import types
|
||||
from typing_extensions import override
|
||||
import vertexai
|
||||
|
||||
from ..utils.vertex_ai_utils import get_express_mode_api_key
|
||||
from .base_memory_service import BaseMemoryService
|
||||
from .base_memory_service import SearchMemoryResponse
|
||||
from .memory_entry import MemoryEntry
|
||||
@@ -40,6 +41,8 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
agent_engine_id: Optional[str] = None,
|
||||
*,
|
||||
express_mode_api_key: Optional[str] = None,
|
||||
):
|
||||
"""Initializes a VertexAiMemoryBankService.
|
||||
|
||||
@@ -49,10 +52,19 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
agent_engine_id: The ID of the agent engine to use for the Memory Bank.
|
||||
e.g. '456' in
|
||||
'projects/my-project/locations/us-central1/reasoningEngines/456'.
|
||||
express_mode_api_key: The API key to use for Express Mode. If not
|
||||
provided, the API key from the GOOGLE_API_KEY environment variable will
|
||||
be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true.
|
||||
Do not use Google AI Studio API key for this field. For more details,
|
||||
visit
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
|
||||
"""
|
||||
self._project = project
|
||||
self._location = location
|
||||
self._agent_engine_id = agent_engine_id
|
||||
self._express_mode_api_key = get_express_mode_api_key(
|
||||
project, location, express_mode_api_key
|
||||
)
|
||||
|
||||
@override
|
||||
async def add_session_to_memory(self, session: Session):
|
||||
@@ -123,11 +135,14 @@ class VertexAiMemoryBankService(BaseMemoryService):
|
||||
|
||||
It needs to be instantiated inside each request so that the event loop
|
||||
management can be properly propagated.
|
||||
|
||||
Returns:
|
||||
An API client for the given project and location.
|
||||
An API client for the given project and location or express mode api key.
|
||||
"""
|
||||
return vertexai.Client(project=self._project, location=self._location)
|
||||
return vertexai.Client(
|
||||
project=self._project,
|
||||
location=self._location,
|
||||
api_key=self._express_mode_api_key,
|
||||
)
|
||||
|
||||
|
||||
def _should_filter_out_event(content: types.Content) -> bool:
|
||||
|
||||
@@ -17,7 +17,6 @@ import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
@@ -35,7 +34,8 @@ import vertexai
|
||||
from . import _session_util
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from ..utils.env_utils import is_env_enabled
|
||||
from ..utils.vertex_ai_utils import get_express_mode_api_key
|
||||
from ..utils.vertex_ai_utils import is_vertex_express_mode
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListSessionsResponse
|
||||
@@ -55,6 +55,8 @@ class VertexAiSessionService(BaseSessionService):
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
agent_engine_id: Optional[str] = None,
|
||||
*,
|
||||
express_mode_api_key: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the VertexAiSessionService.
|
||||
|
||||
@@ -62,10 +64,19 @@ class VertexAiSessionService(BaseSessionService):
|
||||
project: The project id of the project to use.
|
||||
location: The location of the project to use.
|
||||
agent_engine_id: The resource ID of the agent engine to use.
|
||||
express_mode_api_key: The API key to use for Express Mode. If not
|
||||
provided, the API key from the GOOGLE_API_KEY environment variable will
|
||||
be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true.
|
||||
Do not use Google AI Studio API key for this field. For more details,
|
||||
visit
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
|
||||
"""
|
||||
self._project = project
|
||||
self._location = location
|
||||
self._agent_engine_id = agent_engine_id
|
||||
self._express_mode_api_key = get_express_mode_api_key(
|
||||
project, location, express_mode_api_key
|
||||
)
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
@@ -104,7 +115,9 @@ class VertexAiSessionService(BaseSessionService):
|
||||
config = {'session_state': state} if state else {}
|
||||
config.update(kwargs)
|
||||
|
||||
if _is_vertex_express_mode(self._project, self._location):
|
||||
if is_vertex_express_mode(
|
||||
self._project, self._location, self._express_mode_api_key
|
||||
):
|
||||
config['wait_for_completion'] = False
|
||||
api_response = await api_client.aio.agent_engines.sessions.create(
|
||||
name=f'reasoningEngines/{reasoning_engine_id}',
|
||||
@@ -351,27 +364,16 @@ class VertexAiSessionService(BaseSessionService):
|
||||
"""Instantiates an API client for the given project and location.
|
||||
|
||||
Returns:
|
||||
An API client for the given project and location.
|
||||
An API client for the given project and location or express mode api key.
|
||||
"""
|
||||
return vertexai.Client(
|
||||
project=self._project,
|
||||
location=self._location,
|
||||
http_options=self._api_client_http_options_override(),
|
||||
api_key=self._express_mode_api_key,
|
||||
)
|
||||
|
||||
|
||||
def _is_vertex_express_mode(
|
||||
project: Optional[str], location: Optional[str]
|
||||
) -> bool:
|
||||
"""Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode."""
|
||||
return (
|
||||
is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI')
|
||||
and os.environ.get('GOOGLE_API_KEY', None) is not None
|
||||
and project is None
|
||||
and location is None
|
||||
)
|
||||
|
||||
|
||||
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
|
||||
"""Converts an API event object to an Event object."""
|
||||
actions = getattr(api_event_obj, 'actions', None)
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for Vertex AI. Includes helper functions for Express Mode.
|
||||
|
||||
This module is for ADK internal use only.
|
||||
Please do not rely on the implementation details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from ..utils.env_utils import is_env_enabled
|
||||
|
||||
|
||||
def is_vertex_express_mode(
|
||||
project: Optional[str], location: Optional[str], api_key: Optional[str]
|
||||
) -> bool:
|
||||
"""Check if Vertex AI and API key are both enabled replacing project and location, meaning the user is using the Vertex Express Mode."""
|
||||
return (
|
||||
is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI')
|
||||
and api_key is not None
|
||||
and project is None
|
||||
and location is None
|
||||
)
|
||||
|
||||
|
||||
def get_express_mode_api_key(
|
||||
project: Optional[str],
|
||||
location: Optional[str],
|
||||
express_mode_api_key: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""Validates and returns the API key for Express Mode."""
|
||||
if (project or location) and express_mode_api_key:
|
||||
raise ValueError(
|
||||
'Cannot specify project or location and express_mode_api_key. '
|
||||
'Either use project and location, or just the express_mode_api_key.'
|
||||
)
|
||||
if is_env_enabled('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
return express_mode_api_key or os.environ.get('GOOGLE_API_KEY', None)
|
||||
else:
|
||||
return None
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.events.event import Event
|
||||
@@ -69,12 +70,18 @@ MOCK_SESSION_WITH_EMPTY_EVENTS = Session(
|
||||
)
|
||||
|
||||
|
||||
def mock_vertex_ai_memory_bank_service():
|
||||
def mock_vertex_ai_memory_bank_service(
|
||||
project: Optional[str] = 'test-project',
|
||||
location: Optional[str] = 'test-location',
|
||||
agent_engine_id: Optional[str] = '123',
|
||||
express_mode_api_key: Optional[str] = None,
|
||||
):
|
||||
"""Creates a mock Vertex AI Memory Bank service for testing."""
|
||||
return VertexAiMemoryBankService(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
agent_engine_id='123',
|
||||
project=project,
|
||||
location=location,
|
||||
agent_engine_id=agent_engine_id,
|
||||
express_mode_api_key=express_mode_api_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,6 +97,21 @@ def mock_vertexai_client():
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_with_project_location_and_api_key_error():
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
mock_vertex_ai_memory_bank_service(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
express_mode_api_key='test-api-key',
|
||||
)
|
||||
assert (
|
||||
'Cannot specify project or location and express_mode_api_key. Either use'
|
||||
' project and location, or just the express_mode_api_key.'
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_session_to_memory(mock_vertexai_client):
|
||||
memory_service = mock_vertex_ai_memory_bank_service()
|
||||
|
||||
@@ -356,12 +356,18 @@ class MockApiClient:
|
||||
self.event_dict[session_id] = ([event_json], None)
|
||||
|
||||
|
||||
def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
|
||||
def mock_vertex_ai_session_service(
|
||||
project: Optional[str] = 'test-project',
|
||||
location: Optional[str] = 'test-location',
|
||||
agent_engine_id: Optional[str] = None,
|
||||
express_mode_api_key: Optional[str] = None,
|
||||
):
|
||||
"""Creates a mock Vertex AI Session service for testing."""
|
||||
return VertexAiSessionService(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
project=project,
|
||||
location=location,
|
||||
agent_engine_id=agent_engine_id,
|
||||
express_mode_api_key=express_mode_api_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -393,6 +399,21 @@ def mock_get_api_client(mock_api_client_instance):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_with_project_location_and_api_key_error():
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
mock_vertex_ai_session_service(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
express_mode_api_key='test-api-key',
|
||||
)
|
||||
assert (
|
||||
'Cannot specify project or location and express_mode_api_key. Either use'
|
||||
' project and location, or just the express_mode_api_key.'
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for vertex_utils."""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.utils import vertex_ai_utils
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('use_vertexai_env', 'project', 'location', 'api_key', 'expected'),
|
||||
[
|
||||
('true', None, None, 'test-key', True),
|
||||
('1', None, None, 'test-key', True),
|
||||
('false', None, None, 'test-key', False),
|
||||
('0', None, None, 'test-key', False),
|
||||
(None, None, None, 'test-key', False),
|
||||
('true', 'test-project', None, 'test-key', False),
|
||||
('true', None, 'test-location', 'test-key', False),
|
||||
('true', None, None, None, False),
|
||||
],
|
||||
)
|
||||
def test_is_vertex_express_mode(
|
||||
use_vertexai_env, project, location, api_key, expected
|
||||
):
|
||||
env_vars = {}
|
||||
if use_vertexai_env:
|
||||
env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
|
||||
with mock.patch.dict('os.environ', env_vars, clear=True):
|
||||
assert (
|
||||
vertex_ai_utils.is_vertex_express_mode(project, location, api_key)
|
||||
== expected
|
||||
)
|
||||
|
||||
|
||||
def test_get_express_mode_api_key_value_error():
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
vertex_ai_utils.get_express_mode_api_key(
|
||||
project='test-project', location=None, express_mode_api_key='key'
|
||||
)
|
||||
assert (
|
||||
'Cannot specify project or location and express_mode_api_key. Either use'
|
||||
' project and location, or just the express_mode_api_key.'
|
||||
in str(excinfo.value)
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
vertex_ai_utils.get_express_mode_api_key(
|
||||
project=None, location='test-location', express_mode_api_key='key'
|
||||
)
|
||||
assert (
|
||||
'Cannot specify project or location and express_mode_api_key. Either use'
|
||||
' project and location, or just the express_mode_api_key.'
|
||||
in str(excinfo.value)
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
vertex_ai_utils.get_express_mode_api_key(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
express_mode_api_key='key',
|
||||
)
|
||||
assert (
|
||||
'Cannot specify project or location and express_mode_api_key. Either use'
|
||||
' project and location, or just the express_mode_api_key.'
|
||||
in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
'use_vertexai_env',
|
||||
'google_api_key_env',
|
||||
'express_mode_api_key',
|
||||
'expected',
|
||||
),
|
||||
[
|
||||
('true', None, 'express_key', 'express_key'),
|
||||
('1', 'google_key', 'express_key', 'express_key'),
|
||||
('true', 'google_key', None, 'google_key'),
|
||||
('1', None, None, None),
|
||||
('false', 'google_key', 'express_key', None),
|
||||
('0', 'google_key', None, None),
|
||||
(None, 'google_key', 'express_key', None),
|
||||
],
|
||||
)
|
||||
def test_get_express_mode_api_key(
|
||||
use_vertexai_env,
|
||||
google_api_key_env,
|
||||
express_mode_api_key,
|
||||
expected,
|
||||
):
|
||||
env_vars = {}
|
||||
if use_vertexai_env:
|
||||
env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
|
||||
if google_api_key_env:
|
||||
env_vars['GOOGLE_API_KEY'] = google_api_key_env
|
||||
with mock.patch.dict('os.environ', env_vars, clear=True):
|
||||
assert (
|
||||
vertex_ai_utils.get_express_mode_api_key(
|
||||
project=None,
|
||||
location=None,
|
||||
express_mode_api_key=express_mode_api_key,
|
||||
)
|
||||
== expected
|
||||
)
|
||||
Reference in New Issue
Block a user