fix: Allow more credentials types for BigQuery tools

This change accepts the `google.auth.credentials.Credentials` type for `BigQueryCredentialsConfig`, so any subclass of that, including `google.oauth2.credentials.Credentials` would work to integrate with BigQuery service. This opens up a whole range of possibilities, such as using service account credentials to deploy an agent using these tools.

PiperOrigin-RevId: 773190440
This commit is contained in:
Google Team Member
2025-06-18 22:01:07 -07:00
committed by Copybara-Service
parent 17beb32880
commit 2f716ada7f
9 changed files with 155 additions and 51 deletions
+19 -4
View File
@@ -40,13 +40,28 @@ would set:
### With Application Default Credentials
This mode is useful for quick development when the agent builder is the only
user interacting with the agent. The tools are initialized with the default
credentials present on the machine running the agent.
user interacting with the agent. The tools are run with these credentials.
1. Create application default credentials on the machine where the agent would
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent
1. Set `CREDENTIALS_TYPE=None` in `agent.py`
1. Run the agent
### With Service Account Keys
This mode is useful for quick development when the agent builder wants to run
the agent with service account credentials. The tools are run with these
credentials.
1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py`
1. Download the key file and replace `"service_account_key.json"` with the path
1. Run the agent
### With Interactive OAuth
@@ -72,7 +87,7 @@ type.
Note: don't create a separate .env, instead put it to the same .env file that
stores your Vertex AI or Dev ML credentials
1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent
## Sample prompts
+19 -8
View File
@@ -15,24 +15,21 @@
import os
from google.adk.agents import llm_agent
from google.adk.auth import AuthCredentialTypes
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
import google.auth
RUN_WITH_ADC = False
# Define an appropriate credential type
CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2
# Define BigQuery tool config
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
if RUN_WITH_ADC:
# Initialize the tools to use the application default credentials.
application_default_credentials, _ = google.auth.default()
credentials_config = BigQueryCredentialsConfig(
credentials=application_default_credentials
)
else:
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initiaze the tools to do interactive OAuth
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
# must be set
@@ -40,6 +37,20 @@ else:
client_id=os.getenv("OAUTH_CLIENT_ID"),
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
)
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
# Initialize the tools to use the credentials in the service account key.
# If this flow is enabled, make sure to replace the file path with your own
# service account key file
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
credentials_config = BigQueryCredentialsConfig(credentials=creds)
else:
# Initialize the tools to use the application default credentials.
# https://cloud.google.com/docs/authentication/provide-credentials-adc
application_default_credentials, _ = google.auth.default()
credentials_config = BigQueryCredentialsConfig(
credentials=application_default_credentials
)
bigquery_toolset = BigQueryToolset(
credentials_config=credentials_config, bigquery_tool_config=tool_config
@@ -21,9 +21,10 @@ from typing import Optional
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
import google.auth.credentials
from google.auth.exceptions import RefreshError
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
import google.oauth2.credentials
from pydantic import BaseModel
from pydantic import model_validator
@@ -40,26 +41,35 @@ BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"]
@experimental
class BigQueryCredentialsConfig(BaseModel):
"""Configuration for Google API tools. (Experimental)"""
"""Configuration for Google API tools (Experimental).
Please do not use this in production, as it may be deprecated later.
"""
# Configure the model to allow arbitrary types like Credentials
model_config = {"arbitrary_types_allowed": True}
credentials: Optional[Credentials] = None
"""the existing oauth credentials to use. If set,this credential will be used
credentials: Optional[google.auth.credentials.Credentials] = None
"""The existing auth credentials to use. If set, this credential will be used
for every end user, end users don't need to be involved in the oauthflow. This
field is mutually exclusive with client_id, client_secret and scopes.
Don't set this field unless you are sure this credential has the permission to
access every end user's data.
Example usage: when the agent is deployed in Google Cloud environment and
Example usage 1: When the agent is deployed in Google Cloud environment and
the service account (used as application default credentials) has access to
all the required BigQuery resource. Setting this credential to allow user to
access the BigQuery resource without end users going through oauth flow.
To get application default credential: `google.auth.default(...)`. See more
To get application default credential, use: `google.auth.default(...)`. See more
details in https://cloud.google.com/docs/authentication/application-default-credentials.
Example usage 2: When the agent wants to access the user's BigQuery resources
using the service account key credentials.
To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`.
See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
When the deployed environment cannot provide a pre-existing credential,
consider setting below client_id, client_secret and scope for end users to go
through oauth flow, so that agent can access the user data.
@@ -86,7 +96,9 @@ class BigQueryCredentialsConfig(BaseModel):
" client_id/client_secret/scopes."
)
if self.credentials:
if self.credentials and isinstance(
self.credentials, google.oauth2.credentials.Credentials
):
self.client_id = self.credentials.client_id
self.client_secret = self.credentials.client_secret
self.scopes = self.credentials.scopes
@@ -115,7 +127,7 @@ class BigQueryCredentialsManager:
async def get_valid_credentials(
self, tool_context: ToolContext
) -> Optional[Credentials]:
) -> Optional[google.auth.credentials.Credentials]:
"""Get valid credentials, handling refresh and OAuth flow as needed.
Args:
@@ -127,7 +139,7 @@ class BigQueryCredentialsManager:
# First, try to get credentials from the tool context
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
creds = (
Credentials.from_authorized_user_info(
google.oauth2.credentials.Credentials.from_authorized_user_info(
json.loads(creds_json), self.credentials_config.scopes
)
if creds_json
@@ -138,6 +150,11 @@ class BigQueryCredentialsManager:
if not creds:
creds = self.credentials_config.credentials
# If non-oauth credentials are provided then use them as is. This helps
# in flows such as service account keys
if creds and not isinstance(creds, google.oauth2.credentials.Credentials):
return creds
# Check if we have valid credentials
if creds and creds.valid:
return creds
@@ -159,7 +176,7 @@ class BigQueryCredentialsManager:
async def _perform_oauth_flow(
self, tool_context: ToolContext
) -> Optional[Credentials]:
) -> Optional[google.oauth2.credentials.Credentials]:
"""Perform OAuth flow to get new credentials.
Args:
@@ -199,7 +216,7 @@ class BigQueryCredentialsManager:
if auth_response:
# OAuth flow completed, create credentials
creds = Credentials(
creds = google.oauth2.credentials.Credentials(
token=auth_response.oauth2.access_token,
refresh_token=auth_response.oauth2.refresh_token,
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
@@ -19,7 +19,7 @@ from typing import Any
from typing import Callable
from typing import Optional
from google.oauth2.credentials import Credentials
from google.auth.credentials import Credentials
from typing_extensions import override
from ...utils.feature_decorator import experimental
+1 -1
View File
@@ -15,8 +15,8 @@
from __future__ import annotations
import google.api_core.client_info
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from ... import version
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from . import client
+1 -1
View File
@@ -16,8 +16,8 @@ import functools
import types
from typing import Callable
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from . import client
from .config import BigQueryToolConfig
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest import mock
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
# Mock the Google OAuth and API dependencies
from google.oauth2.credentials import Credentials
import google.auth.credentials
import google.oauth2.credentials
import pytest
@@ -27,22 +28,46 @@ class TestBigQueryCredentials:
either existing credentials or client ID/secret pairs are provided.
"""
def test_valid_credentials_object(self):
"""Test that providing valid Credentials object works correctly.
def test_valid_credentials_object_auth_credentials(self):
"""Test that providing valid Credentials object works correctly with
google.auth.credentials.Credentials.
When a user already has valid OAuth credentials, they should be able
to pass them directly without needing to provide client ID/secret.
"""
# Create a mock credentials object with the expected attributes
mock_creds = Mock(spec=Credentials)
mock_creds.client_id = "test_client_id"
mock_creds.client_secret = "test_client_secret"
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
# Create a mock auth credentials object
# auth_creds = google.auth.credentials.Credentials()
auth_creds = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
)
config = BigQueryCredentialsConfig(credentials=mock_creds)
config = BigQueryCredentialsConfig(credentials=auth_creds)
# Verify that the credentials are properly stored and attributes are extracted
assert config.credentials == mock_creds
assert config.credentials == auth_creds
assert config.client_id is None
assert config.client_secret is None
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
def test_valid_credentials_object_oauth2_credentials(self):
"""Test that providing valid Credentials object works correctly with
google.oauth2.credentials.Credentials.
When a user already has valid OAuth credentials, they should be able
to pass them directly without needing to provide client ID/secret.
"""
# Create a mock oauth2 credentials object
oauth2_creds = google.oauth2.credentials.Credentials(
"test_token",
client_id="test_client_id",
client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/calendar"],
)
config = BigQueryCredentialsConfig(credentials=oauth2_creds)
# Verify that the credentials are properly stored and attributes are extracted
assert config.credentials == oauth2_creds
assert config.client_id == "test_client_id"
assert config.client_secret == "test_client_secret"
assert config.scopes == ["https://www.googleapis.com/auth/calendar"]
@@ -22,9 +22,10 @@ from google.adk.tools import ToolContext
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
from google.auth.credentials import Credentials as AuthCredentials
from google.auth.exceptions import RefreshError
# Mock the Google OAuth and API dependencies
from google.oauth2.credentials import Credentials
from google.oauth2.credentials import Credentials as OAuthCredentials
import pytest
@@ -64,9 +65,16 @@ class TestBigQueryCredentialsManager:
"""Create a credentials manager instance for testing."""
return BigQueryCredentialsManager(credentials_config)
@pytest.mark.parametrize(
("credentials_class",),
[
pytest.param(OAuthCredentials, id="oauth"),
pytest.param(AuthCredentials, id="auth"),
],
)
@pytest.mark.asyncio
async def test_get_valid_credentials_with_valid_existing_creds(
self, manager, mock_tool_context
self, manager, mock_tool_context, credentials_class
):
"""Test that valid existing credentials are returned immediately.
@@ -74,7 +82,7 @@ class TestBigQueryCredentialsManager:
should be needed. This is the optimal happy path scenario.
"""
# Create mock credentials that are already valid
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=credentials_class)
mock_creds.valid = True
manager.credentials_config.credentials = mock_creds
@@ -85,6 +93,34 @@ class TestBigQueryCredentialsManager:
mock_tool_context.get_auth_response.assert_not_called()
mock_tool_context.request_credential.assert_not_called()
@pytest.mark.parametrize(
("valid",),
[
pytest.param(False, id="invalid"),
pytest.param(True, id="valid"),
],
)
@pytest.mark.asyncio
async def test_get_valid_credentials_with_existing_non_oauth_creds(
self, manager, mock_tool_context, valid
):
"""Test that existing non-oauth credentials are returned immediately.
When credentials are of non-oauth type, no refresh or OAuth flow
is triggered irrespective of whether it is valid or not.
"""
# Create mock credentials that are already valid
mock_creds = Mock(spec=AuthCredentials)
mock_creds.valid = valid
manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context)
assert result == mock_creds
# Verify no OAuth flow was triggered
mock_tool_context.get_auth_response.assert_not_called()
mock_tool_context.request_credential.assert_not_called()
@pytest.mark.asyncio
async def test_get_credentials_from_cache_when_none_in_manager(
self, manager, mock_tool_context
@@ -113,7 +149,7 @@ class TestBigQueryCredentialsManager:
with patch(
"google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=OAuthCredentials)
mock_creds.valid = True
mock_from_json.return_value = mock_creds
@@ -179,7 +215,7 @@ class TestBigQueryCredentialsManager:
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
# Create expired cached credentials with refresh token
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds = Mock(spec=OAuthCredentials)
mock_cached_creds.valid = False
mock_cached_creds.expired = True
mock_cached_creds.refresh_token = "valid_refresh_token"
@@ -227,7 +263,7 @@ class TestBigQueryCredentialsManager:
users from having to re-authenticate for every expired token.
"""
# Create expired credentials with refresh token
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=OAuthCredentials)
mock_creds.valid = False
mock_creds.expired = True
mock_creds.refresh_token = "refresh_token"
@@ -257,7 +293,7 @@ class TestBigQueryCredentialsManager:
gracefully fall back to requesting a new OAuth flow.
"""
# Create expired credentials that fail to refresh
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=OAuthCredentials)
mock_creds.valid = False
mock_creds.expired = True
mock_creds.refresh_token = "expired_refresh_token"
@@ -287,7 +323,7 @@ class TestBigQueryCredentialsManager:
mock_tool_context.get_auth_response.return_value = mock_auth_response
# Create a mock credentials instance that will represent our created credentials
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=OAuthCredentials)
# Make the JSON match what a real Credentials object would produce
mock_creds_json = (
'{"token": "new_access_token", "refresh_token": "new_refresh_token",'
@@ -300,7 +336,7 @@ class TestBigQueryCredentialsManager:
# Use the full module path as it appears in the project structure
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
result = await manager.get_valid_credentials(mock_tool_context)
@@ -361,7 +397,7 @@ class TestBigQueryCredentialsManager:
mock_tool_context.get_auth_response.return_value = mock_auth_response
# Create the mock credentials instance that will be returned by the constructor
mock_creds = Mock(spec=Credentials)
mock_creds = Mock(spec=OAuthCredentials)
# Make sure our mock JSON matches the structure that real Credentials objects produce
mock_creds_json = (
'{"token": "cached_access_token", "refresh_token":'
@@ -376,7 +412,7 @@ class TestBigQueryCredentialsManager:
# Use the correct module path - without the 'src.' prefix
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
# Complete OAuth flow with first manager
@@ -396,9 +432,9 @@ class TestBigQueryCredentialsManager:
# Mock the from_authorized_user_info method for the second manager
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info"
"google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds = Mock(spec=OAuthCredentials)
mock_cached_creds.valid = True
mock_from_json.return_value = mock_cached_creds