fix: Add support for injecting a custom google.genai.Client into Gemini models

This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings.

Closes #2560

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 874628604
This commit is contained in:
George Weale
2026-02-24 08:34:05 -08:00
committed by Copybara-Service
parent 8c0bd2034c
commit c615757ba1
2 changed files with 206 additions and 0 deletions
+56
View File
@@ -85,6 +85,23 @@ class Gemini(BaseLlm):
Attributes:
model: The name of the Gemini model.
client: An optional preconfigured ``google.genai.Client`` instance.
When provided, ADK uses this client for all API calls instead of
creating one internally from environment variables or ADC. This
allows fine-grained control over authentication, project, location,
and other client-level settings — and enables running agents that
target different Vertex AI regions within the same process.
Example::
from google import genai
from google.adk.models import Gemini
client = genai.Client(
vertexai=True, project="my-project", location="us-central1"
)
model = Gemini(model="gemini-2.5-flash", client=client)
use_interactions_api: Whether to use the interactions API for model
invocation.
"""
@@ -131,6 +148,35 @@ class Gemini(BaseLlm):
```
"""
def __init__(self, *, client: Optional[Client] = None, **kwargs: Any):
  """Initialises a Gemini model wrapper.

  Args:
    client: An optional preconfigured ``google.genai.Client``. When
      provided, ADK routes **all** Gemini API calls (including the Live
      API) through this client instead of building one internally.

      .. note::
        A supplied client is used as-is for Live API connections: ADK
        does not override its ``api_version``. Set the correct version
        yourself (``v1beta1`` for Vertex AI, ``v1alpha`` for the Gemini
        developer API).

      .. warning::
        ``google.genai.Client`` holds threading primitives that cannot
        be pickled. If you deploy to Agent Engine (or any environment
        that serialises the model), do **not** inject a custom client —
        let ADK create one from the environment instead.
    **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor
      (``model``, ``base_url``, ``retry_options``, etc.).
  """
  super().__init__(**kwargs)
  # Pydantic validation must finish before the private client is attached,
  # hence object.__setattr__ only after super().__init__ has run.
  object.__setattr__(self, '_client', client)
@classmethod
@override
def supported_models(cls) -> list[str]:
@@ -299,9 +345,16 @@ class Gemini(BaseLlm):
def api_client(self) -> Client:
"""Provides the api client.
If a preconfigured ``client`` was passed to the constructor it is
returned directly; otherwise a new ``Client`` is created using the
default environment/ADC configuration.
Returns:
The api client.
"""
if self._client is not None:
return self._client
from google.genai import Client
return Client(
@@ -334,6 +387,9 @@ class Gemini(BaseLlm):
@cached_property
def _live_api_client(self) -> Client:
if self._client is not None:
return self._client
from google.genai import Client
return Client(
+150
View File
@@ -2140,3 +2140,153 @@ async def test_connect_speech_config_remains_none_when_both_are_none(
# Verify the final speech_config is still None
assert config_arg.speech_config is None
assert isinstance(connection, GeminiLlmConnection)
# ---------------------------------------------------------------------------
# Tests for custom client injection (Issue #2560)
# ---------------------------------------------------------------------------
def test_custom_client_is_used_for_api_client():
  """An injected client must be handed back by ``api_client`` unchanged."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model.api_client is injected
def test_custom_client_is_used_for_live_api_client():
  """An injected client must be handed back by ``_live_api_client`` unchanged."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._live_api_client is injected
def test_default_api_client_when_no_custom_client():
  """Without injection, ``api_client`` builds a real default ``Client``."""
  from google.genai import Client

  model = Gemini(model="gemini-1.5-flash")
  created = model.api_client

  # A genuine client instance is constructed (not None, not a mock).
  assert created is not None
  assert isinstance(created, Client)
def test_default_live_api_client_when_no_custom_client():
  """Without injection, ``_live_api_client`` builds a real default ``Client``."""
  from google.genai import Client

  model = Gemini(model="gemini-1.5-flash")
  created = model._live_api_client

  assert created is not None
  assert isinstance(created, Client)
def test_custom_client_api_backend_vertexai():
  """A vertexai-enabled injected client selects the VERTEX_AI backend."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = True
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._api_backend == GoogleLLMVariant.VERTEX_AI
def test_custom_client_api_backend_gemini_api():
  """A non-vertexai injected client selects the GEMINI_API backend."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = False
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._api_backend == GoogleLLMVariant.GEMINI_API
@pytest.mark.asyncio
async def test_custom_client_used_for_generate_content():
  """Custom client is used when generate_content_async is called.

  Uses ``mock.AsyncMock`` for ``generate_content`` rather than assigning a
  raw coroutine object as ``return_value``: a raw coroutine is single-use
  (a second call would await an already-exhausted coroutine) and emits a
  "coroutine was never awaited" RuntimeWarning if the call path changes,
  whereas AsyncMock yields a fresh awaitable on every call.
  """
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  generate_content_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  role="model",
                  parts=[Part.from_text(text="Hello from custom client")],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )
  # AsyncMock returns a new awaitable per call, so the stub is robust
  # against retries or repeated invocation.
  custom_client.aio.models.generate_content = mock.AsyncMock(
      return_value=generate_content_response
  )

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )

  responses = [
      resp
      async for resp in gemini.generate_content_async(llm_request, stream=False)
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello from custom client"
  custom_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_custom_client_used_for_live_connect():
  """Custom client is used for live API streaming connections."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  mock_live_session = mock.AsyncMock()

  # Stub ``connect`` with an async context manager that yields the fake
  # live session (equivalent to a hand-rolled __aenter__/__aexit__ class).
  live_cm = mock.MagicMock()
  live_cm.__aenter__ = mock.AsyncMock(return_value=mock_live_session)
  live_cm.__aexit__ = mock.AsyncMock(return_value=None)
  custom_client.aio.live.connect.return_value = live_cm

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )
  llm_request.live_connect_config = types.LiveConnectConfig()

  async with gemini.connect(llm_request) as connection:
    custom_client.aio.live.connect.assert_called_once()
    assert isinstance(connection, GeminiLlmConnection)