diff --git a/src/google/adk/tools/bigquery/data_insights_tool.py b/src/google/adk/tools/bigquery/data_insights_tool.py index b4fe8450..98476c5b 100644 --- a/src/google/adk/tools/bigquery/data_insights_tool.py +++ b/src/google/adk/tools/bigquery/data_insights_tool.py @@ -25,6 +25,8 @@ import requests from . import client from .config import BigQueryToolConfig +_GDA_CLIENT_ID = "GOOGLE_ADK" + def ask_data_insights( project_id: str, @@ -129,6 +131,7 @@ def ask_data_insights( headers = { "Authorization": f"Bearer {credentials.token}", "Content-Type": "application/json", + "X-Goog-API-Client": _GDA_CLIENT_ID, } ca_url = f"https://geminidataanalytics.googleapis.com/v1alpha/projects/{project_id}/locations/{location}:chat" @@ -149,7 +152,7 @@ def ask_data_insights( "systemInstruction": instructions, "options": {"chart": {"image": {"noImage": {}}}}, }, - "clientIdEnum": "GOOGLE_ADK", + "clientIdEnum": _GDA_CLIENT_ID, } resp = _get_stream( diff --git a/src/google/adk/tools/data_agent/data_agent_tool.py b/src/google/adk/tools/data_agent/data_agent_tool.py index 8b5a8882..ca58eb7c 100644 --- a/src/google/adk/tools/data_agent/data_agent_tool.py +++ b/src/google/adk/tools/data_agent/data_agent_tool.py @@ -23,6 +23,7 @@ from ..tool_context import ToolContext from .config import DataAgentToolConfig BASE_URL = "https://geminidataanalytics.googleapis.com/v1beta" +_GDA_CLIENT_ID = "GOOGLE_ADK" def _get_http_headers( @@ -41,6 +42,7 @@ def _get_http_headers( return { "Authorization": f"Bearer {credentials.token}", "Content-Type": "application/json", + "X-Goog-API-Client": _GDA_CLIENT_ID, } @@ -294,7 +296,7 @@ def ask_data_agent( "dataAgentContext": { "dataAgent": data_agent_name, }, - "clientIdEnum": "GOOGLE_ADK", + "clientIdEnum": _GDA_CLIENT_ID, } resp = _get_stream( chat_url, diff --git a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py index c2bbe271..b62c6835 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py @@ -89,6 +89,12 @@ def test_ask_data_insights_success(mock_get_stream): assert result["response"] == "Final formatted string from stream" mock_get_stream.assert_called_once() + # Verify that the correct headers and client ID were passed to _get_stream + args, _ = mock_get_stream.call_args + headers = args[2] + assert headers["X-Goog-API-Client"] == "GOOGLE_ADK" + assert headers["Authorization"] == "Bearer fake-token" + @mock.patch.object(data_insights_tool, "_get_stream") def test_ask_data_insights_handles_exception(mock_get_stream): diff --git a/tests/unittests/tools/data_agent/test_data_agent_tool.py b/tests/unittests/tools/data_agent/test_data_agent_tool.py index 6aa57e65..54b3e8d3 100644 --- a/tests/unittests/tools/data_agent/test_data_agent_tool.py +++ b/tests/unittests/tools/data_agent/test_data_agent_tool.py @@ -196,3 +196,16 @@ def test_get_stream_from_file(mock_post, case_file_path): # 6. Assert that the final list of dicts matches the expected output assert result == expected_final_list + + +def test_get_http_headers_includes_client_id(): + """Tests _get_http_headers includes the correct GDA client ID.""" + mock_creds = mock.Mock() + mock_creds.token = "fake-token" + + # pylint: disable=protected-access + headers = data_agent_tool._get_http_headers(mock_creds) + + assert headers["X-Goog-API-Client"] == "GOOGLE_ADK" + assert headers["Content-Type"] == "application/json" + assert headers["Authorization"] == "Bearer fake-token"