diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 664d0d6c..a619aac4 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -570,19 +570,16 @@ async def _execute_single_function_call_async( return function_response_event with tracer.start_as_current_span(f'execute_tool {tool.name}'): + function_response_event = None try: function_response_event = await _run_with_trace() + return function_response_event + finally: trace_tool_call( tool=tool, args=function_args, function_response_event=function_response_event, ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise async def handle_function_calls_live( @@ -720,19 +717,16 @@ async def _execute_single_function_call_live( return function_response_event with tracer.start_as_current_span(f'execute_tool {tool.name}'): + function_response_event = None try: function_response_event = await _run_with_trace() + return function_response_event + finally: trace_tool_call( tool=tool, args=function_args, function_response_event=function_response_event, ) - return function_response_event - except: - trace_tool_call( - tool=tool, args=function_args, function_response_event=None - ) - raise async def _process_function_live_helper( diff --git a/src/google/adk/tools/apihub_tool/clients/apihub_client.py b/src/google/adk/tools/apihub_tool/clients/apihub_client.py index b9197350..ac566c84 100644 --- a/src/google/adk/tools/apihub_tool/clients/apihub_client.py +++ b/src/google/adk/tools/apihub_tool/clients/apihub_client.py @@ -27,6 +27,7 @@ from urllib.parse import parse_qs from urllib.parse import urlparse from google.auth import default as default_service_credential +from google.auth.exceptions import DefaultCredentialsError from google.auth.transport.requests import Request from google.oauth2 import service_account import requests @@ -329,7 +330,7 @@ class APIHubClient(BaseAPIHubClient): credentials, _ = default_service_credential( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) - except: + except DefaultCredentialsError: credentials = None if not credentials: diff --git a/src/google/adk/tools/application_integration_tool/clients/connections_client.py b/src/google/adk/tools/application_integration_tool/clients/connections_client.py index 1756d5b0..fdec2d22 100644 --- a/src/google/adk/tools/application_integration_tool/clients/connections_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/connections_client.py @@ -815,7 +815,7 @@ class ConnectionsClient: credentials, _ = default_service_credential( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) - except: + except google.auth.exceptions.DefaultCredentialsError: credentials = None if not credentials: diff --git a/src/google/adk/tools/application_integration_tool/clients/integration_client.py b/src/google/adk/tools/application_integration_tool/clients/integration_client.py index a7b5e44c..bb704ee4 100644 --- a/src/google/adk/tools/application_integration_tool/clients/integration_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/integration_client.py @@ -253,7 +253,7 @@ class IntegrationClient: credentials, project_id = default_service_credential( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) - except: + except google.auth.exceptions.DefaultCredentialsError: credentials = None if credentials: quota_project_id = getattr(credentials, "quota_project_id", None) diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 66cfb26b..d5b1264f 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -176,7 +176,7 @@ def _execute_sql( try: # if the json serialization of the value succeeds, use it as is json.dumps(val) - except: + except (TypeError, ValueError, OverflowError): val = str(val) row_values[key] = val rows.append(row_values) diff --git a/src/google/adk/tools/bigtable/query_tool.py b/src/google/adk/tools/bigtable/query_tool.py index 34830cc6..a7a785a2 100644 --- a/src/google/adk/tools/bigtable/query_tool.py +++ b/src/google/adk/tools/bigtable/query_tool.py @@ -101,7 +101,7 @@ def execute_sql( try: # if the json serialization of the value succeeds, use it as is json.dumps(val) - except: + except (TypeError, ValueError, OverflowError): val = str(val) row_values[key] = val rows.append(row_values) diff --git a/src/google/adk/tools/spanner/metadata_tool.py b/src/google/adk/tools/spanner/metadata_tool.py index 4b71f40c..51d8ac1a 100644 --- a/src/google/adk/tools/spanner/metadata_tool.py +++ b/src/google/adk/tools/spanner/metadata_tool.py @@ -277,7 +277,7 @@ def get_table_schema( try: json.dumps(results) - except: + except (TypeError, ValueError, OverflowError): results = str(results) return {"status": "SUCCESS", "results": results} @@ -375,7 +375,7 @@ def list_table_indexes( try: json.dumps(index_info) - except: + except (TypeError, ValueError, OverflowError): index_info = str(index_info) indexes.append(index_info) @@ -479,7 +479,7 @@ def list_table_index_columns( try: json.dumps(index_column_info) - except: + except (TypeError, ValueError, OverflowError): index_column_info = str(index_column_info) index_columns.append(index_column_info) diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index 93944dec..03f695b8 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -515,7 +515,7 @@ def similarity_search( try: # if the json serialization of the row succeeds, use it as is json.dumps(row) - except: + except (TypeError, ValueError, OverflowError): row = str(row) rows.append(row) diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py index e2c2ce14..9f5efdb7 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -107,7 +107,7 @@ def execute_sql( try: # if the json serialization of the row succeeds, use it as is json.dumps(row) - except: + except (TypeError, ValueError, OverflowError): row = str(row) rows.append(row) diff --git a/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py b/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py index 621ebc1b..06244671 100644 --- a/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py +++ b/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock from unittest.mock import patch from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient +from google.auth.exceptions import DefaultCredentialsError import pytest from requests.exceptions import HTTPError @@ -398,6 +399,24 @@ class TestAPIHubClient: # no service account client APIHubClient()._get_access_token() + @patch( + "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential" + ) + def test_get_access_token_default_credentials_error( + self, mock_default_service_credential + ): + mock_default_service_credential.side_effect = DefaultCredentialsError( + "ADC not found" + ) + with pytest.raises( + ValueError, + match=( + "Please provide a service account or an access token to API Hub" + " client." + ), + ): + APIHubClient()._get_access_token() + @patch("requests.get") def test_get_spec_content_api_level(self, mock_get, client): mock_get.side_effect = [ diff --git a/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py b/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py index c2334973..5c94d5ba 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py @@ -631,6 +631,25 @@ class TestConnectionsClient: ): client._get_access_token() + def test_get_access_token_default_credentials_error( + self, project, location, connection_name + ): + client = ConnectionsClient(project, location, connection_name, None) + with mock.patch( + "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential", + side_effect=google.auth.exceptions.DefaultCredentialsError( + "ADC not found" + ), + ): + with pytest.raises( + ValueError, + match=( + "Please provide a service account that has the required" + " permissions" + ), + ): + client._get_access_token() + def test_get_access_token_refreshes_expired_token( self, project, location, connection_name, mock_credentials ): diff --git a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py index 77e74653..eea0cbc3 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py @@ -595,6 +595,34 @@ class TestIntegrationClient: in str(e) ) + def test_get_access_token_default_credentials_error( + self, project, location, integration_name, triggers, connection_name + ): + with mock.patch( + "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", + side_effect=google.auth.exceptions.DefaultCredentialsError( + "ADC not found" + ), + ): + client = IntegrationClient( + project=project, + location=location, + integration=integration_name, + triggers=triggers, + connection=connection_name, + entity_operations=None, + actions=None, + service_account_json=None, + ) + with pytest.raises( + ValueError, + match=( + "Please provide a service account that has the required" + " permissions to access the connection." + ), + ): + client._get_access_token() + def test_get_access_token_uses_cached_token( self, project, diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index 8c182c89..150cdb75 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -1136,6 +1136,33 @@ def test_execute_sql_result_dtype( assert result == {"status": "SUCCESS", "rows": tool_result_rows} +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True) +@mock.patch.object(bigquery.Client, "query", autospec=True) +def test_execute_sql_result_dtype_circular_reference( + mock_query, mock_query_and_wait +): + """Test execute_sql converts circular values to strings.""" + credentials = mock.create_autospec(Credentials, instance=True) + tool_settings = BigQueryToolConfig() + tool_context = mock.create_autospec(ToolContext, instance=True) + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = "SELECT" + mock_query.return_value = query_job + circular_value = [] + circular_value.append(circular_value) + mock_query_and_wait.return_value = [{"x": circular_value}] + + result = query_tool.execute_sql( + "my_project", "SELECT 1", credentials, tool_settings, tool_context + ) + + assert result == { + "status": "SUCCESS", + "rows": [{"x": str(circular_value)}], + } + + @mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_execute_sql_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during execute_sql tool invocation.""" diff --git a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py index 01eaf8d1..0bd0fedc 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py @@ -135,3 +135,37 @@ def test_execute_sql_error(): tool_context=tool_context, ) assert result == {"status": "ERROR", "error_details": "Test error"} + + +def test_execute_sql_row_value_circular_reference_fallback(): + """Test execute_sql converts circular row values to strings.""" + project = "my_project" + instance_id = "my_instance" + query = "SELECT * FROM my_table" + credentials = mock.create_autospec(Credentials, instance=True) + tool_context = mock.create_autospec(ToolContext, instance=True) + + with mock.patch( + "google.adk.tools.bigtable.client.get_bigtable_data_client" + ) as mock_get_client: + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) + mock_client.execute_query.return_value = mock_iterator + circular_value = [] + circular_value.append(circular_value) + mock_row = mock.MagicMock() + mock_row.fields = {"col1": circular_value} + mock_iterator.__iter__.return_value = [mock_row] + + result = execute_sql( + project_id=project, + instance_id=instance_id, + credentials=credentials, + query=query, + settings=BigtableToolSettings(), + tool_context=tool_context, + ) + + assert result["status"] == "SUCCESS" + assert result["rows"][0]["col1"] == str(circular_value) diff --git a/tests/unittests/tools/spanner/test_metadata_tool.py b/tests/unittests/tools/spanner/test_metadata_tool.py index 6ea7dd16..fcfcd4bd 100644 --- a/tests/unittests/tools/spanner/test_metadata_tool.py +++ b/tests/unittests/tools/spanner/test_metadata_tool.py @@ -192,6 +192,45 @@ def test_list_table_indexes_success( assert result["results"][0]["INDEX_NAME"] == "PRIMARY_KEY" +@patch("google.adk.tools.spanner.client.get_spanner_client") +def test_list_table_indexes_circular_row_fallback_to_string( + mock_get_spanner_client, mock_spanner_ids, mock_credentials +): + """Test list_table_indexes stringifies rows with circular references.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_snapshot = MagicMock() + circular_value = [] + circular_value.append(circular_value) + mock_result_set = MagicMock() + mock_result_set.__iter__.return_value = iter([( + circular_value, + "", + "PRIMARY_KEY", + "", + True, + False, + None, + )]) + mock_snapshot.execute_sql.return_value = mock_result_set + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = metadata_tool.list_table_indexes( + mock_spanner_ids["project_id"], + mock_spanner_ids["instance_id"], + mock_spanner_ids["database_id"], + mock_spanner_ids["table_name"], + mock_credentials, + ) + assert result["status"] == "SUCCESS" + assert isinstance(result["results"][0], str) + + @patch("google.adk.tools.spanner.client.get_spanner_client") def test_list_table_index_columns_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py index 2f0976d0..4532dd56 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -187,6 +187,47 @@ def test_similarity_search_error( assert "Test Exception" in result["error_details"] +@mock.patch.object(utils, "embed_contents") +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_circular_row_fallback_to_string( + mock_get_spanner_client, + mock_embed_contents, + mock_spanner_ids, + mock_credentials, +): + """Test similarity_search stringifies rows with circular references.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_snapshot = MagicMock() + circular_row = [] + circular_row.append(circular_row) + mock_embed_contents.return_value = [[0.1, 0.2, 0.3]] + mock_snapshot.execute_sql.return_value = iter([circular_row]) + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options={ + "vertex_ai_embedding_model_name": "text-embedding-005" + }, + credentials=mock_credentials, + ) + + assert result["status"] == "SUCCESS", result + assert result["rows"] == [str(circular_row)] + + @mock.patch.object(client, "get_spanner_client") def test_similarity_search_postgresql_knn_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials diff --git a/tests/unittests/tools/spanner/test_utils.py b/tests/unittests/tools/spanner/test_utils.py index 0dd58f24..fe8d7db4 100644 --- a/tests/unittests/tools/spanner/test_utils.py +++ b/tests/unittests/tools/spanner/test_utils.py @@ -178,6 +178,35 @@ def test_add_contents_empty_contents( mock_spanner_database.batch.assert_not_called() +@mock.patch.object(spanner_utils.client, "get_spanner_client", autospec=True) +def test_execute_sql_circular_row_fallback_to_string(mock_get_spanner_client): + """Test execute_sql stringifies rows with circular references.""" + mock_spanner_client = mock.MagicMock() + mock_instance = mock.MagicMock() + mock_database = mock.MagicMock() + mock_snapshot = mock.MagicMock() + circular_row = [] + circular_row.append(circular_row) + mock_snapshot.execute_sql.return_value = iter([circular_row]) + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = spanner_utils.execute_sql( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + query="SELECT 1", + credentials=mock.Mock(), + settings=SpannerToolSettings(), + tool_context=mock.Mock(), + ) + + assert result == {"status": "SUCCESS", "rows": [str(circular_row)]} + + @mock.patch.object(spanner_utils, "embed_contents", autospec=True) def test_add_contents_additional_columns_list_mismatch( mock_embed_contents, spanner_tool_settings, mock_spanner_client