chore: Use context_id as session_id and construct temp user_id from context_id

PiperOrigin-RevId: 776639713
This commit is contained in:
Xiang (Sean) Zhou
2025-06-27 10:41:34 -07:00
committed by Copybara-Service
parent 51a559eb2a
commit 09e487df3c
2 changed files with 77 additions and 220 deletions
@@ -33,37 +33,19 @@ from google.genai import types as genai_types
from ...runners import RunConfig
from ...utils.feature_decorator import working_in_progress
from .part_converter import convert_a2a_part_to_genai_part
from .utils import _from_a2a_context_id
from .utils import _get_adk_metadata_key
def _get_user_id(request: RequestContext, user_id_from_context: str) -> str:
def _get_user_id(request: RequestContext) -> str:
# Get user from call context if available (auth is enabled on a2a server)
if request.call_context and request.call_context.user:
if (
request.call_context
and request.call_context.user
and request.call_context.user.user_name
):
return request.call_context.user.user_name
# Get user from context id if available
if user_id_from_context:
return user_id_from_context
# Get user from message metadata if available (client is an ADK agent)
if request.message.metadata:
user_id = request.message.metadata.get(_get_adk_metadata_key('user_id'))
if user_id:
return f'ADK_USER_{user_id}'
# Get user from task if available (client is a an ADK agent)
if request.current_task:
user_id = request.current_task.metadata.get(
_get_adk_metadata_key('user_id')
)
if user_id:
return f'ADK_USER_{user_id}'
return (
f'temp_user_{request.task_id}'
if request.task_id
else f'TEMP_USER_{request.message.messageId}'
)
# Get user from context id
return f'A2A_USER_{request.context_id}'
@working_in_progress
@@ -74,11 +56,9 @@ def convert_a2a_request_to_adk_run_args(
if not request.message:
raise ValueError('Request message cannot be None')
_, user_id, session_id = _from_a2a_context_id(request.context_id)
return {
'user_id': _get_user_id(request, user_id),
'session_id': session_id,
'user_id': _get_user_id(request),
'session_id': request.context_id,
'new_message': genai_types.Content(
role='user',
parts=[
@@ -41,7 +41,7 @@ except ImportError as e:
genai_types = DummyTypes()
RequestContext = DummyTypes()
RunConfig = DummyTypes()
_get_user_id = lambda x, y: None
_get_user_id = lambda x: None
convert_a2a_request_to_adk_run_args = lambda x: None
else:
raise e
@@ -61,12 +61,10 @@ class TestGetUserId:
request = Mock(spec=RequestContext)
request.call_context = mock_call_context
request.message = Mock()
request.current_task = None
request.task_id = "task123"
request.context_id = "test_context"
# Act
result = _get_user_id(request, "context_user")
result = _get_user_id(request)
# Assert
assert result == "authenticated_user"
@@ -76,15 +74,13 @@ class TestGetUserId:
# Arrange
request = Mock(spec=RequestContext)
request.call_context = None
request.message = Mock()
request.current_task = None
request.task_id = "task123"
request.context_id = "test_context"
# Act
result = _get_user_id(request, "context_user")
result = _get_user_id(request)
# Assert
assert result == "context_user"
assert result == "A2A_USER_test_context"
def test_get_user_id_from_context_when_call_context_has_no_user(self):
"""Test getting user ID from context when call context has no user."""
@@ -94,133 +90,64 @@ class TestGetUserId:
request = Mock(spec=RequestContext)
request.call_context = mock_call_context
request.message = Mock()
request.current_task = None
request.task_id = "task123"
request.context_id = "test_context"
# Act
result = _get_user_id(request, "context_user")
result = _get_user_id(request)
# Assert
assert result == "context_user"
assert result == "A2A_USER_test_context"
def test_get_user_id_from_message_metadata(self):
"""Test getting user ID from message metadata when context user is not available."""
def test_get_user_id_with_empty_user_name(self):
"""Test getting user ID when user exists but user_name is empty."""
# Arrange
mock_message = Mock()
mock_message.metadata = {"adk_user_id": "message_user"}
mock_user = Mock()
mock_user.user_name = ""
mock_call_context = Mock()
mock_call_context.user = mock_user
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = None
request.task_id = "task123"
request.call_context = mock_call_context
request.context_id = "test_context"
# Act
result = _get_user_id(request, "")
result = _get_user_id(request)
# Assert
assert result == "ADK_USER_message_user"
assert result == "A2A_USER_test_context"
def test_get_user_id_from_task_metadata(self):
"""Test getting user ID from task metadata when message metadata is not available."""
def test_get_user_id_with_none_user_name(self):
"""Test getting user ID when user exists but user_name is None."""
# Arrange
mock_message = Mock()
mock_message.metadata = None
mock_user = Mock()
mock_user.user_name = None
mock_task = Mock()
mock_task.metadata = {"adk_user_id": "task_user"}
mock_call_context = Mock()
mock_call_context.user = mock_user
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = mock_task
request.task_id = "task123"
request.call_context = mock_call_context
request.context_id = "test_context"
# Act
result = _get_user_id(request, "")
result = _get_user_id(request)
# Assert
assert result == "ADK_USER_task_user"
assert result == "A2A_USER_test_context"
def test_get_user_id_fallback_to_task_id(self):
"""Test fallback to task ID when no other user ID is available."""
def test_get_user_id_with_none_context_id(self):
"""Test getting user ID when context_id is None."""
# Arrange
mock_message = Mock()
mock_message.metadata = None
mock_message.messageId = "msg456"
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = None
request.task_id = "task123"
request.context_id = None
# Act
result = _get_user_id(request, "")
result = _get_user_id(request)
# Assert
assert result == "temp_user_task123"
def test_get_user_id_fallback_to_message_id(self):
"""Test fallback to message ID when no task ID is available."""
# Arrange
mock_message = Mock()
mock_message.metadata = None
mock_message.messageId = "msg456"
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = None
request.task_id = None
# Act
result = _get_user_id(request, "")
# Assert
assert result == "TEMP_USER_msg456"
def test_get_user_id_message_metadata_empty(self):
"""Test getting user ID when message metadata exists but doesn't contain user_id."""
# Arrange
mock_message = Mock()
mock_message.metadata = {"other_key": "other_value"}
mock_message.messageId = "msg456"
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = None
request.task_id = "task123"
# Act
result = _get_user_id(request, "")
# Assert
assert result == "temp_user_task123"
def test_get_user_id_task_metadata_empty(self):
"""Test getting user ID when task metadata exists but doesn't contain user_id."""
# Arrange
mock_message = Mock()
mock_message.metadata = None
mock_message.messageId = "msg456"
mock_task = Mock()
mock_task.metadata = {"other_key": "other_value"}
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.current_task = mock_task
request.task_id = "task123"
# Act
result = _get_user_id(request, "")
# Assert
assert result == "temp_user_task123"
assert result == "A2A_USER_None"
class TestConvertA2aRequestToAdkRunArgs:
@@ -229,11 +156,7 @@ class TestConvertA2aRequestToAdkRunArgs:
@patch(
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
)
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
@patch("google.adk.a2a.converters.request_converter._get_user_id")
def test_convert_a2a_request_basic(
self, mock_get_user_id, mock_from_context_id, mock_convert_part
):
def test_convert_a2a_request_basic(self, mock_convert_part):
"""Test basic conversion of A2A request to ADK run args."""
# Arrange
mock_part1 = Mock()
@@ -242,16 +165,16 @@ class TestConvertA2aRequestToAdkRunArgs:
mock_message = Mock()
mock_message.parts = [mock_part1, mock_part2]
mock_user = Mock()
mock_user.user_name = "test_user"
mock_call_context = Mock()
mock_call_context.user = mock_user
request = Mock(spec=RequestContext)
request.message = mock_message
request.context_id = "ADK/app/user/session"
mock_from_context_id.return_value = (
"app_name",
"user_from_context",
"session123",
)
mock_get_user_id.return_value = "final_user"
request.context_id = "test_context_123"
request.call_context = mock_call_context
# Create proper genai_types.Part objects instead of mocks
mock_genai_part1 = genai_types.Part(text="test part 1")
@@ -263,16 +186,14 @@ class TestConvertA2aRequestToAdkRunArgs:
# Assert
assert result is not None
assert result["user_id"] == "final_user"
assert result["session_id"] == "session123"
assert result["user_id"] == "test_user"
assert result["session_id"] == "test_context_123"
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2]
assert isinstance(result["run_config"], RunConfig)
# Verify calls
mock_from_context_id.assert_called_once_with("ADK/app/user/session")
mock_get_user_id.assert_called_once_with(request, "user_from_context")
assert mock_convert_part.call_count == 2
mock_convert_part.assert_any_call(mock_part1)
mock_convert_part.assert_any_call(mock_part2)
@@ -290,11 +211,7 @@ class TestConvertA2aRequestToAdkRunArgs:
@patch(
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
)
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
@patch("google.adk.a2a.converters.request_converter._get_user_id")
def test_convert_a2a_request_empty_parts(
self, mock_get_user_id, mock_from_context_id, mock_convert_part
):
def test_convert_a2a_request_empty_parts(self, mock_convert_part):
"""Test conversion with empty parts list."""
# Arrange
mock_message = Mock()
@@ -302,22 +219,16 @@ class TestConvertA2aRequestToAdkRunArgs:
request = Mock(spec=RequestContext)
request.message = mock_message
request.context_id = "ADK/app/user/session"
mock_from_context_id.return_value = (
"app_name",
"user_from_context",
"session123",
)
mock_get_user_id.return_value = "final_user"
request.context_id = "test_context_123"
request.call_context = None
# Act
result = convert_a2a_request_to_adk_run_args(request)
# Assert
assert result is not None
assert result["user_id"] == "final_user"
assert result["session_id"] == "session123"
assert result["user_id"] == "A2A_USER_test_context_123"
assert result["session_id"] == "test_context_123"
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
assert result["new_message"].parts == []
@@ -329,11 +240,7 @@ class TestConvertA2aRequestToAdkRunArgs:
@patch(
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
)
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
@patch("google.adk.a2a.converters.request_converter._get_user_id")
def test_convert_a2a_request_none_context_id(
self, mock_get_user_id, mock_from_context_id, mock_convert_part
):
def test_convert_a2a_request_none_context_id(self, mock_convert_part):
"""Test conversion when context_id is None."""
# Arrange
mock_part = Mock()
@@ -343,9 +250,7 @@ class TestConvertA2aRequestToAdkRunArgs:
request = Mock(spec=RequestContext)
request.message = mock_message
request.context_id = None
mock_from_context_id.return_value = (None, None, None)
mock_get_user_id.return_value = "fallback_user"
request.call_context = None
# Create proper genai_types.Part object instead of mock
mock_genai_part = genai_types.Part(text="test part")
@@ -356,26 +261,18 @@ class TestConvertA2aRequestToAdkRunArgs:
# Assert
assert result is not None
assert result["user_id"] == "fallback_user"
assert result["user_id"] == "A2A_USER_None"
assert result["session_id"] is None
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
assert result["new_message"].parts == [mock_genai_part]
assert isinstance(result["run_config"], RunConfig)
# Verify calls
mock_from_context_id.assert_called_once_with(None)
mock_get_user_id.assert_called_once_with(request, None)
@patch(
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
)
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
@patch("google.adk.a2a.converters.request_converter._get_user_id")
def test_convert_a2a_request_invalid_context_id(
self, mock_get_user_id, mock_from_context_id, mock_convert_part
):
"""Test conversion when context_id is invalid format."""
def test_convert_a2a_request_no_auth(self, mock_convert_part):
"""Test conversion when no authentication is available."""
# Arrange
mock_part = Mock()
mock_message = Mock()
@@ -383,10 +280,8 @@ class TestConvertA2aRequestToAdkRunArgs:
request = Mock(spec=RequestContext)
request.message = mock_message
request.context_id = "invalid_format"
mock_from_context_id.return_value = (None, None, None)
mock_get_user_id.return_value = "fallback_user"
request.context_id = "session_123"
request.call_context = None
# Create proper genai_types.Part object instead of mock
mock_genai_part = genai_types.Part(text="test part")
@@ -397,17 +292,13 @@ class TestConvertA2aRequestToAdkRunArgs:
# Assert
assert result is not None
assert result["user_id"] == "fallback_user"
assert result["session_id"] is None
assert result["user_id"] == "A2A_USER_session_123"
assert result["session_id"] == "session_123"
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
assert result["new_message"].parts == [mock_genai_part]
assert isinstance(result["run_config"], RunConfig)
# Verify calls
mock_from_context_id.assert_called_once_with("invalid_format")
mock_get_user_id.assert_called_once_with(request, None)
class TestIntegration:
"""Integration test cases combining both functions."""
@@ -431,9 +322,7 @@ class TestIntegration:
request = Mock(spec=RequestContext)
request.call_context = mock_call_context
request.message = mock_message
request.context_id = "ADK/myapp/context_user/mysession"
request.current_task = None
request.task_id = "task123"
request.context_id = "mysession"
# Create proper genai_types.Part object instead of mock
mock_genai_part = genai_types.Part(text="test part")
@@ -444,9 +333,7 @@ class TestIntegration:
# Assert
assert result is not None
assert (
result["user_id"] == "auth_user"
) # Should use authenticated user, not context user
assert result["user_id"] == "auth_user" # Should use authenticated user
assert result["session_id"] == "mysession"
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
@@ -456,27 +343,17 @@ class TestIntegration:
@patch(
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
)
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
def test_end_to_end_conversion_with_fallback_user(
self, mock_from_context_id, mock_convert_part
):
def test_end_to_end_conversion_with_fallback_user(self, mock_convert_part):
"""Test end-to-end conversion with fallback user ID."""
# Arrange
mock_part = Mock()
mock_message = Mock()
mock_message.parts = [mock_part]
mock_message.messageId = "msg789"
mock_message.metadata = None
request = Mock(spec=RequestContext)
request.call_context = None
request.message = mock_message
request.context_id = "invalid_format"
request.current_task = None
request.task_id = None
# Mock the utils function to return None values for invalid context
mock_from_context_id.return_value = (None, None, None)
request.context_id = "test_session_456"
# Create proper genai_types.Part object instead of mock
mock_genai_part = genai_types.Part(text="test part")
@@ -488,9 +365,9 @@ class TestIntegration:
# Assert
assert result is not None
assert (
result["user_id"] == "TEMP_USER_msg789"
) # Should fallback to message ID
assert result["session_id"] is None
result["user_id"] == "A2A_USER_test_session_456"
) # Should fallback to context ID
assert result["session_id"] == "test_session_456"
assert isinstance(result["new_message"], genai_types.Content)
assert result["new_message"].role == "user"
assert result["new_message"].parts == [mock_genai_part]