You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
refactor: extract credentail key building logic to auth_config
PiperOrigin-RevId: 768124459
This commit is contained in:
committed by
Copybara-Service
parent
9abb8414da
commit
309a656f49
@@ -112,7 +112,7 @@ class AuthHandler:
|
||||
|
||||
def parse_and_store_auth_response(self, state: State) -> None:
|
||||
|
||||
credential_key = self.get_credential_key()
|
||||
credential_key = "temp:" + self.auth_config.get_credential_key()
|
||||
|
||||
state[credential_key] = self.auth_config.exchanged_auth_credential
|
||||
if not isinstance(
|
||||
@@ -130,7 +130,7 @@ class AuthHandler:
|
||||
raise ValueError("auth_scheme is empty.")
|
||||
|
||||
def get_auth_response(self, state: State) -> AuthCredential:
|
||||
credential_key = self.get_credential_key()
|
||||
credential_key = "temp:" + self.auth_config.get_credential_key()
|
||||
return state.get(credential_key, None)
|
||||
|
||||
def generate_auth_request(self) -> AuthConfig:
|
||||
@@ -192,29 +192,6 @@ class AuthHandler:
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
def get_credential_key(self) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
if auth_scheme.model_extra:
|
||||
auth_scheme = auth_scheme.model_copy(deep=True)
|
||||
auth_scheme.model_extra.clear()
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
if auth_credential.model_extra:
|
||||
auth_credential = auth_credential.model_copy(deep=True)
|
||||
auth_credential.model_extra.clear()
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
|
||||
return f"temp:adk_{scheme_name}_{credential_name}"
|
||||
|
||||
def generate_auth_uri(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import BaseModelWithConfig
|
||||
from .auth_schemes import AuthScheme
|
||||
@@ -43,6 +45,34 @@ class AuthConfig(BaseModelWithConfig):
|
||||
this field to guide the user through the OAuth2 flow and fill auth response in
|
||||
this field"""
|
||||
|
||||
def get_credential_key(self):
|
||||
"""Generates a hash key based on auth_scheme and raw_auth_credential. This
|
||||
hash key can be used to store / retrieve exchanged_auth_credential in a
|
||||
credentials store.
|
||||
"""
|
||||
auth_scheme = self.auth_scheme
|
||||
|
||||
if auth_scheme.model_extra:
|
||||
auth_scheme = auth_scheme.model_copy(deep=True)
|
||||
auth_scheme.model_extra.clear()
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
|
||||
auth_credential = self.raw_auth_credential
|
||||
if auth_credential.model_extra:
|
||||
auth_credential = auth_credential.model_copy(deep=True)
|
||||
auth_credential.model_extra.clear()
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
|
||||
return f"adk_{scheme_name}_{credential_name}"
|
||||
|
||||
|
||||
class AuthToolArguments(BaseModelWithConfig):
|
||||
"""the arguments for the special long running function tool that is used to
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuthConfig:
|
||||
"""Tests for the AuthConfig method."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_auth_scheme():
|
||||
"""Create an OAuth2 auth scheme for testing."""
|
||||
# Create the OAuthFlows object first
|
||||
flows = OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl="https://example.com/oauth2/authorize",
|
||||
tokenUrl="https://example.com/oauth2/token",
|
||||
scopes={"read": "Read access", "write": "Write access"},
|
||||
)
|
||||
)
|
||||
|
||||
# Then create the OAuth2 object with the flows
|
||||
return OAuth2(flows=flows)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_credentials():
|
||||
"""Create OAuth2 credentials for testing."""
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_config(oauth2_auth_scheme, oauth2_credentials):
|
||||
"""Create an AuthConfig for testing."""
|
||||
# Create a copy of the credentials for the exchanged_auth_credential
|
||||
exchanged_credential = oauth2_credentials.model_copy(deep=True)
|
||||
|
||||
return AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials,
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
|
||||
def test_get_credential_key(auth_config):
|
||||
"""Test generating a unique credential key."""
|
||||
|
||||
key = auth_config.get_credential_key()
|
||||
assert key.startswith("adk_oauth2_")
|
||||
assert "_oauth2_" in key
|
||||
|
||||
|
||||
def test_get_credential_key_with_extras(auth_config):
|
||||
"""Test generating a key when model_extra exists."""
|
||||
# Add model_extra to test cleanup
|
||||
|
||||
original_key = auth_config.get_credential_key()
|
||||
key = auth_config.get_credential_key()
|
||||
|
||||
auth_config.auth_scheme.model_extra["extra_field"] = "value"
|
||||
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
|
||||
|
||||
assert original_key == key
|
||||
assert "extra_field" in auth_config.auth_scheme.model_extra
|
||||
assert "extra_field" in auth_config.raw_auth_credential.model_extra
|
||||
@@ -209,31 +209,6 @@ class TestAuthHandlerInit:
|
||||
assert handler.auth_config == auth_config
|
||||
|
||||
|
||||
class TestGetCredentialKey:
|
||||
"""Tests for the get_credential_key method."""
|
||||
|
||||
def test_get_credential_key(self, auth_config):
|
||||
"""Test generating a unique credential key."""
|
||||
handler = AuthHandler(auth_config)
|
||||
key = handler.get_credential_key()
|
||||
assert key.startswith("temp:adk_oauth2_")
|
||||
assert "_oauth2_" in key
|
||||
|
||||
def test_get_credential_key_with_extras(self, auth_config):
|
||||
"""Test generating a key when model_extra exists."""
|
||||
# Add model_extra to test cleanup
|
||||
|
||||
original_key = AuthHandler(auth_config).get_credential_key()
|
||||
key = AuthHandler(auth_config).get_credential_key()
|
||||
|
||||
auth_config.auth_scheme.model_extra["extra_field"] = "value"
|
||||
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
|
||||
|
||||
assert original_key == key
|
||||
assert "extra_field" in auth_config.auth_scheme.model_extra
|
||||
assert "extra_field" in auth_config.raw_auth_credential.model_extra
|
||||
|
||||
|
||||
class TestGenerateAuthUri:
|
||||
"""Tests for the generate_auth_uri method."""
|
||||
|
||||
@@ -412,8 +387,8 @@ class TestGetAuthResponse:
|
||||
state = MockState()
|
||||
|
||||
# Store a credential in the state
|
||||
credential_key = handler.get_credential_key()
|
||||
state[credential_key] = oauth2_credentials_with_auth_uri
|
||||
credential_key = auth_config.get_credential_key()
|
||||
state["temp:" + credential_key] = oauth2_credentials_with_auth_uri
|
||||
|
||||
result = handler.get_auth_response(state)
|
||||
assert result == oauth2_credentials_with_auth_uri
|
||||
@@ -443,8 +418,10 @@ class TestParseAndStoreAuthResponse:
|
||||
|
||||
handler.parse_and_store_auth_response(state)
|
||||
|
||||
credential_key = handler.get_credential_key()
|
||||
assert state[credential_key] == auth_config.exchanged_auth_credential
|
||||
credential_key = auth_config.get_credential_key()
|
||||
assert (
|
||||
state["temp:" + credential_key] == auth_config.exchanged_auth_credential
|
||||
)
|
||||
|
||||
@patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
|
||||
def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
|
||||
@@ -459,8 +436,8 @@ class TestParseAndStoreAuthResponse:
|
||||
|
||||
handler.parse_and_store_auth_response(state)
|
||||
|
||||
credential_key = handler.get_credential_key()
|
||||
assert state[credential_key] == mock_exchange_token.return_value
|
||||
credential_key = auth_config_with_exchanged.get_credential_key()
|
||||
assert state["temp:" + credential_key] == mock_exchange_token.return_value
|
||||
assert mock_exchange_token.called
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user