refactor: extract credentail key building logic to auth_config

PiperOrigin-RevId: 768124459
This commit is contained in:
Xiang (Sean) Zhou
2025-06-06 10:18:46 -07:00
committed by Copybara-Service
parent 9abb8414da
commit 309a656f49
4 changed files with 131 additions and 56 deletions
+2 -25
View File
@@ -112,7 +112,7 @@ class AuthHandler:
def parse_and_store_auth_response(self, state: State) -> None:
credential_key = self.get_credential_key()
credential_key = "temp:" + self.auth_config.get_credential_key()
state[credential_key] = self.auth_config.exchanged_auth_credential
if not isinstance(
@@ -130,7 +130,7 @@ class AuthHandler:
raise ValueError("auth_scheme is empty.")
def get_auth_response(self, state: State) -> AuthCredential:
credential_key = self.get_credential_key()
credential_key = "temp:" + self.auth_config.get_credential_key()
return state.get(credential_key, None)
def generate_auth_request(self) -> AuthConfig:
@@ -192,29 +192,6 @@ class AuthHandler:
exchanged_auth_credential=exchanged_credential,
)
def get_credential_key(self) -> str:
"""Generates a unique key for the given auth scheme and credential."""
auth_scheme = self.auth_config.auth_scheme
auth_credential = self.auth_config.raw_auth_credential
if auth_scheme.model_extra:
auth_scheme = auth_scheme.model_copy(deep=True)
auth_scheme.model_extra.clear()
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
if auth_credential.model_extra:
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
return f"temp:adk_{scheme_name}_{credential_name}"
def generate_auth_uri(
self,
) -> AuthCredential:
+30
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from .auth_credential import AuthCredential
from .auth_credential import BaseModelWithConfig
from .auth_schemes import AuthScheme
@@ -43,6 +45,34 @@ class AuthConfig(BaseModelWithConfig):
this field to guide the user through the OAuth2 flow and fill auth response in
this field"""
def get_credential_key(self):
"""Generates a hash key based on auth_scheme and raw_auth_credential. This
hash key can be used to store / retrieve exchanged_auth_credential in a
credentials store.
"""
auth_scheme = self.auth_scheme
if auth_scheme.model_extra:
auth_scheme = auth_scheme.model_copy(deep=True)
auth_scheme.model_extra.clear()
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
auth_credential = self.raw_auth_credential
if auth_credential.model_extra:
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
return f"adk_{scheme_name}_{credential_name}"
class AuthToolArguments(BaseModelWithConfig):
"""the arguments for the special long running function tool that is used to
+91
View File
@@ -0,0 +1,91 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_tool import AuthConfig
import pytest
class TestAuthConfig:
"""Tests for the AuthConfig method."""
@pytest.fixture
def oauth2_auth_scheme():
"""Create an OAuth2 auth scheme for testing."""
# Create the OAuthFlows object first
flows = OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://example.com/oauth2/authorize",
tokenUrl="https://example.com/oauth2/token",
scopes={"read": "Read access", "write": "Write access"},
)
)
# Then create the OAuth2 object with the flows
return OAuth2(flows=flows)
@pytest.fixture
def oauth2_credentials():
"""Create OAuth2 credentials for testing."""
return AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="mock_client_id",
client_secret="mock_client_secret",
redirect_uri="https://example.com/callback",
),
)
@pytest.fixture
def auth_config(oauth2_auth_scheme, oauth2_credentials):
"""Create an AuthConfig for testing."""
# Create a copy of the credentials for the exchanged_auth_credential
exchanged_credential = oauth2_credentials.model_copy(deep=True)
return AuthConfig(
auth_scheme=oauth2_auth_scheme,
raw_auth_credential=oauth2_credentials,
exchanged_auth_credential=exchanged_credential,
)
def test_get_credential_key(auth_config):
"""Test generating a unique credential key."""
key = auth_config.get_credential_key()
assert key.startswith("adk_oauth2_")
assert "_oauth2_" in key
def test_get_credential_key_with_extras(auth_config):
"""Test generating a key when model_extra exists."""
# Add model_extra to test cleanup
original_key = auth_config.get_credential_key()
key = auth_config.get_credential_key()
auth_config.auth_scheme.model_extra["extra_field"] = "value"
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
assert original_key == key
assert "extra_field" in auth_config.auth_scheme.model_extra
assert "extra_field" in auth_config.raw_auth_credential.model_extra
+8 -31
View File
@@ -209,31 +209,6 @@ class TestAuthHandlerInit:
assert handler.auth_config == auth_config
class TestGetCredentialKey:
"""Tests for the get_credential_key method."""
def test_get_credential_key(self, auth_config):
"""Test generating a unique credential key."""
handler = AuthHandler(auth_config)
key = handler.get_credential_key()
assert key.startswith("temp:adk_oauth2_")
assert "_oauth2_" in key
def test_get_credential_key_with_extras(self, auth_config):
"""Test generating a key when model_extra exists."""
# Add model_extra to test cleanup
original_key = AuthHandler(auth_config).get_credential_key()
key = AuthHandler(auth_config).get_credential_key()
auth_config.auth_scheme.model_extra["extra_field"] = "value"
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
assert original_key == key
assert "extra_field" in auth_config.auth_scheme.model_extra
assert "extra_field" in auth_config.raw_auth_credential.model_extra
class TestGenerateAuthUri:
"""Tests for the generate_auth_uri method."""
@@ -412,8 +387,8 @@ class TestGetAuthResponse:
state = MockState()
# Store a credential in the state
credential_key = handler.get_credential_key()
state[credential_key] = oauth2_credentials_with_auth_uri
credential_key = auth_config.get_credential_key()
state["temp:" + credential_key] = oauth2_credentials_with_auth_uri
result = handler.get_auth_response(state)
assert result == oauth2_credentials_with_auth_uri
@@ -443,8 +418,10 @@ class TestParseAndStoreAuthResponse:
handler.parse_and_store_auth_response(state)
credential_key = handler.get_credential_key()
assert state[credential_key] == auth_config.exchanged_auth_credential
credential_key = auth_config.get_credential_key()
assert (
state["temp:" + credential_key] == auth_config.exchanged_auth_credential
)
@patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
@@ -459,8 +436,8 @@ class TestParseAndStoreAuthResponse:
handler.parse_and_store_auth_response(state)
credential_key = handler.get_credential_key()
assert state[credential_key] == mock_exchange_token.return_value
credential_key = auth_config_with_exchanged.get_credential_key()
assert state["temp:" + credential_key] == mock_exchange_token.return_value
assert mock_exchange_token.called