feat: Add ApigeeLlm as a model that lets ADK Agent developers connect with an Apigee proxy
PiperOrigin-RevId: 824712152
Committed by Copybara-Service · parent 00d147d62f · commit 87dcb3f7ba
.env-sample
@@ -0,0 +1,8 @@
# This is a sample .env file.
# Copy this file to .env and replace the placeholder values with your actual credentials.

# Your Google API key for accessing Gemini models.
GOOGLE_API_KEY="your-google-api-key"

# The URL of your Apigee proxy.
APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
README.md
@@ -0,0 +1,84 @@
# Hello World with Apigee LLM

This sample demonstrates how to use the Agent Development Kit (ADK) with an LLM fronted by an Apigee proxy. It showcases the flexibility of the `ApigeeLlm` class in configuring the target LLM provider (Gemini or Vertex AI) and API version through the model string.

## Setup

Before running the sample, you need to configure your environment with the necessary credentials.

1. **Create a `.env` file:**

   Copy the sample environment file to a new file named `.env` in the same directory.

   ```bash
   cp .env-sample .env
   ```

2. **Set Environment Variables:**

   Open the `.env` file and provide values for the following variables:

   - `GOOGLE_API_KEY`: Your API key for the Google AI services (Gemini).
   - `APIGEE_PROXY_URL`: The full URL of your Apigee proxy endpoint.

   Example `.env` file:

   ```
   GOOGLE_API_KEY="your-google-api-key"
   APIGEE_PROXY_URL="https://your-apigee-proxy.net/basepath"
   ```

The `main.py` script will automatically load these variables when it runs.

## Run the Sample

Once your `.env` file is configured, you can run the sample with the following command:

```bash
python main.py
```

## Configuring the Apigee LLM

The `ApigeeLlm` class is configured using a special model string format in `agent.py`. This string determines which backend provider (Vertex AI or Gemini) and which API version to use.

### Model String Format

The supported format is:

`apigee/[<provider>/][<version>/]<model_id>`

- **`provider`** (optional): Can be `vertex_ai` or `gemini`.
  - If specified, it forces the use of that provider.
  - If omitted, the provider is determined by the `GOOGLE_GENAI_USE_VERTEXAI` environment variable. If this variable is set to `true` or `1`, Vertex AI is used; otherwise, `gemini` is used by default.

- **`version`** (optional): The API version to use (e.g., `v1`, `v1beta`).
  - If omitted, the default version for the selected provider is used.

- **`model_id`** (required): The identifier for the model you want to use (e.g., `gemini-2.5-flash`).

### Configuration Examples

Here are some examples of how to configure the model string in `agent.py` to achieve different behaviors:

1. **Implicit Provider (determined by environment variable):**

   - `model="apigee/gemini-2.5-flash"`
     - Uses the default API version.
     - Provider is Vertex AI if `GOOGLE_GENAI_USE_VERTEXAI` is true, otherwise Gemini.

   - `model="apigee/v1/gemini-2.5-flash"`
     - Uses API version `v1`.
     - Provider is determined by the environment variable.

2. **Explicit Provider (ignores environment variable):**

   - `model="apigee/vertex_ai/gemini-2.5-flash"`
     - Uses Vertex AI with the default API version.

   - `model="apigee/gemini/gemini-2.5-flash"`
     - Uses Gemini with the default API version.

   - `model="apigee/gemini/v1/gemini-2.5-flash"`
     - Uses Gemini with API version `v1`.

   - `model="apigee/vertex_ai/v1beta/gemini-2.5-flash"`
     - Uses Vertex AI with API version `v1beta`.

By modifying the `model` string in `agent.py`, you can test various configurations without changing the core logic of the agent.
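As a minimal sketch of an explicit configuration, `ApigeeLlm` can also be instantiated directly with the constructor arguments shown in `apigee_llm.py` (`proxy_url`, `custom_headers`). The sketch assumes, as elsewhere in ADK, that `Agent` accepts a model instance in place of a model string; the header name and URL below are placeholders, not required values:

```python
from google.adk import Agent
from google.adk.models.apigee_llm import ApigeeLlm

# Configure the model explicitly instead of relying on environment variables.
# proxy_url overrides APIGEE_PROXY_URL; custom_headers are sent with every
# request to the proxy (the header name here is a hypothetical example).
apigee_model = ApigeeLlm(
    model="apigee/vertex_ai/v1beta/gemini-2.5-flash",
    proxy_url="https://your-apigee-proxy.net/basepath",
    custom_headers={"x-api-key": "your-apigee-api-key"},
)

root_agent = Agent(
    model=apigee_model,  # assumed: Agent also accepts a BaseLlm instance
    name="hello_world_agent",
    instruction="Answer questions.",
)
```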
agent.py
@@ -0,0 +1,108 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from google.adk import Agent
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has.
    tool_context: The tool context, used to record the roll history in
      session state.

  Returns:
    An integer of the result of rolling the die.
  """
  result = random.randint(1, sides)
  if "rolls" not in tool_context.state:
    tool_context.state["rolls"] = []

  tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
  return result


async def check_prime(nums: list[int]) -> str:
  """Check if a given list of numbers are prime.

  Args:
    nums: The list of numbers to check.

  Returns:
    A str indicating which numbers are prime.
  """
  primes = set()
  for number in nums:
    number = int(number)
    if number <= 1:
      continue
    is_prime = True
    for i in range(2, int(number**0.5) + 1):
      if number % i == 0:
        is_prime = False
        break
    if is_prime:
      primes.add(number)
  return (
      "No prime numbers found."
      if not primes
      else f"{', '.join(str(num) for num in primes)} are prime numbers."
  )


root_agent = Agent(
    model="apigee/gemini-2.5-flash",
    name="hello_world_agent",
    description=(
        "hello world agent that can roll an 8-sided die and check prime"
        " numbers."
    ),
    instruction="""
      You roll dice and answer questions about the outcome of the dice rolls.
      You can roll dice of different sizes.
      You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
      It is ok to discuss previous dice rolls, and comment on the dice rolls.
      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
      You should never roll a die on your own.
      When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
      You should not check prime numbers before calling the tool.
      When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
      1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
      2. After you get the function response from the roll_die tool, you should call the check_prime tool with the roll_die result.
        2.1 If the user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
      3. When you respond, you must include the roll_die result from step 1.
      You should always perform the previous 3 steps when asked to roll a die and check prime numbers.
      You should not rely on previous history for prime results.
    """,
    tools=[
        roll_die,
        check_prime,
    ],
    # planner=BuiltInPlanner(
    #     thinking_config=types.ThinkingConfig(
    #         include_thoughts=True,
    #     ),
    # ),
    generate_content_config=types.GenerateContentConfig(
        safety_settings=[
            types.SafetySetting(  # avoid false alarms about rolling dice.
                category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                threshold=types.HarmBlockThreshold.OFF,
            ),
        ]
    ),
)
main.py
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import time

import agent
from dotenv import load_dotenv
from google.adk.agents.run_config import RunConfig
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions.session import Session
from google.genai import types

load_dotenv(override=True)
logs.log_to_tmp_folder()


async def main():
  app_name = "my_app"
  user_id_1 = "user1"
  runner = InMemoryRunner(
      agent=agent.root_agent,
      app_name=app_name,
  )
  session_11 = await runner.session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  async def run_prompt(session: Session, new_message: str):
    content = types.Content(
        role="user", parts=[types.Part.from_text(text=new_message)]
    )
    print("** User says:", content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      if event.content.parts and event.content.parts[0].text:
        print(f"** {event.author}: {event.content.parts[0].text}")

  async def run_prompt_bytes(session: Session, new_message: str):
    content = types.Content(
        role="user",
        parts=[
            types.Part.from_bytes(
                data=str.encode(new_message), mime_type="text/plain"
            )
        ],
    )
    print("** User says:", content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
        run_config=RunConfig(save_input_blobs_as_artifacts=True),
    ):
      if event.content.parts and event.content.parts[0].text:
        print(f"** {event.author}: {event.content.parts[0].text}")

  async def check_rolls_in_state(rolls_size: int):
    session = await runner.session_service.get_session(
        app_name=app_name, user_id=user_id_1, session_id=session_11.id
    )
    assert len(session.state["rolls"]) == rolls_size
    for roll in session.state["rolls"]:
      assert roll > 0 and roll <= 100

  start_time = time.time()
  print("Start time:", start_time)
  print("------------------------------------")
  await run_prompt(session_11, "Hi")
  await run_prompt(session_11, "Roll a die with 100 sides")
  await check_rolls_in_state(1)
  await run_prompt(session_11, "Roll a die again with 100 sides.")
  await check_rolls_in_state(2)
  await run_prompt(session_11, "What numbers did I get?")
  await run_prompt_bytes(session_11, "Hi bytes")
  print(
      await runner.artifact_service.list_artifact_keys(
          app_name=app_name, user_id=user_id_1, session_id=session_11.id
      )
  )
  end_time = time.time()
  print("------------------------------------")
  print("End time:", end_time)
  print("Total time:", end_time - start_time)


if __name__ == "__main__":
  # The API key can be set in a .env file.
  # For example, create a .env file with the following content:
  # GOOGLE_API_KEY="your-api-key"
  # APIGEE_PROXY_URL="your-proxy-url"
  if not os.getenv("GOOGLE_API_KEY"):
    raise ValueError("GOOGLE_API_KEY environment variable is not set.")
  if not os.getenv("APIGEE_PROXY_URL"):
    raise ValueError("APIGEE_PROXY_URL environment variable is not set.")
  asyncio.run(main())
src/google/adk/models/__init__.py
@@ -14,6 +14,7 @@

"""Defines the interface to support a model."""

from .apigee_llm import ApigeeLlm
from .base_llm import BaseLlm
from .gemma_llm import Gemma
from .google_llm import Gemini
@@ -31,3 +32,4 @@ __all__ = [

LLMRegistry.register(Gemini)
LLMRegistry.register(Gemma)
LLMRegistry.register(ApigeeLlm)
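Registration is what makes the plain `model="apigee/..."` string in `agent.py` resolve to `ApigeeLlm`: the registry matches model strings against the patterns each model class reports via `supported_models()`. A minimal sketch of that matching rule, using only the `r'apigee\/.*'` pattern returned by `ApigeeLlm.supported_models()` in the next file (this re-implements the pattern match for illustration, not the actual registry):

```python
import re

# ApigeeLlm claims every model string under the "apigee/" prefix; anything
# else (e.g., a bare Gemini model ID) is left to the other registered models.
APIGEE_PATTERN = re.compile(r'apigee\/.*')

assert APIGEE_PATTERN.match('apigee/gemini-2.5-flash')
assert APIGEE_PATTERN.match('apigee/vertex_ai/v1beta/gemini-2.5-flash')
assert not APIGEE_PATTERN.match('gemini-2.5-flash')
```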
src/google/adk/models/apigee_llm.py
@@ -0,0 +1,232 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from functools import cached_property
import logging
import os
import re
from typing import Optional
from typing import TYPE_CHECKING

from google.adk import version as adk_version
from google.genai import Client
from google.genai import types
from typing_extensions import override

from .google_llm import Gemini

if TYPE_CHECKING:
  from .llm_request import LlmRequest

logger = logging.getLogger('google_adk.' + __name__)

_APIGEE_PROXY_URL_ENV_VARIABLE_NAME = 'APIGEE_PROXY_URL'
_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_USE_VERTEXAI'


class ApigeeLlm(Gemini):
  """A BaseLlm implementation for calling an Apigee proxy.

  Attributes:
    model: The name of the Gemini model.
  """

  def __init__(
      self,
      *,
      model: str,
      proxy_url: str | None = None,
      custom_headers: dict[str, str] | None = None,
      retry_options: Optional[types.HttpRetryOptions] = None,
  ):
    """Initializes the Apigee LLM backend.

    Args:
      model: The model string specifying the LLM provider (e.g., Vertex AI,
        Gemini), API version, and the model ID. Supported format:
        `apigee/[<provider>/][<version>/]<model_id>`

        Components:
          `provider` (optional): `vertex_ai` or `gemini`. If omitted, behavior
            depends on the `GOOGLE_GENAI_USE_VERTEXAI` environment variable.
            If that is not set to `true` or `1`, it defaults to `gemini`.
            `provider` takes precedence over `GOOGLE_GENAI_USE_VERTEXAI`.
          `version` (optional): The API version (e.g., `v1`, `v1beta`). If
            omitted, the default version for the provider is used.
          `model_id` (required): The model identifier (e.g.,
            `gemini-2.5-flash`).

        Examples:
          - `apigee/gemini-2.5-flash`
          - `apigee/v1/gemini-2.5-flash`
          - `apigee/vertex_ai/gemini-2.5-flash`
          - `apigee/gemini/v1/gemini-2.5-flash`
          - `apigee/vertex_ai/v1beta/gemini-2.5-flash`

      proxy_url: The URL of the Apigee proxy.
      custom_headers: A dictionary of headers to be sent with the request.
      retry_options: Allows google-genai to retry failed responses.
    """

    super().__init__(model=model, retry_options=retry_options)
    # Validate the model string.
    if not _validate_model_string(model):
      raise ValueError(f'Invalid model string: {model}')

    self._isvertexai = _identify_vertexai(model)
    self._api_version = _identify_api_version(model)
    self._proxy_url = proxy_url or os.environ.get(
        _APIGEE_PROXY_URL_ENV_VARIABLE_NAME
    )
    self._custom_headers = custom_headers or {}
    self._user_agent = f'google-adk/{adk_version.__version__}'

  @classmethod
  @override
  def supported_models(cls) -> list[str]:
    """Provides the list of supported model regexes.

    Returns:
      A list of supported model regexes.
    """

    return [
        r'apigee\/.*',
    ]

  @cached_property
  def api_client(self) -> Client:
    """Provides the API client.

    Returns:
      The API client.
    """

    kwargs_for_http_options = {}
    if self._api_version:
      kwargs_for_http_options['api_version'] = self._api_version
    http_options = types.HttpOptions(
        base_url=self._proxy_url,
        headers=self._merge_tracking_headers(self._custom_headers),
        retry_options=self.retry_options,
        **kwargs_for_http_options,
    )

    return Client(
        vertexai=self._isvertexai,
        http_options=http_options,
    )

  @override
  async def _preprocess_request(self, llm_request: LlmRequest) -> None:
    # Strip the `apigee/...` prefix so the proxy receives the bare model ID.
    llm_request.model = _get_model_id(llm_request.model)
    await super()._preprocess_request(llm_request)


def _identify_vertexai(model: str) -> bool:
  """Returns True if the model spec should use the Vertex AI provider."""
  return not model.startswith('apigee/gemini/') and (
      model.startswith('apigee/vertex_ai/')
      or os.environ.get(
          _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME, '0'
      ).lower()
      in ['true', '1']
  )


def _identify_api_version(model: str) -> str:
  """Returns the API version for the model spec, or '' for the default."""
  model = model.removeprefix('apigee/')
  components = model.split('/')

  if len(components) == 3:
    # Format: <provider>/<version>/<model_id>
    return components[1]
  if len(components) == 2:
    # Format: <version>/<model_id> or <provider>/<model_id>
    # _validate_model_string ensures that if the first component is not a
    # provider, it can be a version.
    if components[0] not in ('vertex_ai', 'gemini') and components[
        0
    ].startswith('v'):
      return components[0]
  return ''


def _get_model_id(model: str) -> str:
  """Returns the model ID for the model spec."""
  model = model.removeprefix('apigee/')
  components = model.split('/')

  # The model_id is the last component in the model string.
  return components[-1]


def _validate_model_string(model: str) -> bool:
  """Validates the model string for Apigee LLM.

  The model string specifies the LLM provider (e.g., Vertex AI, Gemini), API
  version, and the model ID.

  Args:
    model: The model string. Supported format:
      `apigee/[<provider>/][<version>/]<model_id>`

  Returns:
    True if the model string is valid, False otherwise.
  """
  if not model.startswith('apigee/'):
    return False

  # Remove the leading "apigee/" from the model string.
  model = model.removeprefix('apigee/')

  # The remainder has to be non-empty, i.e. the model_id cannot be empty.
  if not model:
    return False

  components = model.split('/')
  # Exactly 1 component means only the model_id is present. This is a valid
  # format.
  if len(components) == 1:
    return True

  # More than 3 components is invalid.
  if len(components) > 3:
    return False

  # 3 components means the provider, version, and model_id are all present.
  if len(components) == 3:
    # Format: <provider>/<version>/<model_id>
    if components[0] not in ('vertex_ai', 'gemini'):
      return False
    if not components[1].startswith('v'):
      return False
    return True

  # 2 components means either the provider or the version (but not both),
  # followed by the model_id.
  if len(components) == 2:
    if components[0] in ['vertex_ai', 'gemini']:
      return True
    if components[0].startswith('v'):
      return True
    return False

  return False
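To make the parsing rules concrete, here is a short illustration of what the helpers above return for a few model strings. These are private module functions, called here only for documentation; the import path matches the one used by the tests below:

```python
from google.adk.models import apigee_llm

# The model ID is always the last path component after the "apigee/" prefix.
assert apigee_llm._get_model_id('apigee/gemini-2.5-flash') == 'gemini-2.5-flash'

# A leading "v..." component that is not a provider name is the API version;
# an empty string means "use the provider's default version".
assert apigee_llm._identify_api_version('apigee/v1/gemini-2.5-flash') == 'v1'
assert apigee_llm._identify_api_version('apigee/vertex_ai/gemini-2.5-flash') == ''
assert apigee_llm._identify_api_version('apigee/vertex_ai/v1beta/gemini-pro') == 'v1beta'

# Unknown providers are rejected during validation.
assert not apigee_llm._validate_model_string('apigee/openai/v1/gpt')
```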
test_apigee_llm.py
@@ -0,0 +1,414 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
from unittest import mock
from unittest.mock import AsyncMock

from google.adk.models.apigee_llm import ApigeeLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai.types import Content
from google.genai.types import Part
import pytest

BASE_MODEL_ID = 'gemini-2.5-flash'
APIGEE_GEMINI_MODEL_ID = 'apigee/gemini/v1/' + BASE_MODEL_ID
APIGEE_VERTEX_MODEL_ID = 'apigee/vertex_ai/v1beta/gemini-pro'
VERTEX_BASE_MODEL_ID = 'gemini-pro'
PROXY_URL = 'https://test.apigee.net'


@pytest.fixture
def llm_request():
  """Provides a sample LlmRequest for testing."""
  return LlmRequest(
      model=APIGEE_GEMINI_MODEL_ID,
      contents=[
          types.Content(
              role='user', parts=[types.Part.from_text(text='Test prompt')]
          )
      ],
  )


@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_non_streaming(
    mock_client_constructor, llm_request
):
  """Tests the generate_content_async method for non-streaming responses."""
  apigee_llm_instance = ApigeeLlm(
      model=APIGEE_GEMINI_MODEL_ID,
      proxy_url=PROXY_URL,
  )
  mock_client_instance = mock.Mock()
  mock_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  parts=[Part.from_text(text='Test response')],
                  role='model',
              )
          )
      ]
  )
  mock_client_instance.aio.models.generate_content = AsyncMock(
      return_value=mock_response
  )
  mock_client_constructor.return_value = mock_client_instance

  response_generator = apigee_llm_instance.generate_content_async(llm_request)
  responses = [resp async for resp in response_generator]

  assert len(responses) == 1
  llm_response = responses[0]
  assert llm_response.content.parts[0].text == 'Test response'
  assert llm_response.content.role == 'model'

  mock_client_constructor.assert_called_once()
  _, kwargs = mock_client_constructor.call_args
  assert not kwargs['vertexai']
  http_options = kwargs['http_options']
  assert http_options.base_url == PROXY_URL
  assert http_options.api_version == 'v1'
  assert 'user-agent' in http_options.headers
  assert 'x-goog-api-client' in http_options.headers

  mock_client_instance.aio.models.generate_content.assert_called_once_with(
      model=BASE_MODEL_ID,
      contents=llm_request.contents,
      config=llm_request.config,
  )


@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_streaming(
    mock_client_constructor, llm_request
):
  """Tests the generate_content_async method for streaming responses."""
  apigee_llm_instance = ApigeeLlm(
      model=APIGEE_GEMINI_MODEL_ID,
      proxy_url=PROXY_URL,
  )
  mock_client_instance = mock.Mock()
  mock_responses = [
      types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=Content(
                      parts=[Part.from_text(text='Hello')],
                  )
              )
          ]
      ),
      types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=Content(
                      parts=[Part.from_text(text=',')],
                  )
              )
          ]
      ),
      types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=Content(
                      parts=[Part.from_text(text=' world!')],
                  )
              )
          ]
      ),
  ]

  async def mock_stream_generator():
    for r in mock_responses:
      yield r

  mock_client_instance.aio.models.generate_content_stream = AsyncMock(
      return_value=mock_stream_generator()
  )
  mock_client_constructor.return_value = mock_client_instance

  response_generator = apigee_llm_instance.generate_content_async(
      llm_request, stream=True
  )
  responses = [resp async for resp in response_generator]

  assert responses
  full_text_parts = []
  for r in responses:
    for p in r.content.parts:
      if p.text:
        full_text_parts.append(p.text)
  full_text = ''.join(full_text_parts)
  assert 'Hello, world!' in full_text

  mock_client_instance.aio.models.generate_content_stream.assert_called_once_with(
      model=BASE_MODEL_ID,
      contents=llm_request.contents,
      config=llm_request.config,
  )


@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_generate_content_async_with_custom_headers(
    mock_client_constructor, llm_request
):
  """Tests that custom headers are passed in the request."""
  custom_headers = {
      'X-Custom-Header': 'custom-value',
  }
  apigee_llm = ApigeeLlm(
      model=APIGEE_GEMINI_MODEL_ID,
      proxy_url=PROXY_URL,
      custom_headers=custom_headers,
  )
  mock_client_instance = mock.Mock()
  mock_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  parts=[Part.from_text(text='Test response')],
                  role='model',
              )
          )
      ]
  )
  mock_client_instance.aio.models.generate_content = AsyncMock(
      return_value=mock_response
  )
  mock_client_constructor.return_value = mock_client_instance

  response_generator = apigee_llm.generate_content_async(llm_request)
  _ = [resp async for resp in response_generator]  # Consume generator

  mock_client_constructor.assert_called_once()
  _, kwargs = mock_client_constructor.call_args
  http_options = kwargs['http_options']
  assert http_options.headers['X-Custom-Header'] == 'custom-value'
  assert 'user-agent' in http_options.headers


@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_vertex_model_path_parsing(mock_client_constructor):
  """Tests that Vertex AI model paths are parsed correctly."""
  apigee_llm = ApigeeLlm(model=APIGEE_VERTEX_MODEL_ID, proxy_url=PROXY_URL)
  llm_request = LlmRequest(
      model=APIGEE_VERTEX_MODEL_ID,
      contents=[
          types.Content(
              role='user', parts=[types.Part.from_text(text='Test prompt')]
          )
      ],
  )
  mock_client_instance = mock.Mock()
  mock_client_instance.aio.models.generate_content = AsyncMock(
      return_value=types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=Content(
                      parts=[Part.from_text(text='Test response')],
                      role='model',
                  )
              )
          ]
      )
  )
  mock_client_constructor.return_value = mock_client_instance

  _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

  mock_client_constructor.assert_called_once()
  _, kwargs = mock_client_constructor.call_args
  assert kwargs['vertexai']
  assert kwargs['http_options'].api_version == 'v1beta'

  mock_client_instance.aio.models.generate_content.assert_called_once()
  call_kwargs = (
      mock_client_instance.aio.models.generate_content.call_args.kwargs
  )
  assert call_kwargs['model'] == VERTEX_BASE_MODEL_ID


@pytest.mark.asyncio
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_proxy_url_from_env_variable(mock_client_constructor):
  """Tests that proxy_url is read from the environment variable."""
  with mock.patch.dict(
      os.environ, {'APIGEE_PROXY_URL': 'https://env.proxy.url'}
  ):
    apigee_llm = ApigeeLlm(model=APIGEE_GEMINI_MODEL_ID)
    llm_request = LlmRequest(
        model=APIGEE_GEMINI_MODEL_ID,
        contents=[
            types.Content(
                role='user', parts=[types.Part.from_text(text='Test prompt')]
            )
        ],
    )
    mock_client_instance = mock.Mock()
    mock_client_instance.aio.models.generate_content = AsyncMock(
        return_value=types.GenerateContentResponse(
            candidates=[
                types.Candidate(
                    content=Content(
                        parts=[Part.from_text(text='Test response')],
                        role='model',
                    )
                )
            ]
        )
    )
    mock_client_constructor.return_value = mock_client_instance

    _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)]

    mock_client_constructor.assert_called_once()
    _, kwargs = mock_client_constructor.call_args
    assert kwargs['http_options'].base_url == 'https://env.proxy.url'


@pytest.mark.asyncio
@pytest.mark.parametrize(
    (
        'model_string',
        'use_vertexai_env',
        'expected_is_vertexai',
        'expected_api_version',
        'expected_model_id',
    ),
    [
        ('apigee/gemini-2.5-flash', None, False, None, 'gemini-2.5-flash'),
        ('apigee/gemini-2.5-flash', 'true', True, None, 'gemini-2.5-flash'),
        ('apigee/gemini-2.5-flash', '1', True, None, 'gemini-2.5-flash'),
        ('apigee/gemini-2.5-flash', 'false', False, None, 'gemini-2.5-flash'),
        ('apigee/gemini-2.5-flash', '0', False, None, 'gemini-2.5-flash'),
        (
            'apigee/v1/gemini-2.5-flash',
            None,
            False,
            'v1',
            'gemini-2.5-flash',
        ),
        (
            'apigee/v1/gemini-2.5-flash',
            'true',
            True,
            'v1',
            'gemini-2.5-flash',
        ),
        (
            'apigee/vertex_ai/gemini-2.5-flash',
            None,
            True,
            None,
            'gemini-2.5-flash',
        ),
        (
            'apigee/vertex_ai/gemini-2.5-flash',
            'false',
            True,
            None,
            'gemini-2.5-flash',
        ),
        (
            'apigee/gemini/v1/gemini-2.5-flash',
            'true',
            False,
            'v1',
            'gemini-2.5-flash',
        ),
        (
            'apigee/vertex_ai/v1beta/gemini-2.5-flash',
            'false',
            True,
            'v1beta',
            'gemini-2.5-flash',
        ),
    ],
)
@mock.patch('google.adk.models.apigee_llm.Client')
async def test_model_string_parsing_and_client_initialization(
    mock_client_constructor,
    model_string,
    use_vertexai_env,
    expected_is_vertexai,
    expected_api_version,
    expected_model_id,
):
  """Tests model string parsing and genai.Client initialization."""
  env_vars = {}
  if use_vertexai_env is not None:
    env_vars['GOOGLE_GENAI_USE_VERTEXAI'] = use_vertexai_env

  # The ApigeeLlm is initialized in the 'with' block to make sure that the
  # mock of the environment variable is active.
  with mock.patch.dict(os.environ, env_vars, clear=True):
    apigee_llm = ApigeeLlm(model=model_string, proxy_url=PROXY_URL)
    request = LlmRequest(model=model_string, contents=[])

    mock_client_instance = mock.Mock()
    mock_client_instance.aio.models.generate_content = AsyncMock(
        return_value=types.GenerateContentResponse(
            candidates=[
                types.Candidate(
                    content=Content(parts=[Part.from_text(text='')])
                )
            ]
        )
    )
    mock_client_constructor.return_value = mock_client_instance

    _ = [resp async for resp in apigee_llm.generate_content_async(request)]

    mock_client_constructor.assert_called_once()
    _, kwargs = mock_client_constructor.call_args
    assert kwargs['vertexai'] == expected_is_vertexai
    http_options = kwargs['http_options']
    assert http_options.api_version == expected_api_version

    (
        mock_client_instance.aio.models.generate_content.assert_called_once_with(
            model=expected_model_id,
            contents=request.contents,
            config=request.config,
        )
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'invalid_model_string',
    [
        'apigee/openai/v1/gpt',
        'apigee/',  # Missing model_id
        'apigee',  # Invalid format
        'gemini-pro',  # Invalid format
        'apigee/vertex_ai/v1/model/extra',  # Too many components
        'apigee/unknown/model',
    ],
)
async def test_invalid_model_strings_raise_value_error(invalid_model_string):
  """Tests that invalid model strings raise a ValueError."""
  with pytest.raises(
      ValueError, match=f'Invalid model string: {invalid_model_string}'
  ):
    ApigeeLlm(model=invalid_model_string, proxy_url=PROXY_URL)