From 8dff85099d67623dd6f4a707fb932ea55b8aaf9b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 3 Nov 2025 10:31:52 -0800 Subject: [PATCH] fix: Support models slash prefix in model name extraction Support "models/" prefix in model name extraction PiperOrigin-RevId: 827557443 --- src/google/adk/utils/model_name_utils.py | 8 ++++++-- tests/unittests/utils/test_model_name_utils.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 6e8add25..641988d4 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -27,8 +27,8 @@ def extract_model_name(model_string: str) -> str: """Extract the actual model name from either simple or path-based format. Args: - model_string: Either a simple model name like "gemini-2.5-pro" or - a path-based model name like "projects/.../models/gemini-2.0-flash-001" + model_string: Either a simple model name like "gemini-2.5-pro" or a + path-based model name like "projects/.../models/gemini-2.0-flash-001" Returns: The extracted model name (e.g., "gemini-2.5-pro") @@ -41,6 +41,10 @@ def extract_model_name(model_string: str) -> str: if match: return match.group(1) + # Handle 'models/' prefixed names like "models/gemini-2.5-pro" + if model_string.startswith('models/'): + return model_string[len('models/') :] + # If it's not a path-based model, return as-is (simple model name) return model_string diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index c80380df..2e3b70a9 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -42,6 +42,11 @@ class TestExtractModelName: path_model_3 = 'projects/test-project/locations/europe-west1/publishers/google/models/claude-3-sonnet' assert extract_model_name(path_model_3) == 'claude-3-sonnet' + def test_extract_model_name_with_models_prefix(self): + """Test extraction of model names with 'models/' prefix.""" + assert extract_model_name('models/gemini-2.5-pro') == 'gemini-2.5-pro' + assert extract_model_name('models/gemini-1.5-flash') == 'gemini-1.5-flash' + def test_extract_model_name_invalid_path(self): """Test that invalid path formats return the original string.""" invalid_paths = [