You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add SSL certificate verification configuration to OpenAPI tools
This change introduces a `verify` parameter to `RestApiTool` and `OpenAPIToolset`. This parameter allows users to configure how SSL certificates are verified when making API calls using the `requests` library. Options include providing a path to a CA bundle, disabling verification, or using a custom `ssl.SSLContext`. New methods `configure_verify` and `configure_verify_all` are added to update this setting after initialization. This is useful for environments with TLS-intercepting proxies. Fixes: https://github.com/google/adk-python/issues/3720 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 840809727
This commit is contained in:
committed by
Copybara-Service
parent
711df01e73
commit
9d2388a46f
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
@@ -68,6 +69,7 @@ class OpenAPIToolset(BaseToolset):
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
|
||||
):
|
||||
"""Initializes the OpenAPIToolset.
|
||||
|
||||
@@ -102,10 +104,19 @@ class OpenAPIToolset(BaseToolset):
|
||||
``google.adk.tools.openapi_tool.auth.auth_helpers``
|
||||
tool_filter: The filter used to filter the tools in the toolset. It can be
|
||||
either a tool predicate or a list of tool names of the tools to expose.
|
||||
ssl_verify: SSL certificate verification option for all tools. Can be:
|
||||
- None: Use default verification (True)
|
||||
- True: Verify SSL certificates using system CA
|
||||
- False: Disable SSL verification (insecure, not recommended)
|
||||
- str: Path to a CA bundle file or directory for custom CA
|
||||
- ssl.SSLContext: Custom SSL context for advanced configuration
|
||||
This is useful for enterprise environments where requests go through
|
||||
a TLS-intercepting proxy with a custom CA certificate.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self._ssl_verify = ssl_verify
|
||||
self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
@@ -121,6 +132,26 @@ class OpenAPIToolset(BaseToolset):
|
||||
if auth_credential:
|
||||
tool.configure_auth_credential(auth_credential)
|
||||
|
||||
def configure_ssl_verify_all(
|
||||
self, ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None
|
||||
):
|
||||
"""Configure SSL certificate verification for all tools.
|
||||
|
||||
This is useful for enterprise environments where requests go through a
|
||||
TLS-intercepting proxy with a custom CA certificate.
|
||||
|
||||
Args:
|
||||
ssl_verify: SSL certificate verification option. Can be:
|
||||
- None: Use default verification (True)
|
||||
- True: Verify SSL certificates using system CA
|
||||
- False: Disable SSL verification (insecure, not recommended)
|
||||
- str: Path to a CA bundle file or directory for custom CA
|
||||
- ssl.SSLContext: Custom SSL context for advanced configuration
|
||||
"""
|
||||
self._ssl_verify = ssl_verify
|
||||
for tool in self._tools:
|
||||
tool.configure_ssl_verify(ssl_verify)
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
@@ -154,7 +185,7 @@ class OpenAPIToolset(BaseToolset):
|
||||
|
||||
tools = []
|
||||
for o in operations:
|
||||
tool = RestApiTool.from_parsed_operation(o)
|
||||
tool = RestApiTool.from_parsed_operation(o, ssl_verify=self._ssl_verify)
|
||||
logger.info("Parsed tool: %s", tool.name)
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
@@ -88,6 +89,7 @@ class RestApiTool(BaseTool):
|
||||
auth_scheme: Optional[Union[AuthScheme, str]] = None,
|
||||
auth_credential: Optional[Union[AuthCredential, str]] = None,
|
||||
should_parse_operation=True,
|
||||
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
|
||||
):
|
||||
"""Initializes the RestApiTool with the given parameters.
|
||||
|
||||
@@ -114,6 +116,12 @@ class RestApiTool(BaseTool):
|
||||
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
|
||||
auth_credential: The authentication credential of the tool.
|
||||
should_parse_operation: Whether to parse the operation.
|
||||
ssl_verify: SSL certificate verification option. Can be:
|
||||
- None: Use default verification
|
||||
- True: Verify SSL certificates using system CA
|
||||
- False: Disable SSL verification (insecure, not recommended)
|
||||
- str: Path to a CA bundle file or directory for custom CA
|
||||
- ssl.SSLContext: Custom SSL context for advanced configuration
|
||||
"""
|
||||
# Gemini restrict the length of function name to be less than 64 characters
|
||||
self.name = name[:60]
|
||||
@@ -136,15 +144,21 @@ class RestApiTool(BaseTool):
|
||||
# Private properties
|
||||
self.credential_exchanger = AutoAuthCredentialExchanger()
|
||||
self._default_headers: Dict[str, str] = {}
|
||||
self._ssl_verify = ssl_verify
|
||||
if should_parse_operation:
|
||||
self._operation_parser = OperationParser(self.operation)
|
||||
|
||||
@classmethod
|
||||
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
|
||||
def from_parsed_operation(
|
||||
cls,
|
||||
parsed: ParsedOperation,
|
||||
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
|
||||
) -> "RestApiTool":
|
||||
"""Initializes the RestApiTool from a ParsedOperation object.
|
||||
|
||||
Args:
|
||||
parsed: A ParsedOperation object.
|
||||
ssl_verify: SSL certificate verification option.
|
||||
|
||||
Returns:
|
||||
A RestApiTool object.
|
||||
@@ -163,6 +177,7 @@ class RestApiTool(BaseTool):
|
||||
operation=parsed.operation,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
auth_credential=parsed.auth_credential,
|
||||
ssl_verify=ssl_verify,
|
||||
)
|
||||
generated._operation_parser = operation_parser
|
||||
return generated
|
||||
@@ -218,6 +233,24 @@ class RestApiTool(BaseTool):
|
||||
auth_credential = AuthCredential.model_validate_json(auth_credential)
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
def configure_ssl_verify(
|
||||
self, ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None
|
||||
):
|
||||
"""Configures SSL certificate verification for the API call.
|
||||
|
||||
This is useful for enterprise environments where requests go through a
|
||||
TLS-intercepting proxy with a custom CA certificate.
|
||||
|
||||
Args:
|
||||
ssl_verify: SSL certificate verification option. Can be:
|
||||
- None: Use default verification (True)
|
||||
- True: Verify SSL certificates using system CA
|
||||
- False: Disable SSL verification (insecure, not recommended)
|
||||
- str: Path to a CA bundle file or directory for custom CA
|
||||
- ssl.SSLContext: Custom SSL context for advanced configuration
|
||||
"""
|
||||
self._ssl_verify = ssl_verify
|
||||
|
||||
def set_default_headers(self, headers: Dict[str, str]):
|
||||
"""Sets default headers that are merged into every request."""
|
||||
self._default_headers = headers
|
||||
@@ -415,6 +448,8 @@ class RestApiTool(BaseTool):
|
||||
|
||||
# Got all parameters. Call the API.
|
||||
request_params = self._prepare_request_params(api_params, api_args)
|
||||
if self._ssl_verify is not None:
|
||||
request_params["verify"] = self._ssl_verify
|
||||
response = requests.request(**request_params)
|
||||
|
||||
# Parse API response
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
@@ -137,3 +138,34 @@ def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
|
||||
for tool in toolset._tools:
|
||||
assert tool.auth_scheme == auth_scheme
|
||||
assert tool.auth_credential == auth_credential
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"verify_value", ["/path/to/enterprise-ca-bundle.crt", False]
|
||||
)
|
||||
def test_openapi_toolset_verify_on_init(
|
||||
openapi_spec: Dict[str, Any], verify_value: str | bool
|
||||
):
|
||||
"""Test configuring verify during initialization."""
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=openapi_spec,
|
||||
ssl_verify=verify_value,
|
||||
)
|
||||
for tool in toolset._tools:
|
||||
assert tool._ssl_verify == verify_value
|
||||
|
||||
|
||||
def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]):
|
||||
"""Test configure_verify_all method."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
|
||||
# Initially verify should be None
|
||||
for tool in toolset._tools:
|
||||
assert tool._ssl_verify is None
|
||||
|
||||
# Configure verify for all tools
|
||||
ca_bundle_path = "/path/to/custom-ca.crt"
|
||||
toolset.configure_ssl_verify_all(ca_bundle_path)
|
||||
|
||||
for tool in toolset._tools:
|
||||
assert tool._ssl_verify == ca_bundle_path
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
|
||||
import json
|
||||
import ssl
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
@@ -48,6 +50,11 @@ class TestRestApiTool:
|
||||
mock_context.request_credential.return_value = {}
|
||||
return mock_context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssl_context(self):
|
||||
"""Fixture for a mock ssl.SSLContext."""
|
||||
return mock.create_autospec(ssl.SSLContext)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_operation_parser(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
@@ -934,6 +941,101 @@ class TestRestApiTool:
|
||||
assert "param_name" in request_params["params"]
|
||||
assert "empty_param" not in request_params["params"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"verify_input, expected_verify_in_call",
|
||||
[
|
||||
(True, True),
|
||||
(False, False),
|
||||
(
|
||||
"/path/to/enterprise-ca-bundle.crt",
|
||||
"/path/to/enterprise-ca-bundle.crt",
|
||||
),
|
||||
(
|
||||
"USE_SSL_FIXTURE",
|
||||
"USE_SSL_FIXTURE",
|
||||
),
|
||||
(None, None), # None means 'verify' should not be in call_kwargs
|
||||
],
|
||||
)
|
||||
async def test_call_with_verify_options(
|
||||
self,
|
||||
mock_tool_context,
|
||||
sample_endpoint,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
mock_ssl_context,
|
||||
verify_input,
|
||||
expected_verify_in_call,
|
||||
):
|
||||
"""Test different values for the 'verify' parameter."""
|
||||
if verify_input == "USE_SSL_FIXTURE":
|
||||
verify_input = mock_ssl_context
|
||||
if expected_verify_in_call == "USE_SSL_FIXTURE":
|
||||
expected_verify_in_call = mock_ssl_context
|
||||
|
||||
mock_response = mock.create_autospec(
|
||||
requests.Response, instance=True, spec_set=True
|
||||
)
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpoint,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
ssl_verify=verify_input,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
requests, "request", return_value=mock_response, autospec=True
|
||||
) as mock_request:
|
||||
await tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
assert mock_request.called
|
||||
_, call_kwargs = mock_request.call_args
|
||||
if expected_verify_in_call is None:
|
||||
assert "verify" not in call_kwargs
|
||||
else:
|
||||
assert call_kwargs["verify"] == expected_verify_in_call
|
||||
|
||||
async def test_call_with_configure_verify(
|
||||
self,
|
||||
mock_tool_context,
|
||||
sample_endpoint,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
"""Test that configure_verify updates the verify setting."""
|
||||
mock_response = mock.create_autospec(
|
||||
requests.Response, instance=True, spec_set=True
|
||||
)
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpoint,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
|
||||
ca_bundle_path = "/path/to/custom-ca.crt"
|
||||
tool.configure_ssl_verify(ca_bundle_path)
|
||||
|
||||
with patch.object(
|
||||
requests, "request", return_value=mock_response
|
||||
) as mock_request:
|
||||
await tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
assert mock_request.called
|
||||
call_kwargs = mock_request.call_args[1]
|
||||
assert call_kwargs["verify"] == ca_bundle_path
|
||||
|
||||
|
||||
def test_snake_to_lower_camel():
|
||||
assert snake_to_lower_camel("single") == "single"
|
||||
|
||||
Reference in New Issue
Block a user