You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Allow more credentials types for BigQuery tools
This change accepts the `google.auth.credentials.Credentials` type for `BigQueryCredentialsConfig`, so any subclass of that, including `google.oauth2.credentials.Credentials` would work to integrate with BigQuery service. This opens up a whole range of possibilities, such as using service account credentials to deploy an agent using these tools. PiperOrigin-RevId: 773190440
This commit is contained in:
committed by
Copybara-Service
parent
17beb32880
commit
2f716ada7f
@@ -40,13 +40,28 @@ would set:
|
||||
### With Application Default Credentials
|
||||
|
||||
This mode is useful for quick development when the agent builder is the only
|
||||
user interacting with the agent. The tools are initialized with the default
|
||||
credentials present on the machine running the agent.
|
||||
user interacting with the agent. The tools are run with these credentials.
|
||||
|
||||
1. Create application default credentials on the machine where the agent would
|
||||
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
|
||||
|
||||
1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent
|
||||
1. Set `CREDENTIALS_TYPE=None` in `agent.py`
|
||||
|
||||
1. Run the agent
|
||||
|
||||
### With Service Account Keys
|
||||
|
||||
This mode is useful for quick development when the agent builder wants to run
|
||||
the agent with service account credentials. The tools are run with these
|
||||
credentials.
|
||||
|
||||
1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
|
||||
|
||||
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py`
|
||||
|
||||
1. Download the key file and replace `"service_account_key.json"` with the path
|
||||
|
||||
1. Run the agent
|
||||
|
||||
### With Interactive OAuth
|
||||
|
||||
@@ -72,7 +87,7 @@ type.
|
||||
Note: don't create a separate .env, instead put it to the same .env file that
|
||||
stores your Vertex AI or Dev ML credentials
|
||||
|
||||
1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent
|
||||
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent
|
||||
|
||||
## Sample prompts
|
||||
|
||||
|
||||
@@ -15,24 +15,21 @@
|
||||
import os
|
||||
|
||||
from google.adk.agents import llm_agent
|
||||
from google.adk.auth import AuthCredentialTypes
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
import google.auth
|
||||
|
||||
RUN_WITH_ADC = False
|
||||
# Define an appropriate credential type
|
||||
CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2
|
||||
|
||||
|
||||
# Define BigQuery tool config
|
||||
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
|
||||
|
||||
if RUN_WITH_ADC:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
credentials=application_default_credentials
|
||||
)
|
||||
else:
|
||||
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
|
||||
# Initiaze the tools to do interactive OAuth
|
||||
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
|
||||
# must be set
|
||||
@@ -40,6 +37,20 @@ else:
|
||||
client_id=os.getenv("OAUTH_CLIENT_ID"),
|
||||
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
|
||||
)
|
||||
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
|
||||
# Initialize the tools to use the credentials in the service account key.
|
||||
# If this flow is enabled, make sure to replace the file path with your own
|
||||
# service account key file
|
||||
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
|
||||
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
|
||||
credentials_config = BigQueryCredentialsConfig(credentials=creds)
|
||||
else:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
# https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
credentials=application_default_credentials
|
||||
)
|
||||
|
||||
bigquery_toolset = BigQueryToolset(
|
||||
credentials_config=credentials_config, bigquery_tool_config=tool_config
|
||||
|
||||
@@ -21,9 +21,10 @@ from typing import Optional
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
import google.auth.credentials
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
import google.oauth2.credentials
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
@@ -40,26 +41,35 @@ BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"]
|
||||
|
||||
@experimental
|
||||
class BigQueryCredentialsConfig(BaseModel):
|
||||
"""Configuration for Google API tools. (Experimental)"""
|
||||
"""Configuration for Google API tools (Experimental).
|
||||
|
||||
Please do not use this in production, as it may be deprecated later.
|
||||
"""
|
||||
|
||||
# Configure the model to allow arbitrary types like Credentials
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
credentials: Optional[Credentials] = None
|
||||
"""the existing oauth credentials to use. If set,this credential will be used
|
||||
credentials: Optional[google.auth.credentials.Credentials] = None
|
||||
"""The existing auth credentials to use. If set, this credential will be used
|
||||
for every end user, end users don't need to be involved in the oauthflow. This
|
||||
field is mutually exclusive with client_id, client_secret and scopes.
|
||||
Don't set this field unless you are sure this credential has the permission to
|
||||
access every end user's data.
|
||||
|
||||
Example usage: when the agent is deployed in Google Cloud environment and
|
||||
Example usage 1: When the agent is deployed in Google Cloud environment and
|
||||
the service account (used as application default credentials) has access to
|
||||
all the required BigQuery resource. Setting this credential to allow user to
|
||||
access the BigQuery resource without end users going through oauth flow.
|
||||
|
||||
To get application default credential: `google.auth.default(...)`. See more
|
||||
To get application default credential, use: `google.auth.default(...)`. See more
|
||||
details in https://cloud.google.com/docs/authentication/application-default-credentials.
|
||||
|
||||
Example usage 2: When the agent wants to access the user's BigQuery resources
|
||||
using the service account key credentials.
|
||||
|
||||
To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`.
|
||||
See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
|
||||
|
||||
When the deployed environment cannot provide a pre-existing credential,
|
||||
consider setting below client_id, client_secret and scope for end users to go
|
||||
through oauth flow, so that agent can access the user data.
|
||||
@@ -86,7 +96,9 @@ class BigQueryCredentialsConfig(BaseModel):
|
||||
" client_id/client_secret/scopes."
|
||||
)
|
||||
|
||||
if self.credentials:
|
||||
if self.credentials and isinstance(
|
||||
self.credentials, google.oauth2.credentials.Credentials
|
||||
):
|
||||
self.client_id = self.credentials.client_id
|
||||
self.client_secret = self.credentials.client_secret
|
||||
self.scopes = self.credentials.scopes
|
||||
@@ -115,7 +127,7 @@ class BigQueryCredentialsManager:
|
||||
|
||||
async def get_valid_credentials(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
) -> Optional[google.auth.credentials.Credentials]:
|
||||
"""Get valid credentials, handling refresh and OAuth flow as needed.
|
||||
|
||||
Args:
|
||||
@@ -127,7 +139,7 @@ class BigQueryCredentialsManager:
|
||||
# First, try to get credentials from the tool context
|
||||
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
||||
creds = (
|
||||
Credentials.from_authorized_user_info(
|
||||
google.oauth2.credentials.Credentials.from_authorized_user_info(
|
||||
json.loads(creds_json), self.credentials_config.scopes
|
||||
)
|
||||
if creds_json
|
||||
@@ -138,6 +150,11 @@ class BigQueryCredentialsManager:
|
||||
if not creds:
|
||||
creds = self.credentials_config.credentials
|
||||
|
||||
# If non-oauth credentials are provided then use them as is. This helps
|
||||
# in flows such as service account keys
|
||||
if creds and not isinstance(creds, google.oauth2.credentials.Credentials):
|
||||
return creds
|
||||
|
||||
# Check if we have valid credentials
|
||||
if creds and creds.valid:
|
||||
return creds
|
||||
@@ -159,7 +176,7 @@ class BigQueryCredentialsManager:
|
||||
|
||||
async def _perform_oauth_flow(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
) -> Optional[google.oauth2.credentials.Credentials]:
|
||||
"""Perform OAuth flow to get new credentials.
|
||||
|
||||
Args:
|
||||
@@ -199,7 +216,7 @@ class BigQueryCredentialsManager:
|
||||
|
||||
if auth_response:
|
||||
# OAuth flow completed, create credentials
|
||||
creds = Credentials(
|
||||
creds = google.oauth2.credentials.Credentials(
|
||||
token=auth_response.oauth2.access_token,
|
||||
refresh_token=auth_response.oauth2.refresh_token,
|
||||
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.auth.credentials import Credentials
|
||||
from typing_extensions import override
|
||||
|
||||
from ...utils.feature_decorator import experimental
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import google.api_core.client_info
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ... import version
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from . import client
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ import functools
|
||||
import types
|
||||
from typing import Callable
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from . import client
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
@@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||
# Mock the Google OAuth and API dependencies
|
||||
from google.oauth2.credentials import Credentials
|
||||
import google.auth.credentials
|
||||
import google.oauth2.credentials
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -27,22 +28,46 @@ class TestBigQueryCredentials:
|
||||
either existing credentials or client ID/secret pairs are provided.
|
||||
"""
|
||||
|
||||
def test_valid_credentials_object(self):
|
||||
"""Test that providing valid Credentials object works correctly.
|
||||
def test_valid_credentials_object_auth_credentials(self):
|
||||
"""Test that providing valid Credentials object works correctly with
|
||||
google.auth.credentials.Credentials.
|
||||
|
||||
When a user already has valid OAuth credentials, they should be able
|
||||
to pass them directly without needing to provide client ID/secret.
|
||||
"""
|
||||
# Create a mock credentials object with the expected attributes
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds.client_id = "test_client_id"
|
||||
mock_creds.client_secret = "test_client_secret"
|
||||
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
|
||||
# Create a mock auth credentials object
|
||||
# auth_creds = google.auth.credentials.Credentials()
|
||||
auth_creds = mock.create_autospec(
|
||||
google.auth.credentials.Credentials, instance=True
|
||||
)
|
||||
|
||||
config = BigQueryCredentialsConfig(credentials=mock_creds)
|
||||
config = BigQueryCredentialsConfig(credentials=auth_creds)
|
||||
|
||||
# Verify that the credentials are properly stored and attributes are extracted
|
||||
assert config.credentials == mock_creds
|
||||
assert config.credentials == auth_creds
|
||||
assert config.client_id is None
|
||||
assert config.client_secret is None
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
|
||||
|
||||
def test_valid_credentials_object_oauth2_credentials(self):
|
||||
"""Test that providing valid Credentials object works correctly with
|
||||
google.oauth2.credentials.Credentials.
|
||||
|
||||
When a user already has valid OAuth credentials, they should be able
|
||||
to pass them directly without needing to provide client ID/secret.
|
||||
"""
|
||||
# Create a mock oauth2 credentials object
|
||||
oauth2_creds = google.oauth2.credentials.Credentials(
|
||||
"test_token",
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
scopes=["https://www.googleapis.com/auth/calendar"],
|
||||
)
|
||||
|
||||
config = BigQueryCredentialsConfig(credentials=oauth2_creds)
|
||||
|
||||
# Verify that the credentials are properly stored and attributes are extracted
|
||||
assert config.credentials == oauth2_creds
|
||||
assert config.client_id == "test_client_id"
|
||||
assert config.client_secret == "test_client_secret"
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/calendar"]
|
||||
|
||||
@@ -22,9 +22,10 @@ from google.adk.tools import ToolContext
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
||||
from google.auth.credentials import Credentials as AuthCredentials
|
||||
from google.auth.exceptions import RefreshError
|
||||
# Mock the Google OAuth and API dependencies
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -64,9 +65,16 @@ class TestBigQueryCredentialsManager:
|
||||
"""Create a credentials manager instance for testing."""
|
||||
return BigQueryCredentialsManager(credentials_config)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials_class",),
|
||||
[
|
||||
pytest.param(OAuthCredentials, id="oauth"),
|
||||
pytest.param(AuthCredentials, id="auth"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_valid_credentials_with_valid_existing_creds(
|
||||
self, manager, mock_tool_context
|
||||
self, manager, mock_tool_context, credentials_class
|
||||
):
|
||||
"""Test that valid existing credentials are returned immediately.
|
||||
|
||||
@@ -74,7 +82,7 @@ class TestBigQueryCredentialsManager:
|
||||
should be needed. This is the optimal happy path scenario.
|
||||
"""
|
||||
# Create mock credentials that are already valid
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=credentials_class)
|
||||
mock_creds.valid = True
|
||||
manager.credentials_config.credentials = mock_creds
|
||||
|
||||
@@ -85,6 +93,34 @@ class TestBigQueryCredentialsManager:
|
||||
mock_tool_context.get_auth_response.assert_not_called()
|
||||
mock_tool_context.request_credential.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("valid",),
|
||||
[
|
||||
pytest.param(False, id="invalid"),
|
||||
pytest.param(True, id="valid"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_valid_credentials_with_existing_non_oauth_creds(
|
||||
self, manager, mock_tool_context, valid
|
||||
):
|
||||
"""Test that existing non-oauth credentials are returned immediately.
|
||||
|
||||
When credentials are of non-oauth type, no refresh or OAuth flow
|
||||
is triggered irrespective of whether it is valid or not.
|
||||
"""
|
||||
# Create mock credentials that are already valid
|
||||
mock_creds = Mock(spec=AuthCredentials)
|
||||
mock_creds.valid = valid
|
||||
manager.credentials_config.credentials = mock_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
assert result == mock_creds
|
||||
# Verify no OAuth flow was triggered
|
||||
mock_tool_context.get_auth_response.assert_not_called()
|
||||
mock_tool_context.request_credential.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credentials_from_cache_when_none_in_manager(
|
||||
self, manager, mock_tool_context
|
||||
@@ -113,7 +149,7 @@ class TestBigQueryCredentialsManager:
|
||||
with patch(
|
||||
"google.oauth2.credentials.Credentials.from_authorized_user_info"
|
||||
) as mock_from_json:
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=OAuthCredentials)
|
||||
mock_creds.valid = True
|
||||
mock_from_json.return_value = mock_creds
|
||||
|
||||
@@ -179,7 +215,7 @@ class TestBigQueryCredentialsManager:
|
||||
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
|
||||
|
||||
# Create expired cached credentials with refresh token
|
||||
mock_cached_creds = Mock(spec=Credentials)
|
||||
mock_cached_creds = Mock(spec=OAuthCredentials)
|
||||
mock_cached_creds.valid = False
|
||||
mock_cached_creds.expired = True
|
||||
mock_cached_creds.refresh_token = "valid_refresh_token"
|
||||
@@ -227,7 +263,7 @@ class TestBigQueryCredentialsManager:
|
||||
users from having to re-authenticate for every expired token.
|
||||
"""
|
||||
# Create expired credentials with refresh token
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=OAuthCredentials)
|
||||
mock_creds.valid = False
|
||||
mock_creds.expired = True
|
||||
mock_creds.refresh_token = "refresh_token"
|
||||
@@ -257,7 +293,7 @@ class TestBigQueryCredentialsManager:
|
||||
gracefully fall back to requesting a new OAuth flow.
|
||||
"""
|
||||
# Create expired credentials that fail to refresh
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=OAuthCredentials)
|
||||
mock_creds.valid = False
|
||||
mock_creds.expired = True
|
||||
mock_creds.refresh_token = "expired_refresh_token"
|
||||
@@ -287,7 +323,7 @@ class TestBigQueryCredentialsManager:
|
||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||
|
||||
# Create a mock credentials instance that will represent our created credentials
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=OAuthCredentials)
|
||||
# Make the JSON match what a real Credentials object would produce
|
||||
mock_creds_json = (
|
||||
'{"token": "new_access_token", "refresh_token": "new_refresh_token",'
|
||||
@@ -300,7 +336,7 @@ class TestBigQueryCredentialsManager:
|
||||
|
||||
# Use the full module path as it appears in the project structure
|
||||
with patch(
|
||||
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
|
||||
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials",
|
||||
return_value=mock_creds,
|
||||
) as mock_credentials_class:
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
@@ -361,7 +397,7 @@ class TestBigQueryCredentialsManager:
|
||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||
|
||||
# Create the mock credentials instance that will be returned by the constructor
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds = Mock(spec=OAuthCredentials)
|
||||
# Make sure our mock JSON matches the structure that real Credentials objects produce
|
||||
mock_creds_json = (
|
||||
'{"token": "cached_access_token", "refresh_token":'
|
||||
@@ -376,7 +412,7 @@ class TestBigQueryCredentialsManager:
|
||||
|
||||
# Use the correct module path - without the 'src.' prefix
|
||||
with patch(
|
||||
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
|
||||
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials",
|
||||
return_value=mock_creds,
|
||||
) as mock_credentials_class:
|
||||
# Complete OAuth flow with first manager
|
||||
@@ -396,9 +432,9 @@ class TestBigQueryCredentialsManager:
|
||||
|
||||
# Mock the from_authorized_user_info method for the second manager
|
||||
with patch(
|
||||
"google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info"
|
||||
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info"
|
||||
) as mock_from_json:
|
||||
mock_cached_creds = Mock(spec=Credentials)
|
||||
mock_cached_creds = Mock(spec=OAuthCredentials)
|
||||
mock_cached_creds.valid = True
|
||||
mock_from_json.return_value = mock_cached_creds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user