You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add Gemma3Ollama model integration and a sample
This change introduces `Gemma3Ollama`, a new LLM model class for running Gemma 3 models locally via Ollama, leveraging LiteLLM. The function calling logic previously in the `Gemma` class has been refactored into a `GemmaFunctionCallingMixin` and is now used by both `Gemma` and `Gemma3Ollama`. A new sample application, `hello_world_gemma3_ollama`, is added to demonstrate using `Gemma3Ollama` with an agent. Unit tests for `Gemma3Ollama` are also included. Merge: https://github.com/google/adk-python/pull/3120 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 839996879
This commit is contained in:
committed by
Copybara-Service
parent
b0c3cc6e36
commit
e9182e5eb4
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.models import Gemma3Ollama
|
||||
|
||||
# Silence LiteLLM's chatty INFO-level output; only warnings and above surface.
litellm_logger = logging.getLogger("LiteLLM")
litellm_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def roll_die(sides: int) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has. Must be at least 1.

  Returns:
    An integer in [1, sides] — the result of rolling the die.

  Raises:
    ValueError: If ``sides`` is less than 1.
  """
  # Fail fast with a clear message instead of letting random.randint raise a
  # generic "empty range" ValueError for a zero/negative-sided die.
  if sides < 1:
    raise ValueError(f'A die must have at least 1 side, got {sides}.')
  return random.randint(1, sides)
|
||||
|
||||
|
||||
async def check_prime(nums: list[int]) -> str:
  """Check if a given list of numbers are prime.

  Args:
    nums: The list of numbers to check.

  Returns:
    A str listing the prime numbers found (in ascending order), or a
    message stating that no primes were found.
  """
  primes = set()
  for number in nums:
    # Defensive cast: LLM tool calls may deliver numbers as strings/floats.
    number = int(number)
    if number <= 1:
      # 0, 1, and negatives are not prime by definition.
      continue
    is_prime = True
    # Trial division up to sqrt(number) suffices for primality.
    for i in range(2, int(number**0.5) + 1):
      if number % i == 0:
        is_prime = False
        break
    if is_prime:
      primes.add(number)
  # Sort before joining: set iteration order is implementation-dependent,
  # so an unsorted join would produce nondeterministic output text.
  return (
      "No prime numbers found."
      if not primes
      else f"{', '.join(str(num) for num in sorted(primes))} are prime"
      " numbers."
  )
|
||||
|
||||
|
||||
# Sample agent demonstrating tool calling with a local Gemma 3 model served
# by Ollama (via LiteLLM). Requires a running Ollama server with a Gemma 3
# model pulled — see Gemma3Ollama's docstring.
root_agent = Agent(
    # Defaults to 'ollama/gemma3:12b' — TODO confirm against gemma_llm.
    model=Gemma3Ollama(),
    name="data_processing_agent",
    description=(
        "hello world agent that can roll a dice of 8 sides and check prime"
        " numbers."
    ),
    # Gemma's tool use is prompt-driven (no native function-calling API), so
    # the instruction spells out the tool-call protocol very explicitly.
    instruction="""
      You roll dice and answer questions about the outcome of the dice rolls.
      You can roll dice of different sizes.
      You can use multiple tools in parallel by calling functions in parallel (in one request and in one round).
      It is ok to discuss previous dice rolls, and comment on the dice rolls.
      When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
      You should never roll a die on your own.
      When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
      You should not check prime numbers before calling the tool.
      When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
      1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
      2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
      2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
      3. When you respond, you must include the roll_die result from step 1.
      You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
      You should not rely on the previous history on prime results.
