diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 23c9c278..b8c5117e 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -85,6 +85,23 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. + client: An optional preconfigured ``google.genai.Client`` instance. + When provided, ADK uses this client for all API calls instead of + creating one internally from environment variables or ADC. This + allows fine-grained control over authentication, project, location, + and other client-level settings — and enables running agents that + target different Vertex AI regions within the same process. + + Example:: + + from google import genai + from google.adk.models import Gemini + + client = genai.Client( + vertexai=True, project="my-project", location="us-central1" + ) + model = Gemini(model="gemini-2.5-flash", client=client) + use_interactions_api: Whether to use the interactions API for model invocation. """ @@ -131,6 +148,35 @@ class Gemini(BaseLlm): ``` """ + def __init__(self, *, client: Optional[Client] = None, **kwargs: Any): + """Initialises a Gemini model wrapper. + + Args: + client: An optional preconfigured ``google.genai.Client``. When + provided, ADK uses this client for **all** Gemini API calls + (including the Live API) instead of creating one internally. + + .. note:: + When a custom client is supplied it is used as-is for Live API + connections. ADK will **not** override the client's + ``api_version``; you are responsible for setting the correct + version (``v1beta1`` for Vertex AI, ``v1alpha`` for the + Gemini developer API) on the client yourself. + + .. warning:: + ``google.genai.Client`` contains threading primitives that + cannot be pickled. If you are deploying to Agent Engine (or + any environment that serialises the model), do **not** pass a + custom client — let ADK create one from the environment + instead. + + **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor + (``model``, ``base_url``, ``retry_options``, etc.). + """ + super().__init__(**kwargs) + # Store after super().__init__ so Pydantic validation runs first. + object.__setattr__(self, '_client', client) + @classmethod @override def supported_models(cls) -> list[str]: @@ -299,9 +345,16 @@ class Gemini(BaseLlm): def api_client(self) -> Client: """Provides the api client. + If a preconfigured ``client`` was passed to the constructor it is + returned directly; otherwise a new ``Client`` is created using the + default environment/ADC configuration. + Returns: The api client. """ + if self._client is not None: + return self._client + from google.genai import Client return Client( @@ -334,6 +387,9 @@ class Gemini(BaseLlm): @cached_property def _live_api_client(self) -> Client: + if self._client is not None: + return self._client + from google.genai import Client return Client( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 70aa01b6..75d4c0fd 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -2140,3 +2140,153 @@ async def test_connect_speech_config_remains_none_when_both_are_none( # Verify the final speech_config is still None assert config_arg.speech_config is None assert isinstance(connection, GeminiLlmConnection) + + +# --------------------------------------------------------------------------- +# Tests for custom client injection (Issue #2560) +# --------------------------------------------------------------------------- + + +def test_custom_client_is_used_for_api_client(): + """When a custom client is provided, api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini.api_client is custom_client + + +def test_custom_client_is_used_for_live_api_client(): + """When a custom client is provided, _live_api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._live_api_client is custom_client + + +def test_default_api_client_when_no_custom_client(): + """Without a custom client, api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + # api_client should construct a real Client (not None) + client = gemini.api_client + assert client is not None + # Verify it is not a mock — it's a real google.genai.Client + from google.genai import Client + + assert isinstance(client, Client) + + +def test_default_live_api_client_when_no_custom_client(): + """Without a custom client, _live_api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + client = gemini._live_api_client + assert client is not None + from google.genai import Client + + assert isinstance(client, Client) + + +def test_custom_client_api_backend_vertexai(): + """_api_backend reflects the custom client's vertexai setting.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = True + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI + + +def test_custom_client_api_backend_gemini_api(): + """_api_backend reflects non-vertexai custom client.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.GEMINI_API + + +@pytest.mark.asyncio +async def test_custom_client_used_for_generate_content(): + """Custom client is used when generate_content_async is called.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", + parts=[Part.from_text(text="Hello from custom client")], + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + async def mock_coro(): + return generate_content_response + + custom_client.aio.models.generate_content.return_value = mock_coro() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + ), + ) + + responses = [ + resp + async for resp in gemini.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + assert responses[0].content.parts[0].text == "Hello from custom client" + custom_client.aio.models.generate_content.assert_called_once() + + +@pytest.mark.asyncio +async def test_custom_client_used_for_live_connect(): + """Custom client is used for live API streaming connections.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + mock_live_session = mock.AsyncMock() + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + custom_client.aio.live.connect.return_value = MockLiveConnect() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + ), + ) + llm_request.live_connect_config = types.LiveConnectConfig() + + async with gemini.connect(llm_request) as connection: + custom_client.aio.live.connect.assert_called_once() + assert isinstance(connection, GeminiLlmConnection)