feat: Add Gemma3Ollama model integration and a sample

This change introduces `Gemma3Ollama`, a new LLM model class for running Gemma 3 models locally via Ollama, leveraging LiteLLM. The function calling logic previously in the `Gemma` class has been refactored into a `GemmaFunctionCallingMixin` and is now used by both `Gemma` and `Gemma3Ollama`. A new sample application, `hello_world_gemma3_ollama`, is added to demonstrate using `Gemma3Ollama` with an agent. Unit tests for `Gemma3Ollama` are also included.
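In short, the new class drops into an agent like any other ADK model. A minimal sketch of the intended usage (mirroring the sample added in this change; assumes a local Ollama daemon with `gemma3:12b` already pulled):

```python
# Minimal sketch, not part of the diff: a local Gemma 3 agent via Ollama.
# Assumes `ollama pull gemma3:12b` has been run and Ollama is serving locally.
from google.adk.agents.llm_agent import Agent
from google.adk.models import Gemma3Ollama

root_agent = Agent(
    model=Gemma3Ollama(),  # defaults to 'ollama/gemma3:12b'
    name="local_gemma_agent",
    instruction="Answer briefly.",
)
```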

Merge: https://github.com/google/adk-python/pull/3120

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 839996879
Doug Reid
2025-12-03 18:25:31 -08:00
committed by Copybara-Service
parent b0c3cc6e36
commit e9182e5eb4
6 changed files with 356 additions and 60 deletions
@@ -0,0 +1,16 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,93 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from google.adk.agents.llm_agent import Agent
from google.adk.models import Gemma3Ollama
litellm_logger = logging.getLogger("LiteLLM")
litellm_logger.setLevel(logging.WARNING)
def roll_die(sides: int) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
return random.randint(1, sides)
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
"No prime numbers found."
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model=Gemma3Ollama(),
name="data_processing_agent",
    description=(
        "hello world agent that can roll an 8-sided die and check prime"
        " numbers."
    ),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
It is ok to discuss previous dice rolls, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
  You should always perform the previous 3 steps when asked to roll a die and check prime numbers.
  You should not rely on previous history for prime results.
""",
tools=[
roll_die,
check_prime,
],
)
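Both tools are ordinary Python callables, so they can be sanity-checked outside the agent loop. A quick check of their return shapes (illustrative values only):

```python
import asyncio

# Illustrative only: exercises the tool functions defined above directly.
print(roll_die(8))                        # an int in [1, 8], e.g. 5
print(asyncio.run(check_prime([4, 5])))   # "5 are prime numbers."
print(asyncio.run(check_prime([4, 6])))   # "No prime numbers found."
```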
@@ -0,0 +1,77 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import agent
from dotenv import load_dotenv
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.genai import types
load_dotenv(override=True)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_1 = await session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
      if event.content and event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_1, 'Hi, introduce yourself.')
await run_prompt(
session_1, 'Roll a die with 100 sides and check if it is prime'
)
await run_prompt(session_1, 'Roll it again.')
await run_prompt(session_1, 'What numbers did I get?')
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())
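The runner fails at request time if Ollama is not reachable, so a pre-flight check can give a clearer error. A sketch assuming Ollama's default local endpoint and its standard `/api/tags` model listing (both are assumptions about the local setup, not part of this change):

```python
import json
import urllib.request

def ollama_has_gemma3(base_url: str = "http://localhost:11434") -> bool:
  """Returns True if a local Ollama daemon reports a pulled gemma3 model."""
  try:
    with urllib.request.urlopen(f"{base_url}/api/tags", timeout=5) as resp:
      tags = json.load(resp)
  except OSError:
    return False
  return any(
      m.get("name", "").startswith("gemma3") for m in tags.get("models", [])
  )
```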
@@ -53,3 +53,13 @@ try:
except Exception:
# LiteLLM support requires: pip install google-adk[extensions]
pass
+# Optionally register Gemma3Ollama if litellm package is installed
+try:
+  from .gemma_llm import Gemma3Ollama
+  LLMRegistry.register(Gemma3Ollama)
+  __all__.append('Gemma3Ollama')
+except Exception:
+  # Gemma3Ollama requires LiteLLM: pip install google-adk[extensions]
+  pass
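Since registration goes through `LLMRegistry` and `Gemma3Ollama.supported_models()` matches `r'ollama/gemma3.*'`, the model should also resolve from a plain string once the optional import succeeds. A sketch, under the same assumption that LiteLLM is installed:

```python
from google.adk.agents.llm_agent import Agent

# The registry matches the string against r'ollama/gemma3.*' and
# instantiates Gemma3Ollama, so no explicit model object is needed.
agent = Agent(
    model='ollama/gemma3:12b',
    name='string_resolved_agent',
    instruction='Answer briefly.',
)
```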
@@ -38,67 +38,21 @@ from typing_extensions import override
logger = logging.getLogger('google_adk.' + __name__)
-class GemmaFunctionCallModel(BaseModel):
-  """Flexible Pydantic model for parsing inline Gemma function call responses."""
+class GemmaFunctionCallingMixin:
+  """Mixin providing function calling support for Gemma models.
-  name: str = Field(validation_alias=AliasChoices('name', 'function'))
-  parameters: dict[str, Any] = Field(
-      validation_alias=AliasChoices('parameters', 'args')
-  )
-class Gemma(Gemini):
-  """Integration for Gemma models exposed via the Gemini API.
-  Only Gemma 3 models are supported at this time. For agentic use cases,
-  use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended.
-  For full documentation, see: https://ai.google.dev/gemma/docs/core/
-  NOTE: Gemma does **NOT** support system instructions. Any system instructions
-  will be replaced with an initial *user* prompt in the LLM request. If system
-  instructions change over the course of agent execution, the initial content
-  **SHOULD** be replaced. Special care is warranted here.
-  See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
-  NOTE: Gemma's function calling support is limited. It does not have full access to the
-  same built-in tools as Gemini. It also does not have special API support for tools and
-  functions. Rather, tools must be passed in via a `user` prompt, and extracted from model
-  responses based on approximate shape.
-  NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports
-  usage via the Gemini API.
+  Gemma models don't have native function calling support, so this mixin
+  provides the logic to:
+  1. Convert function declarations to system instruction prompts
+  2. Convert function call/response parts to text in the conversation
+  3. Extract function calls from model text responses
  """
-  model: str = (
-      'gemma-3-27b-it'  # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
-  )
-  @classmethod
-  @override
-  def supported_models(cls) -> list[str]:
-    """Provides the list of supported models.
-    Returns:
-      A list of supported models.
-    """
-    return [
-        r'gemma-3.*',
-    ]
-  @cached_property
-  def _api_backend(self) -> GoogleLLMVariant:
-    return GoogleLLMVariant.GEMINI_API
  def _move_function_calls_into_system_instruction(
      self, llm_request: LlmRequest
-  ):
-    if llm_request.model is None or not llm_request.model.startswith('gemma-3'):
-      return
-    # Iterate through the existing contents to find and convert function calls and responses
-    # from text parts, as Gemma models don't directly support function calling.
+  ) -> None:
+    """Converts function declarations to system instructions for Gemma."""
+    # Convert function calls/responses in contents to text
    new_contents: list[Content] = []
    for content_item in llm_request.contents:
      (
@@ -136,7 +90,10 @@ class Gemma(Gemini):
llm_request.config.tools = []
-  def _extract_function_calls_from_response(self, llm_response: LlmResponse):
+  def _extract_function_calls_from_response(
+      self, llm_response: LlmResponse
+  ) -> None:
+    """Extracts function calls from Gemma text responses."""
if llm_response.partial or (llm_response.turn_complete is True):
return
@@ -182,12 +139,78 @@ class Gemma(Gemini):
llm_response.content.parts = [function_call_part]
except (json.JSONDecodeError, ValidationError) as e:
logger.debug(
-            f'Error attempting to parse JSON into function call. Leaving as text'
-            f' response. %s',
+            'Error attempting to parse JSON into function call. Leaving as text'
+            ' response. %s',
            e,
        )
    except Exception as e:
-      logger.warning('Error processing Gemma function call response: %s', e)
+      logger.warning(
+          'Error processing Gemma function call response: %s',
+          e,
+          exc_info=True,
+      )
+class GemmaFunctionCallModel(BaseModel):
+  """Flexible Pydantic model for parsing inline Gemma function call responses."""
+  name: str = Field(validation_alias=AliasChoices('name', 'function'))
+  parameters: dict[str, Any] = Field(
+      validation_alias=AliasChoices('parameters', 'args')
+  )
+class Gemma(GemmaFunctionCallingMixin, Gemini):
+  """Integration for Gemma models exposed via the Gemini API.
+  Only Gemma 3 models are supported at this time. For agentic use cases,
+  gemma-3-27b-it and gemma-3-12b-it are strongly recommended.
+  For full documentation, see: https://ai.google.dev/gemma/docs/core/
+  NOTE: Gemma does **NOT** support system instructions. Any system instructions
+  will be replaced with an initial *user* prompt in the LLM request. If system
+  instructions change over the course of agent execution, the initial content
+  **SHOULD** be replaced. Special care is warranted here.
+  See:
+  https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
+  NOTE: Gemma's function calling support is limited. It does not have full
+  access to the same built-in tools as Gemini. It also does not have special
+  API support for tools and functions. Rather, tools must be passed in via a
+  `user` prompt, and extracted from model responses based on approximate
+  shape.
+  NOTE: Vertex AI API support for Gemma is not currently included. This
+  **ONLY** supports usage via the Gemini API.
+  """
+  model: str = (
+      'gemma-3-27b-it'  # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
+  )
+  def __repr__(self) -> str:
+    return f'{self.__class__.__name__}(model="{self.model}")'
+  @classmethod
+  @override
+  def supported_models(cls) -> list[str]:
+    """Provides the list of supported models.
+    Returns:
+      A list of supported models.
+    """
+    return [
+        r'gemma-3.*',
+    ]
+  @cached_property
+  def _api_backend(self) -> GoogleLLMVariant:
+    return GoogleLLMVariant.GEMINI_API
@override
async def _preprocess_request(self, llm_request: LlmRequest) -> None:
@@ -329,3 +352,55 @@ def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]:
if last_json_str:
return True, last_json_str
return False, None
+try:
+  from google.adk.models.lite_llm import LiteLlm  # noqa: F401
+except Exception:
+  # LiteLLM not available, Gemma3Ollama will not be defined
+  LiteLlm = None
+if LiteLlm is not None:
+  class Gemma3Ollama(GemmaFunctionCallingMixin, LiteLlm):
+    """Integration for Gemma 3 models running locally via Ollama.
+    This enables fully local agent workflows using Gemma 3 models.
+    Requires Ollama to be running with a Gemma 3 model pulled.
+    Example:
+      ollama pull gemma3:12b
+      model = Gemma3Ollama(model="ollama/gemma3:12b")
+    """
+    def __init__(self, model: str = 'ollama/gemma3:12b', **kwargs):
+      super().__init__(model=model, **kwargs)
+    def __repr__(self) -> str:
+      return f'{self.__class__.__name__}(model="{self.model}")'
+    @classmethod
+    @override
+    def supported_models(cls) -> list[str]:
+      return [
+          r'ollama/gemma3.*',
+      ]
+    @override
+    async def generate_content_async(
+        self, llm_request: LlmRequest, stream: bool = False
+    ) -> AsyncGenerator[LlmResponse, None]:
+      """Sends a request to Gemma via Ollama/LiteLLM.
+      Args:
+        llm_request: LlmRequest, the request to send.
+        stream: bool = False, whether to do streaming call.
+      Yields:
+        LlmResponse: The model response.
+      """
+      self._move_function_calls_into_system_instruction(llm_request)
+      async for response in super().generate_content_async(llm_request, stream):
+        self._extract_function_calls_from_response(response)
+        yield response
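The `AliasChoices` on `GemmaFunctionCallModel` are what let the mixin accept both spellings Gemma tends to emit for inline calls. An illustration using only the fields shown above:

```python
from typing import Any

from pydantic import AliasChoices, BaseModel, Field

class GemmaFunctionCallModel(BaseModel):
  name: str = Field(validation_alias=AliasChoices('name', 'function'))
  parameters: dict[str, Any] = Field(
      validation_alias=AliasChoices('parameters', 'args')
  )

# Both inline shapes parse to the same call:
a = GemmaFunctionCallModel.model_validate(
    {'name': 'roll_die', 'parameters': {'sides': 6}}
)
b = GemmaFunctionCallModel.model_validate(
    {'function': 'roll_die', 'args': {'sides': 6}}
)
assert a == b
```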
@@ -504,3 +504,28 @@ def test_process_response_last_json_object():
assert part.function_call.name == "second_call"
assert part.function_call.args == {"b": 2}
assert part.text is None
+# Tests for Gemma3Ollama (only run when LiteLLM is installed)
+try:
+  from google.adk.models.gemma_llm import Gemma3Ollama
+  def test_gemma3_ollama_supported_models():
+    assert Gemma3Ollama.supported_models() == [r"ollama/gemma3.*"]
+  @pytest.mark.parametrize(
+      "model_arg,expected_model",
+      [
+          (None, "ollama/gemma3:12b"),
+          ("ollama/gemma3:27b", "ollama/gemma3:27b"),
+      ],
+  )
+  def test_gemma3_ollama_model(model_arg, expected_model):
+    model = (
+        Gemma3Ollama() if model_arg is None else Gemma3Ollama(model=model_arg)
+    )
+    assert model.model == expected_model
+except ImportError:
+  # LiteLLM not installed, skip Gemma3Ollama tests
+  pass
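One more assertion could pin the printable form added by `__repr__`. A hypothetical extra test that would live inside the same guarded `try` block:

```python
def test_gemma3_ollama_repr():
  # Mirrors the __repr__ defined on Gemma3Ollama in this change.
  assert repr(Gemma3Ollama()) == 'Gemma3Ollama(model="ollama/gemma3:12b")'
```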