feat: Add ApigeeLlm, a model that lets ADK Agent developers connect to an Apigee proxy

PiperOrigin-RevId: 824712152
Google Team Member
2025-10-27 15:53:02 -07:00
committed by Copybara-Service
parent 00d147d62f
commit 87dcb3f7ba
7 changed files with 960 additions and 0 deletions
@@ -0,0 +1,8 @@
# This is a sample .env file.
# Copy this file to .env and replace the placeholder values with your actual credentials.
# Your Google API key for accessing Gemini models.
GOOGLE_API_KEY="your-google-api-key"
# The URL of your Apigee proxy.
APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
@@ -0,0 +1,84 @@
# Hello World with Apigee LLM
This sample demonstrates how to use the Agent Development Kit (ADK) with an LLM fronted by an Apigee proxy. It showcases the flexibility of the `ApigeeLlm` class in configuring the target LLM provider (Gemini or Vertex AI) and API version through the model string.
## Setup
Before running the sample, you need to configure your environment with the necessary credentials.
1. **Create a `.env` file:**
Copy the sample environment file to a new file named `.env` in the same directory.
```bash
cp .env-sample .env
```
2. **Set Environment Variables:**
Open the `.env` file and provide values for the following variables:
- `GOOGLE_API_KEY`: Your API key for the Google AI services (Gemini).
- `APIGEE_PROXY_URL`: The full URL of your Apigee proxy endpoint.
Example `.env` file:
```
GOOGLE_API_KEY="your-google-api-key"
APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
```
The `main.py` script will automatically load these variables when it runs.
## Run the Sample
Once your `.env` file is configured, you can run the sample with the following command:
```bash
python main.py
```
## Configuring the Apigee LLM
The `ApigeeLlm` class is configured using a special model string format in `agent.py`. This string determines which backend provider (Vertex AI or Gemini) and which API version to use.
### Model String Format
The supported format is:
`apigee/[<provider>/][<version>/]<model_id>`
- **`provider`** (optional): Can be `vertex_ai` or `gemini`.
- If specified, it forces the use of that provider.
- If omitted, the provider is determined by the `GOOGLE_GENAI_USE_VERTEXAI` environment variable: if it is set to `true` or `1`, Vertex AI is used; otherwise, `gemini` is used by default (see the `.env` snippet after this list).
- **`version`** (optional): The API version to use (e.g., `v1`, `v1beta`).
- If omitted, the default version for the selected provider is used.
- **`model_id`** (required): The identifier for the model you want to use (e.g., `gemini-2.5-flash`).
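For example, with the implicit form `apigee/gemini-2.5-flash`, you can switch the backend to Vertex AI by adding one line to your `.env` file (a minimal sketch; per the rule above, `true` or `1` both work):
```
GOOGLE_GENAI_USE_VERTEXAI="true"
```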
### Configuration Examples
Here are some examples of how to configure the model string in `agent.py` to achieve different behaviors:
1. **Implicit Provider (determined by environment variable):**
- `model="apigee/gemini-2.5-flash"`
- Uses the default API version.
- Provider is Vertex AI if `GOOGLE_GENAI_USE_VERTEXAI` is true, otherwise Gemini.
- `model="apigee/v1/gemini-2.5-flash"`
- Uses API version `v1`.
- Provider is determined by the environment variable.
2. **Explicit Provider (ignores environment variable):**
- `model="apigee/vertex_ai/gemini-2.5-flash"`
- Uses Vertex AI with the default API version.
- `model="apigee/gemini/gemini-2.5-flash"`
- Uses Gemini with the default API version.
- `model="apigee/gemini/v1/gemini-2.5-flash"`
- Uses Gemini with API version `v1`.
- `model="apigee/vertex_ai/v1beta/gemini-2.5-flash"`
- Uses Vertex AI with API version `v1beta`.
By modifying the `model` string in `agent.py`, you can test various configurations without changing the core logic of the agent.
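Besides passing a model string, you can also construct `ApigeeLlm` directly in `agent.py` when you want to set the proxy URL or custom headers in code instead of relying on the `APIGEE_PROXY_URL` environment variable. A minimal sketch (the URL and header values below are placeholders, not part of the sample):

```python
from google.adk import Agent
from google.adk.models.apigee_llm import ApigeeLlm

root_agent = Agent(
    model=ApigeeLlm(
        model="apigee/gemini/v1/gemini-2.5-flash",
        # Placeholder; if omitted, the APIGEE_PROXY_URL env variable is used.
        proxy_url="https://your-apigee-proxy.net/basepath",
        # Hypothetical header your proxy deployment might require.
        custom_headers={"x-api-key": "your-key"},
    ),
    name="hello_world_agent",
    instruction="Roll dice and check whether numbers are prime.",
)
```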
@@ -0,0 +1,108 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if "rolls" not in tool_context.state:
tool_context.state["rolls"] = []
tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
return result
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
"No prime numbers found."
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model="apigee/gemini-2.5-flash",
name="hello_world_agent",
description=(
"hello world agent that can roll a dice of 8 sides and check prime"
" numbers."
),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
      You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
      It is ok to discuss previous dice rolls, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
      You should not rely on the previous history for prime results.
""",
tools=[
roll_die,
check_prime,
],
# planner=BuiltInPlanner(
# thinking_config=types.ThinkingConfig(
# include_thoughts=True,
# ),
# ),
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import time
import agent
from dotenv import load_dotenv
from google.adk.agents.run_config import RunConfig
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions.session import Session
from google.genai import types
load_dotenv(override=True)
logs.log_to_tmp_folder()
async def main():
app_name = "my_app"
user_id_1 = "user1"
runner = InMemoryRunner(
agent=agent.root_agent,
app_name=app_name,
)
session_11 = await runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role="user", parts=[types.Part.from_text(text=new_message)]
)
print("** User says:", content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f"** {event.author}: {event.content.parts[0].text}")
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
role="user",
parts=[
types.Part.from_bytes(
data=str.encode(new_message), mime_type="text/plain"
)
],
)
print("** User says:", content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f"** {event.author}: {event.content.parts[0].text}")
async def check_rolls_in_state(rolls_size: int):
session = await runner.session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
assert len(session.state["rolls"]) == rolls_size
for roll in session.state["rolls"]:
assert roll > 0 and roll <= 100
start_time = time.time()
print("Start time:", start_time)
print("------------------------------------")
await run_prompt(session_11, "Hi")
await run_prompt(session_11, "Roll a die with 100 sides")
await check_rolls_in_state(1)
await run_prompt(session_11, "Roll a die again with 100 sides.")
await check_rolls_in_state(2)
await run_prompt(session_11, "What numbers did I got?")
await run_prompt_bytes(session_11, "Hi bytes")
print(
await runner.artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print("------------------------------------")
print("End time:", end_time)
print("Total time:", end_time - start_time)
if __name__ == "__main__":
# The API key can be set in a .env file.
# For example, create a .env file with the following content:
# GOOGLE_API_KEY="your-api-key"
# APIGEE_PROXY_URL="your-proxy-url"
if not os.getenv("GOOGLE_API_KEY"):
raise ValueError("GOOGLE_API_KEY environment variable is not set.")
if not os.getenv("APIGEE_PROXY_URL"):
raise ValueError("APIGEE_PROXY_URL environment variable is not set.")
asyncio.run(main())
@@ -14,6 +14,7 @@
"""Defines the interface to support a model."""
from .apigee_llm import ApigeeLlm
from .base_llm import BaseLlm
from .gemma_llm import Gemma
from .google_llm import Gemini
@@ -31,3 +32,4 @@ __all__ = [
LLMRegistry.register(Gemini)
LLMRegistry.register(Gemma)
LLMRegistry.register(ApigeeLlm)
@@ -0,0 +1,232 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import cached_property
import logging
import os
import re
from typing import Optional
from typing import TYPE_CHECKING
from google.adk import version as adk_version
from google.genai import Client
from google.genai import types
from typing_extensions import override
from .google_llm import Gemini
if TYPE_CHECKING:
from .llm_request import LlmRequest
logger = logging.getLogger('google_adk.' + __name__)
_APIGEE_PROXY_URL_ENV_VARIABLE_NAME = 'APIGEE_PROXY_URL'
_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_USE_VERTEXAI'
class ApigeeLlm(Gemini):
"""A BaseLlm implementation for calling Apigee proxy.
Attributes:
model: The name of the Gemini model.
"""
def __init__(
self,
*,
model: str,
proxy_url: str | None = None,
custom_headers: dict[str, str] | None = None,
retry_options: Optional[types.HttpRetryOptions] = None,
):
"""Initializes the Apigee LLM backend.
Args:
model: The model string specifies the LLM provider (e.g., Vertex AI,
Gemini), API version, and the model ID. Supported format:
`apigee/[<provider>/][<version>/]<model_id>`
        Components:
`provider` (optional): `vertex_ai` or `gemini`. If omitted, behavior
depends on the `GOOGLE_GENAI_USE_VERTEXAI` environment variable. If
that is not set to TRUE or 1, it defaults to `gemini`. `provider`
takes precedence over `GOOGLE_GENAI_USE_VERTEXAI`.
`version` (optional): The API version (e.g., `v1`, `v1beta`). If
omitted, the default version for the provider is used.
`model_id` (required): The model identifier (e.g.,
`gemini-2.5-flash`).
        Examples:
- `apigee/gemini-2.5-flash`
- `apigee/v1/gemini-2.5-flash`
- `apigee/vertex_ai/gemini-2.5-flash`
- `apigee/gemini/v1/gemini-2.5-flash`
- `apigee/vertex_ai/v1beta/gemini-2.5-flash`
proxy_url: The URL of the Apigee proxy.
custom_headers: A dictionary of headers to be sent with the request.
retry_options: Allow google-genai to retry failed responses.
"""
super().__init__(model=model, retry_options=retry_options)
    # Validate the model string using the module-level helper.
if not _validate_model_string(model):
raise ValueError(f'Invalid model string: {model}')
self._isvertexai = _identify_vertexai(model)
self._api_version = _identify_api_version(model)
self._proxy_url = proxy_url or os.environ.get(
_APIGEE_PROXY_URL_ENV_VARIABLE_NAME
)
self._custom_headers = custom_headers or {}
self._user_agent = f'google-adk/{adk_version.__version__}'
@classmethod
@override
def supported_models(cls) -> list[str]:
"""Provides the list of supported models.
Returns:
A list of supported models.
"""
return [
r'apigee\/.*',
]
@cached_property
def api_client(self) -> Client:
"""Provides the api client.
Returns:
The api client.
"""
kwargs_for_http_options = {}
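    # Pass api_version only when the model string specified one; otherwise the
    # genai client falls back to the provider's default API version.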
if self._api_version:
kwargs_for_http_options['api_version'] = self._api_version
http_options = types.HttpOptions(
base_url=self._proxy_url,
headers=self._merge_tracking_headers(self._custom_headers),
retry_options=self.retry_options,
**kwargs_for_http_options,
)
return Client(
vertexai=self._isvertexai,
http_options=http_options,
)
@override
async def _preprocess_request(self, llm_request: LlmRequest) -> None:
llm_request.model = _get_model_id(llm_request.model)
await super()._preprocess_request(llm_request)
def _identify_vertexai(model: str) -> bool:
"""Returns True if the model spec starts with apigee/vertex_ai."""
return not model.startswith('apigee/gemini/') and (
model.startswith('apigee/vertex_ai/')
or os.environ.get(
_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME, '0'
).lower()
in ['true', '1']
)
def _identify_api_version(model: str) -> str:
"""Returns the api version for the model spec."""
model = model.removeprefix('apigee/')
components = model.split('/')
if len(components) == 3:
# Format: <provider>/<version>/<model_id>
return components[1]
if len(components) == 2:
# Format: <version>/<model_id> or <provider>/<model_id>
# _validate_model_string ensures that if the first component is not a
# provider, it can be a version.
if components[0] not in ('vertex_ai', 'gemini') and components[
0
].startswith('v'):
return components[0]
return ''
def _get_model_id(model: str) -> str:
"""Returns the model ID for the model spec."""
model = model.removeprefix('apigee/')
components = model.split('/')
# Model_id is the last component in the model string.
return components[-1]
def _validate_model_string(model: str) -> bool:
"""Validates the model string for Apigee LLM.
The model string specifies the LLM provider (e.g., Vertex AI, Gemini), API
version, and the model ID.
Args:
model: The model string. Supported format:
`apigee/[<provider>/][<version>/]<model_id>`
Returns:
True if the model string is valid, False otherwise.
"""
if not model.startswith('apigee/'):
return False
# Remove leading "apigee/" from the model string.
model = model.removeprefix('apigee/')
  # The remaining string must be non-empty, i.e. the model_id cannot be empty.
if not model:
return False
components = model.split('/')
# If the model string has exactly 1 component, it means only the model_id is
# present. This is a valid format.
if len(components) == 1:
return True
# If the model string has more than 3 components, it is invalid.
if len(components) > 3:
return False
  # If the model string has 3 components, they must be the provider, version,
  # and model_id, in that order.
if len(components) == 3:
# Format: <provider>/<version>/<model_id>
if components[0] not in ('vertex_ai', 'gemini'):
return False
if not components[1].startswith('v'):
return False
return True
  # If the model string has 2 components, the first must be either the
  # provider or the version (but not both), followed by the model_id.
if len(components) == 2:
if components[0] in ['vertex_ai', 'gemini']:
return True
if components[0].startswith('v'):
return True
return False
return False
@@ -0,0 +1,414 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from unittest import mock
from unittest.mock import AsyncMock
from google.adk.models.apigee_llm import ApigeeLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai.types import Content
from google.genai.types import Part
import pytest
BASE_MODEL_ID = 'gemini-2.5-flash'
APIGEE_GEMINI_MODEL_ID = 'apigee/gemini/v1/' + BASE_MODEL_ID
APIGEE_VERTEX_MODEL_ID = 'apigee/vertex_ai/v1beta/gemini-pro'
VERTEX_BASE_MODEL_ID = 'gemini-pro'
PROXY_URL = 'https://test.apigee.net'
@pytest.fixture
def llm_request():
"""Provides a sample LlmRequest for testing."""
return LlmRequest(
model=APIGEE_GEMINI_MODEL_ID,
contents=[
types.Content(
role='user', parts=[types.Part.from_text(text='Test prompt')]
)
],
)
@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_non_streaming(
mock_client_constructor, llm_request
):
"""Tests the generate_content_async method for non-streaming responses."""
apigee_llm_instance = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
)
mock_client_instance = mock.Mock()
mock_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_client_constructor.return_value = mock_client_instance
response_generator = apigee_llm_instance.generate_content_async(llm_request)
responses = [resp async for resp in response_generator]
assert len(responses) == 1
llm_response = responses[0]
assert llm_response.content.parts[0].text == 'Test response'
assert llm_response.content.role == 'model'
mock_client_constructor.assert_called_once()
_, kwargs = mock_client_constructor.call_args
assert not kwargs['vertexai']
http_options = kwargs['http_options']
assert http_options.base_url == PROXY_URL
assert http_options.api_version == 'v1'
assert 'user-agent' in http_options.headers
assert 'x-goog-api-client' in http_options.headers
mock_client_instance.aio.models.generate_content.assert_called_once_with(
model=BASE_MODEL_ID,
contents=llm_request.contents,
config=llm_request.config,
)
@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_streaming(
mock_client_constructor, llm_request
):
"""Tests the generate_content_async method for streaming responses."""
apigee_llm_instance = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
)
mock_client_instance = mock.Mock()
mock_responses = [
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Hello')],
)
)
]
),
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text=',')],
)
)
]
),
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text=' world!')],
)
)
]
),
]
async def mock_stream_generator():
for r in mock_responses:
yield r
mock_client_instance.aio.models.generate_content_stream = AsyncMock(
return_value=mock_stream_generator()
)
mock_client_constructor.return_value = mock_client_instance
response_generator = apigee_llm_instance.generate_content_async(
llm_request, stream=True
)
responses = [resp async for resp in response_generator]
assert responses
full_text_parts = []
for r in responses:
for p in r.content.parts:
if p.text:
full_text_parts.append(p.text)
full_text = ''.join(full_text_parts)
assert 'Hello, world!' in full_text
mock_client_instance.aio.models.generate_content_stream.assert_called_once_with(
model=BASE_MODEL_ID,
contents=llm_request.contents,
config=llm_request.config,
)
@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_with_custom_headers(
mock_client_constructor, llm_request
):
"""Tests that custom headers are passed in the request."""
custom_headers = {
'X-Custom-Header': 'custom-value',
}
apigee_llm = ApigeeLlm(
model=APIGEE_GEMINI_MODEL_ID,
proxy_url=PROXY_URL,
custom_headers=custom_headers,
)
mock_client_instance = mock.Mock()
mock_response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=mock_response
)
mock_client_constructor.return_value = mock_client_instance
response_generator = apigee_llm.generate_content_async(llm_request)
_ = [resp async for resp in response_generator] # Consume generator
mock_client_constructor.assert_called_once()
_, kwargs = mock_client_constructor.call_args
http_options = kwargs['http_options']
assert http_options.headers['X-Custom-Header'] == 'custom-value'
assert 'user-agent' in http_options.headers
@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_vertex_model_path_parsing(mock_client_constructor):
"""Tests that Vertex AI model paths are parsed correctly."""
apigee_llm = ApigeeLlm(model=APIGEE_VERTEX_MODEL_ID, proxy_url=PROXY_URL)
llm_request = LlmRequest(
model=APIGEE_VERTEX_MODEL_ID,
contents=[
types.Content(
role='user', parts=[types.Part.from_text(text='Test prompt')]
)
],
)
mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
mock_client_constructor.assert_called_once()
_, kwargs = mock_client_constructor.call_args
assert kwargs['vertexai']
assert kwargs['http_options'].api_version == 'v1beta'
mock_client_instance.aio.models.generate_content.assert_called_once()
call_kwargs = (
mock_client_instance.aio.models.generate_content.call_args.kwargs
)
assert call_kwargs['model'] == VERTEX_BASE_MODEL_ID
@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_proxy_url_from_env_variable(mock_client_constructor):
"""Tests that proxy_url is read from environment variable."""
with mock.patch.dict(
os.environ, {'APIGEE_PROXY_URL': 'https://env.proxy.url'}
):
apigee_llm = ApigeeLlm(model=APIGEE_GEMINI_MODEL_ID)
llm_request = LlmRequest(
model=APIGEE_GEMINI_MODEL_ID,
contents=[
types.Content(
role='user', parts=[types.Part.from_text(text='Test prompt')]
)
],
)
mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
parts=[Part.from_text(text='Test response')],
role='model',
)
)
]
)
)
mock_client_constructor.return_value = mock_client_instance
_ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]
mock_client_constructor.assert_called_once()
_, kwargs = mock_client_constructor.call_args
assert kwargs['http_options'].base_url == 'https://env.proxy.url'
@pytest.mark.asyncio
@pytest.mark.parametrize(
(
'model_string',
'use_vertexai_env',
'expected_is_vertexai',
'expected_api_version',
'expected_model_id',
),
[
('apigee/gemini-2.5-flash', None, False, None, 'gemini-2.5-flash'),
('apigee/gemini-2.5-flash', 'true', True, None, 'gemini-2.5-flash'),
('apigee/gemini-2.5-flash', '1', True, None, 'gemini-2.5-flash'),
('apigee/gemini-2.5-flash', 'false', False, None, 'gemini-2.5-flash'),
('apigee/gemini-2.5-flash', '0', False, None, 'gemini-2.5-flash'),
(
'apigee/v1/gemini-2.5-flash',
None,
False,
'v1',
'gemini-2.5-flash',
),
(
'apigee/v1/gemini-2.5-flash',
'true',
True,
'v1',
'gemini-2.5-flash',
),
(
'apigee/vertex_ai/gemini-2.5-flash',
None,
True,
None,
'gemini-2.5-flash',
),
(
'apigee/vertex_ai/gemini-2.5-flash',
'false',
True,
None,
'gemini-2.5-flash',
),
(
'apigee/gemini/v1/gemini-2.5-flash',
'true',
False,
'v1',
'gemini-2.5-flash',
),
(
'apigee/vertex_ai/v1beta/gemini-2.5-flash',
'false',
True,
'v1beta',
'gemini-2.5-flash',
),
],
)
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_model_string_parsing_and_client_initialization(
mock_client_constructor,
model_string,
use_vertexai_env,
expected_is_vertexai,
expected_api_version,
expected_model_id,
):
"""Tests model string parsing and genai.Client initialization."""
env_vars = {}
if use_vertexai_env is not None:
env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env
  # Initialize ApigeeLlm inside the `with` block so that the environment
  # variable patch is active during construction.
with mock.patch.dict(os.environ, env_vars, clear=True):
apigee_llm = ApigeeLlm(model=model_string, proxy_url=PROXY_URL)
request = LlmRequest(model=model_string, contents=[])
mock_client_instance = mock.Mock()
mock_client_instance.aio.models.generate_content = AsyncMock(
return_value=types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(parts=[Part.from_text(text='')])
)
]
)
)
mock_client_constructor.return_value = mock_client_instance
_ = [resp async for resp in apigee_llm.generate_content_async(request)]
mock_client_constructor.assert_called_once()
_, kwargs = mock_client_constructor.call_args
assert kwargs['vertexai'] == expected_is_vertexai
http_options = kwargs['http_options']
assert http_options.api_version == expected_api_version
(
mock_client_instance.aio.models.generate_content.assert_called_once_with(
model=expected_model_id,
contents=request.contents,
config=request.config,
)
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'invalid_model_string',
[
'apigee/openai/v1/gpt',
'apigee/', # Missing model_id
'apigee', # Invalid format
'gemini-pro', # Invalid format
'apigee/vertex_ai/v1/model/extra', # Too many components
'apigee/unknown/model',
],
)
async def test_invalid_model_strings_raise_value_error(invalid_model_string):
"""Tests that invalid model strings raise a ValueError."""
with pytest.raises(
ValueError, match=f'Invalid model string: {invalid_model_string}'
):
ApigeeLlm(model=invalid_model_string, proxy_url=PROXY_URL)