diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index cd4583c7..050ce133 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -40,13 +40,28 @@ would set: ### With Application Default Credentials This mode is useful for quick development when the agent builder is the only -user interacting with the agent. The tools are initialized with the default -credentials present on the machine running the agent. +user interacting with the agent. The tools are run with these credentials. 1. Create application default credentials on the machine where the agent would be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. -1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent +1. Set `CREDENTIALS_TYPE=None` in `agent.py` + +1. Run the agent + +### With Service Account Keys + +This mode is useful for quick development when the agent builder wants to run +the agent with service account credentials. The tools are run with these +credentials. + +1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py` + +1. Download the key file and replace `"service_account_key.json"` with the path + +1. Run the agent ### With Interactive OAuth @@ -72,7 +87,7 @@ type. Note: don't create a separate .env, instead put it to the same .env file that stores your Vertex AI or Dev ML credentials -1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent +1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent ## Sample prompts diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py index 0999ca12..3cd1eb99 100644 --- a/contributing/samples/bigquery/agent.py +++ b/contributing/samples/bigquery/agent.py @@ -15,24 +15,21 @@ import os from google.adk.agents import llm_agent +from google.adk.auth import AuthCredentialTypes from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode import google.auth -RUN_WITH_ADC = False +# Define an appropriate credential type +CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2 +# Define BigQuery tool config tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) -if RUN_WITH_ADC: - # Initialize the tools to use the application default credentials. - application_default_credentials, _ = google.auth.default() - credentials_config = BigQueryCredentialsConfig( - credentials=application_default_credentials - ) -else: +if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initiaze the tools to do interactive OAuth # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET # must be set @@ -40,6 +37,20 @@ else: client_id=os.getenv("OAUTH_CLIENT_ID"), client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) +elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: + # Initialize the tools to use the credentials in the service account key. + # If this flow is enabled, make sure to replace the file path with your own + # service account key file + # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys + creds, _ = google.auth.load_credentials_from_file("service_account_key.json") + credentials_config = BigQueryCredentialsConfig(credentials=creds) +else: + # Initialize the tools to use the application default credentials. + # https://cloud.google.com/docs/authentication/provide-credentials-adc + application_default_credentials, _ = google.auth.default() + credentials_config = BigQueryCredentialsConfig( + credentials=application_default_credentials + ) bigquery_toolset = BigQueryToolset( credentials_config=credentials_config, bigquery_tool_config=tool_config diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index 0a99136c..d0f3abe0 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -21,9 +21,10 @@ from typing import Optional from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlowAuthorizationCode from fastapi.openapi.models import OAuthFlows +import google.auth.credentials from google.auth.exceptions import RefreshError from google.auth.transport.requests import Request -from google.oauth2.credentials import Credentials +import google.oauth2.credentials from pydantic import BaseModel from pydantic import model_validator @@ -40,26 +41,35 @@ BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @experimental class BigQueryCredentialsConfig(BaseModel): - """Configuration for Google API tools. (Experimental)""" + """Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ # Configure the model to allow arbitrary types like Credentials model_config = {"arbitrary_types_allowed": True} - credentials: Optional[Credentials] = None - """the existing oauth credentials to use. If set,this credential will be used + credentials: Optional[google.auth.credentials.Credentials] = None + """The existing auth credentials to use. If set, this credential will be used for every end user, end users don't need to be involved in the oauthflow. This field is mutually exclusive with client_id, client_secret and scopes. Don't set this field unless you are sure this credential has the permission to access every end user's data. - Example usage: when the agent is deployed in Google Cloud environment and + Example usage 1: When the agent is deployed in Google Cloud environment and the service account (used as application default credentials) has access to all the required BigQuery resource. Setting this credential to allow user to access the BigQuery resource without end users going through oauth flow. - To get application default credential: `google.auth.default(...)`. See more + To get application default credential, use: `google.auth.default(...)`. See more details in https://cloud.google.com/docs/authentication/application-default-credentials. + Example usage 2: When the agent wants to access the user's BigQuery resources + using the service account key credentials. + + To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`. + See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + When the deployed environment cannot provide a pre-existing credential, consider setting below client_id, client_secret and scope for end users to go through oauth flow, so that agent can access the user data. @@ -86,7 +96,9 @@ class BigQueryCredentialsConfig(BaseModel): " client_id/client_secret/scopes." ) - if self.credentials: + if self.credentials and isinstance( + self.credentials, google.oauth2.credentials.Credentials + ): self.client_id = self.credentials.client_id self.client_secret = self.credentials.client_secret self.scopes = self.credentials.scopes @@ -115,7 +127,7 @@ class BigQueryCredentialsManager: async def get_valid_credentials( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.auth.credentials.Credentials]: """Get valid credentials, handling refresh and OAuth flow as needed. Args: @@ -127,7 +139,7 @@ class BigQueryCredentialsManager: # First, try to get credentials from the tool context creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) creds = ( - Credentials.from_authorized_user_info( + google.oauth2.credentials.Credentials.from_authorized_user_info( json.loads(creds_json), self.credentials_config.scopes ) if creds_json @@ -138,6 +150,11 @@ class BigQueryCredentialsManager: if not creds: creds = self.credentials_config.credentials + # If non-oauth credentials are provided then use them as is. This helps + # in flows such as service account keys + if creds and not isinstance(creds, google.oauth2.credentials.Credentials): + return creds + # Check if we have valid credentials if creds and creds.valid: return creds @@ -159,7 +176,7 @@ class BigQueryCredentialsManager: async def _perform_oauth_flow( self, tool_context: ToolContext - ) -> Optional[Credentials]: + ) -> Optional[google.oauth2.credentials.Credentials]: """Perform OAuth flow to get new credentials. Args: @@ -199,7 +216,7 @@ class BigQueryCredentialsManager: if auth_response: # OAuth flow completed, create credentials - creds = Credentials( + creds = google.oauth2.credentials.Credentials( token=auth_response.oauth2.access_token, refresh_token=auth_response.oauth2.refresh_token, token_uri=auth_scheme.flows.authorizationCode.tokenUrl, diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py index 18273418..50d49ff7 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -19,7 +19,7 @@ from typing import Any from typing import Callable from typing import Optional -from google.oauth2.credentials import Credentials +from google.auth.credentials import Credentials from typing_extensions import override from ...utils.feature_decorator import experimental diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index 23f1befc..8b2816eb 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -15,8 +15,8 @@ from __future__ import annotations import google.api_core.client_info +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from ... import version diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py index 4f540061..64f23d07 100644 --- a/src/google/adk/tools/bigquery/metadata_tool.py +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . import client diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index d3a94fda..147d0b4d 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -16,8 +16,8 @@ import functools import types from typing import Callable +from google.auth.credentials import Credentials from google.cloud import bigquery -from google.oauth2.credentials import Credentials from . import client from .config import BigQueryToolConfig diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9fa152fc..05af3aaf 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest import mock from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +import google.auth.credentials +import google.oauth2.credentials import pytest @@ -27,22 +28,46 @@ class TestBigQueryCredentials: either existing credentials or client ID/secret pairs are provided. """ - def test_valid_credentials_object(self): - """Test that providing valid Credentials object works correctly. + def test_valid_credentials_object_auth_credentials(self): + """Test that providing valid Credentials object works correctly with + google.auth.credentials.Credentials. When a user already has valid OAuth credentials, they should be able to pass them directly without needing to provide client ID/secret. """ - # Create a mock credentials object with the expected attributes - mock_creds = Mock(spec=Credentials) - mock_creds.client_id = "test_client_id" - mock_creds.client_secret = "test_client_secret" - mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"] + # Create a mock auth credentials object + # auth_creds = google.auth.credentials.Credentials() + auth_creds = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) - config = BigQueryCredentialsConfig(credentials=mock_creds) + config = BigQueryCredentialsConfig(credentials=auth_creds) # Verify that the credentials are properly stored and attributes are extracted - assert config.credentials == mock_creds + assert config.credentials == auth_creds + assert config.client_id is None + assert config.client_secret is None + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + + def test_valid_credentials_object_oauth2_credentials(self): + """Test that providing valid Credentials object works correctly with + google.oauth2.credentials.Credentials. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. + """ + # Create a mock oauth2 credentials object + oauth2_creds = google.oauth2.credentials.Credentials( + "test_token", + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/calendar"], + ) + + config = BigQueryCredentialsConfig(credentials=oauth2_creds) + + # Verify that the credentials are properly stored and attributes are extracted + assert config.credentials == oauth2_creds assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" assert config.scopes == ["https://www.googleapis.com/auth/calendar"] diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py index 95d8b00d..47d95590 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py @@ -22,9 +22,10 @@ from google.adk.tools import ToolContext from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager +from google.auth.credentials import Credentials as AuthCredentials from google.auth.exceptions import RefreshError # Mock the Google OAuth and API dependencies -from google.oauth2.credentials import Credentials +from google.oauth2.credentials import Credentials as OAuthCredentials import pytest @@ -64,9 +65,16 @@ class TestBigQueryCredentialsManager: """Create a credentials manager instance for testing.""" return BigQueryCredentialsManager(credentials_config) + @pytest.mark.parametrize( + ("credentials_class",), + [ + pytest.param(OAuthCredentials, id="oauth"), + pytest.param(AuthCredentials, id="auth"), + ], + ) @pytest.mark.asyncio async def test_get_valid_credentials_with_valid_existing_creds( - self, manager, mock_tool_context + self, manager, mock_tool_context, credentials_class ): """Test that valid existing credentials are returned immediately. @@ -74,7 +82,7 @@ class TestBigQueryCredentialsManager: should be needed. This is the optimal happy path scenario. """ # Create mock credentials that are already valid - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=credentials_class) mock_creds.valid = True manager.credentials_config.credentials = mock_creds @@ -85,6 +93,34 @@ class TestBigQueryCredentialsManager: mock_tool_context.get_auth_response.assert_not_called() mock_tool_context.request_credential.assert_not_called() + @pytest.mark.parametrize( + ("valid",), + [ + pytest.param(False, id="invalid"), + pytest.param(True, id="valid"), + ], + ) + @pytest.mark.asyncio + async def test_get_valid_credentials_with_existing_non_oauth_creds( + self, manager, mock_tool_context, valid + ): + """Test that existing non-oauth credentials are returned immediately. + + When credentials are of non-oauth type, no refresh or OAuth flow + is triggered irrespective of whether it is valid or not. + """ + # Create mock credentials that are already valid + mock_creds = Mock(spec=AuthCredentials) + mock_creds.valid = valid + manager.credentials_config.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + assert result == mock_creds + # Verify no OAuth flow was triggered + mock_tool_context.get_auth_response.assert_not_called() + mock_tool_context.request_credential.assert_not_called() + @pytest.mark.asyncio async def test_get_credentials_from_cache_when_none_in_manager( self, manager, mock_tool_context @@ -113,7 +149,7 @@ class TestBigQueryCredentialsManager: with patch( "google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = True mock_from_json.return_value = mock_creds @@ -179,7 +215,7 @@ class TestBigQueryCredentialsManager: mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json # Create expired cached credentials with refresh token - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = False mock_cached_creds.expired = True mock_cached_creds.refresh_token = "valid_refresh_token" @@ -227,7 +263,7 @@ class TestBigQueryCredentialsManager: users from having to re-authenticate for every expired token. """ # Create expired credentials with refresh token - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "refresh_token" @@ -257,7 +293,7 @@ class TestBigQueryCredentialsManager: gracefully fall back to requesting a new OAuth flow. """ # Create expired credentials that fail to refresh - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) mock_creds.valid = False mock_creds.expired = True mock_creds.refresh_token = "expired_refresh_token" @@ -287,7 +323,7 @@ class TestBigQueryCredentialsManager: mock_tool_context.get_auth_response.return_value = mock_auth_response # Create a mock credentials instance that will represent our created credentials - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make the JSON match what a real Credentials object would produce mock_creds_json = ( '{"token": "new_access_token", "refresh_token": "new_refresh_token",' @@ -300,7 +336,7 @@ class TestBigQueryCredentialsManager: # Use the full module path as it appears in the project structure with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: result = await manager.get_valid_credentials(mock_tool_context) @@ -361,7 +397,7 @@ class TestBigQueryCredentialsManager: mock_tool_context.get_auth_response.return_value = mock_auth_response # Create the mock credentials instance that will be returned by the constructor - mock_creds = Mock(spec=Credentials) + mock_creds = Mock(spec=OAuthCredentials) # Make sure our mock JSON matches the structure that real Credentials objects produce mock_creds_json = ( '{"token": "cached_access_token", "refresh_token":' @@ -376,7 +412,7 @@ class TestBigQueryCredentialsManager: # Use the correct module path - without the 'src.' prefix with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials", + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: # Complete OAuth flow with first manager @@ -396,9 +432,9 @@ class TestBigQueryCredentialsManager: # Mock the from_authorized_user_info method for the second manager with patch( - "google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info" + "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: - mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = True mock_from_json.return_value = mock_cached_creds