diff --git a/contributing/samples/hello_world_apigeellm/.env-sample b/contributing/samples/hello_world_apigeellm/.env-sample
new file mode 100644
index 00000000..eeef7fad
--- /dev/null
+++ b/contributing/samples/hello_world_apigeellm/.env-sample
@@ -0,0 +1,8 @@
+# This is a sample .env file.
+# Copy this file to .env and replace the placeholder values with your actual credentials.
+
+# Your Google API key for accessing Gemini models.
+GOOGLE_API_KEY="your-google-api-key"
+
+# The URL of your Apigee proxy.
+APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
diff --git a/contributing/samples/hello_world_apigeellm/README.md b/contributing/samples/hello_world_apigeellm/README.md
new file mode 100644
index 00000000..41cfa50a
--- /dev/null
+++ b/contributing/samples/hello_world_apigeellm/README.md
@@ -0,0 +1,84 @@
+# Hello World with Apigee LLM
+
+This sample demonstrates how to use the Agent Development Kit (ADK) with an LLM fronted by an Apigee proxy. It showcases the flexibility of the `ApigeeLlm` class in configuring the target LLM provider (Gemini or Vertex AI) and API version through the model string.
+
+## Setup
+
+Before running the sample, you need to configure your environment with the necessary credentials.
+
+1. **Create a `.env` file:**
+   Copy the sample environment file to a new file named `.env` in the same directory.
+   ```bash
+   cp .env-sample .env
+   ```
+
+2. **Set Environment Variables:**
+   Open the `.env` file and provide values for the following variables:
+
+   - `GOOGLE_API_KEY`: Your API key for Google AI services (Gemini).
+   - `APIGEE_PROXY_URL`: The full URL of your Apigee proxy endpoint.
+
+   Example `.env` file:
+   ```
+   GOOGLE_API_KEY="your-google-api-key"
+   APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
+   ```
+
+   The `main.py` script will automatically load these variables when it runs.
+
+## Run the Sample
+
+Once your `.env` file is configured, you can run the sample with the following command:
+
+```bash
+python main.py
+```
+
+## Configuring the Apigee LLM
+
+The `ApigeeLlm` class is configured using a special model string format in `agent.py`. This string determines which backend provider (Vertex AI or Gemini) and which API version to use.
+
+### Model String Format
+
+The supported format is:
+
+`apigee/[<provider>/][<version>/]<model_id>`
+
+- **`provider`** (optional): Can be `vertex_ai` or `gemini`.
+  - If specified, it forces the use of that provider.
+  - If omitted, the provider is determined by the `GOOGLE_GENAI_USE_VERTEXAI` environment variable. If this variable is set to `true` or `1`, Vertex AI is used; otherwise, `gemini` is used by default.
+
+- **`version`** (optional): The API version to use (e.g., `v1`, `v1beta`).
+  - If omitted, the default version for the selected provider is used.
+
+- **`model_id`** (required): The identifier for the model you want to use (e.g., `gemini-2.5-flash`).
+
+### Configuration Examples
+
+Here are some examples of how to configure the model string in `agent.py` to achieve different behaviors:
+
+1. **Implicit Provider (determined by environment variable):**
+
+   - `model="apigee/gemini-2.5-flash"`
+     - Uses the default API version.
+     - Provider is Vertex AI if `GOOGLE_GENAI_USE_VERTEXAI` is true, otherwise Gemini.
+
+   - `model="apigee/v1/gemini-2.5-flash"`
+     - Uses API version `v1`.
+     - Provider is determined by the environment variable.
+
+2. **Explicit Provider (ignores environment variable):**
+
+   - `model="apigee/vertex_ai/gemini-2.5-flash"`
+     - Uses Vertex AI with the default API version.
+
+   - `model="apigee/gemini/gemini-2.5-flash"`
+     - Uses Gemini with the default API version.
+
+   - `model="apigee/gemini/v1/gemini-2.5-flash"`
+     - Uses Gemini with API version `v1`.
+
+   - `model="apigee/vertex_ai/v1beta/gemini-2.5-flash"`
+     - Uses Vertex AI with API version `v1beta`.
+
+By modifying the `model` string in `agent.py`, you can test various configurations without changing the core logic of the agent.
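+
+### Using `ApigeeLlm` Directly
+
+The `apigee/...` string above is resolved through the ADK model registry. You
+can also construct `ApigeeLlm` yourself, for example to set the proxy URL or
+extra request headers in code rather than through environment variables. A
+minimal sketch (the custom header below is illustrative, not required):
+
+```python
+from google.adk import Agent
+from google.adk.models.apigee_llm import ApigeeLlm
+
+root_agent = Agent(
+    model=ApigeeLlm(
+        model="apigee/gemini/v1/gemini-2.5-flash",
+        proxy_url="https://your-apigee-proxy.net/basepath",
+        custom_headers={"x-request-source": "adk-sample"},  # illustrative
+    ),
+    name="hello_world_agent",
+    instruction="Roll dice and answer questions about the results.",
+)
+```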
diff --git a/contributing/samples/hello_world_apigeellm/agent.py b/contributing/samples/hello_world_apigeellm/agent.py
new file mode 100644
index 00000000..21bf0936
--- /dev/null
+++ b/contributing/samples/hello_world_apigeellm/agent.py
@@ -0,0 +1,108 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+from google.adk import Agent
+from google.adk.tools.tool_context import ToolContext
+from google.genai import types
+
+
+def roll_die(sides: int, tool_context: ToolContext) -> int:
+  """Roll a die and return the rolled result.
+
+  Args:
+    sides: The integer number of sides the die has.
+
+  Returns:
+    An integer of the result of rolling the die.
+  """
+  result = random.randint(1, sides)
+  if "rolls" not in tool_context.state:
+    tool_context.state["rolls"] = []
+
+  tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
+  return result
+
+
+async def check_prime(nums: list[int]) -> str:
+  """Check if a given list of numbers are prime.
+
+  Args:
+    nums: The list of numbers to check.
+
+  Returns:
+    A str indicating which numbers are prime.
+  """
+  primes = set()
+  for number in nums:
+    number = int(number)
+    if number <= 1:
+      continue
+    is_prime = True
+    for i in range(2, int(number**0.5) + 1):
+      if number % i == 0:
+        is_prime = False
+        break
+    if is_prime:
+      primes.add(number)
+  return (
+      "No prime numbers found."
+      if not primes
+      else f"{', '.join(str(num) for num in primes)} are prime numbers."
+  )
+
+
+root_agent = Agent(
+    model="apigee/gemini-2.5-flash",
+    name="hello_world_agent",
+    description=(
+        "hello world agent that can roll an 8-sided die and check prime"
+        " numbers."
+    ),
+    instruction="""
+      You roll dice and answer questions about the outcome of the dice rolls.
+      You can roll dice of different sizes.
+      You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
+      It is ok to discuss previous dice rolls, and comment on the dice rolls.
+      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
+      You should never roll a die on your own.
+      When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
+      You should not check prime numbers before calling the tool.
+      When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
+      1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
+      2. After you get the function response from the roll_die tool, you should call the check_prime tool with the roll_die result.
+        2.1 If the user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
+      3. When you respond, you must include the roll_die result from step 1.
+      You should always perform the previous 3 steps when asked to roll a die and check prime numbers.
+      You should not rely on previous history for prime results.
+    """,
+    tools=[
+        roll_die,
+        check_prime,
+    ],
+    # planner=BuiltInPlanner(
+    #     thinking_config=types.ThinkingConfig(
+    #         include_thoughts=True,
+    #     ),
+    # ),
+    generate_content_config=types.GenerateContentConfig(
+        safety_settings=[
+            types.SafetySetting(  # avoid false alarms about rolling dice.
+                category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                threshold=types.HarmBlockThreshold.OFF,
+            ),
+        ]
+    ),
+)
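+
+# Other model strings from the README work here as well, e.g.:
+#   model="apigee/v1/gemini-2.5-flash"                (provider from GOOGLE_GENAI_USE_VERTEXAI)
+#   model="apigee/vertex_ai/v1beta/gemini-2.5-flash"  (explicit provider and version)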
diff --git a/contributing/samples/hello_world_apigeellm/main.py b/contributing/samples/hello_world_apigeellm/main.py
new file mode 100644
index 00000000..1e81097d
--- /dev/null
+++ b/contributing/samples/hello_world_apigeellm/main.py
@@ -0,0 +1,112 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import time
+
+import agent
+from dotenv import load_dotenv
+from google.adk.agents.run_config import RunConfig
+from google.adk.cli.utils import logs
+from google.adk.runners import InMemoryRunner
+from google.adk.sessions.session import Session
+from google.genai import types
+
+load_dotenv(override=True)
+logs.log_to_tmp_folder()
+
+
+async def main():
+  app_name = "my_app"
+  user_id_1 = "user1"
+  runner = InMemoryRunner(
+      agent=agent.root_agent,
+      app_name=app_name,
+  )
+  session_11 = await runner.session_service.create_session(
+      app_name=app_name, user_id=user_id_1
+  )
+
+  async def run_prompt(session: Session, new_message: str):
+    content = types.Content(
+        role="user", parts=[types.Part.from_text(text=new_message)]
+    )
+    print("** User says:", content.model_dump(exclude_none=True))
+    async for event in runner.run_async(
+        user_id=user_id_1,
+        session_id=session.id,
+        new_message=content,
+    ):
+      if event.content.parts and event.content.parts[0].text:
+        print(f"** {event.author}: {event.content.parts[0].text}")
+
+  async def run_prompt_bytes(session: Session, new_message: str):
+    content = types.Content(
+        role="user",
+        parts=[
+            types.Part.from_bytes(
+                data=str.encode(new_message), mime_type="text/plain"
+            )
+        ],
+    )
+    print("** User says:", content.model_dump(exclude_none=True))
+    async for event in runner.run_async(
+        user_id=user_id_1,
+        session_id=session.id,
+        new_message=content,
+        run_config=RunConfig(save_input_blobs_as_artifacts=True),
+    ):
+      if event.content.parts and event.content.parts[0].text:
+        print(f"** {event.author}: {event.content.parts[0].text}")
+
+  async def check_rolls_in_state(rolls_size: int):
+    session = await runner.session_service.get_session(
+        app_name=app_name, user_id=user_id_1, session_id=session_11.id
+    )
+    assert len(session.state["rolls"]) == rolls_size
+    for roll in session.state["rolls"]:
+      assert roll > 0 and roll <= 100
+
+  start_time = time.time()
+  print("Start time:", start_time)
+  print("------------------------------------")
+  await run_prompt(session_11, "Hi")
+  await run_prompt(session_11, "Roll a die with 100 sides")
+  await check_rolls_in_state(1)
+  await run_prompt(session_11, "Roll a die again with 100 sides.")
+  await check_rolls_in_state(2)
+  await run_prompt(session_11, "What numbers did I get?")
+  await run_prompt_bytes(session_11, "Hi bytes")
+  print(
+      await runner.artifact_service.list_artifact_keys(
+          app_name=app_name, user_id=user_id_1, session_id=session_11.id
+      )
+  )
+  end_time = time.time()
+  print("------------------------------------")
+  print("End time:", end_time)
+  print("Total time:", end_time - start_time)
+
+
+if __name__ == "__main__":
+  # The API key can be set in a .env file.
+  # For example, create a .env file with the following content:
+  # GOOGLE_API_KEY="your-api-key"
+  # APIGEE_PROXY_URL="your-proxy-url"
+  if not os.getenv("GOOGLE_API_KEY"):
+    raise ValueError("GOOGLE_API_KEY environment variable is not set.")
+  if not os.getenv("APIGEE_PROXY_URL"):
+    raise ValueError("APIGEE_PROXY_URL environment variable is not set.")
+  asyncio.run(main())
diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py
index c08570a9..9f3c2a2c 100644
--- a/src/google/adk/models/__init__.py
+++ b/src/google/adk/models/__init__.py
@@ -14,6 +14,7 @@
 """Defines the interface to support a model."""
 
+from .apigee_llm import ApigeeLlm
 from .base_llm import BaseLlm
 from .gemma_llm import Gemma
 from .google_llm import Gemini
@@ -31,3 +32,4 @@ __all__ = [
 
 LLMRegistry.register(Gemini)
 LLMRegistry.register(Gemma)
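+# ApigeeLlm advertises supported_models() = [r'apigee\/.*'], so registering it
+# here lets Agent(model="apigee/...") resolve to ApigeeLlm via the registry.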
+LLMRegistry.register(ApigeeLlm)
diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py
new file mode 100644
index 00000000..c93e5fe2
--- /dev/null
+++ b/src/google/adk/models/apigee_llm.py
@@ -0,0 +1,232 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import annotations
+
+from functools import cached_property
+import logging
+import os
+import re
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from google.adk import version as adk_version
+from google.genai import Client
+from google.genai import types
+from typing_extensions import override
+
+from .google_llm import Gemini
+
+if TYPE_CHECKING:
+  from .llm_request import LlmRequest
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+_APIGEE_PROXY_URL_ENV_VARIABLE_NAME = 'APIGEE_PROXY_URL'
+_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_USE_VERTEXAI'
+
+
+class ApigeeLlm(Gemini):
+  """A BaseLlm implementation that calls an LLM through an Apigee proxy.
+
+  Attributes:
+    model: The Apigee model string, e.g. `apigee/gemini-2.5-flash`.
+  """
+
+  def __init__(
+      self,
+      *,
+      model: str,
+      proxy_url: str | None = None,
+      custom_headers: dict[str, str] | None = None,
+      retry_options: Optional[types.HttpRetryOptions] = None,
+  ):
+    """Initializes the Apigee LLM backend.
+
+    Args:
+      model: The model string specifies the LLM provider (e.g., Vertex AI,
+        Gemini), API version, and the model ID. Supported format:
+        `apigee/[<provider>/][<version>/]<model_id>`
+
+        Components
+          `provider` (optional): `vertex_ai` or `gemini`. If omitted, behavior
+            depends on the `GOOGLE_GENAI_USE_VERTEXAI` environment variable. If
+            that is not set to `true` or `1`, it defaults to `gemini`. A
+            `provider` in the model string takes precedence over
+            `GOOGLE_GENAI_USE_VERTEXAI`.
+          `version` (optional): The API version (e.g., `v1`, `v1beta`). If
+            omitted, the default version for the provider is used.
+          `model_id` (required): The model identifier (e.g.,
+            `gemini-2.5-flash`).
+
+        Examples
+          - `apigee/gemini-2.5-flash`
+          - `apigee/v1/gemini-2.5-flash`
+          - `apigee/vertex_ai/gemini-2.5-flash`
+          - `apigee/gemini/v1/gemini-2.5-flash`
+          - `apigee/vertex_ai/v1beta/gemini-2.5-flash`
+
+      proxy_url: The URL of the Apigee proxy. Falls back to the
+        `APIGEE_PROXY_URL` environment variable if not provided.
+      custom_headers: A dictionary of headers to be sent with the request.
+      retry_options: Allow google-genai to retry failed responses.
+    """
+
+    super().__init__(model=model, retry_options=retry_options)
+    # Validate the model string before deriving the provider and API version.
+    if not _validate_model_string(model):
+      raise ValueError(f'Invalid model string: {model}')
+
+    self._isvertexai = _identify_vertexai(model)
+    self._api_version = _identify_api_version(model)
+    self._proxy_url = proxy_url or os.environ.get(
+        _APIGEE_PROXY_URL_ENV_VARIABLE_NAME
+    )
+    self._custom_headers = custom_headers or {}
+    self._user_agent = f'google-adk/{adk_version.__version__}'
+
+  @classmethod
+  @override
+  def supported_models(cls) -> list[str]:
+    """Provides the list of supported models.
+
+    Returns:
+      A list of supported models.
+    """
+
+    return [
+        r'apigee\/.*',
+    ]
+
+  @cached_property
+  def api_client(self) -> Client:
+    """Provides the api client.
+
+    Returns:
+      The api client.
+    """
+
+    kwargs_for_http_options = {}
+    if self._api_version:
+      kwargs_for_http_options['api_version'] = self._api_version
+    http_options = types.HttpOptions(
+        base_url=self._proxy_url,
+        headers=self._merge_tracking_headers(self._custom_headers),
+        retry_options=self.retry_options,
+        **kwargs_for_http_options,
+    )
+
+    return Client(
+        vertexai=self._isvertexai,
+        http_options=http_options,
+    )
+
+  @override
+  async def _preprocess_request(self, llm_request: LlmRequest) -> None:
+    llm_request.model = _get_model_id(llm_request.model)
+    await super()._preprocess_request(llm_request)
+
+
+def _identify_vertexai(model: str) -> bool:
+  """Returns True if the model spec resolves to the Vertex AI provider."""
+  return not model.startswith('apigee/gemini/') and (
+      model.startswith('apigee/vertex_ai/')
+      or os.environ.get(
+          _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME, '0'
+      ).lower()
+      in ['true', '1']
+  )
+
+
+def _identify_api_version(model: str) -> str:
+  """Returns the api version for the model spec, or '' for the default."""
+  model = model.removeprefix('apigee/')
+  components = model.split('/')
+
+  if len(components) == 3:
+    # Format: <provider>/<version>/<model_id>
+    return components[1]
+  if len(components) == 2:
+    # Format: <provider>/<model_id> or <version>/<model_id>
+    # _validate_model_string ensures that if the first component is not a
+    # provider, it can be a version.
+    if components[0] not in ('vertex_ai', 'gemini') and components[
+        0
+    ].startswith('v'):
+      return components[0]
+    return ''
+  # Single component: only the model_id is present; use the provider default.
+  return ''
+
+
+def _get_model_id(model: str) -> str:
+  """Returns the model ID for the model spec."""
+  model = model.removeprefix('apigee/')
+  components = model.split('/')
+
+  # Model_id is the last component in the model string.
+  return components[-1]
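+
+
+# Illustrative parse results (assuming GOOGLE_GENAI_USE_VERTEXAI is unset):
+#   'apigee/gemini-2.5-flash'             -> gemini,    default version, 'gemini-2.5-flash'
+#   'apigee/v1/gemini-2.5-flash'          -> gemini,    'v1',            'gemini-2.5-flash'
+#   'apigee/vertex_ai/gemini-2.5-flash'   -> vertex_ai, default version, 'gemini-2.5-flash'
+#   'apigee/vertex_ai/v1beta/gemini-pro'  -> vertex_ai, 'v1beta',        'gemini-pro'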
+
+
+def _validate_model_string(model: str) -> bool:
+  """Validates the model string for Apigee LLM.
+
+  The model string specifies the LLM provider (e.g., Vertex AI, Gemini), API
+  version, and the model ID.
+
+  Args:
+    model: The model string. Supported format:
+      `apigee/[<provider>/][<version>/]<model_id>`
+
+  Returns:
+    True if the model string is valid, False otherwise.
+  """
+  if not model.startswith('apigee/'):
+    return False
+
+  # Remove leading "apigee/" from the model string.
+  model = model.removeprefix('apigee/')
+
+  # The string has to be non-empty, i.e. the model_id cannot be empty.
+  if not model:
+    return False
+
+  components = model.split('/')
+  # If the model string has exactly 1 component, only the model_id is present.
+  # This is a valid format.
+  if len(components) == 1:
+    return True
+
+  # If the model string has more than 3 components, it is invalid.
+  if len(components) > 3:
+    return False
+
+  # If the model string has 3 components, they must be the provider, version,
+  # and model_id, in that order.
+  if len(components) == 3:
+    # Format: <provider>/<version>/<model_id>
+    if components[0] not in ('vertex_ai', 'gemini'):
+      return False
+    if not components[1].startswith('v'):
+      return False
+    return True
+
+  # If the model string has 2 components, it means either the provider or the
+  # version (but not both), followed by the model_id.
+  if len(components) == 2:
+    if components[0] in ['vertex_ai', 'gemini']:
+      return True
+    if components[0].startswith('v'):
+      return True
+    return False
+
+  return False
diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py
new file mode 100644
index 00000000..eeafd5be
--- /dev/null
+++ b/tests/unittests/models/test_apigee_llm.py
@@ -0,0 +1,414 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+from unittest import mock
+from unittest.mock import AsyncMock
+
+from google.adk.models.apigee_llm import ApigeeLlm
+from google.adk.models.llm_request import LlmRequest
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+BASE_MODEL_ID = 'gemini-2.5-flash'
+APIGEE_GEMINI_MODEL_ID = 'apigee/gemini/v1/' + BASE_MODEL_ID
+APIGEE_VERTEX_MODEL_ID = 'apigee/vertex_ai/v1beta/gemini-pro'
+VERTEX_BASE_MODEL_ID = 'gemini-pro'
+PROXY_URL = 'https://test.apigee.net'
+
+
+@pytest.fixture
+def llm_request():
+  """Provides a sample LlmRequest for testing."""
+  return LlmRequest(
+      model=APIGEE_GEMINI_MODEL_ID,
+      contents=[
+          types.Content(
+              role='user', parts=[types.Part.from_text(text='Test prompt')]
+          )
+      ],
+  )
+
+
+@pytest.mark.asyncio
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_generate_content_async_non_streaming(
+    mock_client_constructor, llm_request
+):
+  """Tests the generate_content_async method for non-streaming responses."""
+  apigee_llm_instance = ApigeeLlm(
+      model=APIGEE_GEMINI_MODEL_ID,
+      proxy_url=PROXY_URL,
+  )
+  mock_client_instance = mock.Mock()
+  mock_response = types.GenerateContentResponse(
+      candidates=[
+          types.Candidate(
+              content=Content(
+                  parts=[Part.from_text(text='Test response')],
+                  role='model',
+              )
+          )
+      ]
+  )
+  mock_client_instance.aio.models.generate_content = AsyncMock(
+      return_value=mock_response
+  )
+  mock_client_constructor.return_value = mock_client_instance
+
+  response_generator = apigee_llm_instance.generate_content_async(llm_request)
+  responses = [resp async for resp in response_generator]
+
+  assert len(responses) == 1
+  llm_response = responses[0]
+  assert llm_response.content.parts[0].text == 'Test response'
+  assert llm_response.content.role == 'model'
+
+  mock_client_constructor.assert_called_once()
+  _, kwargs = mock_client_constructor.call_args
+  assert not kwargs['vertexai']
+  http_options = kwargs['http_options']
+  assert http_options.base_url == PROXY_URL
+  assert http_options.api_version == 'v1'
+  assert 'user-agent' in http_options.headers
+  assert 'x-goog-api-client' in http_options.headers
+
+  mock_client_instance.aio.models.generate_content.assert_called_once_with(
+      model=BASE_MODEL_ID,
+      contents=llm_request.contents,
+      config=llm_request.config,
+  )
+
+
+@pytest.mark.asyncio
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_generate_content_async_streaming(
+    mock_client_constructor, llm_request
+):
+  """Tests the generate_content_async method for streaming responses."""
+  apigee_llm_instance = ApigeeLlm(
+      model=APIGEE_GEMINI_MODEL_ID,
+      proxy_url=PROXY_URL,
+  )
+  mock_client_instance = mock.Mock()
+  mock_responses = [
+      types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=Content(
+                      parts=[Part.from_text(text='Hello')],
+                  )
+              )
+          ]
+      ),
+      types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=Content(
+                      parts=[Part.from_text(text=',')],
+                  )
+              )
+          ]
+      ),
+      types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=Content(
+                      parts=[Part.from_text(text=' world!')],
+                  )
+              )
+          ]
+      ),
+  ]
+
+  async def mock_stream_generator():
+    for r in mock_responses:
+      yield r
+
+  mock_client_instance.aio.models.generate_content_stream = AsyncMock(
+      return_value=mock_stream_generator()
+  )
+  mock_client_constructor.return_value = mock_client_instance
+
+  response_generator = apigee_llm_instance.generate_content_async(
+      llm_request, stream=True
+  )
+  responses = [resp async for resp in response_generator]
+
+  assert responses
+  full_text_parts = []
+  for r in responses:
+    for p in r.content.parts:
+      if p.text:
+        full_text_parts.append(p.text)
+  full_text = ''.join(full_text_parts)
+  assert 'Hello, world!' in full_text
+
+  mock_client_instance.aio.models.generate_content_stream.assert_called_once_with(
+      model=BASE_MODEL_ID,
+      contents=llm_request.contents,
+      config=llm_request.config,
+  )
+
+
+@pytest.mark.asyncio
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_generate_content_async_with_custom_headers(
+    mock_client_constructor, llm_request
+):
+  """Tests that custom headers are passed in the request."""
+  custom_headers = {
+      'X-Custom-Header': 'custom-value',
+  }
+  apigee_llm = ApigeeLlm(
+      model=APIGEE_GEMINI_MODEL_ID,
+      proxy_url=PROXY_URL,
+      custom_headers=custom_headers,
+  )
+  mock_client_instance = mock.Mock()
+  mock_response = types.GenerateContentResponse(
+      candidates=[
+          types.Candidate(
+              content=Content(
+                  parts=[Part.from_text(text='Test response')],
+                  role='model',
+              )
+          )
+      ]
+  )
+  mock_client_instance.aio.models.generate_content = AsyncMock(
+      return_value=mock_response
+  )
+  mock_client_constructor.return_value = mock_client_instance
+
+  response_generator = apigee_llm.generate_content_async(llm_request)
+  _ = [resp async for resp in response_generator]  # Consume generator
+
+  mock_client_constructor.assert_called_once()
+  _, kwargs = mock_client_constructor.call_args
+  http_options = kwargs['http_options']
+  assert http_options.headers['X-Custom-Header'] == 'custom-value'
+  assert 'user-agent' in http_options.headers
+
+
+@pytest.mark.asyncio
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_vertex_model_path_parsing(mock_client_constructor):
+  """Tests that Vertex AI model paths are parsed correctly."""
+  apigee_llm = ApigeeLlm(model=APIGEE_VERTEX_MODEL_ID, proxy_url=PROXY_URL)
+  llm_request = LlmRequest(
+      model=APIGEE_VERTEX_MODEL_ID,
+      contents=[
+          types.Content(
+              role='user',
+              parts=[types.Part.from_text(text='Test prompt')],
+          )
+      ],
+  )
+  mock_client_instance = mock.Mock()
+  mock_client_instance.aio.models.generate_content = AsyncMock(
+      return_value=types.GenerateContentResponse(
+          candidates=[
+              types.Candidate(
+                  content=Content(
+                      parts=[Part.from_text(text='Test response')],
+                      role='model',
+                  )
+              )
+          ]
+      )
+  )
+  mock_client_constructor.return_value = mock_client_instance
+
+  _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
+
+  mock_client_constructor.assert_called_once()
+  _, kwargs = mock_client_constructor.call_args
+  assert kwargs['vertexai']
+  assert kwargs['http_options'].api_version == 'v1beta'
+
+  mock_client_instance.aio.models.generate_content.assert_called_once()
+  call_kwargs = (
+      mock_client_instance.aio.models.generate_content.call_args.kwargs
+  )
+  assert call_kwargs['model'] == VERTEX_BASE_MODEL_ID
+
+
+@pytest.mark.asyncio
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_proxy_url_from_env_variable(mock_client_constructor):
+  """Tests that the proxy_url is read from the environment variable."""
+  with mock.patch.dict(
+      os.environ, {'APIGEE_PROXY_URL': 'https://env.proxy.url'}
+  ):
+    apigee_llm = ApigeeLlm(model=APIGEE_GEMINI_MODEL_ID)
+    llm_request = LlmRequest(
+        model=APIGEE_GEMINI_MODEL_ID,
+        contents=[
+            types.Content(
+                role='user', parts=[types.Part.from_text(text='Test prompt')]
+            )
+        ],
+    )
+    mock_client_instance = mock.Mock()
+    mock_client_instance.aio.models.generate_content = AsyncMock(
+        return_value=types.GenerateContentResponse(
+            candidates=[
+                types.Candidate(
+                    content=Content(
+                        parts=[Part.from_text(text='Test response')],
+                        role='model',
+                    )
+                )
+            ]
+        )
+    )
+    mock_client_constructor.return_value = mock_client_instance
+
+    _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
+
+    mock_client_constructor.assert_called_once()
+    _, kwargs = mock_client_constructor.call_args
+    assert kwargs['http_options'].base_url == 'https://env.proxy.url'
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    (
+        'model_string',
+        'use_vertexai_env',
+        'expected_is_vertexai',
+        'expected_api_version',
+        'expected_model_id',
+    ),
+    [
+        ('apigee/gemini-2.5-flash', None, False, None, 'gemini-2.5-flash'),
+        ('apigee/gemini-2.5-flash', 'true', True, None, 'gemini-2.5-flash'),
+        ('apigee/gemini-2.5-flash', '1', True, None, 'gemini-2.5-flash'),
+        ('apigee/gemini-2.5-flash', 'false', False, None, 'gemini-2.5-flash'),
+        ('apigee/gemini-2.5-flash', '0', False, None, 'gemini-2.5-flash'),
+        (
+            'apigee/v1/gemini-2.5-flash',
+            None,
+            False,
+            'v1',
+            'gemini-2.5-flash',
+        ),
+        (
+            'apigee/v1/gemini-2.5-flash',
+            'true',
+            True,
+            'v1',
+            'gemini-2.5-flash',
+        ),
+        (
+            'apigee/vertex_ai/gemini-2.5-flash',
+            None,
+            True,
+            None,
+            'gemini-2.5-flash',
+        ),
+        (
+            'apigee/vertex_ai/gemini-2.5-flash',
+            'false',
+            True,
+            None,
+            'gemini-2.5-flash',
+        ),
+        (
+            'apigee/gemini/v1/gemini-2.5-flash',
+            'true',
+            False,
+            'v1',
+            'gemini-2.5-flash',
+        ),
+        (
+            'apigee/vertex_ai/v1beta/gemini-2.5-flash',
+            'false',
+            True,
+            'v1beta',
+            'gemini-2.5-flash',
+        ),
+    ],
+)
+@mock.patch('google.adk.models.apigee_llm.Client')
+async def test_model_string_parsing_and_client_initialization(
+    mock_client_constructor,
+    model_string,
+    use_vertexai_env,
+    expected_is_vertexai,
+    expected_api_version,
+    expected_model_id,
+):
+  """Tests model string parsing and genai.Client initialization."""
+  env_vars = {}
+  if use_vertexai_env is not None:
+    env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
+
+  # ApigeeLlm is initialized inside the 'with' block to make sure the mocked
+  # environment variables are active.
+  with mock.patch.dict(os.environ, env_vars, clear=True):
+    apigee_llm = ApigeeLlm(model=model_string, proxy_url=PROXY_URL)
+    request = LlmRequest(model=model_string, contents=[])
+
+    mock_client_instance = mock.Mock()
+    mock_client_instance.aio.models.generate_content = AsyncMock(
+        return_value=types.GenerateContentResponse(
+            candidates=[
+                types.Candidate(
+                    content=Content(parts=[Part.from_text(text='')])
+                )
+            ]
+        )
+    )
+    mock_client_constructor.return_value = mock_client_instance
+
+    _ = [resp async for resp in apigee_llm.generate_content_async(request)]
+
+    mock_client_constructor.assert_called_once()
+    _, kwargs = mock_client_constructor.call_args
+    assert kwargs['vertexai'] == expected_is_vertexai
+    http_options = kwargs['http_options']
+    assert http_options.api_version == expected_api_version
+
+    mock_client_instance.aio.models.generate_content.assert_called_once_with(
+        model=expected_model_id,
+        contents=request.contents,
+        config=request.config,
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    'invalid_model_string',
+    [
+        'apigee/openai/v1/gpt',  # Unknown provider.
+        'apigee/',  # Missing model_id.
+        'apigee',  # Missing the 'apigee/' prefix.
+        'gemini-pro',  # Missing the 'apigee/' prefix.
+        'apigee/vertex_ai/v1/model/extra',  # Too many components.
+        'apigee/unknown/model',  # Unknown provider, not a version.
+    ],
+)
+async def test_invalid_model_strings_raise_value_error(invalid_model_string):
+  """Tests that invalid model strings raise a ValueError."""
+  with pytest.raises(
+      ValueError, match=f'Invalid model string: {invalid_model_string}'
+  ):
+    ApigeeLlm(model=invalid_model_string, proxy_url=PROXY_URL)
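+
+
+# A small additional sketch: exercises the private _identify_api_version
+# helper directly ('' means the provider's default version is used). Testing
+# a private helper is a convenience choice here, not part of the public API.
+@pytest.mark.parametrize(
+    ('model_string', 'expected_api_version'),
+    [
+        ('apigee/gemini-2.5-flash', ''),
+        ('apigee/v1/gemini-2.5-flash', 'v1'),
+        ('apigee/vertex_ai/gemini-2.5-flash', ''),
+        ('apigee/gemini/v1beta/gemini-2.5-flash', 'v1beta'),
+    ],
+)
+def test_identify_api_version(model_string, expected_api_version):
+  from google.adk.models.apigee_llm import _identify_api_version
+
+  assert _identify_api_version(model_string) == expected_api_version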