You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: remove bare excepts
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 867666149
This commit is contained in:
committed by
Copybara-Service
parent
fd8a9e3962
commit
0758f877b1
@@ -570,19 +570,16 @@ async def _execute_single_function_call_async(
|
||||
return function_response_event
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
function_response_event = None
|
||||
try:
|
||||
function_response_event = await _run_with_trace()
|
||||
return function_response_event
|
||||
finally:
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
except:
|
||||
trace_tool_call(
|
||||
tool=tool, args=function_args, function_response_event=None
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def handle_function_calls_live(
|
||||
@@ -720,19 +717,16 @@ async def _execute_single_function_call_live(
|
||||
return function_response_event
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
function_response_event = None
|
||||
try:
|
||||
function_response_event = await _run_with_trace()
|
||||
return function_response_event
|
||||
finally:
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
except:
|
||||
trace_tool_call(
|
||||
tool=tool, args=function_args, function_response_event=None
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def _process_function_live_helper(
|
||||
|
||||
@@ -27,6 +27,7 @@ from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
@@ -329,7 +330,7 @@ class APIHubClient(BaseAPIHubClient):
|
||||
credentials, _ = default_service_credential(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
except:
|
||||
except DefaultCredentialsError:
|
||||
credentials = None
|
||||
|
||||
if not credentials:
|
||||
|
||||
@@ -815,7 +815,7 @@ class ConnectionsClient:
|
||||
credentials, _ = default_service_credential(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
except:
|
||||
except google.auth.exceptions.DefaultCredentialsError:
|
||||
credentials = None
|
||||
|
||||
if not credentials:
|
||||
|
||||
@@ -253,7 +253,7 @@ class IntegrationClient:
|
||||
credentials, project_id = default_service_credential(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
except:
|
||||
except google.auth.exceptions.DefaultCredentialsError:
|
||||
credentials = None
|
||||
if credentials:
|
||||
quota_project_id = getattr(credentials, "quota_project_id", None)
|
||||
|
||||
@@ -176,7 +176,7 @@ def _execute_sql(
|
||||
try:
|
||||
# if the json serialization of the value succeeds, use it as is
|
||||
json.dumps(val)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
val = str(val)
|
||||
row_values[key] = val
|
||||
rows.append(row_values)
|
||||
|
||||
@@ -101,7 +101,7 @@ def execute_sql(
|
||||
try:
|
||||
# if the json serialization of the value succeeds, use it as is
|
||||
json.dumps(val)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
val = str(val)
|
||||
row_values[key] = val
|
||||
rows.append(row_values)
|
||||
|
||||
@@ -277,7 +277,7 @@ def get_table_schema(
|
||||
|
||||
try:
|
||||
json.dumps(results)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
results = str(results)
|
||||
|
||||
return {"status": "SUCCESS", "results": results}
|
||||
@@ -375,7 +375,7 @@ def list_table_indexes(
|
||||
|
||||
try:
|
||||
json.dumps(index_info)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
index_info = str(index_info)
|
||||
|
||||
indexes.append(index_info)
|
||||
@@ -479,7 +479,7 @@ def list_table_index_columns(
|
||||
|
||||
try:
|
||||
json.dumps(index_column_info)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
index_column_info = str(index_column_info)
|
||||
|
||||
index_columns.append(index_column_info)
|
||||
|
||||
@@ -515,7 +515,7 @@ def similarity_search(
|
||||
try:
|
||||
# if the json serialization of the row succeeds, use it as is
|
||||
json.dumps(row)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
row = str(row)
|
||||
|
||||
rows.append(row)
|
||||
|
||||
@@ -107,7 +107,7 @@ def execute_sql(
|
||||
try:
|
||||
# if the json serialization of the row succeeds, use it as is
|
||||
json.dumps(row)
|
||||
except:
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
row = str(row)
|
||||
|
||||
rows.append(row)
|
||||
|
||||
@@ -18,6 +18,7 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -398,6 +399,24 @@ class TestAPIHubClient:
|
||||
# no service account client
|
||||
APIHubClient()._get_access_token()
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_default_credentials_error(
|
||||
self, mock_default_service_credential
|
||||
):
|
||||
mock_default_service_credential.side_effect = DefaultCredentialsError(
|
||||
"ADC not found"
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
),
|
||||
):
|
||||
APIHubClient()._get_access_token()
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_api_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
|
||||
@@ -631,6 +631,25 @@ class TestConnectionsClient:
|
||||
):
|
||||
client._get_access_token()
|
||||
|
||||
def test_get_access_token_default_credentials_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
side_effect=google.auth.exceptions.DefaultCredentialsError(
|
||||
"ADC not found"
|
||||
),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account that has the required"
|
||||
" permissions"
|
||||
),
|
||||
):
|
||||
client._get_access_token()
|
||||
|
||||
def test_get_access_token_refreshes_expired_token(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
|
||||
@@ -595,6 +595,34 @@ class TestIntegrationClient:
|
||||
in str(e)
|
||||
)
|
||||
|
||||
def test_get_access_token_default_credentials_error(
|
||||
self, project, location, integration_name, triggers, connection_name
|
||||
):
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
side_effect=google.auth.exceptions.DefaultCredentialsError(
|
||||
"ADC not found"
|
||||
),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
triggers=triggers,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account that has the required"
|
||||
" permissions to access the connection."
|
||||
),
|
||||
):
|
||||
client._get_access_token()
|
||||
|
||||
def test_get_access_token_uses_cached_token(
|
||||
self,
|
||||
project,
|
||||
|
||||
@@ -1136,6 +1136,33 @@ def test_execute_sql_result_dtype(
|
||||
assert result == {"status": "SUCCESS", "rows": tool_result_rows}
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "query", autospec=True)
|
||||
def test_execute_sql_result_dtype_circular_reference(
|
||||
mock_query, mock_query_and_wait
|
||||
):
|
||||
"""Test execute_sql converts circular values to strings."""
|
||||
credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = BigQueryToolConfig()
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = "SELECT"
|
||||
mock_query.return_value = query_job
|
||||
circular_value = []
|
||||
circular_value.append(circular_value)
|
||||
mock_query_and_wait.return_value = [{"x": circular_value}]
|
||||
|
||||
result = query_tool.execute_sql(
|
||||
"my_project", "SELECT 1", credentials, tool_settings, tool_context
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "SUCCESS",
|
||||
"rows": [{"x": str(circular_value)}],
|
||||
}
|
||||
|
||||
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_execute_sql_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during execute_sql tool invocation."""
|
||||
|
||||
@@ -135,3 +135,37 @@ def test_execute_sql_error():
|
||||
tool_context=tool_context,
|
||||
)
|
||||
assert result == {"status": "ERROR", "error_details": "Test error"}
|
||||
|
||||
|
||||
def test_execute_sql_row_value_circular_reference_fallback():
|
||||
"""Test execute_sql converts circular row values to strings."""
|
||||
project = "my_project"
|
||||
instance_id = "my_instance"
|
||||
query = "SELECT * FROM my_table"
|
||||
credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.bigtable.client.get_bigtable_data_client"
|
||||
) as mock_get_client:
|
||||
mock_client = mock.MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True)
|
||||
mock_client.execute_query.return_value = mock_iterator
|
||||
circular_value = []
|
||||
circular_value.append(circular_value)
|
||||
mock_row = mock.MagicMock()
|
||||
mock_row.fields = {"col1": circular_value}
|
||||
mock_iterator.__iter__.return_value = [mock_row]
|
||||
|
||||
result = execute_sql(
|
||||
project_id=project,
|
||||
instance_id=instance_id,
|
||||
credentials=credentials,
|
||||
query=query,
|
||||
settings=BigtableToolSettings(),
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["rows"][0]["col1"] == str(circular_value)
|
||||
|
||||
@@ -192,6 +192,45 @@ def test_list_table_indexes_success(
|
||||
assert result["results"][0]["INDEX_NAME"] == "PRIMARY_KEY"
|
||||
|
||||
|
||||
@patch("google.adk.tools.spanner.client.get_spanner_client")
|
||||
def test_list_table_indexes_circular_row_fallback_to_string(
|
||||
mock_get_spanner_client, mock_spanner_ids, mock_credentials
|
||||
):
|
||||
"""Test list_table_indexes stringifies rows with circular references."""
|
||||
mock_spanner_client = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_database = MagicMock()
|
||||
mock_snapshot = MagicMock()
|
||||
circular_value = []
|
||||
circular_value.append(circular_value)
|
||||
mock_result_set = MagicMock()
|
||||
mock_result_set.__iter__.return_value = iter([(
|
||||
circular_value,
|
||||
"",
|
||||
"PRIMARY_KEY",
|
||||
"",
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)])
|
||||
mock_snapshot.execute_sql.return_value = mock_result_set
|
||||
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
|
||||
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
|
||||
mock_instance.database.return_value = mock_database
|
||||
mock_spanner_client.instance.return_value = mock_instance
|
||||
mock_get_spanner_client.return_value = mock_spanner_client
|
||||
|
||||
result = metadata_tool.list_table_indexes(
|
||||
mock_spanner_ids["project_id"],
|
||||
mock_spanner_ids["instance_id"],
|
||||
mock_spanner_ids["database_id"],
|
||||
mock_spanner_ids["table_name"],
|
||||
mock_credentials,
|
||||
)
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert isinstance(result["results"][0], str)
|
||||
|
||||
|
||||
@patch("google.adk.tools.spanner.client.get_spanner_client")
|
||||
def test_list_table_index_columns_success(
|
||||
mock_get_spanner_client, mock_spanner_ids, mock_credentials
|
||||
|
||||
@@ -187,6 +187,47 @@ def test_similarity_search_error(
|
||||
assert "Test Exception" in result["error_details"]
|
||||
|
||||
|
||||
@mock.patch.object(utils, "embed_contents")
|
||||
@mock.patch.object(client, "get_spanner_client")
|
||||
def test_similarity_search_circular_row_fallback_to_string(
|
||||
mock_get_spanner_client,
|
||||
mock_embed_contents,
|
||||
mock_spanner_ids,
|
||||
mock_credentials,
|
||||
):
|
||||
"""Test similarity_search stringifies rows with circular references."""
|
||||
mock_spanner_client = MagicMock()
|
||||
mock_instance = MagicMock()
|
||||
mock_database = MagicMock()
|
||||
mock_snapshot = MagicMock()
|
||||
circular_row = []
|
||||
circular_row.append(circular_row)
|
||||
mock_embed_contents.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_snapshot.execute_sql.return_value = iter([circular_row])
|
||||
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
|
||||
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
|
||||
mock_instance.database.return_value = mock_database
|
||||
mock_spanner_client.instance.return_value = mock_instance
|
||||
mock_get_spanner_client.return_value = mock_spanner_client
|
||||
|
||||
result = search_tool.similarity_search(
|
||||
project_id=mock_spanner_ids["project_id"],
|
||||
instance_id=mock_spanner_ids["instance_id"],
|
||||
database_id=mock_spanner_ids["database_id"],
|
||||
table_name=mock_spanner_ids["table_name"],
|
||||
query="test query",
|
||||
embedding_column_to_search="embedding_col",
|
||||
columns=["col1"],
|
||||
embedding_options={
|
||||
"vertex_ai_embedding_model_name": "text-embedding-005"
|
||||
},
|
||||
credentials=mock_credentials,
|
||||
)
|
||||
|
||||
assert result["status"] == "SUCCESS", result
|
||||
assert result["rows"] == [str(circular_row)]
|
||||
|
||||
|
||||
@mock.patch.object(client, "get_spanner_client")
|
||||
def test_similarity_search_postgresql_knn_success(
|
||||
mock_get_spanner_client, mock_spanner_ids, mock_credentials
|
||||
|
||||
@@ -178,6 +178,35 @@ def test_add_contents_empty_contents(
|
||||
mock_spanner_database.batch.assert_not_called()
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils.client, "get_spanner_client", autospec=True)
|
||||
def test_execute_sql_circular_row_fallback_to_string(mock_get_spanner_client):
|
||||
"""Test execute_sql stringifies rows with circular references."""
|
||||
mock_spanner_client = mock.MagicMock()
|
||||
mock_instance = mock.MagicMock()
|
||||
mock_database = mock.MagicMock()
|
||||
mock_snapshot = mock.MagicMock()
|
||||
circular_row = []
|
||||
circular_row.append(circular_row)
|
||||
mock_snapshot.execute_sql.return_value = iter([circular_row])
|
||||
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
|
||||
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
|
||||
mock_instance.database.return_value = mock_database
|
||||
mock_spanner_client.instance.return_value = mock_instance
|
||||
mock_get_spanner_client.return_value = mock_spanner_client
|
||||
|
||||
result = spanner_utils.execute_sql(
|
||||
project_id="test-project",
|
||||
instance_id="test-instance",
|
||||
database_id="test-database",
|
||||
query="SELECT 1",
|
||||
credentials=mock.Mock(),
|
||||
settings=SpannerToolSettings(),
|
||||
tool_context=mock.Mock(),
|
||||
)
|
||||
|
||||
assert result == {"status": "SUCCESS", "rows": [str(circular_row)]}
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
|
||||
def test_add_contents_additional_columns_list_mismatch(
|
||||
mock_embed_contents, spanner_tool_settings, mock_spanner_client
|
||||
|
||||
Reference in New Issue
Block a user