""",
    tools=[
        roll_die,
        check_prime,
    ],
)
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import agent
|
||||
from dotenv import load_dotenv
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.cli.utils import logs
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
|
||||
# Load environment (e.g. API keys) from .env, overriding existing values,
# and route ADK logs to a temp folder so console output stays readable.
load_dotenv(override=True)
logs.log_to_tmp_folder()
|
||||
|
||||
|
||||
async def main():
  """Runs the sample agent through a scripted multi-turn conversation."""

  app_name = 'my_app'
  user_id_1 = 'user1'
  # In-memory services: nothing is persisted across runs.
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name=app_name,
      agent=agent.root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  session_1 = await session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  async def run_prompt(session: Session, new_message: str):
    # Sends one user turn to the runner and prints every text response event.
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      # Only print events whose first part carries text (skips events that
      # contain only function calls/responses).
      if event.content.parts and event.content.parts[0].text:
        print(f'** {event.author}: {event.content.parts[0].text}')

  # Time the whole conversation so local-model latency is visible.
  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_1, 'Hi, introduce yourself.')
  await run_prompt(
      session_1, 'Roll a die with 100 sides and check if it is prime'
  )
  await run_prompt(session_1, 'Roll it again.')
  await run_prompt(session_1, 'What numbers did I get?')
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)


if __name__ == '__main__':
  asyncio.run(main())
|
||||
@@ -53,3 +53,13 @@ try:
|
||||
except Exception:
|
||||
# LiteLLM support requires: pip install google-adk[extensions]
|
||||
pass
|
||||
|
||||
# Optionally register Gemma3Ollama if litellm package is installed
# NOTE: the broad except is deliberate best-effort — any import or
# registration failure simply leaves the optional model unregistered.
try:
  from .gemma_llm import Gemma3Ollama

  LLMRegistry.register(Gemma3Ollama)
  __all__.append('Gemma3Ollama')
except Exception:
  # Gemma3Ollama requires LiteLLM: pip install google-adk[extensions]
  pass
|
||||
|
||||
@@ -38,67 +38,21 @@ from typing_extensions import override
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class GemmaFunctionCallModel(BaseModel):
|
||||
"""Flexible Pydantic model for parsing inline Gemma function call responses."""
|
||||
class GemmaFunctionCallingMixin:
|
||||
"""Mixin providing function calling support for Gemma models.
|
||||
|
||||
name: str = Field(validation_alias=AliasChoices('name', 'function'))
|
||||
parameters: dict[str, Any] = Field(
|
||||
validation_alias=AliasChoices('parameters', 'args')
|
||||
)
|
||||
|
||||
|
||||
class Gemma(Gemini):
|
||||
"""Integration for Gemma models exposed via the Gemini API.
|
||||
|
||||
Only Gemma 3 models are supported at this time. For agentic use cases,
|
||||
use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended.
|
||||
|
||||
For full documentation, see: https://ai.google.dev/gemma/docs/core/
|
||||
|
||||
NOTE: Gemma does **NOT** support system instructions. Any system instructions
|
||||
will be replaced with an initial *user* prompt in the LLM request. If system
|
||||
instructions change over the course of agent execution, the initial content
|
||||
**SHOULD** be replaced. Special care is warranted here.
|
||||
See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
|
||||
|
||||
NOTE: Gemma's function calling support is limited. It does not have full access to the
|
||||
same built-in tools as Gemini. It also does not have special API support for tools and
|
||||
functions. Rather, tools must be passed in via a `user` prompt, and extracted from model
|
||||
responses based on approximate shape.
|
||||
|
||||
NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports
|
||||
usage via the Gemini API.
|
||||
Gemma models don't have native function calling support, so this mixin
|
||||
provides the logic to:
|
||||
1. Convert function declarations to system instruction prompts
|
||||
2. Convert function call/response parts to text in the conversation
|
||||
3. Extract function calls from model text responses
|
||||
"""
|
||||
|
||||
model: str = (
|
||||
'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def supported_models(cls) -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return [
|
||||
r'gemma-3.*',
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def _api_backend(self) -> GoogleLLMVariant:
|
||||
return GoogleLLMVariant.GEMINI_API
|
||||
|
||||
def _move_function_calls_into_system_instruction(
|
||||
self, llm_request: LlmRequest
|
||||
):
|
||||
if llm_request.model is None or not llm_request.model.startswith('gemma-3'):
|
||||
return
|
||||
|
||||
# Iterate through the existing contents to find and convert function calls and responses
|
||||
# from text parts, as Gemma models don't directly support function calling.
|
||||
) -> None:
|
||||
"""Converts function declarations to system instructions for Gemma."""
|
||||
# Convert function calls/responses in contents to text
|
||||
new_contents: list[Content] = []
|
||||
for content_item in llm_request.contents:
|
||||
(
|
||||
@@ -136,7 +90,10 @@ class Gemma(Gemini):
|
||||
|
||||
llm_request.config.tools = []
|
||||
|
||||
def _extract_function_calls_from_response(self, llm_response: LlmResponse):
|
||||
def _extract_function_calls_from_response(
|
||||
self, llm_response: LlmResponse
|
||||
) -> None:
|
||||
"""Extracts function calls from Gemma text responses."""
|
||||
if llm_response.partial or (llm_response.turn_complete is True):
|
||||
return
|
||||
|
||||
@@ -182,12 +139,78 @@ class Gemma(Gemini):
|
||||
llm_response.content.parts = [function_call_part]
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
logger.debug(
|
||||
f'Error attempting to parse JSON into function call. Leaving as text'
|
||||
f' response. %s',
|
||||
'Error attempting to parse JSON into function call. Leaving as text'
|
||||
' response. %s',
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning('Error processing Gemma function call response: %s', e)
|
||||
logger.warning(
|
||||
'Error processing Gemma function call response: %s',
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class GemmaFunctionCallModel(BaseModel):
  """Flexible Pydantic model for parsing inline Gemma function call responses."""

  # Accepts either 'name' or 'function' as the key for the function name,
  # since Gemma emits both shapes in its inline JSON tool calls.
  name: str = Field(validation_alias=AliasChoices('name', 'function'))
  # Accepts either 'parameters' or 'args' for the argument payload.
  parameters: dict[str, Any] = Field(
      validation_alias=AliasChoices('parameters', 'args')
  )
|
||||
|
||||
|
||||
class Gemma(GemmaFunctionCallingMixin, Gemini):
|
||||
"""Integration for Gemma models exposed via the Gemini API.
|
||||
|
||||
Only Gemma 3 models are supported at this time. For agentic use cases,
|
||||
use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended.
|
||||
|
||||
For full documentation, see: https://ai.google.dev/gemma/docs/core/
|
||||
|
||||
NOTE: Gemma does **NOT** support system instructions. Any system instructions
|
||||
will be replaced with an initial *user* prompt in the LLM request. If system
|
||||
instructions change over the course of agent execution, the initial content
|
||||
**SHOULD** be replaced. Special care is warranted here.
|
||||
See:
|
||||
https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions
|
||||
|
||||
NOTE: Gemma's function calling support is limited. It does not have full
|
||||
access to the
|
||||
same built-in tools as Gemini. It also does not have special API support for
|
||||
tools and
|
||||
functions. Rather, tools must be passed in via a `user` prompt, and extracted
|
||||
from model
|
||||
responses based on approximate shape.
|
||||
|
||||
NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY**
|
||||
supports
|
||||
usage via the Gemini API.
|
||||
"""
|
||||
|
||||
model: str = (
|
||||
'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it]
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(model="{self.model}")'
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def supported_models(cls) -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return [
|
||||
r'gemma-3.*',
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def _api_backend(self) -> GoogleLLMVariant:
|
||||
return GoogleLLMVariant.GEMINI_API
|
||||
|
||||
@override
|
||||
async def _preprocess_request(self, llm_request: LlmRequest) -> None:
|
||||
@@ -329,3 +352,55 @@ def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]:
|
||||
if last_json_str:
|
||||
return True, last_json_str
|
||||
return False, None
|
||||
|
||||
|
||||
# LiteLLM is an optional dependency; Gemma3Ollama is only defined when it is
# importable, mirroring the optional registration in models/__init__.py.
try:
  from google.adk.models.lite_llm import LiteLlm  # noqa: F401
except Exception:
  # LiteLLM not available, Gemma3Ollama will not be defined
  LiteLlm = None

if LiteLlm is not None:

  class Gemma3Ollama(GemmaFunctionCallingMixin, LiteLlm):
    """Integration for Gemma 3 models running locally via Ollama.

    This enables fully local agent workflows using Gemma 3 models.
    Requires Ollama to be running with a Gemma 3 model pulled.

    Example:
      ollama pull gemma3:12b
      model = Gemma3Ollama(model="ollama/gemma3:12b")
    """

    def __init__(self, model: str = 'ollama/gemma3:12b', **kwargs):
      # Forwards all keyword arguments to LiteLlm; only the default model
      # name is specialized here.
      super().__init__(model=model, **kwargs)

    def __repr__(self) -> str:
      return f'{self.__class__.__name__}(model="{self.model}")'

    @classmethod
    @override
    def supported_models(cls) -> list[str]:
      """Returns regex patterns of supported model names."""
      return [
          r'ollama/gemma3.*',
      ]

    @override
    async def generate_content_async(
        self, llm_request: LlmRequest, stream: bool = False
    ) -> AsyncGenerator[LlmResponse, None]:
      """Sends a request to Gemma via Ollama/LiteLLM.

      Args:
        llm_request: LlmRequest, the request to send.
        stream: bool = False, whether to do streaming call.

      Yields:
        LlmResponse: The model response.
      """
      # Gemma has no native function-calling API: rewrite tool declarations
      # into the prompt before the call, then parse inline JSON calls out of
      # each response (both helpers come from GemmaFunctionCallingMixin).
      self._move_function_calls_into_system_instruction(llm_request)

      async for response in super().generate_content_async(llm_request, stream):
        self._extract_function_calls_from_response(response)
        yield response
||||
|
||||
@@ -504,3 +504,28 @@ def test_process_response_last_json_object():
|
||||
assert part.function_call.name == "second_call"
|
||||
assert part.function_call.args == {"b": 2}
|
||||
assert part.text is None
|
||||
|
||||
|
||||
# Tests for Gemma3Ollama (only run when LiteLLM is installed)
try:
  from google.adk.models.gemma_llm import Gemma3Ollama

  def test_gemma3_ollama_supported_models():
    # Only Ollama-prefixed Gemma 3 model names are matched by this class.
    assert Gemma3Ollama.supported_models() == [r"ollama/gemma3.*"]

  @pytest.mark.parametrize(
      "model_arg,expected_model",
      [
          (None, "ollama/gemma3:12b"),  # default model when none is given
          ("ollama/gemma3:27b", "ollama/gemma3:27b"),  # explicit override
      ],
  )
  def test_gemma3_ollama_model(model_arg, expected_model):
    # Verifies the constructor's default and explicit model selection.
    model = (
        Gemma3Ollama() if model_arg is None else Gemma3Ollama(model=model_arg)
    )
    assert model.model == expected_model

except ImportError:
  # LiteLLM not installed, skip Gemma3Ollama tests
  pass
|
||||
|
||||
Reference in New Issue
Block a user