You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Use context_id as session_id and construct temp user_id from context_id
PiperOrigin-RevId: 776639713
This commit is contained in:
committed by
Copybara-Service
parent
51a559eb2a
commit
09e487df3c
@@ -33,37 +33,19 @@ from google.genai import types as genai_types
|
||||
from ...runners import RunConfig
|
||||
from ...utils.feature_decorator import working_in_progress
|
||||
from .part_converter import convert_a2a_part_to_genai_part
|
||||
from .utils import _from_a2a_context_id
|
||||
from .utils import _get_adk_metadata_key
|
||||
|
||||
|
||||
def _get_user_id(request: RequestContext, user_id_from_context: str) -> str:
|
||||
def _get_user_id(request: RequestContext) -> str:
|
||||
# Get user from call context if available (auth is enabled on a2a server)
|
||||
if request.call_context and request.call_context.user:
|
||||
if (
|
||||
request.call_context
|
||||
and request.call_context.user
|
||||
and request.call_context.user.user_name
|
||||
):
|
||||
return request.call_context.user.user_name
|
||||
|
||||
# Get user from context id if available
|
||||
if user_id_from_context:
|
||||
return user_id_from_context
|
||||
|
||||
# Get user from message metadata if available (client is an ADK agent)
|
||||
if request.message.metadata:
|
||||
user_id = request.message.metadata.get(_get_adk_metadata_key('user_id'))
|
||||
if user_id:
|
||||
return f'ADK_USER_{user_id}'
|
||||
|
||||
# Get user from task if available (client is a an ADK agent)
|
||||
if request.current_task:
|
||||
user_id = request.current_task.metadata.get(
|
||||
_get_adk_metadata_key('user_id')
|
||||
)
|
||||
if user_id:
|
||||
return f'ADK_USER_{user_id}'
|
||||
return (
|
||||
f'temp_user_{request.task_id}'
|
||||
if request.task_id
|
||||
else f'TEMP_USER_{request.message.messageId}'
|
||||
)
|
||||
# Get user from context id
|
||||
return f'A2A_USER_{request.context_id}'
|
||||
|
||||
|
||||
@working_in_progress
|
||||
@@ -74,11 +56,9 @@ def convert_a2a_request_to_adk_run_args(
|
||||
if not request.message:
|
||||
raise ValueError('Request message cannot be None')
|
||||
|
||||
_, user_id, session_id = _from_a2a_context_id(request.context_id)
|
||||
|
||||
return {
|
||||
'user_id': _get_user_id(request, user_id),
|
||||
'session_id': session_id,
|
||||
'user_id': _get_user_id(request),
|
||||
'session_id': request.context_id,
|
||||
'new_message': genai_types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
|
||||
@@ -41,7 +41,7 @@ except ImportError as e:
|
||||
genai_types = DummyTypes()
|
||||
RequestContext = DummyTypes()
|
||||
RunConfig = DummyTypes()
|
||||
_get_user_id = lambda x, y: None
|
||||
_get_user_id = lambda x: None
|
||||
convert_a2a_request_to_adk_run_args = lambda x: None
|
||||
else:
|
||||
raise e
|
||||
@@ -61,12 +61,10 @@ class TestGetUserId:
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = mock_call_context
|
||||
request.message = Mock()
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.context_id = "test_context"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "context_user")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "authenticated_user"
|
||||
@@ -76,15 +74,13 @@ class TestGetUserId:
|
||||
# Arrange
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = Mock()
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.context_id = "test_context"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "context_user")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "context_user"
|
||||
assert result == "A2A_USER_test_context"
|
||||
|
||||
def test_get_user_id_from_context_when_call_context_has_no_user(self):
|
||||
"""Test getting user ID from context when call context has no user."""
|
||||
@@ -94,133 +90,64 @@ class TestGetUserId:
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = mock_call_context
|
||||
request.message = Mock()
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.context_id = "test_context"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "context_user")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "context_user"
|
||||
assert result == "A2A_USER_test_context"
|
||||
|
||||
def test_get_user_id_from_message_metadata(self):
|
||||
"""Test getting user ID from message metadata when context user is not available."""
|
||||
def test_get_user_id_with_empty_user_name(self):
|
||||
"""Test getting user ID when user exists but user_name is empty."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = {"adk_user_id": "message_user"}
|
||||
mock_user = Mock()
|
||||
mock_user.user_name = ""
|
||||
|
||||
mock_call_context = Mock()
|
||||
mock_call_context.user = mock_user
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.call_context = mock_call_context
|
||||
request.context_id = "test_context"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "ADK_USER_message_user"
|
||||
assert result == "A2A_USER_test_context"
|
||||
|
||||
def test_get_user_id_from_task_metadata(self):
|
||||
"""Test getting user ID from task metadata when message metadata is not available."""
|
||||
def test_get_user_id_with_none_user_name(self):
|
||||
"""Test getting user ID when user exists but user_name is None."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = None
|
||||
mock_user = Mock()
|
||||
mock_user.user_name = None
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.metadata = {"adk_user_id": "task_user"}
|
||||
mock_call_context = Mock()
|
||||
mock_call_context.user = mock_user
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = mock_task
|
||||
request.task_id = "task123"
|
||||
request.call_context = mock_call_context
|
||||
request.context_id = "test_context"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "ADK_USER_task_user"
|
||||
assert result == "A2A_USER_test_context"
|
||||
|
||||
def test_get_user_id_fallback_to_task_id(self):
|
||||
"""Test fallback to task ID when no other user ID is available."""
|
||||
def test_get_user_id_with_none_context_id(self):
|
||||
"""Test getting user ID when context_id is None."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = None
|
||||
mock_message.messageId = "msg456"
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.context_id = None
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
result = _get_user_id(request)
|
||||
|
||||
# Assert
|
||||
assert result == "temp_user_task123"
|
||||
|
||||
def test_get_user_id_fallback_to_message_id(self):
|
||||
"""Test fallback to message ID when no task ID is available."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = None
|
||||
mock_message.messageId = "msg456"
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = None
|
||||
request.task_id = None
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
|
||||
# Assert
|
||||
assert result == "TEMP_USER_msg456"
|
||||
|
||||
def test_get_user_id_message_metadata_empty(self):
|
||||
"""Test getting user ID when message metadata exists but doesn't contain user_id."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = {"other_key": "other_value"}
|
||||
mock_message.messageId = "msg456"
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
|
||||
# Assert
|
||||
assert result == "temp_user_task123"
|
||||
|
||||
def test_get_user_id_task_metadata_empty(self):
|
||||
"""Test getting user ID when task metadata exists but doesn't contain user_id."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
mock_message.metadata = None
|
||||
mock_message.messageId = "msg456"
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.metadata = {"other_key": "other_value"}
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.current_task = mock_task
|
||||
request.task_id = "task123"
|
||||
|
||||
# Act
|
||||
result = _get_user_id(request, "")
|
||||
|
||||
# Assert
|
||||
assert result == "temp_user_task123"
|
||||
assert result == "A2A_USER_None"
|
||||
|
||||
|
||||
class TestConvertA2aRequestToAdkRunArgs:
|
||||
@@ -229,11 +156,7 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
@patch(
|
||||
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
|
||||
)
|
||||
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
|
||||
@patch("google.adk.a2a.converters.request_converter._get_user_id")
|
||||
def test_convert_a2a_request_basic(
|
||||
self, mock_get_user_id, mock_from_context_id, mock_convert_part
|
||||
):
|
||||
def test_convert_a2a_request_basic(self, mock_convert_part):
|
||||
"""Test basic conversion of A2A request to ADK run args."""
|
||||
# Arrange
|
||||
mock_part1 = Mock()
|
||||
@@ -242,16 +165,16 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
mock_message = Mock()
|
||||
mock_message.parts = [mock_part1, mock_part2]
|
||||
|
||||
mock_user = Mock()
|
||||
mock_user.user_name = "test_user"
|
||||
|
||||
mock_call_context = Mock()
|
||||
mock_call_context.user = mock_user
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.message = mock_message
|
||||
request.context_id = "ADK/app/user/session"
|
||||
|
||||
mock_from_context_id.return_value = (
|
||||
"app_name",
|
||||
"user_from_context",
|
||||
"session123",
|
||||
)
|
||||
mock_get_user_id.return_value = "final_user"
|
||||
request.context_id = "test_context_123"
|
||||
request.call_context = mock_call_context
|
||||
|
||||
# Create proper genai_types.Part objects instead of mocks
|
||||
mock_genai_part1 = genai_types.Part(text="test part 1")
|
||||
@@ -263,16 +186,14 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "final_user"
|
||||
assert result["session_id"] == "session123"
|
||||
assert result["user_id"] == "test_user"
|
||||
assert result["session_id"] == "test_context_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part1, mock_genai_part2]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
|
||||
# Verify calls
|
||||
mock_from_context_id.assert_called_once_with("ADK/app/user/session")
|
||||
mock_get_user_id.assert_called_once_with(request, "user_from_context")
|
||||
assert mock_convert_part.call_count == 2
|
||||
mock_convert_part.assert_any_call(mock_part1)
|
||||
mock_convert_part.assert_any_call(mock_part2)
|
||||
@@ -290,11 +211,7 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
@patch(
|
||||
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
|
||||
)
|
||||
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
|
||||
@patch("google.adk.a2a.converters.request_converter._get_user_id")
|
||||
def test_convert_a2a_request_empty_parts(
|
||||
self, mock_get_user_id, mock_from_context_id, mock_convert_part
|
||||
):
|
||||
def test_convert_a2a_request_empty_parts(self, mock_convert_part):
|
||||
"""Test conversion with empty parts list."""
|
||||
# Arrange
|
||||
mock_message = Mock()
|
||||
@@ -302,22 +219,16 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.message = mock_message
|
||||
request.context_id = "ADK/app/user/session"
|
||||
|
||||
mock_from_context_id.return_value = (
|
||||
"app_name",
|
||||
"user_from_context",
|
||||
"session123",
|
||||
)
|
||||
mock_get_user_id.return_value = "final_user"
|
||||
request.context_id = "test_context_123"
|
||||
request.call_context = None
|
||||
|
||||
# Act
|
||||
result = convert_a2a_request_to_adk_run_args(request)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "final_user"
|
||||
assert result["session_id"] == "session123"
|
||||
assert result["user_id"] == "A2A_USER_test_context_123"
|
||||
assert result["session_id"] == "test_context_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == []
|
||||
@@ -329,11 +240,7 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
@patch(
|
||||
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
|
||||
)
|
||||
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
|
||||
@patch("google.adk.a2a.converters.request_converter._get_user_id")
|
||||
def test_convert_a2a_request_none_context_id(
|
||||
self, mock_get_user_id, mock_from_context_id, mock_convert_part
|
||||
):
|
||||
def test_convert_a2a_request_none_context_id(self, mock_convert_part):
|
||||
"""Test conversion when context_id is None."""
|
||||
# Arrange
|
||||
mock_part = Mock()
|
||||
@@ -343,9 +250,7 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
request = Mock(spec=RequestContext)
|
||||
request.message = mock_message
|
||||
request.context_id = None
|
||||
|
||||
mock_from_context_id.return_value = (None, None, None)
|
||||
mock_get_user_id.return_value = "fallback_user"
|
||||
request.call_context = None
|
||||
|
||||
# Create proper genai_types.Part object instead of mock
|
||||
mock_genai_part = genai_types.Part(text="test part")
|
||||
@@ -356,26 +261,18 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "fallback_user"
|
||||
assert result["user_id"] == "A2A_USER_None"
|
||||
assert result["session_id"] is None
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
|
||||
# Verify calls
|
||||
mock_from_context_id.assert_called_once_with(None)
|
||||
mock_get_user_id.assert_called_once_with(request, None)
|
||||
|
||||
@patch(
|
||||
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
|
||||
)
|
||||
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
|
||||
@patch("google.adk.a2a.converters.request_converter._get_user_id")
|
||||
def test_convert_a2a_request_invalid_context_id(
|
||||
self, mock_get_user_id, mock_from_context_id, mock_convert_part
|
||||
):
|
||||
"""Test conversion when context_id is invalid format."""
|
||||
def test_convert_a2a_request_no_auth(self, mock_convert_part):
|
||||
"""Test conversion when no authentication is available."""
|
||||
# Arrange
|
||||
mock_part = Mock()
|
||||
mock_message = Mock()
|
||||
@@ -383,10 +280,8 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.message = mock_message
|
||||
request.context_id = "invalid_format"
|
||||
|
||||
mock_from_context_id.return_value = (None, None, None)
|
||||
mock_get_user_id.return_value = "fallback_user"
|
||||
request.context_id = "session_123"
|
||||
request.call_context = None
|
||||
|
||||
# Create proper genai_types.Part object instead of mock
|
||||
mock_genai_part = genai_types.Part(text="test part")
|
||||
@@ -397,17 +292,13 @@ class TestConvertA2aRequestToAdkRunArgs:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["user_id"] == "fallback_user"
|
||||
assert result["session_id"] is None
|
||||
assert result["user_id"] == "A2A_USER_session_123"
|
||||
assert result["session_id"] == "session_123"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
assert isinstance(result["run_config"], RunConfig)
|
||||
|
||||
# Verify calls
|
||||
mock_from_context_id.assert_called_once_with("invalid_format")
|
||||
mock_get_user_id.assert_called_once_with(request, None)
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration test cases combining both functions."""
|
||||
@@ -431,9 +322,7 @@ class TestIntegration:
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = mock_call_context
|
||||
request.message = mock_message
|
||||
request.context_id = "ADK/myapp/context_user/mysession"
|
||||
request.current_task = None
|
||||
request.task_id = "task123"
|
||||
request.context_id = "mysession"
|
||||
|
||||
# Create proper genai_types.Part object instead of mock
|
||||
mock_genai_part = genai_types.Part(text="test part")
|
||||
@@ -444,9 +333,7 @@ class TestIntegration:
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert (
|
||||
result["user_id"] == "auth_user"
|
||||
) # Should use authenticated user, not context user
|
||||
assert result["user_id"] == "auth_user" # Should use authenticated user
|
||||
assert result["session_id"] == "mysession"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
@@ -456,27 +343,17 @@ class TestIntegration:
|
||||
@patch(
|
||||
"google.adk.a2a.converters.request_converter.convert_a2a_part_to_genai_part"
|
||||
)
|
||||
@patch("google.adk.a2a.converters.request_converter._from_a2a_context_id")
|
||||
def test_end_to_end_conversion_with_fallback_user(
|
||||
self, mock_from_context_id, mock_convert_part
|
||||
):
|
||||
def test_end_to_end_conversion_with_fallback_user(self, mock_convert_part):
|
||||
"""Test end-to-end conversion with fallback user ID."""
|
||||
# Arrange
|
||||
mock_part = Mock()
|
||||
mock_message = Mock()
|
||||
mock_message.parts = [mock_part]
|
||||
mock_message.messageId = "msg789"
|
||||
mock_message.metadata = None
|
||||
|
||||
request = Mock(spec=RequestContext)
|
||||
request.call_context = None
|
||||
request.message = mock_message
|
||||
request.context_id = "invalid_format"
|
||||
request.current_task = None
|
||||
request.task_id = None
|
||||
|
||||
# Mock the utils function to return None values for invalid context
|
||||
mock_from_context_id.return_value = (None, None, None)
|
||||
request.context_id = "test_session_456"
|
||||
|
||||
# Create proper genai_types.Part object instead of mock
|
||||
mock_genai_part = genai_types.Part(text="test part")
|
||||
@@ -488,9 +365,9 @@ class TestIntegration:
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert (
|
||||
result["user_id"] == "TEMP_USER_msg789"
|
||||
) # Should fallback to message ID
|
||||
assert result["session_id"] is None
|
||||
result["user_id"] == "A2A_USER_test_session_456"
|
||||
) # Should fallback to context ID
|
||||
assert result["session_id"] == "test_session_456"
|
||||
assert isinstance(result["new_message"], genai_types.Content)
|
||||
assert result["new_message"].role == "user"
|
||||
assert result["new_message"].parts == [mock_genai_part]
|
||||
|
||||
Reference in New Issue
Block a user