From e9182e5eb4a37fb5219fc607cd8f06d7e6982e83 Mon Sep 17 00:00:00 2001 From: Doug Reid <21148125+douglas-reid@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:25:31 -0800 Subject: [PATCH] feat: Add Gemma3Ollama model integration and a sample This change introduces `Gemma3Ollama`, a new LLM model class for running Gemma 3 models locally via Ollama, leveraging LiteLLM. The function calling logic previously in the `Gemma` class has been refactored into a `GemmaFunctionCallingMixin` and is now used by both `Gemma` and `Gemma3Ollama`. A new sample application, `hello_world_gemma3_ollama`, is added to demonstrate using `Gemma3Ollama` with an agent. Unit tests for `Gemma3Ollama` are also included. Merge: https://github.com/google/adk-python/pull/3120 Co-authored-by: George Weale PiperOrigin-RevId: 839996879 --- .../hello_world_gemma3_ollama/__init__.py | 16 ++ .../hello_world_gemma3_ollama/agent.py | 93 +++++++++ .../samples/hello_world_gemma3_ollama/main.py | 77 +++++++ src/google/adk/models/__init__.py | 10 + src/google/adk/models/gemma_llm.py | 195 ++++++++++++------ tests/unittests/models/test_gemma_llm.py | 25 +++ 6 files changed, 356 insertions(+), 60 deletions(-) create mode 100644 contributing/samples/hello_world_gemma3_ollama/__init__.py create mode 100644 contributing/samples/hello_world_gemma3_ollama/agent.py create mode 100644 contributing/samples/hello_world_gemma3_ollama/main.py diff --git a/contributing/samples/hello_world_gemma3_ollama/__init__.py b/contributing/samples/hello_world_gemma3_ollama/__init__.py new file mode 100644 index 00000000..7d5bb0b1 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from . import agent diff --git a/contributing/samples/hello_world_gemma3_ollama/agent.py b/contributing/samples/hello_world_gemma3_ollama/agent.py new file mode 100644 index 00000000..58294e56 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/agent.py @@ -0,0 +1,93 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import random + +from google.adk.agents.llm_agent import Agent +from google.adk.models import Gemma3Ollama + +litellm_logger = logging.getLogger("LiteLLM") +litellm_logger.setLevel(logging.WARNING) + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model=Gemma3Ollama(), + name="data_processing_agent", + description=( + "hello world agent that can roll a dice of 8 sides and check prime" + " numbers." + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel (in one request and in one round). + It is ok to discuss previous dice rolls, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_gemma3_ollama/main.py b/contributing/samples/hello_world_gemma3_ollama/main.py new file mode 100644 index 00000000..a383b4f2 --- /dev/null +++ b/contributing/samples/hello_world_gemma3_ollama/main.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_1 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_1, 'Hi, introduce yourself.') + await run_prompt( + session_1, 'Roll a die with 100 sides and check if it is prime' + ) + await run_prompt(session_1, 'Roll it again.') + await run_prompt(session_1, 'What numbers did I get?') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 1be0cc69..d190dcf9 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -53,3 +53,13 @@ try: except Exception: # LiteLLM support requires: pip install google-adk[extensions] pass + +# Optionally register Gemma3Ollama if litellm package is installed +try: + from .gemma_llm import Gemma3Ollama + + LLMRegistry.register(Gemma3Ollama) + __all__.append('Gemma3Ollama') +except Exception: + # Gemma3Ollama requires LiteLLM: pip install google-adk[extensions] + pass diff --git a/src/google/adk/models/gemma_llm.py b/src/google/adk/models/gemma_llm.py index 3233d66f..e45987b9 100644 --- a/src/google/adk/models/gemma_llm.py +++ b/src/google/adk/models/gemma_llm.py @@ -38,67 +38,21 @@ from typing_extensions import override logger = logging.getLogger('google_adk.' + __name__) -class GemmaFunctionCallModel(BaseModel): - """Flexible Pydantic model for parsing inline Gemma function call responses.""" +class GemmaFunctionCallingMixin: + """Mixin providing function calling support for Gemma models. - name: str = Field(validation_alias=AliasChoices('name', 'function')) - parameters: dict[str, Any] = Field( - validation_alias=AliasChoices('parameters', 'args') - ) - - -class Gemma(Gemini): - """Integration for Gemma models exposed via the Gemini API. - - Only Gemma 3 models are supported at this time. For agentic use cases, - use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended. - - For full documentation, see: https://ai.google.dev/gemma/docs/core/ - - NOTE: Gemma does **NOT** support system instructions. Any system instructions - will be replaced with an initial *user* prompt in the LLM request. If system - instructions change over the course of agent execution, the initial content - **SHOULD** be replaced. Special care is warranted here. - See: https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions - - NOTE: Gemma's function calling support is limited. It does not have full access to the - same built-in tools as Gemini. It also does not have special API support for tools and - functions. Rather, tools must be passed in via a `user` prompt, and extracted from model - responses based on approximate shape. - - NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** supports - usage via the Gemini API. + Gemma models don't have native function calling support, so this mixin + provides the logic to: + 1. Convert function declarations to system instruction prompts + 2. Convert function call/response parts to text in the conversation + 3. Extract function calls from model text responses """ - model: str = ( - 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] - ) - - @classmethod - @override - def supported_models(cls) -> list[str]: - """Provides the list of supported models. - - Returns: - A list of supported models. - """ - - return [ - r'gemma-3.*', - ] - - @cached_property - def _api_backend(self) -> GoogleLLMVariant: - return GoogleLLMVariant.GEMINI_API - def _move_function_calls_into_system_instruction( self, llm_request: LlmRequest - ): - if llm_request.model is None or not llm_request.model.startswith('gemma-3'): - return - - # Iterate through the existing contents to find and convert function calls and responses - # from text parts, as Gemma models don't directly support function calling. + ) -> None: + """Converts function declarations to system instructions for Gemma.""" + # Convert function calls/responses in contents to text new_contents: list[Content] = [] for content_item in llm_request.contents: ( @@ -136,7 +90,10 @@ class Gemma(Gemini): llm_request.config.tools = [] - def _extract_function_calls_from_response(self, llm_response: LlmResponse): + def _extract_function_calls_from_response( + self, llm_response: LlmResponse + ) -> None: + """Extracts function calls from Gemma text responses.""" if llm_response.partial or (llm_response.turn_complete is True): return @@ -182,12 +139,78 @@ class Gemma(Gemini): llm_response.content.parts = [function_call_part] except (json.JSONDecodeError, ValidationError) as e: logger.debug( - f'Error attempting to parse JSON into function call. Leaving as text' - f' response. %s', + 'Error attempting to parse JSON into function call. Leaving as text' + ' response. %s', e, ) except Exception as e: - logger.warning('Error processing Gemma function call response: %s', e) + logger.warning( + 'Error processing Gemma function call response: %s', + e, + exc_info=True, + ) + + +class GemmaFunctionCallModel(BaseModel): + """Flexible Pydantic model for parsing inline Gemma function call responses.""" + + name: str = Field(validation_alias=AliasChoices('name', 'function')) + parameters: dict[str, Any] = Field( + validation_alias=AliasChoices('parameters', 'args') + ) + + +class Gemma(GemmaFunctionCallingMixin, Gemini): + """Integration for Gemma models exposed via the Gemini API. + + Only Gemma 3 models are supported at this time. For agentic use cases, + use of gemma-3-27b-it and gemma-3-12b-it are strongly recommended. + + For full documentation, see: https://ai.google.dev/gemma/docs/core/ + + NOTE: Gemma does **NOT** support system instructions. Any system instructions + will be replaced with an initial *user* prompt in the LLM request. If system + instructions change over the course of agent execution, the initial content + **SHOULD** be replaced. Special care is warranted here. + See: + https://ai.google.dev/gemma/docs/core/prompt-structure#system-instructions + + NOTE: Gemma's function calling support is limited. It does not have full + access to the + same built-in tools as Gemini. It also does not have special API support for + tools and + functions. Rather, tools must be passed in via a `user` prompt, and extracted + from model + responses based on approximate shape. + + NOTE: Vertex AI API support for Gemma is not currently included. This **ONLY** + supports + usage via the Gemini API. + """ + + model: str = ( + 'gemma-3-27b-it' # Others: [gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it] + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(model="{self.model}")' + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Provides the list of supported models. + + Returns: + A list of supported models. + """ + + return [ + r'gemma-3.*', + ] + + @cached_property + def _api_backend(self) -> GoogleLLMVariant: + return GoogleLLMVariant.GEMINI_API @override async def _preprocess_request(self, llm_request: LlmRequest) -> None: @@ -329,3 +352,55 @@ def _get_last_valid_json_substring(text: str) -> tuple[bool, str | None]: if last_json_str: return True, last_json_str return False, None + + +try: + from google.adk.models.lite_llm import LiteLlm # noqa: F401 +except Exception: + # LiteLLM not available, Gemma3Ollama will not be defined + LiteLlm = None + +if LiteLlm is not None: + + class Gemma3Ollama(GemmaFunctionCallingMixin, LiteLlm): + """Integration for Gemma 3 models running locally via Ollama. + + This enables fully local agent workflows using Gemma 3 models. + Requires Ollama to be running with a Gemma 3 model pulled. + + Example: + ollama pull gemma3:12b + model = Gemma3Ollama(model="ollama/gemma3:12b") + """ + + def __init__(self, model: str = 'ollama/gemma3:12b', **kwargs): + super().__init__(model=model, **kwargs) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(model="{self.model}")' + + @classmethod + @override + def supported_models(cls) -> list[str]: + return [ + r'ollama/gemma3.*', + ] + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + """Sends a request to Gemma via Ollama/LiteLLM. + + Args: + llm_request: LlmRequest, the request to send. + stream: bool = False, whether to do streaming call. + + Yields: + LlmResponse: The model response. + """ + self._move_function_calls_into_system_instruction(llm_request) + + async for response in super().generate_content_async(llm_request, stream): + self._extract_function_calls_from_response(response) + yield response diff --git a/tests/unittests/models/test_gemma_llm.py b/tests/unittests/models/test_gemma_llm.py index 2cf98306..82e19b1a 100644 --- a/tests/unittests/models/test_gemma_llm.py +++ b/tests/unittests/models/test_gemma_llm.py @@ -504,3 +504,28 @@ def test_process_response_last_json_object(): assert part.function_call.name == "second_call" assert part.function_call.args == {"b": 2} assert part.text is None + + +# Tests for Gemma3Ollama (only run when LiteLLM is installed) +try: + from google.adk.models.gemma_llm import Gemma3Ollama + + def test_gemma3_ollama_supported_models(): + assert Gemma3Ollama.supported_models() == [r"ollama/gemma3.*"] + + @pytest.mark.parametrize( + "model_arg,expected_model", + [ + (None, "ollama/gemma3:12b"), + ("ollama/gemma3:27b", "ollama/gemma3:27b"), + ], + ) + def test_gemma3_ollama_model(model_arg, expected_model): + model = ( + Gemma3Ollama() if model_arg is None else Gemma3Ollama(model=model_arg) + ) + assert model.model == expected_model + +except ImportError: + # LiteLLM not installed, skip Gemma3Ollama tests + pass