Mirror of https://github.com/encounter/adk-python.git (synced 2026-03-30 10:57:20 -07:00)
fix: Add support for injecting a custom google.genai.Client into Gemini models
This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings.

Close #2560

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 874628604
committed by Copybara-Service
parent 8c0bd2034c
commit c615757ba1
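For orientation, a minimal sketch of the pattern this change enables, assuming the published google-genai and google-adk packages; the explicit api_key shown here is a placeholder and only illustrates the kind of authentication control the commit message refers to:

    from google import genai
    from google.adk.models import Gemini

    # Hypothetical client for the Gemini developer API with an explicit key,
    # instead of letting ADK construct one from environment variables.
    client = genai.Client(api_key="YOUR_API_KEY")  # placeholder
    model = Gemini(model="gemini-2.5-flash", client=client)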
@@ -85,6 +85,23 @@ class Gemini(BaseLlm):

  Attributes:
    model: The name of the Gemini model.
    client: An optional preconfigured ``google.genai.Client`` instance.
      When provided, ADK uses this client for all API calls instead of
      creating one internally from environment variables or ADC. This
      allows fine-grained control over authentication, project, location,
      and other client-level settings — and enables running agents that
      target different Vertex AI regions within the same process.

      Example::

        from google import genai
        from google.adk.models import Gemini

        client = genai.Client(
            vertexai=True, project="my-project", location="us-central1"
        )
        model = Gemini(model="gemini-2.5-flash", client=client)

    use_interactions_api: Whether to use the interactions API for model
      invocation.
  """
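To make the multi-region point in the docstring above concrete, a sketch with two injected clients (project and regions are placeholders):

    from google import genai
    from google.adk.models import Gemini

    # Two preconfigured clients, each pinned to a different Vertex AI region.
    us_client = genai.Client(
        vertexai=True, project="my-project", location="us-central1"
    )
    eu_client = genai.Client(
        vertexai=True, project="my-project", location="europe-west4"
    )

    # Each model instance calls its own region within the same process.
    us_model = Gemini(model="gemini-2.5-flash", client=us_client)
    eu_model = Gemini(model="gemini-2.5-flash", client=eu_client)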
@@ -131,6 +148,35 @@ class Gemini(BaseLlm):
  ```
  """

  def __init__(self, *, client: Optional[Client] = None, **kwargs: Any):
    """Initialises a Gemini model wrapper.

    Args:
      client: An optional preconfigured ``google.genai.Client``. When
        provided, ADK uses this client for **all** Gemini API calls
        (including the Live API) instead of creating one internally.

        .. note::
          When a custom client is supplied it is used as-is for Live API
          connections. ADK will **not** override the client's
          ``api_version``; you are responsible for setting the correct
          version (``v1beta1`` for Vertex AI, ``v1alpha`` for the
          Gemini developer API) on the client yourself.

        .. warning::
          ``google.genai.Client`` contains threading primitives that
          cannot be pickled. If you are deploying to Agent Engine (or
          any environment that serialises the model), do **not** pass a
          custom client — let ADK create one from the environment
          instead.

      **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor
        (``model``, ``base_url``, ``retry_options``, etc.).
    """
    super().__init__(**kwargs)
    # Store after super().__init__ so Pydantic validation runs first.
    object.__setattr__(self, '_client', client)

  @classmethod
  @override
  def supported_models(cls) -> list[str]:
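Because ADK does not override ``api_version`` on an injected client, a client intended for Live API use has to carry the right version itself. A sketch of one way to set it via ``http_options``, assuming the current google-genai surface (verify against your installed version):

    from google import genai
    from google.genai import types
    from google.adk.models import Gemini

    # Vertex AI Live API expects v1beta1; the Gemini developer API expects v1alpha.
    live_client = genai.Client(
        vertexai=True,
        project="my-project",      # placeholder
        location="us-central1",    # placeholder
        http_options=types.HttpOptions(api_version="v1beta1"),
    )
    model = Gemini(model="gemini-2.5-flash", client=live_client)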
@@ -299,9 +345,16 @@ class Gemini(BaseLlm):
  def api_client(self) -> Client:
    """Provides the api client.

    If a preconfigured ``client`` was passed to the constructor it is
    returned directly; otherwise a new ``Client`` is created using the
    default environment/ADC configuration.

    Returns:
      The api client.
    """
    if self._client is not None:
      return self._client

    from google.genai import Client

    return Client(
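For the default path (no injected client), the Client created here is configured from the environment and ADC. A sketch of the typical setup, using the environment variables the google-genai client reads (treat the exact names as an assumption to check against its docs):

    # Vertex AI backend via Application Default Credentials:
    #   export GOOGLE_GENAI_USE_VERTEXAI=true
    #   export GOOGLE_CLOUD_PROJECT=my-project
    #   export GOOGLE_CLOUD_LOCATION=us-central1
    # Gemini developer API alternative:
    #   export GOOGLE_API_KEY=...
    from google.adk.models import Gemini

    # No client argument: api_client builds a default google.genai.Client
    # from the environment/ADC configuration above.
    model = Gemini(model="gemini-2.5-flash")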
@@ -334,6 +387,9 @@ class Gemini(BaseLlm):

  @cached_property
  def _live_api_client(self) -> Client:
    if self._client is not None:
      return self._client

    from google.genai import Client

    return Client(
@@ -2140,3 +2140,153 @@ async def test_connect_speech_config_remains_none_when_both_are_none(
  # Verify the final speech_config is still None
  assert config_arg.speech_config is None
  assert isinstance(connection, GeminiLlmConnection)


# ---------------------------------------------------------------------------
# Tests for custom client injection (Issue #2560)
# ---------------------------------------------------------------------------


def test_custom_client_is_used_for_api_client():
  """When a custom client is provided, api_client returns it directly."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  assert gemini.api_client is custom_client


def test_custom_client_is_used_for_live_api_client():
  """When a custom client is provided, _live_api_client returns it directly."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  assert gemini._live_api_client is custom_client


def test_default_api_client_when_no_custom_client():
  """Without a custom client, api_client creates a default Client."""
  gemini = Gemini(model="gemini-1.5-flash")

  # api_client should construct a real Client (not None)
  client = gemini.api_client
  assert client is not None
  # Verify it is not a mock — it's a real google.genai.Client
  from google.genai import Client

  assert isinstance(client, Client)


def test_default_live_api_client_when_no_custom_client():
  """Without a custom client, _live_api_client creates a default Client."""
  gemini = Gemini(model="gemini-1.5-flash")

  client = gemini._live_api_client
  assert client is not None
  from google.genai import Client

  assert isinstance(client, Client)


def test_custom_client_api_backend_vertexai():
  """_api_backend reflects the custom client's vertexai setting."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = True
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI


def test_custom_client_api_backend_gemini_api():
  """_api_backend reflects a non-vertexai custom client."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  assert gemini._api_backend == GoogleLLMVariant.GEMINI_API


@pytest.mark.asyncio
async def test_custom_client_used_for_generate_content():
  """Custom client is used when generate_content_async is called."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  generate_content_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  role="model",
                  parts=[Part.from_text(text="Hello from custom client")],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  async def mock_coro():
    return generate_content_response

  custom_client.aio.models.generate_content.return_value = mock_coro()

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )

  responses = [
      resp
      async for resp in gemini.generate_content_async(llm_request, stream=False)
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello from custom client"
  custom_client.aio.models.generate_content.assert_called_once()


@pytest.mark.asyncio
async def test_custom_client_used_for_live_connect():
  """Custom client is used for live API streaming connections."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  mock_live_session = mock.AsyncMock()

  class MockLiveConnect:

    async def __aenter__(self):
      return mock_live_session

    async def __aexit__(self, *args):
      pass

  custom_client.aio.live.connect.return_value = MockLiveConnect()

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )
  llm_request.live_connect_config = types.LiveConnectConfig()

  async with gemini.connect(llm_request) as connection:
    custom_client.aio.live.connect.assert_called_once()
    assert isinstance(connection, GeminiLlmConnection)