diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index a0cabc28..1607dcad 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -112,7 +112,7 @@ class AuthHandler: def parse_and_store_auth_response(self, state: State) -> None: - credential_key = self.get_credential_key() + credential_key = "temp:" + self.auth_config.get_credential_key() state[credential_key] = self.auth_config.exchanged_auth_credential if not isinstance( @@ -130,7 +130,7 @@ class AuthHandler: raise ValueError("auth_scheme is empty.") def get_auth_response(self, state: State) -> AuthCredential: - credential_key = self.get_credential_key() + credential_key = "temp:" + self.auth_config.get_credential_key() return state.get(credential_key, None) def generate_auth_request(self) -> AuthConfig: @@ -192,29 +192,6 @@ class AuthHandler: exchanged_auth_credential=exchanged_credential, ) - def get_credential_key(self) -> str: - """Generates a unique key for the given auth scheme and credential.""" - auth_scheme = self.auth_config.auth_scheme - auth_credential = self.auth_config.raw_auth_credential - if auth_scheme.model_extra: - auth_scheme = auth_scheme.model_copy(deep=True) - auth_scheme.model_extra.clear() - scheme_name = ( - f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}" - if auth_scheme - else "" - ) - if auth_credential.model_extra: - auth_credential = auth_credential.model_copy(deep=True) - auth_credential.model_extra.clear() - credential_name = ( - f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}" - if auth_credential - else "" - ) - - return f"temp:adk_{scheme_name}_{credential_name}" - def generate_auth_uri( self, ) -> AuthCredential: diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py index d5604243..a1a19ab8 100644 --- a/src/google/adk/auth/auth_tool.py +++ b/src/google/adk/auth/auth_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from .auth_credential import AuthCredential from .auth_credential import BaseModelWithConfig from .auth_schemes import AuthScheme @@ -43,6 +45,34 @@ class AuthConfig(BaseModelWithConfig): this field to guide the user through the OAuth2 flow and fill auth response in this field""" + def get_credential_key(self): + """Generates a hash key based on auth_scheme and raw_auth_credential. This + hash key can be used to store / retrieve exchanged_auth_credential in a + credentials store. + """ + auth_scheme = self.auth_scheme + + if auth_scheme.model_extra: + auth_scheme = auth_scheme.model_copy(deep=True) + auth_scheme.model_extra.clear() + scheme_name = ( + f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}" + if auth_scheme + else "" + ) + + auth_credential = self.raw_auth_credential + if auth_credential.model_extra: + auth_credential = auth_credential.model_copy(deep=True) + auth_credential.model_extra.clear() + credential_name = ( + f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}" + if auth_credential + else "" + ) + + return f"adk_{scheme_name}_{credential_name}" + class AuthToolArguments(BaseModelWithConfig): """the arguments for the special long running function tool that is used to diff --git a/tests/unittests/auth/test_auth_config.py b/tests/unittests/auth/test_auth_config.py new file mode 100644 index 00000000..7a20cfc6 --- /dev/null +++ b/tests/unittests/auth/test_auth_config.py @@ -0,0 +1,91 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +import pytest + + +class TestAuthConfig: + """Tests for the AuthConfig method.""" + + +@pytest.fixture +def oauth2_auth_scheme(): + """Create an OAuth2 auth scheme for testing.""" + # Create the OAuthFlows object first + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/oauth2/authorize", + tokenUrl="https://example.com/oauth2/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + + # Then create the OAuth2 object with the flows + return OAuth2(flows=flows) + + +@pytest.fixture +def oauth2_credentials(): + """Create OAuth2 credentials for testing.""" + return AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="mock_client_id", + client_secret="mock_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + +@pytest.fixture +def auth_config(oauth2_auth_scheme, oauth2_credentials): + """Create an AuthConfig for testing.""" + # Create a copy of the credentials for the exchanged_auth_credential + exchanged_credential = oauth2_credentials.model_copy(deep=True) + + return AuthConfig( + auth_scheme=oauth2_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged_credential, + ) + + +def test_get_credential_key(auth_config): + """Test generating a unique credential key.""" + + key = auth_config.get_credential_key() + assert key.startswith("adk_oauth2_") + assert "_oauth2_" in key + + +def test_get_credential_key_with_extras(auth_config): + """Test generating a key when model_extra exists.""" + # Add model_extra to test cleanup + + original_key = auth_config.get_credential_key() + key = auth_config.get_credential_key() + + auth_config.auth_scheme.model_extra["extra_field"] = "value" + auth_config.raw_auth_credential.model_extra["extra_field"] = "value" + + assert original_key == key + assert "extra_field" in auth_config.auth_scheme.model_extra + assert "extra_field" in auth_config.raw_auth_credential.model_extra diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index bdac9892..c717682a 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -209,31 +209,6 @@ class TestAuthHandlerInit: assert handler.auth_config == auth_config -class TestGetCredentialKey: - """Tests for the get_credential_key method.""" - - def test_get_credential_key(self, auth_config): - """Test generating a unique credential key.""" - handler = AuthHandler(auth_config) - key = handler.get_credential_key() - assert key.startswith("temp:adk_oauth2_") - assert "_oauth2_" in key - - def test_get_credential_key_with_extras(self, auth_config): - """Test generating a key when model_extra exists.""" - # Add model_extra to test cleanup - - original_key = AuthHandler(auth_config).get_credential_key() - key = AuthHandler(auth_config).get_credential_key() - - auth_config.auth_scheme.model_extra["extra_field"] = "value" - auth_config.raw_auth_credential.model_extra["extra_field"] = "value" - - assert original_key == key - assert "extra_field" in auth_config.auth_scheme.model_extra - assert "extra_field" in auth_config.raw_auth_credential.model_extra - - class TestGenerateAuthUri: """Tests for the generate_auth_uri method.""" @@ -412,8 +387,8 @@ class TestGetAuthResponse: state = MockState() # Store a credential in the state - credential_key = handler.get_credential_key() - state[credential_key] = oauth2_credentials_with_auth_uri + credential_key = auth_config.get_credential_key() + state["temp:" + credential_key] = oauth2_credentials_with_auth_uri result = handler.get_auth_response(state) assert result == oauth2_credentials_with_auth_uri @@ -443,8 +418,10 @@ class TestParseAndStoreAuthResponse: handler.parse_and_store_auth_response(state) - credential_key = handler.get_credential_key() - assert state[credential_key] == auth_config.exchanged_auth_credential + credential_key = auth_config.get_credential_key() + assert ( + state["temp:" + credential_key] == auth_config.exchanged_auth_credential + ) @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token") def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged): @@ -459,8 +436,8 @@ class TestParseAndStoreAuthResponse: handler.parse_and_store_auth_response(state) - credential_key = handler.get_credential_key() - assert state[credential_key] == mock_exchange_token.return_value + credential_key = auth_config_with_exchanged.get_credential_key() + assert state["temp:" + credential_key] == mock_exchange_token.return_value assert mock_exchange_token.called