diff --git a/.github/.release-please-manifest.json b/.github/.release-please-manifest.json index 661ffa45..f97891a6 100644 --- a/.github/.release-please-manifest.json +++ b/.github/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.25.1" + ".": "1.26.0" } diff --git a/.github/release-please-config.json b/.github/release-please-config.json index e7ecc230..053aab23 100644 --- a/.github/release-please-config.json +++ b/.github/release-please-config.json @@ -1,24 +1,60 @@ { "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", - "last-release-sha": "9f7d5b3f1476234e552b783415527cc4bac55b39", + "last-release-sha": "8f5428150d18ed732b66379c0acb806a9121c3cb", "packages": { ".": { "release-type": "python", + "versioning": "always-bump-minor", "package-name": "google-adk", "include-component-in-tag": false, "skip-github-release": true, "changelog-path": "CHANGELOG.md", "changelog-sections": [ - {"type": "feat", "section": "Features"}, - {"type": "fix", "section": "Bug Fixes"}, - {"type": "perf", "section": "Performance Improvements"}, - {"type": "refactor", "section": "Code Refactoring"}, - {"type": "docs", "section": "Documentation"}, - {"type": "test", "section": "Tests", "hidden": true}, - {"type": "build", "section": "Build System", "hidden": true}, - {"type": "ci", "section": "CI/CD", "hidden": true}, - {"type": "style", "section": "Styles", "hidden": true}, - {"type": "chore", "section": "Miscellaneous Chores", "hidden": true} + { + "type": "feat", + "section": "Features" + }, + { + "type": "fix", + "section": "Bug Fixes" + }, + { + "type": "perf", + "section": "Performance Improvements" + }, + { + "type": "refactor", + "section": "Code Refactoring" + }, + { + "type": "docs", + "section": "Documentation" + }, + { + "type": "test", + "section": "Tests", + "hidden": true + }, + { + "type": "build", + "section": "Build System", + "hidden": true + }, + { + "type": "ci", + "section": "CI/CD", + "hidden": true + }, + { + "type": "style", + "section": "Styles", + "hidden": true + }, + { + "type": "chore", + "section": "Miscellaneous Chores", + "hidden": true + } ] } } diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 7670733e..8e506d92 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -30,8 +30,8 @@ jobs: - name: Check for logger pattern in all changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -61,8 +61,8 @@ jobs: - name: Check for import pattern in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E '__init__.py$|version.py$|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -88,8 +88,8 @@ jobs: - name: Check for import from cli package in certain changed Python files run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' | grep -v -E 'cli/.*|src/google/adk/tools/apihub_tool/apihub_toolset.py|tests/.*|contributing/samples/' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files to check:" echo "$CHANGED_FILES" @@ -110,4 +110,4 @@ jobs: fi else echo "✅ No relevant Python files found." - fi \ No newline at end of file + fi diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 49536911..840d4ea8 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -42,8 +42,8 @@ jobs: - name: Run isort on changed files id: run_isort run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index f2626209..e893ce9e 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,10 +1,7 @@ name: Mypy Type Check on: - push: - branches: [ main ] - pull_request: - branches: [ main ] + workflow_dispatch: jobs: mypy: diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index d2eac1da..a2d9e6d7 100644 --- a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -42,8 +42,8 @@ jobs: - name: Run pyink on changed files id: run_pyink run: | - git fetch origin ${{ github.base_ref }} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true) + git fetch origin ${GITHUB_BASE_REF} + CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) if [ -n "$CHANGED_FILES" ]; then echo "Changed Python files:" echo "$CHANGED_FILES" diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 8689f1a1..866ba8b3 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -43,7 +43,7 @@ jobs: run: | uv venv .venv source .venv/bin/activate - uv sync --extra test --extra eval --extra a2a + uv sync --extra test - name: Run unit tests with pytest run: | diff --git a/.github/workflows/release-cherry-pick.yml b/.github/workflows/release-cherry-pick.yml index bf2247c4..ad324a08 100644 --- a/.github/workflows/release-cherry-pick.yml +++ b/.github/workflows/release-cherry-pick.yml @@ -1,5 +1,6 @@ # Step 3 (optional): Cherry-picks a commit from main to the release/candidate branch. # Use between step 1 and step 4 to include bug fixes in an in-progress release. +# Note: Does NOT auto-trigger release-please to preserve manual changelog edits. name: "Release: Cherry-pick" on: @@ -12,7 +13,6 @@ on: permissions: contents: write - actions: write jobs: cherry-pick: @@ -30,17 +30,14 @@ jobs: - name: Cherry-pick commit run: | - echo "Cherry-picking ${{ inputs.commit_sha }} to release/candidate" - git cherry-pick ${{ inputs.commit_sha }} + echo "Cherry-picking ${INPUTS_COMMIT_SHA} to release/candidate" + git cherry-pick ${INPUTS_COMMIT_SHA} + env: + INPUTS_COMMIT_SHA: ${{ inputs.commit_sha }} - name: Push changes run: | git push origin release/candidate echo "Successfully cherry-picked commit to release/candidate" - - - name: Trigger Release Please - env: - GH_TOKEN: ${{ github.token }} - run: | - gh workflow run release-please.yml --repo ${{ github.repository }} --ref release/candidate - echo "Triggered Release Please workflow" + echo "Note: Release Please is NOT auto-triggered to preserve manual changelog edits." + echo "Run release-please.yml manually if you want to regenerate the changelog." diff --git a/.github/workflows/release-finalize.yml b/.github/workflows/release-finalize.yml index ade58ec2..b9d6203f 100644 --- a/.github/workflows/release-finalize.yml +++ b/.github/workflows/release-finalize.yml @@ -68,9 +68,11 @@ jobs: - name: Rename release/candidate to release/v{version} if: steps.check.outputs.is_release_pr == 'true' run: | - VERSION="v${{ steps.version.outputs.version }}" + VERSION="v${STEPS_VERSION_OUTPUTS_VERSION}" git push origin "release/candidate:refs/heads/release/$VERSION" ":release/candidate" echo "Renamed release/candidate to release/$VERSION" + env: + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} - name: Update PR label to tagged if: steps.check.outputs.is_release_pr == 'true' diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 791d84a5..41d8d864 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -1,11 +1,10 @@ # Runs release-please to create/update a PR with version bump and changelog. -# Triggered automatically by step 1 (cut) or step 3 (cherry-pick). +# Triggered only by workflow_dispatch (from release-cut.yml). +# Does NOT auto-run on push to preserve manual changelog edits after cherry-picks. name: "Release: Please" on: - push: - branches: - - release/candidate + # Only run via workflow_dispatch (triggered by release-cut.yml) workflow_dispatch: permissions: @@ -14,8 +13,6 @@ permissions: jobs: release-please: - # Skip if this is a release-please PR merge (handled by Release: Finalize) - if: "!startsWith(github.event.head_commit.message, 'chore(release')" runs-on: ubuntu-latest steps: - name: Check if release/candidate still exists diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml index 5979cd9c..95ee326a 100644 --- a/.github/workflows/release-publish.yml +++ b/.github/workflows/release-publish.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Validate branch run: | - if [[ ! "${{ github.ref_name }}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + if [[ ! "${GITHUB_REF_NAME}" =~ ^release/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "Error: Must run from a release/v* branch (e.g., release/v0.3.0)" exit 1 fi @@ -23,7 +23,7 @@ jobs: - name: Extract version id: version run: | - VERSION="${{ github.ref_name }}" + VERSION="${GITHUB_REF_NAME}" VERSION="${VERSION#release/v}" echo "version=$VERSION" >> $GITHUB_OUTPUT echo "Publishing version: $VERSION" @@ -51,9 +51,10 @@ jobs: - name: Create merge-back PR env: GH_TOKEN: ${{ secrets.RELEASE_PAT }} + STEPS_VERSION_OUTPUTS_VERSION: ${{ steps.version.outputs.version }} run: | gh pr create \ --base main \ - --head "${{ github.ref_name }}" \ - --title "chore: merge release v${{ steps.version.outputs.version }} to main" \ - --body "Syncs version bump and CHANGELOG from release v${{ steps.version.outputs.version }} to main." + --head "${GITHUB_REF_NAME}" \ + --title "chore: merge release v${STEPS_VERSION_OUTPUTS_VERSION} to main" \ + --body "Syncs version bump and CHANGELOG from release v${STEPS_VERSION_OUTPUTS_VERSION} to main." diff --git a/CHANGELOG.md b/CHANGELOG.md index c4867801..92a8197b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,97 @@ # Changelog +## [1.26.0](https://github.com/google/adk-python/compare/v1.25.1...v1.26.0) (2026-02-26) + + +### Features + +* **[Core]** + * Add intra-invocation compaction and token compaction pre-request ([485fcb8](https://github.com/google/adk-python/commit/485fcb84e3ca351f83416c012edcafcec479c1db)) + * Use `--memory_service_uri` in ADK CLI run command ([a7b5097](https://github.com/google/adk-python/commit/a7b509763c1732f0363e90952bb4c2672572d542)) + +* **[Models]** + * Add `/chat/completions` integration to `ApigeeLlm` ([9c4c445](https://github.com/google/adk-python/commit/9c4c44536904f5cf3301a5abb910a5666344a8c5)) + * Add `/chat/completions` streaming support to Apigee LLM ([121d277](https://github.com/google/adk-python/commit/121d27741684685c564e484704ae949c5f0807b1)) + * Expand LiteLlm supported models and add registry tests ([d5332f4](https://github.com/google/adk-python/commit/d5332f44347f44d60360e14205a2342a0c990d66)) + +* **[Tools]** + * Add `load_skill_from_dir()` method ([9f7d5b3](https://github.com/google/adk-python/commit/9f7d5b3f1476234e552b783415527cc4bac55b39)) + * Agent Skills spec compliance — validation, aliases, scripts, and auto-injection ([223d9a7](https://github.com/google/adk-python/commit/223d9a7ff52d8da702f1f436bd22e94ad78bd5da)) + * BigQuery ADK support for search catalog tool ([bef3f11](https://github.com/google/adk-python/commit/bef3f117b4842ce62760328304484cd26a1ec30a)) + * Make skill instruction optimizable and can adapt to user tasks ([21be6ad](https://github.com/google/adk-python/commit/21be6adcb86722a585b26f600c45c85e593b4ee0)) + * Pass trace context in MCP tool call's `_meta` field with OpenTelemetry propagator ([bcbfeba](https://github.com/google/adk-python/commit/bcbfeba953d46fca731b11542a00103cef374e57)) + +* **[Evals]** + * Introduce User Personas to the ADK evaluation framework ([6a808c6](https://github.com/google/adk-python/commit/6a808c60b38ad7140ddeb222887c6accc63edce9)) + +* **[Services]** + * Add generate/create modes for Vertex AI Memory Bank writes ([811e50a](https://github.com/google/adk-python/commit/811e50a0cbb181d502b9837711431ef78fca3f34)) + * Add support for memory consolidation via Vertex AI Memory Bank ([4a88804](https://github.com/google/adk-python/commit/4a88804ec7d17fb4031b238c362f27d240df0a13)) + +* **[A2A]** + * Add interceptor framework to `A2aAgentExecutor` ([87fcd77](https://github.com/google/adk-python/commit/87fcd77caa9672f219c12e5a0e2ff65cbbaaf6f3)) + +* **[Auth]** + * Add native support for `id_token` in OAuth2 credentials ([33f7d11](https://github.com/google/adk-python/commit/33f7d118b377b60f998c92944d2673679fddbc6e)) + * Support ID token exchange in `ServiceAccountCredentialExchanger` ([7be90db](https://github.com/google/adk-python/commit/7be90db24b41f1830e39ca3d7e15bf4dbfa5a304)), closes [#4458](https://github.com/google/adk-python/issues/4458) + +* **[Integrations]** + * Agent Registry in ADK ([abaa929](https://github.com/google/adk-python/commit/abaa92944c4cd43d206e2986d405d4ee07d45afe)) + * Add schema auto-upgrade, tool provenance, HITL tracing, and span hierarchy fix to BigQuery Agent Analytics plugin ([4260ef0](https://github.com/google/adk-python/commit/4260ef0c7c37ecdfea295fb0e1a933bb0df78bea)) + * Change default BigQuery table ID and update docstring ([7557a92](https://github.com/google/adk-python/commit/7557a929398ec2a1f946500d906cef5a4f86b5d1)) + * Update Agent Registry to create AgentCard from info in get agents endpoint ([c33d614](https://github.com/google/adk-python/commit/c33d614004a47d1a74951dd13628fd2300aeb9ef)) + +* **[Web]** + * Enable dependency injection for agent loader in FastAPI app gen ([34da2d5](https://github.com/google/adk-python/commit/34da2d5b26e82f96f1951334fe974a0444843720)) + + +### Bug Fixes + +* Add OpenAI strict JSON schema enforcement in LiteLLM ([2dbd1f2](https://github.com/google/adk-python/commit/2dbd1f25bdb1d88a6873d824b81b3dd5243332a4)), closes [#4573](https://github.com/google/adk-python/issues/4573) +* Add push notification config store to agent_to_a2a ([4ca904f](https://github.com/google/adk-python/commit/4ca904f11113c4faa3e17bb4a9662dca1f936e2e)), closes [#4126](https://github.com/google/adk-python/issues/4126) +* Add support for injecting a custom google.genai.Client into Gemini models ([48105b4](https://github.com/google/adk-python/commit/48105b49c5ab8e4719a66e7219f731b2cd293b00)), closes [#2560](https://github.com/google/adk-python/issues/2560) +* Add support for injecting a custom google.genai.Client into Gemini models ([c615757](https://github.com/google/adk-python/commit/c615757ba12093ba4a2ba19bee3f498fef91584c)), closes [#2560](https://github.com/google/adk-python/issues/2560) +* Check both `input_stream` parameter name and its annotation to decide whether it's a streaming tool that accept input stream ([d56cb41](https://github.com/google/adk-python/commit/d56cb4142c5040b6e7d13beb09123b8a59341384)) +* **deps:** Increase pydantic lower version to 2.7.0 ([dbd6420](https://github.com/google/adk-python/commit/dbd64207aebea8c5af19830a9a02d4c05d1d9469)) +* edit copybara and BUILD config for new adk/integrations folder (added with Agent Registry) ([37d52b4](https://github.com/google/adk-python/commit/37d52b4caf6738437e62fe804103efe4bde363a1)) +* Expand add_memory to accept MemoryEntry ([f27a9cf](https://github.com/google/adk-python/commit/f27a9cfb87caecb8d52967c50637ed5ad541cd07)) +* Fix pickling lock errors in McpSessionManager ([4e2d615](https://github.com/google/adk-python/commit/4e2d6159ae3552954aaae295fef3e09118502898)) +* fix typo in PlanReActPlanner instruction ([6d53d80](https://github.com/google/adk-python/commit/6d53d800d5f6dc5d4a3a75300e34d5a9b0f006f5)) +* handle UnicodeDecodeError when loading skills in ADK ([3fbc27f](https://github.com/google/adk-python/commit/3fbc27fa4ddb58b2b69ee1bea1e3a7b2514bd725)) +* Improve BigQuery Agent Analytics plugin reliability and code quality ([ea03487](https://github.com/google/adk-python/commit/ea034877ec15eef1be8f9a4be9fcd95446a3dc21)) +* Include list of skills in every message and remove list_skills tool from system instruction ([4285f85](https://github.com/google/adk-python/commit/4285f852d54670390b19302ed38306bccc0a7cee)) +* Invoke on_tool_error_callback for missing tools in live mode ([e6b601a](https://github.com/google/adk-python/commit/e6b601a2ab71b7e2df0240fd55550dca1eba8397)) +* Keep query params embedded in OpenAPI paths when using httpx ([ffbcc0a](https://github.com/google/adk-python/commit/ffbcc0a626deb24fe38eab402b3d6ace484115df)), closes [#4555](https://github.com/google/adk-python/issues/4555) +* Only relay the LiveRequest after tools is invoked ([b53bc55](https://github.com/google/adk-python/commit/b53bc555cceaa11dc53b42c9ca1d650592fb4365)) +* Parallelize tool resolution in LlmAgent.canonical_tools() ([7478bda](https://github.com/google/adk-python/commit/7478bdaa9817b0285b4119e8c739d7520373f719)) +* race condition in table creation for `DatabaseSessionService` ([fbe9ecc](https://github.com/google/adk-python/commit/fbe9eccd05e628daa67059ba2e6a0d03966b240d)) +* Re-export DEFAULT_SKILL_SYSTEM_INSTRUCTION to skills and skill/prompt.py to avoid breaking current users ([40ec134](https://github.com/google/adk-python/commit/40ec1343c2708e1cf0d39cd8b8a96f3729f843de)) +* Refactor LiteLLM streaming response parsing for compatibility with LiteLLM 1.81+ ([e8019b1](https://github.com/google/adk-python/commit/e8019b1b1b0b43dcc5fa23075942b31db502ffdd)), closes [#4225](https://github.com/google/adk-python/issues/4225) +* remove duplicate session GET when using API server, unbreak auto_session_create when using API server ([445dc18](https://github.com/google/adk-python/commit/445dc189e915ce5198e822ad7fadd6bb0880a95e)) +* Remove experimental decorators from user persona data models ([eccdf6d](https://github.com/google/adk-python/commit/eccdf6d01e70c37a1e5aa47c40d74469580365d2)) +* Replace the global DEFAULT_USER_PERSONA_REGISTRY with a function call to get_default_persona_registry ([2703613](https://github.com/google/adk-python/commit/2703613572a38bf4f9e25569be2ee678dc91b5b5)) +* **skill:** coloate default skill SI with skilltoolset ([fc1f1db](https://github.com/google/adk-python/commit/fc1f1db00562a79cd6c742cfd00f6267295c29a8)) +* Update agent_engine_sandbox_code_executor in ADK ([ee8d956](https://github.com/google/adk-python/commit/ee8d956413473d1bbbb025a470ad882c1487d8b8)) +* Update agent_engine_sandbox_code_executor in ADK ([dab80e4](https://github.com/google/adk-python/commit/dab80e4a8f3c5476f731335724bff5df3e6f3650)) +* Update sample skills agent to use weather-skill instead of weather_skill ([8f54281](https://github.com/google/adk-python/commit/8f5428150d18ed732b66379c0acb806a9121c3cb)) +* update Spanner query tools to async functions ([1dbcecc](https://github.com/google/adk-python/commit/1dbceccf36c28d693b0982b531a99877a3e75169)) +* use correct msg_out/msg_err keys for Agent Engine sandbox output ([b1e33a9](https://github.com/google/adk-python/commit/b1e33a90b4ba716d717e0488b84892b8a7f42aac)) +* Validate session before streaming instead of eagerly advancing the runner generator ([ab32f33](https://github.com/google/adk-python/commit/ab32f33e7418d452e65cf6f5b6cbfe1371600323)) +* **web:** allow session resume without new message ([30b2ed3](https://github.com/google/adk-python/commit/30b2ed3ef8ee6d3633743c0db00533683d3342d8)) + + +### Code Refactoring + +* Extract reusable function for building agent transfer instructions ([e1e0d63](https://github.com/google/adk-python/commit/e1e0d6361675e7b9a2c9b2523e3a72e2e5e7ce05)) +* Extract reusable private methods ([976a238](https://github.com/google/adk-python/commit/976a238544330528b4f9f4bea6c4e75ec13b33e1)) +* Extract reusable private methods ([42eeaef](https://github.com/google/adk-python/commit/42eeaef2b34c860f126c79c552435458614255ad)) +* Extract reusable private methods ([706f9fe](https://github.com/google/adk-python/commit/706f9fe74db0197e19790ca542d372ce46d0ae87)) + + +### Documentation + +* add `thinking_config` in `generate_content_config` in example agent ([c6b1c74](https://github.com/google/adk-python/commit/c6b1c74321faf62cc52d2518eb9ea0dcef050cde)) + ## [1.25.1](https://github.com/google/adk-python/compare/v1.25.0...v1.25.1) (2026-02-18) ### Bug Fixes diff --git a/contributing/samples/agent_engine_code_execution/README b/contributing/samples/agent_engine_code_execution/README index 8d5a4442..b0443ae2 100644 --- a/contributing/samples/agent_engine_code_execution/README +++ b/contributing/samples/agent_engine_code_execution/README @@ -7,9 +7,9 @@ This sample data science agent uses Agent Engine Code Execution Sandbox to execu ## How to use -* 1. Follow https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/code-execution/overview to create a code execution sandbox environment. +* 1. Follow https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create-an-agent-engine-instance to create an agent engine instance. Replace the AGENT_ENGINE_RESOURCE_NAME with the one you just created. A new sandbox environment under this agent engine instance will be created for each session with TTL of 1 year. But sandbox can only main its state for up to 14 days. This is the recommended usage for production environments. -* 2. Replace the SANDBOX_RESOURCE_NAME with the one you just created. If you dont want to create a new sandbox environment directly, the Agent Engine Code Execution Sandbox will create one for you by default using the AGENT_ENGINE_RESOURCE_NAME you specified, however, please ensure to clean up sandboxes after use; otherwise, it will consume quotas. +* 2. For testing or protyping purposes, create a sandbox environment by following this guide: https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create_a_sandbox. Replace the SANDBOX_RESOURCE_NAME with the one you just created. This will be used as the default sandbox environment for all the code executions throughout the lifetime of the agent. As the sandbox is re-used across sessions, all sessions will share the same Python environment and variable values." ## Sample prompt diff --git a/contributing/samples/agent_engine_code_execution/agent.py b/contributing/samples/agent_engine_code_execution/agent.py index d85989eb..a32e4ca4 100644 --- a/contributing/samples/agent_engine_code_execution/agent.py +++ b/contributing/samples/agent_engine_code_execution/agent.py @@ -85,11 +85,10 @@ When plotting trends, you should make sure to sort and order the data by the x-a """, code_executor=AgentEngineSandboxCodeExecutor( - # Replace with your sandbox resource name if you already have one. - sandbox_resource_name="SANDBOX_RESOURCE_NAME", + # Replace with your sandbox resource name if you already have one. Only use it for testing or prototyping purposes, because this will use the same sandbox for all requests. # "projects/vertex-agent-loadtest/locations/us-central1/reasoningEngines/6842889780301135872/sandboxEnvironments/6545148628569161728", - # Replace with agent engine resource name used for creating sandbox if - # sandbox_resource_name is not set. + sandbox_resource_name=None, + # Replace with agent engine resource name used for creating sandbox environment. agent_engine_resource_name="AGENT_ENGINE_RESOURCE_NAME", ), ) diff --git a/contributing/samples/agent_registry_agent/README.md b/contributing/samples/agent_registry_agent/README.md new file mode 100644 index 00000000..b9370b64 --- /dev/null +++ b/contributing/samples/agent_registry_agent/README.md @@ -0,0 +1,49 @@ +# Agent Registry Sample + +This sample demonstrates how to use the `AgentRegistry` client to discover agents and MCP servers registered in Google Cloud. + +## Setup + +1. Ensure you have Google Cloud credentials configured (e.g., `gcloud auth application-default login`). +2. Set the following environment variables: + +```bash +export GOOGLE_CLOUD_PROJECT=your-project-id +export GOOGLE_CLOUD_LOCATION=global # or your specific region +``` + +3. Obtain the full resource names for the agents and MCP servers you want to use. You can do this by running the sample script once to list them: + + ```bash + python3 agent.py + ``` + + Alternatively, use `gcloud` to list them: + + ```bash + # For agents + gcloud alpha agent-registry agents list --project=$GOOGLE_CLOUD_PROJECT --location=$GOOGLE_CLOUD_LOCATION + + # For MCP servers + gcloud alpha agent-registry mcp-servers list --project=$GOOGLE_CLOUD_PROJECT --location=$GOOGLE_CLOUD_LOCATION + ``` + +4. Replace `AGENT_NAME` and `MCP_SERVER_NAME` in `agent.py` with the last part of the resource names (e.g., if the name is `projects/.../agents/my-agent`, use `my-agent`). + +## Running the Sample + +Run the sample script to list available agents and MCP servers: + +```bash +python3 agent.py +``` + +## How it Works + +The sample uses `AgentRegistry` to: +- List registered agents using `list_agents()`. +- List registered MCP servers using `list_mcp_servers()`. + +It also shows (in comments) how to: +- Get a `RemoteA2aAgent` instance using `get_remote_a2a_agent(name)`. +- Get an `McpToolset` instance using `get_mcp_toolset(name)`. diff --git a/contributing/samples/agent_registry_agent/__init__.py b/contributing/samples/agent_registry_agent/__init__.py new file mode 100644 index 00000000..4015e47d --- /dev/null +++ b/contributing/samples/agent_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/agent_registry_agent/agent.py b/contributing/samples/agent_registry_agent/agent.py new file mode 100644 index 00000000..38036dea --- /dev/null +++ b/contributing/samples/agent_registry_agent/agent.py @@ -0,0 +1,63 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample agent demonstrating Agent Registry discovery.""" + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.integrations.agent_registry import AgentRegistry + +# Project and location can be set via environment variables: +# GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION +project_id = os.environ.get("GOOGLE_CLOUD_PROJECT") +location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global") + +# Initialize Agent Registry client +registry = AgentRegistry(project_id=project_id, location=location) + +print(f"Listing agents in {project_id}/{location}...") +agents = registry.list_agents() +for agent in agents.get("agents", []): + print(f"- Agent: {agent.get('displayName')} ({agent.get('name')})") + +print(f"\nListing MCP servers in {project_id}/{location}...") +mcp_servers = registry.list_mcp_servers() +for server in mcp_servers.get("mcpServers", []): + print(f"- MCP Server: {server.get('displayName')} ({server.get('name')})") + +# Example of using a specific agent or MCP server from the registry: +# (Note: These names should be full resource names as returned by list methods) + +# 1. Using a Remote A2A Agent as a sub-agent +# TODO: Replace AGENT_NAME with your agent name +remote_agent = registry.get_remote_a2a_agent( + f"projects/{project_id}/locations/{location}/agents/AGENT_NAME" +) + +# 2. Using an MCP Server in a toolset +# TODO: Replace MCP_SERVER_NAME with your MCP server name +mcp_toolset = registry.get_mcp_toolset( + f"projects/{project_id}/locations/{location}/mcpServers/MCP_SERVER_NAME" +) + +root_agent = LlmAgent( + model="gemini-2.5-flash", + name="discovery_agent", + instruction=( + "You have access to tools and sub-agents discovered via Registry." + ), + tools=[mcp_toolset], + sub_agents=[remote_agent], +) diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py index 9f55ef80..87faea31 100644 --- a/contributing/samples/api_registry_agent/agent.py +++ b/contributing/samples/api_registry_agent/agent.py @@ -15,7 +15,7 @@ import os from google.adk.agents.llm_agent import LlmAgent -from google.adk.tools.api_registry import ApiRegistry +from google.adk.integrations.api_registry import ApiRegistry # TODO: Fill in with your GCloud project id and MCP server name PROJECT_ID = "your-google-cloud-project-id" diff --git a/contributing/samples/authn-adk-all-in-one/requirements.txt b/contributing/samples/authn-adk-all-in-one/requirements.txt index 6cd3c4bb..777d8d52 100644 --- a/contributing/samples/authn-adk-all-in-one/requirements.txt +++ b/contributing/samples/authn-adk-all-in-one/requirements.txt @@ -1,5 +1,5 @@ google-adk==1.12 -Flask==3.1.1 +Flask==3.1.3 flask-cors==6.0.1 python-dotenv==1.1.1 PyJWT[crypto]==2.10.1 diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 3ed97432..99481390 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. +11. `search_catalog` + Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. + ## How to use Set up environment variables in your `.env` file for using @@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf. * which tables exist in the ml_datasets dataset? * show more details about the penguins table * compute penguins population per island. +* are there any tables related to animals in project ? \ No newline at end of file diff --git a/contributing/samples/bigtable/agent.py b/contributing/samples/bigtable/agent.py index d35f51c1..1d52e1fe 100644 --- a/contributing/samples/bigtable/agent.py +++ b/contributing/samples/bigtable/agent.py @@ -16,14 +16,17 @@ import os from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.tools.bigtable import query_tool as bigtable_query_tool from google.adk.tools.bigtable.bigtable_credentials import BigtableCredentialsConfig from google.adk.tools.bigtable.bigtable_toolset import BigtableToolset from google.adk.tools.bigtable.settings import BigtableToolSettings +from google.adk.tools.google_tool import GoogleTool import google.auth +from google.cloud.bigtable.data.execute_query.metadata import SqlType -# Define an appropriate credential type -CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2 - +# Define an appropriate credential type. +# None for Application Default Credentials +CREDENTIALS_TYPE = None # Define Bigtable tool config with read capability set to allowed. tool_settings = BigtableToolSettings() @@ -59,6 +62,53 @@ bigtable_toolset = BigtableToolset( credentials_config=credentials_config, bigtable_tool_settings=tool_settings ) +_BIGTABLE_PROJECT_ID = "" +_BIGTABLE_INSTANCE_ID = "" + + +def search_hotels_by_location( + location_name: str, + credentials: google.auth.credentials.Credentials, + settings: BigtableToolSettings, + tool_context: google.adk.tools.tool_context.ToolContext, +): + """Search hotels by location name. + + This function takes a location name and returns a list of hotels + in that area. + + Args: + location_name (str): The geographical location (e.g., city or town) for the + hotel search. + Example: { "location_name": "Basel" } + + Returns: + The hotels name, price tier. + """ + + sql_template = """ + SELECT + TO_INT64(cf['id']) as id, + CAST(cf['name'] AS STRING) AS name, + CAST(cf['location'] AS STRING) AS location, + CAST(cf['price_tier'] AS STRING) AS price_tier, + CAST(cf['checkin_date'] AS STRING) AS checkin_date, + CAST(cf['checkout_date'] AS STRING) AS checkout_date + FROM hotels + WHERE LOWER(CAST(cf['location'] AS STRING)) LIKE LOWER(CONCAT('%', @location_name, '%')) + """ + return bigtable_query_tool.execute_sql( + project_id=_BIGTABLE_PROJECT_ID, + instance_id=_BIGTABLE_INSTANCE_ID, + query=sql_template, + credentials=credentials, + settings=settings, + tool_context=tool_context, + parameters={"location": location_name}, + parameter_types={"location": SqlType.String()}, + ) + + # The variable name `root_agent` determines what your root agent is for the # debug CLI root_agent = LlmAgent( @@ -72,5 +122,13 @@ root_agent = LlmAgent( You are a data agent with access to several Bigtable tools. Make use of those tools to answer the user's questions. """, - tools=[bigtable_toolset], + tools=[ + bigtable_toolset, + # Or, uncomment to use customized Bigtable tools. + # GoogleTool( + # func=search_hotels_by_location, + # credentials_config=credentials_config, + # tool_settings=tool_settings, + # ), + ], ) diff --git a/contributing/samples/fields_output_schema/agent.py b/contributing/samples/fields_output_schema/agent.py index de40774d..f948668a 100644 --- a/contributing/samples/fields_output_schema/agent.py +++ b/contributing/samples/fields_output_schema/agent.py @@ -22,9 +22,20 @@ class WeatherData(BaseModel): wind_speed: str +def get_current_year() -> str: + """Get the current year. + + Returns: + The current year as a string + """ + from datetime import datetime + + return str(datetime.now().year) + + root_agent = Agent( name='root_agent', - model='gemini-2.0-flash', + model='gemini-2.5-flash', instruction="""\ Answer user's questions based on the data you have. @@ -43,6 +54,7 @@ Here are the data you have for Cupertino * wind_speed: 13 mph """, - output_schema=WeatherData, + output_schema=list[WeatherData], output_key='weather_data', + tools=[get_current_year], ) diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 39eec53c..6d5355db 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -17,9 +17,47 @@ import pathlib from google.adk import Agent +from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor from google.adk.skills import load_skill_from_dir from google.adk.skills import models -from google.adk.tools import skill_toolset +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.skill_toolset import SkillToolset +from google.genai import types + + +class GetTimezoneTool(BaseTool): + """A tool to get the timezone for a given location.""" + + def __init__(self): + super().__init__( + name="get_timezone", + description="Returns the timezone for a given location.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration | None: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the timezone for.", + }, + }, + "required": ["location"], + }, + ) + + async def run_async(self, *, args: dict, tool_context) -> str: + return f"The timezone for {args['location']} is UTC+00:00." + + +def get_current_humidity(location: str) -> str: + """Returns the current humidity for a given location.""" + return f"The humidity in {location} is 45%." + greeting_skill = models.Skill( frontmatter=models.Frontmatter( @@ -27,6 +65,7 @@ greeting_skill = models.Skill( description=( "A friendly greeting skill that can say hello to a specific person." ), + metadata={"adk_additional_tools": ["get_timezone"]}, ), instructions=( "Step 1: Read the 'references/hello_world.txt' file to understand how" @@ -41,18 +80,21 @@ greeting_skill = models.Skill( ) weather_skill = load_skill_from_dir( - pathlib.Path(__file__).parent / "skills" / "weather_skill" + pathlib.Path(__file__).parent / "skills" / "weather-skill" ) -my_skill_toolset = skill_toolset.SkillToolset( - skills=[greeting_skill, weather_skill] +# WARNING: UnsafeLocalCodeExecutor has security concerns and should NOT +# be used in production environments. +my_skill_toolset = SkillToolset( + skills=[greeting_skill, weather_skill], + additional_tools=[GetTimezoneTool(), get_current_humidity], + code_executor=UnsafeLocalCodeExecutor(), ) root_agent = Agent( model="gemini-2.5-flash", name="skill_user_agent", description="An agent that can use specialized skills.", - instruction=skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION, tools=[ my_skill_toolset, ], diff --git a/contributing/samples/skills_agent/skills/weather-skill/SKILL.md b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md new file mode 100644 index 00000000..6893ef67 --- /dev/null +++ b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md @@ -0,0 +1,8 @@ +--- +name: weather-skill +description: A skill that provides weather information based on reference data. +--- + +Step 1: Check 'references/weather_info.md' for the current weather. +Step 2: If humidity is requested, use run 'scripts/get_humidity.py' with the `location` argument. +Step 3: Provide the update to the user. diff --git a/contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md b/contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md similarity index 100% rename from contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md rename to contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md diff --git a/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py b/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py new file mode 100644 index 00000000..a2e1dc47 --- /dev/null +++ b/contributing/samples/skills_agent/skills/weather-skill/scripts/get_humidity.py @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + + +def get_humidity(location: str) -> str: + """Fetch live humidity for a given location. (Simulated)""" + print(f"Fetching live humidity for {location}...") + return "45% (Simulated)" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--location", type=str, default="Mountain View") + args = parser.parse_args() + + print(get_humidity(args.location)) diff --git a/contributing/samples/skills_agent/skills/weather_skill/SKILL.md b/contributing/samples/skills_agent/skills/weather_skill/SKILL.md index 67d87105..ea79220a 100644 --- a/contributing/samples/skills_agent/skills/weather_skill/SKILL.md +++ b/contributing/samples/skills_agent/skills/weather_skill/SKILL.md @@ -1,6 +1,9 @@ --- name: weather-skill description: A skill that provides weather information based on reference data. +metadata: + adk_additional_tools: + - get_current_humidity --- Step 1: Check 'references/weather_info.md' for the current weather. diff --git a/pyproject.toml b/pyproject.toml index 9bec96cb..83b1a3f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery-storage>=2.0.0", "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database + "google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool @@ -48,15 +49,15 @@ dependencies = [ "httpx>=0.27.0, <1.0.0", # HTTP client library "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation "mcp>=1.23.0, <2.0.0", # For MCP Toolset - "opentelemetry-api>=1.36.0, <1.40.0", # OpenTelemetry - keep below 1.40.0 to reduce risk of breaking changes around log-signal APIs. + "opentelemetry-api>=1.36.0, <1.39.0", # OpenTelemetry - keep below 1.39.0 due to current agent_engines exporter constraints. "opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-trace>=1.9.0, <2.0.0", "opentelemetry-exporter-otlp-proto-http>=1.36.0", "opentelemetry-resourcedetector-gcp>=1.9.0a0, <2.0.0", - "opentelemetry-sdk>=1.36.0, <1.40.0", + "opentelemetry-sdk>=1.36.0, <1.39.0", "pyarrow>=14.0.0", - "pydantic>=2.7.0, <3.0.0", # For data validation/models + "pydantic>=2.12.0, <3.0.0", # For data validation/models "python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service "python-dotenv>=1.0.0, <2.0.0", # To manage environment variables "requests>=2.32.4, <3.0.0", @@ -109,6 +110,7 @@ community = [ eval = [ # go/keep-sorted start "Jinja2>=3.1.4,<4.0.0", # For eval template rendering + "gepa>=0.1.0", "google-cloud-aiplatform[evaluation]>=1.100.0", "pandas>=2.2.3", "rouge-score>=0.1.2", @@ -124,7 +126,7 @@ test = [ "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <1.80.17", # For LiteLLM tests + "litellm>=1.75.5, <2.0.0", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM "opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0", @@ -155,8 +157,9 @@ extensions = [ "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor "kubernetes>=29.0.0", # For GkeCodeExecutor + "k8s-agent-sandbox>=0.1.1.post2", # For GkeCodeExecutor sandbox mode "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent - "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it + "litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. "lxml>=5.3.0", # For load_web_page tool. diff --git a/src/google/adk/a2a/agent/__init__.py b/src/google/adk/a2a/agent/__init__.py new file mode 100644 index 00000000..8026986e --- /dev/null +++ b/src/google/adk/a2a/agent/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A2A agents package.""" + +from .config import A2aRemoteAgentConfig +from .config import ParametersConfig +from .config import RequestInterceptor + +__all__ = [ + "A2aRemoteAgentConfig", + "ParametersConfig", + "RequestInterceptor", +] diff --git a/src/google/adk/a2a/agent/config.py b/src/google/adk/a2a/agent/config.py new file mode 100644 index 00000000..98984362 --- /dev/null +++ b/src/google/adk/a2a/agent/config.py @@ -0,0 +1,110 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for A2A agents.""" + +from __future__ import annotations + +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.client.middleware import ClientCallContext +from a2a.server.events import Event as A2AEvent +from a2a.types import Message as A2AMessage +from pydantic import BaseModel + +from ...a2a.converters.part_converter import A2APartToGenAIPartConverter +from ...a2a.converters.part_converter import convert_a2a_part_to_genai_part +from ...a2a.converters.to_adk_event import A2AArtifactUpdateToEventConverter +from ...a2a.converters.to_adk_event import A2AMessageToEventConverter +from ...a2a.converters.to_adk_event import A2AStatusUpdateToEventConverter +from ...a2a.converters.to_adk_event import A2ATaskToEventConverter +from ...a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from ...a2a.converters.to_adk_event import convert_a2a_message_to_event +from ...a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from ...a2a.converters.to_adk_event import convert_a2a_task_to_event +from ...agents.invocation_context import InvocationContext +from ...events.event import Event + + +class ParametersConfig(BaseModel): + """Configuration for the parameters passed to the A2A send_message request.""" + + request_metadata: Optional[dict[str, Any]] = None + client_call_context: Optional[ClientCallContext] = None + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. + # + # requested_extension: Optional[list[str]] = None + # message_send_configuration: Optional[MessageSendConfiguration] = None + + +class RequestInterceptor(BaseModel): + """Interceptor for A2A requests.""" + + before_request: Optional[ + Callable[ + [InvocationContext, A2AMessage, ParametersConfig], + Awaitable[tuple[Union[A2AMessage, Event], ParametersConfig]], + ] + ] = None + """Hook executed before the agent starts processing the request. + + Returns an Event if the request should be aborted and the Event + returned to the caller. + """ + + after_request: Optional[ + Callable[ + [InvocationContext, A2AEvent, Event], Awaitable[Union[Event, None]] + ] + ] = None + """Hook executed after the agent has processed the request. + + Returns None if the event should not be sent to the caller. + """ + + +class A2aRemoteAgentConfig(BaseModel): + """Configuration for A2A remote agents.""" + + # Converts standard A2A Messages into ADK Event. + a2a_message_converter: A2AMessageToEventConverter = ( + convert_a2a_message_to_event + ) + + # Converts an A2A Task into an ADK Event. + a2a_task_converter: A2ATaskToEventConverter = convert_a2a_task_to_event + + # Converts A2A TaskStatusUpdateEvents into ADK Event. + a2a_status_update_converter: A2AStatusUpdateToEventConverter = ( + convert_a2a_status_update_to_event + ) + + # Converts A2A TaskArtifactUpdateEvents into ADK Event. + a2a_artifact_update_converter: A2AArtifactUpdateToEventConverter = ( + convert_a2a_artifact_update_to_event + ) + + # A low-level hook that converts individual A2A Message Parts + # into native ADK/GenAI Part objects. + # This is utilized internally by the other converters. + a2a_part_converter: A2APartToGenAIPartConverter = ( + convert_a2a_part_to_genai_part + ) + + request_interceptors: Optional[list[RequestInterceptor]] = None diff --git a/src/google/adk/a2a/agent/utils.py b/src/google/adk/a2a/agent/utils.py new file mode 100644 index 00000000..7cbb25eb --- /dev/null +++ b/src/google/adk/a2a/agent/utils.py @@ -0,0 +1,70 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for A2A agents.""" + +from __future__ import annotations + +from typing import Optional +from typing import Union + +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.middleware import ClientCallContext +from a2a.types import Message as A2AMessage + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from .config import ParametersConfig +from .config import RequestInterceptor + + +async def execute_before_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_request: A2AMessage, +) -> tuple[Union[A2AMessage, Event], ParametersConfig]: + """Executes registered before_request interceptors.""" + + params = ParametersConfig( + client_call_context=ClientCallContext(state=ctx.session.state) + ) + if request_interceptors: + for interceptor in request_interceptors: + if not interceptor.before_request: + continue + + result, params = await interceptor.before_request( + ctx, a2a_request, params + ) + if isinstance(result, Event): + return result, params + a2a_request = result + + return a2a_request, params + + +async def execute_after_request_interceptors( + request_interceptors: Optional[list[RequestInterceptor]], + ctx: InvocationContext, + a2a_response: A2AMessage | A2AClientEvent, + event: Event, +) -> Optional[Event]: + """Executes registered after_request interceptors.""" + if request_interceptors: + for interceptor in reversed(request_interceptors): + if interceptor.after_request: + event = await interceptor.after_request(ctx, a2a_response, event) + if not event: + return None + return event diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 59bbefa1..a2a0ee75 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -370,7 +370,7 @@ def convert_a2a_message_to_event( @a2a_experimental def convert_event_to_a2a_message( event: Event, - invocation_context: InvocationContext, + invocation_context: InvocationContext | None = None, role: Role = Role.agent, part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, ) -> Optional[Message]: @@ -390,8 +390,6 @@ def convert_event_to_a2a_message( """ if not event: raise ValueError("Event cannot be None") - if not invocation_context: - raise ValueError("Invocation context cannot be None") if not event.content or not event.content.parts: return None diff --git a/src/google/adk/a2a/converters/from_adk_event.py b/src/google/adk/a2a/converters/from_adk_event.py new file mode 100644 index 00000000..05bf16d1 --- /dev/null +++ b/src/google/adk/a2a/converters/from_adk_event.py @@ -0,0 +1,288 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime +from datetime import timezone +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +import uuid + +from a2a.server.events import Event as A2AEvent +from a2a.types import Artifact +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from ..experimental import a2a_experimental +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import convert_genai_part_to_a2a_part +from .part_converter import GenAIPartToA2APartConverter +from .utils import _get_adk_metadata_key + +# Constants +DEFAULT_ERROR_MESSAGE = "An error occurred during processing" + +# Logger +logger = logging.getLogger("google_adk." + __name__) + +A2AUpdateEvent = Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent] + +AdkEventToA2AEventsConverter = Callable[ + [ + Event, + Optional[Dict[str, str]], + Optional[str], + Optional[str], + GenAIPartToA2APartConverter, + ], + List[A2AUpdateEvent], +] +"""A callable that converts an ADK Event into a list of A2A events. + +This interface allows for custom logic to map ADK's event structure to the +event structure expected by the A2A server. + +Args: + event: The source ADK Event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. + task_id: The ID of the A2A task being processed. + context_id: The context ID from the A2A request. + part_converter: A function to convert GenAI content parts to A2A + parts. + +Returns: + A list of A2A events. +""" + + +def _convert_adk_parts_to_a2a_parts( + event: Event, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> Optional[List[A2APart]]: + """Converts an ADK event to an A2A parts list. + + Args: + event: The ADK event to convert. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A parts representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + + if not event.content or not event.content.parts: + return [] + + try: + output_parts = [] + for part in event.content.parts: + a2a_parts = part_converter(part) + if not isinstance(a2a_parts, list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + output_parts.append(a2a_part) + + return output_parts + + except Exception as e: + logger.error("Failed to convert event to status message: %s", e) + raise + + +def create_error_status_event( + event: Event, + task_id: Optional[str] = None, + context_id: Optional[str] = None, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent for error scenarios. + + Args: + event: The ADK event containing error information. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + + Returns: + A TaskStatusUpdateEvent with FAILED state. + """ + error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + + error_event = TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.failed, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[A2APart(root=TextPart(text=error_message))], + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + return _add_event_metadata(event, [error_event])[0] + + +@a2a_experimental +def convert_event_to_a2a_events( + event: Event, + agents_artifacts: Dict[str, str], + task_id: Optional[str] = None, + context_id: Optional[str] = None, + part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, +) -> List[A2AUpdateEvent]: + """Converts a GenAI event to a list of A2A StatusUpdate and ArtifactUpdate events. + + Args: + event: The ADK event to convert. + agents_artifacts: State map for tracking active artifact IDs across chunks. + task_id: Optional task ID to use for generated events. + context_id: Optional Context ID to use for generated events. + part_converter: The function to convert GenAI part to A2A part. + + Returns: + A list of A2A update events representing the converted ADK event. + + Raises: + ValueError: If required parameters are invalid. + """ + if not event: + raise ValueError("Event cannot be None") + if agents_artifacts is None: + raise ValueError("Agents artifacts cannot be None") + + a2a_events = [] + try: + a2a_parts = _convert_adk_parts_to_a2a_parts( + event, part_converter=part_converter + ) + # Handle artifact updates for normal parts + if a2a_parts: + agent_name = event.author + partial = event.partial or False + + artifact_id = agents_artifacts.get(agent_name) + if artifact_id: + append = partial + if not partial: + del agents_artifacts[agent_name] + else: + artifact_id = str(uuid.uuid4()) + # TODO: Clarify if new artifact id must have append=False + append = False + if partial: + agents_artifacts[agent_name] = artifact_id + + a2a_events.append( + TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + last_chunk=not partial, + append=append, + artifact=Artifact( + artifact_id=artifact_id, + parts=a2a_parts, + ), + ) + ) + + a2a_events = _add_event_metadata(event, a2a_events) + return a2a_events + + except Exception as e: + logger.error("Failed to convert event to A2A events: %s", e) + raise + + +def _serialize_value(value: Any) -> Optional[Any]: + """Serializes a value and returns it if it contains meaningful content. + + Returns None if the value is empty or missing. + """ + if value is None: + return None + + # Handle Pydantic models + if hasattr(value, "model_dump"): + try: + dumped = value.model_dump( + exclude_none=True, + exclude_unset=True, + exclude_defaults=True, + by_alias=True, + ) + return dumped if dumped else None + except Exception as e: + logger.warning("Failed to serialize Pydantic model, falling back: %s", e) + return str(value) + + return str(value) + + +# TODO: Clarify if this metadata needs to be translated back into the ADK event +def _add_event_metadata( + event: Event, a2a_events: List[A2AEvent] +) -> List[A2AEvent]: + """Gets the context metadata for the event and applies it to A2A events.""" + if not event: + raise ValueError("Event cannot be None") + + metadata_values = { + "invocation_id": event.invocation_id, + "author": event.author, + "event_id": event.id, + "branch": event.branch, + "citation_metadata": event.citation_metadata, + "grounding_metadata": event.grounding_metadata, + "custom_metadata": event.custom_metadata, + "usage_metadata": event.usage_metadata, + "error_code": event.error_code, + "actions": event.actions, + } + + metadata = {} + for field_name, field_value in metadata_values.items(): + value = _serialize_value(field_value) + if value is not None: + metadata[_get_adk_metadata_key(field_name)] = value + + for a2a_event in a2a_events: + if isinstance(a2a_event, TaskStatusUpdateEvent): + a2a_event.status.message.metadata = metadata.copy() + elif isinstance(a2a_event, TaskArtifactUpdateEvent): + a2a_event.artifact.metadata = metadata.copy() + + return a2a_events diff --git a/src/google/adk/a2a/converters/long_running_functions.py b/src/google/adk/a2a/converters/long_running_functions.py new file mode 100644 index 00000000..0bbb46da --- /dev/null +++ b/src/google/adk/a2a/converters/long_running_functions.py @@ -0,0 +1,215 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +from typing import List +from typing import Set +import uuid + +from a2a.server.agent_execution.context import RequestContext +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.genai import types as genai_types + +from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from .part_converter import A2APartToGenAIPartConverter +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _get_adk_metadata_key + + +class LongRunningFunctions: + """Keeps track of long running function calls and related responses.""" + + def __init__( + self, part_converter: A2APartToGenAIPartConverter | None = None + ) -> None: + self._parts: List[genai_types.Part] = [] + self._long_running_tool_ids: Set[str] = set() + self._part_converter = part_converter or convert_a2a_part_to_genai_part + self._task_state: TaskState = TaskState.input_required + + def has_long_running_function_calls(self) -> bool: + """Returns True if there are long running function calls.""" + return bool(self._long_running_tool_ids) + + def process_event(self, event: Event) -> Event: + """Processes parts to extract long running calls and responses. + + Returns a copy of the input event with processed parts removed from + event.content.parts. + + Args: + event: The ADK event containing long running tool IDs and content parts. + """ + event = event.model_copy(deep=True) + if not event.content or not event.content.parts: + return event + + kept_parts = [] + for part in event.content.parts: + should_remove = False + if part.function_call: + if part.function_call.id in event.long_running_tool_ids: + if not event.partial: + self._parts.append(part) + self._long_running_tool_ids.add(part.function_call.id) + should_remove = True + + elif part.function_response: + if part.function_response.id in self._long_running_tool_ids: + if not event.partial: + self._parts.append(part) + should_remove = True + + if not should_remove: + kept_parts.append(part) + + event.content.parts = kept_parts + return event + + def create_long_running_function_call_event( + self, + task_id: str, + context_id: str, + ) -> TaskStatusUpdateEvent: + """Creates a task status update event for the long running function calls.""" + if not self._long_running_tool_ids: + return None + + a2a_parts = self._return_long_running_parts() + if not a2a_parts: + return None + + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=self._task_state, + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=a2a_parts, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + ) + + def _return_long_running_parts(self) -> List[A2APart]: + """Converts long-running parts to A2A parts.""" + if not self._long_running_tool_ids: + return [] + + output_parts = [] + for part in self._parts: + a2a_parts = self._part_converter(part) + if not isinstance(a2a_parts, list): + a2a_parts = [a2a_parts] if a2a_parts else [] + for a2a_part in a2a_parts: + self._mark_long_running_function_call(a2a_part) + output_parts.append(a2a_part) + + return output_parts + + def _mark_long_running_function_call(self, a2a_part: A2APart) -> None: + """Processes long-running tool metadata for an A2A part. + + Args: + a2a_part: The A2A part to potentially mark as long-running. + """ + + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ): + a2a_part.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ] = True + # If the function is a request for EUC, set the task state to + # auth_required. Otherwise, set it to input_required. Save the state of + # the last function call, as it will be the state of the task. + if a2a_part.root.metadata.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME: + self._task_state = TaskState.auth_required + else: + self._task_state = TaskState.input_required + + +def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: + """Processes user input events, validating function responses.""" + + if ( + not context.current_task + or not context.current_task.status + or ( + context.current_task.status.state != TaskState.input_required + and context.current_task.status.state != TaskState.auth_required + ) + ): + return None + + # If the task is in input_required or auth_required state, we expect the user + # to provide a response for the function call. Check if the user input + # contains a function response. + for a2a_part in context.message.parts: + if ( + isinstance(a2a_part.root, DataPart) + and a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + ): + return None + + return TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=context.current_task.status.state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[ + A2APart( + root=TextPart( + text=( + "It was not provided a function response for the" + " function call." + ) + ) + ) + ], + ), + ), + final=True, + ) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index 72dbcb21..ef4a94fd 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -61,13 +61,18 @@ def convert_a2a_part_to_genai_part( """Convert an A2A Part to a Google GenAI Part.""" part = a2a_part.root if isinstance(part, a2a_types.TextPart): - return genai_types.Part(text=part.text) + thought = None + if part.metadata: + thought = part.metadata.get(_get_adk_metadata_key('thought')) + return genai_types.Part(text=part.text, thought=thought) if isinstance(part, a2a_types.FilePart): if isinstance(part.file, a2a_types.FileWithUri): return genai_types.Part( file_data=genai_types.FileData( - file_uri=part.file.uri, mime_type=part.file.mime_type + file_uri=part.file.uri, + mime_type=part.file.mime_type, + display_name=part.file.name, ) ) @@ -76,6 +81,7 @@ def convert_a2a_part_to_genai_part( inline_data=genai_types.Blob( data=base64.b64decode(part.file.bytes), mime_type=part.file.mime_type, + display_name=part.file.name, ) ) else: @@ -101,10 +107,25 @@ def convert_a2a_part_to_genai_part( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ): + # Restore thought_signature if present + thought_signature = None + thought_sig_key = _get_adk_metadata_key('thought_signature') + if thought_sig_key in part.metadata: + sig_value = part.metadata[thought_sig_key] + if isinstance(sig_value, bytes): + thought_signature = sig_value + elif isinstance(sig_value, str): + try: + thought_signature = base64.b64decode(sig_value) + except Exception: + logger.warning( + 'Failed to decode thought_signature: %s', sig_value + ) return genai_types.Part( function_call=genai_types.FunctionCall.model_validate( part.data, by_alias=True - ) + ), + thought_signature=thought_signature, ) if ( part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] @@ -170,6 +191,7 @@ def convert_genai_part_to_a2a_part( file=a2a_types.FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, + name=part.file_data.display_name, ) ) ) @@ -193,6 +215,7 @@ def convert_genai_part_to_a2a_part( file=a2a_types.FileWithBytes( bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), mime_type=part.inline_data.mime_type, + name=part.inline_data.display_name, ) ) @@ -211,16 +234,22 @@ def convert_genai_part_to_a2a_part( # TODO once A2A defined how to service such information, migrate below # logic accordingly if part.function_call: + fc_metadata = { + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + } + # Preserve thought_signature if present + if part.thought_signature is not None: + fc_metadata[_get_adk_metadata_key('thought_signature')] = ( + base64.b64encode(part.thought_signature).decode('utf-8') + ) return a2a_types.Part( root=a2a_types.DataPart( data=part.function_call.model_dump( by_alias=True, exclude_none=True ), - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - }, + metadata=fc_metadata, ) ) diff --git a/src/google/adk/a2a/converters/to_adk_event.py b/src/google/adk/a2a/converters/to_adk_event.py new file mode 100644 index 00000000..66d7768e --- /dev/null +++ b/src/google/adk/a2a/converters/to_adk_event.py @@ -0,0 +1,374 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +import logging +from typing import Any +from typing import List +from typing import Optional +import uuid + +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Task +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from google.genai import types as genai_types + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from ..experimental import a2a_experimental +from .part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from .part_converter import A2APartToGenAIPartConverter +from .part_converter import convert_a2a_part_to_genai_part +from .utils import _get_adk_metadata_key + +# Logger +logger = logging.getLogger("google_adk." + __name__) + +A2AMessageToEventConverter = Callable[ + [ + Message, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A Message to an ADK Event. + +Args: + Message: The A2A message to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2ATaskToEventConverter = Callable[ + [ + Task, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A Task to an ADK Event. + +Args: + Task: The A2A task to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2AStatusUpdateToEventConverter = Callable[ + [ + TaskStatusUpdateEvent, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A TaskStatusUpdateEvent to an ADK Event. + +Args: + TaskStatusUpdateEvent: The A2A status update event to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + +A2AArtifactUpdateToEventConverter = Callable[ + [ + TaskArtifactUpdateEvent, + Optional[str], + Optional[InvocationContext], + A2APartToGenAIPartConverter, + ], + Optional[Event], +] +"""A Callable that converts an A2A TaskArtifactUpdateEvent to an ADK Event. + +Args: + TaskArtifactUpdateEvent: The A2A artifact update event to convert. + Optional[str]: The author of the event. + Optional[InvocationContext]: The invocation context. + A2APartToGenAIPartConverter: The part converter function. + +Returns: + Optional[Event]: The converted ADK Event. +""" + + +def _convert_a2a_parts_to_adk_parts( + a2a_parts: List[A2APart], + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> tuple[List[genai_types.Part], set[str]]: + """Converts a list of A2A parts to a list of ADK parts.""" + output_parts = [] + long_running_function_ids = set() + + for a2a_part in a2a_parts: + try: + parts = part_converter(a2a_part) + if not isinstance(parts, list): + parts = [parts] if parts else [] + if not parts: + logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) + continue + + # Check for long-running functions + if ( + a2a_part.root.metadata + and a2a_part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + ) + is True + ): + for part in parts: + if part.function_call: + long_running_function_ids.add(part.function_call.id) + + output_parts.extend(parts) + + except Exception as e: + logger.error("Failed to convert A2A part: %s, error: %s", a2a_part, e) + # Continue processing other parts instead of failing completely + continue + + if not output_parts: + logger.warning("No parts could be converted from A2A message") + + return output_parts, long_running_function_ids + + +def _create_event( + output_parts: List[genai_types.Part], + invocation_context: Optional[InvocationContext], + author: Optional[str], + long_running_function_ids: Optional[set[str]] = None, + partial: bool = False, +) -> Optional[Event]: + """Creates an ADK event from parts and metadata.""" + if not output_parts: + return None + + event = Event( + invocation_id=( + invocation_context.invocation_id + if invocation_context + else str(uuid.uuid4()) + ), + author=author or "a2a agent", + branch=invocation_context.branch if invocation_context else None, + long_running_tool_ids=( + long_running_function_ids if long_running_function_ids else None + ), + content=genai_types.Content( + role="model", + parts=output_parts, + ), + partial=partial, + ) + + return event + + +@a2a_experimental +def convert_a2a_task_to_event( + a2a_task: Task, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task to an ADK event. + + Args: + a2a_task: The A2A task to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted task. + + Raises: + ValueError: If a2a_task is None. + RuntimeError: If conversion of the underlying message fails. + """ + if a2a_task is None: + raise ValueError("A2A task cannot be None") + + try: + output_parts = [] + long_running_function_ids = set() + if a2a_task.artifacts: + artifact_parts = [ + part for artifact in a2a_task.artifacts for part in artifact.parts + ] + output_parts, _ = _convert_a2a_parts_to_adk_parts( + artifact_parts, part_converter + ) + if ( + a2a_task.status.message + and a2a_task.status.state == TaskState.input_required + ): + parts, ids = _convert_a2a_parts_to_adk_parts( + a2a_task.status.message.parts, part_converter + ) + output_parts.extend(parts) + long_running_function_ids.update(ids) + + return _create_event( + output_parts, + invocation_context, + author, + long_running_function_ids, + ) + + except Exception as e: + logger.error("Failed to convert A2A task to event: %s", e) + raise + + +@a2a_experimental +def convert_a2a_message_to_event( + a2a_message: Message, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A message to an ADK event. + + Args: + a2a_message: The A2A message to convert. Must not be None. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + If provided, the branch will be set from the context. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object with converted content and long-running function + metadata. + + Raises: + ValueError: If a2a_message is None. + RuntimeError: If conversion of message parts fails. + """ + if a2a_message is None: + raise ValueError("A2A message cannot be None") + + try: + output_parts, _ = _convert_a2a_parts_to_adk_parts( + a2a_message.parts, part_converter + ) + return _create_event(output_parts, invocation_context, author) + + except Exception as e: + logger.error("Failed to convert A2A message to event: %s", e) + raise RuntimeError(f"Failed to convert message: {e}") from e + + +@a2a_experimental +def convert_a2a_status_update_to_event( + a2a_status_update: TaskStatusUpdateEvent, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task status update to an ADK event. + + Args: + a2a_status_update: The A2A task status update to convert. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted status update. + """ + if a2a_status_update is None: + raise ValueError("A2A status update cannot be None") + + try: + output_parts = [] + long_running_function_ids = set() + if a2a_status_update.status.message: + parts, ids = _convert_a2a_parts_to_adk_parts( + a2a_status_update.status.message.parts, part_converter + ) + output_parts.extend(parts) + long_running_function_ids.update(ids) + + return _create_event( + output_parts, + invocation_context, + author, + long_running_function_ids, + ) + except Exception as e: + logger.error("Failed to convert A2A status update to event: %s", e) + raise RuntimeError(f"Failed to convert status update: {e}") from e + + +# TODO: Add support for non-ADK Artifact Updates. +@a2a_experimental +def convert_a2a_artifact_update_to_event( + a2a_artifact_update: TaskArtifactUpdateEvent, + author: Optional[str] = None, + invocation_context: Optional[InvocationContext] = None, + part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, +) -> Optional[Event]: + """Converts an A2A task artifact update to an ADK event. + + Args: + a2a_artifact_update: The A2A task artifact update to convert. + author: The author of the event. Defaults to "a2a agent" if not provided. + invocation_context: The invocation context containing session information. + part_converter: The function to convert A2A part to GenAI part. + + Returns: + An ADK Event object representing the converted artifact update. + """ + if a2a_artifact_update is None: + raise ValueError("A2A artifact update cannot be None") + + try: + output_parts, _ = _convert_a2a_parts_to_adk_parts( + a2a_artifact_update.artifact.parts, part_converter + ) + return _create_event( + output_parts, + invocation_context, + author, + partial=not a2a_artifact_update.last_chunk, + ) + except Exception as e: + logger.error("Failed to convert A2A artifact update to event: %s", e) + raise RuntimeError(f"Failed to convert artifact update: {e}") from e diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index cca728db..da28955a 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -35,47 +35,33 @@ from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart from google.adk.runners import Runner -from pydantic import BaseModel from typing_extensions import override from ...utils.context_utils import Aclosing -from ..converters.event_converter import AdkEventToA2AEventsConverter -from ..converters.event_converter import convert_event_to_a2a_events -from ..converters.part_converter import A2APartToGenAIPartConverter -from ..converters.part_converter import convert_a2a_part_to_genai_part -from ..converters.part_converter import convert_genai_part_to_a2a_part -from ..converters.part_converter import GenAIPartToA2APartConverter -from ..converters.request_converter import A2ARequestToAgentRunRequestConverter from ..converters.request_converter import AgentRunRequest -from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key from ..experimental import a2a_experimental +from .a2a_agent_executor_impl import _A2aAgentExecutor as ExecutorImpl +from .config import A2aAgentExecutorConfig +from .executor_context import ExecutorContext from .task_result_aggregator import TaskResultAggregator +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors logger = logging.getLogger('google_adk.' + __name__) -@a2a_experimental -class A2aAgentExecutorConfig(BaseModel): - """Configuration for the A2aAgentExecutor.""" - - a2a_part_converter: A2APartToGenAIPartConverter = ( - convert_a2a_part_to_genai_part - ) - gen_ai_part_converter: GenAIPartToA2APartConverter = ( - convert_genai_part_to_a2a_part - ) - request_converter: A2ARequestToAgentRunRequestConverter = ( - convert_a2a_request_to_agent_run_request - ) - event_converter: AdkEventToA2AEventsConverter = convert_event_to_a2a_events - - @a2a_experimental class A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and publishes updates to an event queue. + + Args: + runner: The runner to use for the agent. + config: The config to use for the executor. + use_legacy: Whether to use the legacy executor implementation. """ def __init__( @@ -83,10 +69,15 @@ class A2aAgentExecutor(AgentExecutor): *, runner: Runner | Callable[..., Runner | Awaitable[Runner]], config: Optional[A2aAgentExecutorConfig] = None, + use_legacy: bool = True, ): super().__init__() - self._runner = runner - self._config = config or A2aAgentExecutorConfig() + if not use_legacy: + self._executor_impl = ExecutorImpl(runner=runner, config=config) + else: + self._executor_impl = None + self._runner = runner + self._config = config or A2aAgentExecutorConfig() async def _resolve_runner(self) -> Runner: """Resolve the runner, handling cases where it's a callable that returns a Runner.""" @@ -115,6 +106,10 @@ class A2aAgentExecutor(AgentExecutor): @override async def cancel(self, context: RequestContext, event_queue: EventQueue): """Cancel the execution.""" + if self._executor_impl: + await self._executor_impl.cancel(context, event_queue) + return + # TODO: Implement proper cancellation logic if needed raise NotImplementedError('Cancellation is not supported') @@ -125,6 +120,7 @@ class A2aAgentExecutor(AgentExecutor): event_queue: EventQueue, ): """Executes an A2A request and publishes updates to the event queue + specified. It runs as following: * Takes the input from the A2A request * Convert the input to ADK input content, and runs the ADK agent @@ -132,9 +128,17 @@ class A2aAgentExecutor(AgentExecutor): * Converts the ADK output events into A2A task updates * Publishes the updates back to A2A server via event queue """ + if self._executor_impl: + await self._executor_impl.execute(context, event_queue) + return + if not context.message: raise ValueError('A2A request must have a message') + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + # for new task, create a task submitted event if not context.current_task: await event_queue.enqueue_event( @@ -202,6 +206,13 @@ class A2aAgentExecutor(AgentExecutor): run_config=run_request.run_config, ) + executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + # publish the task working event await event_queue.enqueue_event( TaskStatusUpdateEvent( @@ -230,6 +241,15 @@ class A2aAgentExecutor(AgentExecutor): context.context_id, self._config.gen_ai_part_converter, ): + a2a_event = await execute_after_event_interceptors( + a2a_event, + executor_context, + adk_event, + self._config.execute_interceptors, + ) + if a2a_event is None: + continue + task_result_aggregator.process_event(a2a_event) await event_queue.enqueue_event(a2a_event) @@ -253,31 +273,34 @@ class A2aAgentExecutor(AgentExecutor): ) ) # public the final status update event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, ) else: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.now(timezone.utc).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, - ) + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=context.context_id, + final=True, ) + final_event = await execute_after_agent_interceptors( + executor_context, + final_event, + self._config.execute_interceptors, + ) + await event_queue.enqueue_event(final_event) + async def _prepare_session( self, context: RequestContext, diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py new file mode 100644 index 00000000..cec68f36 --- /dev/null +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -0,0 +1,310 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import inspect +import logging +from typing import Awaitable +from typing import Callable +from typing import Optional +import uuid + +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from typing_extensions import override + +from ...runners import Runner +from ...utils.context_utils import Aclosing +from ..converters.from_adk_event import create_error_status_event +from ..converters.long_running_functions import handle_user_input +from ..converters.long_running_functions import LongRunningFunctions +from ..converters.request_converter import AgentRunRequest +from ..converters.utils import _get_adk_metadata_key +from ..experimental import a2a_experimental +from .config import A2aAgentExecutorConfig +from .executor_context import ExecutorContext +from .utils import execute_after_agent_interceptors +from .utils import execute_after_event_interceptors +from .utils import execute_before_agent_interceptors + +logger = logging.getLogger('google_adk.' + __name__) + + +@a2a_experimental +class _A2aAgentExecutor(AgentExecutor): + """An AgentExecutor that runs an ADK Agent against an A2A request and + + publishes updates to an event queue. + """ + + def __init__( + self, + *, + runner: Runner | Callable[..., Runner | Awaitable[Runner]], + config: Optional[A2aAgentExecutorConfig] = None, + ): + super().__init__() + self._runner = runner + self._config = config or A2aAgentExecutorConfig() + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cancel the execution.""" + # TODO: Implement proper cancellation logic if needed + raise NotImplementedError('Cancellation is not supported') + + @override + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ): + """Executes an A2A request and publishes updates to the event queue + + specified. It runs as following: + * Takes the input from the A2A request + * Convert the input to ADK input content, and runs the ADK agent + * Collects output events of the underlying ADK Agent + * Converts the ADK output events into A2A task updates + * Publishes the updates back to A2A server via event queue + """ + if not context.message: + raise ValueError('A2A request must have a message') + + context = await execute_before_agent_interceptors( + context, self._config.execute_interceptors + ) + + runner = await self._resolve_runner() + try: + run_request = self._config.request_converter( + context, + self._config.a2a_part_converter, + ) + await self._resolve_session(run_request, runner) + + executor_context = ExecutorContext( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + runner=runner, + ) + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + Task( + id=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + history=[context.message], + metadata=self._get_invocation_metadata(executor_context), + ) + ) + else: + # Check if the user input is responding to the agent's + # request for input. + missing_user_input_event = handle_user_input(context) + if missing_user_input_event: + missing_user_input_event.metadata = self._get_invocation_metadata( + executor_context + ) + await event_queue.enqueue_event(missing_user_input_event) + return + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + metadata=self._get_invocation_metadata(executor_context), + ) + ) + + # Handle the request and publish updates to the event queue + await self._handle_request( + context, + executor_context, + event_queue, + runner, + run_request, + ) + except Exception as e: + logger.error('Error handling A2A request: %s', e, exc_info=True) + # Publish failure event + try: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=str(e))], + ), + ), + context_id=context.context_id, + final=True, + ) + ) + except Exception as enqueue_error: + logger.error( + 'Failed to publish failure event: %s', enqueue_error, exc_info=True + ) + + async def _handle_request( + self, + context: RequestContext, + executor_context: ExecutorContext, + event_queue: EventQueue, + runner: Runner, + run_request: AgentRunRequest, + ): + agents_artifact: dict[str, str] = {} + error_event = None + long_running_functions = LongRunningFunctions( + self._config.gen_ai_part_converter + ) + async with Aclosing(runner.run_async(**vars(run_request))) as agen: + async for adk_event in agen: + # Handle error scenarios + if adk_event and (adk_event.error_code or adk_event.error_message): + error_event = create_error_status_event( + adk_event, + context.task_id, + context.context_id, + ) + + # Handle long running function calls + adk_event = long_running_functions.process_event(adk_event) + + for a2a_event in self._config.adk_event_converter( + adk_event, + agents_artifact, + context.task_id, + context.context_id, + self._config.gen_ai_part_converter, + ): + a2a_event.metadata = self._get_invocation_metadata(executor_context) + a2a_event = await execute_after_event_interceptors( + a2a_event, + executor_context, + adk_event, + self._config.execute_interceptors, + ) + if not a2a_event: + continue + await event_queue.enqueue_event(a2a_event) + + if error_event: + final_event = error_event + elif long_running_functions.has_long_running_function_calls(): + final_event = ( + long_running_functions.create_long_running_function_call_event( + context.task_id, context.context_id + ) + ) + else: + final_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, + ) + + final_event.metadata = self._get_invocation_metadata(executor_context) + final_event = await execute_after_agent_interceptors( + executor_context, final_event, self._config.execute_interceptors + ) + await event_queue.enqueue_event(final_event) + + async def _resolve_runner(self) -> Runner: + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + if isinstance(self._runner, Runner): + return self._runner + if callable(self._runner): + result = self._runner() + + if inspect.iscoroutine(result): + resolved_runner = await result + else: + resolved_runner = result + + self._runner = resolved_runner + return resolved_runner + + raise TypeError( + 'Runner must be a Runner instance or a callable that returns a' + f' Runner, got {type(self._runner)}' + ) + + async def _resolve_session( + self, + run_request: AgentRunRequest, + runner: Runner, + ): + session_id = run_request.session_id + # create a new session if not exists + user_id = run_request.user_id + session = await runner.session_service.get_session( + app_name=runner.app_name, + user_id=user_id, + session_id=session_id, + ) + if session is None: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + session_id=session_id, + ) + # Update run_request with the new session_id + run_request.session_id = session.id + + def _get_invocation_metadata( + self, executor_context: ExecutorContext + ) -> dict[str, str]: + return { + _get_adk_metadata_key('app_name'): executor_context.app_name, + _get_adk_metadata_key('user_id'): executor_context.user_id, + _get_adk_metadata_key('session_id'): executor_context.session_id, + # TODO: Remove this metadata once the new agent executor + # is fully adopted. + _get_adk_metadata_key('agent_executor_v2'): True, + } diff --git a/src/google/adk/a2a/executor/config.py b/src/google/adk/a2a/executor/config.py new file mode 100644 index 00000000..c083affd --- /dev/null +++ b/src/google/adk/a2a/executor/config.py @@ -0,0 +1,107 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from typing import Awaitable +from typing import Callable +from typing import Optional +from typing import Union + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent +from pydantic import BaseModel + +from ...events.event import Event +from ..converters.event_converter import AdkEventToA2AEventsConverter +from ..converters.event_converter import convert_event_to_a2a_events as legacy_convert_event_to_a2a_events +from ..converters.from_adk_event import AdkEventToA2AEventsConverter as AdkEventToA2AEventsConverterImpl +from ..converters.from_adk_event import convert_event_to_a2a_events as convert_event_to_a2a_events_impl +from ..converters.part_converter import A2APartToGenAIPartConverter +from ..converters.part_converter import convert_a2a_part_to_genai_part +from ..converters.part_converter import convert_genai_part_to_a2a_part +from ..converters.part_converter import GenAIPartToA2APartConverter +from ..converters.request_converter import A2ARequestToAgentRunRequestConverter +from ..converters.request_converter import convert_a2a_request_to_agent_run_request +from ..converters.utils import _get_adk_metadata_key +from ..experimental import a2a_experimental +from .executor_context import ExecutorContext + + +@dataclasses.dataclass +class ExecuteInterceptor: + """Interceptor for the A2aAgentExecutor.""" + + before_agent: Optional[ + Callable[[RequestContext], Awaitable[RequestContext]] + ] = None + """Hook executed before the agent starts processing the request. + + Allows inspection or modification of the incoming request context. + Must return a valid `RequestContext` to continue execution. + """ + + after_event: Optional[ + Callable[ + [ExecutorContext, A2AEvent, Event], + Awaitable[Union[A2AEvent, None]], + ] + ] = None + """Hook executed after an ADK event is converted to an A2A event. + + Allows mutating the outgoing event before it is enqueued. + Return `None` to filter out and drop the event entirely, + which also halts any subsequent interceptors in the chain. + """ + + after_agent: Optional[ + Callable[ + [ExecutorContext, TaskStatusUpdateEvent], + Awaitable[TaskStatusUpdateEvent], + ] + ] = None + """Hook executed after the agent finishes and the final event is prepared. + + Allows inspection or modification of the terminal status event (e.g., + completed or failed) before it is enqueued. Must return a valid + `TaskStatusUpdateEvent`. + """ + + +@a2a_experimental +class A2aAgentExecutorConfig(BaseModel): + """Configuration for the A2aAgentExecutor.""" + + a2a_part_converter: A2APartToGenAIPartConverter = ( + convert_a2a_part_to_genai_part + ) + gen_ai_part_converter: GenAIPartToA2APartConverter = ( + convert_genai_part_to_a2a_part + ) + request_converter: A2ARequestToAgentRunRequestConverter = ( + convert_a2a_request_to_agent_run_request + ) + event_converter: AdkEventToA2AEventsConverter = ( + legacy_convert_event_to_a2a_events + ) + """Set up the default event converter implementation to be used by the legacy agent executor implementation.""" + + adk_event_converter: AdkEventToA2AEventsConverterImpl = ( + convert_event_to_a2a_events_impl + ) + """Set up the imlp event converter implementation to be used by the new agent executor implementation.""" + + execute_interceptors: Optional[list[ExecuteInterceptor]] = None diff --git a/src/google/adk/a2a/executor/executor_context.py b/src/google/adk/a2a/executor/executor_context.py new file mode 100644 index 00000000..313afee6 --- /dev/null +++ b/src/google/adk/a2a/executor/executor_context.py @@ -0,0 +1,49 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.runners import Runner + + +class ExecutorContext: + """Context for the executor.""" + + def __init__( + self, + app_name: str, + user_id: str, + session_id: str, + runner: Runner, + ): + self._app_name = app_name + self._user_id = user_id + self._session_id = session_id + self._runner = runner + + @property + def app_name(self) -> str: + return self._app_name + + @property + def user_id(self) -> str: + return self._user_id + + @property + def session_id(self) -> str: + return self._session_id + + @property + def runner(self) -> Runner: + return self._runner diff --git a/src/google/adk/a2a/executor/utils.py b/src/google/adk/a2a/executor/utils.py new file mode 100644 index 00000000..d01066ea --- /dev/null +++ b/src/google/adk/a2a/executor/utils.py @@ -0,0 +1,67 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Optional + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent +from a2a.types import TaskStatusUpdateEvent + +from ...events.event import Event +from ..converters.utils import _get_adk_metadata_key +from .config import ExecuteInterceptor +from .executor_context import ExecutorContext + + +async def execute_before_agent_interceptors( + context: RequestContext, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> RequestContext: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.before_agent: + context = await interceptor.before_agent(context) + return context + + +async def execute_after_event_interceptors( + a2a_event: A2AEvent, + executor_context: ExecutorContext, + adk_event: Event, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> Optional[A2AEvent]: + if execute_interceptors: + for interceptor in execute_interceptors: + if interceptor.after_event: + a2a_event = await interceptor.after_event( + executor_context, a2a_event, adk_event + ) + if a2a_event is None: + return None + return a2a_event + + +async def execute_after_agent_interceptors( + executor_context: ExecutorContext, + final_event: TaskStatusUpdateEvent, + execute_interceptors: Optional[list[ExecuteInterceptor]], +) -> TaskStatusUpdateEvent: + if execute_interceptors: + for interceptor in reversed(execute_interceptors): + if interceptor.after_agent: + final_event = await interceptor.after_agent( + executor_context, final_event + ) + return final_event diff --git a/src/google/adk/a2a/experimental.py b/src/google/adk/a2a/experimental.py index 77c31fde..7f331eb7 100644 --- a/src/google/adk/a2a/experimental.py +++ b/src/google/adk/a2a/experimental.py @@ -23,7 +23,7 @@ a2a_experimental = _make_feature_decorator( default_message=( "ADK Implementation for A2A support (A2aAgentExecutor, RemoteA2aAgent " "and corresponding supporting components etc.) is in experimental mode " - "and is subjected to breaking changes. A2A protocol and SDK are" + "and is subject to breaking changes. A2A protocol and SDK are " "themselves not experimental. Once it's stable enough the experimental " "mode will be removed. Your feedback is welcome." ), diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 155888bc..d6a07080 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -20,7 +20,9 @@ from typing import Union from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import PushNotificationConfigStore from a2a.types import AgentCard from starlette.applications import Starlette @@ -78,6 +80,7 @@ def to_a2a( port: int = 8000, protocol: str = "http", agent_card: Optional[Union[AgentCard, str]] = None, + push_config_store: Optional[PushNotificationConfigStore] = None, runner: Optional[Runner] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -90,6 +93,9 @@ def to_a2a( agent_card: Optional pre-built AgentCard object or path to agent card JSON. If not provided, will be built automatically from the agent. + push_config_store: Optional A2A push notification config store. If not + provided, an in-memory store will be created so push-notification + config RPC methods are supported. runner: Optional pre-built Runner object. If not provided, a default runner will be created using in-memory services. @@ -127,8 +133,13 @@ def to_a2a( runner=runner or create_runner, ) + if push_config_store is None: + push_config_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=task_store + agent_executor=agent_executor, + task_store=task_store, + push_config_store=push_config_store, ) # Use provided agent card or build one from the agent diff --git a/src/google/adk/agents/agent_config.py b/src/google/adk/agents/agent_config.py index add31f4b..2d3c6270 100644 --- a/src/google/adk/agents/agent_config.py +++ b/src/google/adk/agents/agent_config.py @@ -16,14 +16,14 @@ from __future__ import annotations from typing import Annotated from typing import Any -from typing import get_args from typing import Union from pydantic import Discriminator from pydantic import RootModel from pydantic import Tag -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig from .llm_agent_config import LlmAgentConfig from .loop_agent_config import LoopAgentConfig @@ -68,6 +68,6 @@ ConfigsUnion = Annotated[ # Use a RootModel to represent the agent directly at the top level. # The `discriminator` is applied to the union within the RootModel. -@experimental +@experimental(FeatureName.AGENT_CONFIG) class AgentConfig(RootModel[ConfigsUnion]): """The config for the YAML schema to create an agent.""" diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 7e46436a..dec85690 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -40,10 +40,11 @@ from typing_extensions import TypeAlias from ..events.event import Event from ..events.event_actions import EventActions +from ..features import experimental +from ..features import FeatureName from ..telemetry import tracing from ..telemetry.tracing import tracer from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext @@ -70,7 +71,7 @@ AfterAgentCallback: TypeAlias = Union[ SelfAgent = TypeVar('SelfAgent', bound='BaseAgent') -@experimental +@experimental(FeatureName.AGENT_STATE) class BaseAgentState(BaseModel): """Base class for all agent states.""" @@ -121,7 +122,9 @@ class BaseAgent(BaseModel): One-line description is enough and preferred. """ - parent_agent: Optional[BaseAgent] = Field(default=None, init=False) + parent_agent: Optional[BaseAgent] = Field( + default=None, init=False, exclude=True + ) """The parent agent of this agent. Note that an agent can ONLY be added as sub-agent once. @@ -618,7 +621,7 @@ class BaseAgent(BaseModel): @final @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def from_config( cls: Type[SelfAgent], config: BaseAgentConfig, @@ -642,7 +645,7 @@ class BaseAgent(BaseModel): return cls(**kwargs) @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: Type[SelfAgent], config: BaseAgentConfig, diff --git a/src/google/adk/agents/base_agent_config.py b/src/google/adk/agents/base_agent_config.py index 9f1f5566..3859cb35 100644 --- a/src/google/adk/agents/base_agent_config.py +++ b/src/google/adk/agents/base_agent_config.py @@ -26,14 +26,15 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .common_configs import AgentRefConfig from .common_configs import CodeConfig TBaseAgentConfig = TypeVar('TBaseAgentConfig', bound='BaseAgentConfig') -@experimental +@experimental(FeatureName.AGENT_CONFIG) class BaseAgentConfig(BaseModel): """The config for the YAML schema of a BaseAgent. diff --git a/src/google/adk/agents/common_configs.py b/src/google/adk/agents/common_configs.py index 4e4c49f3..49baa8a4 100644 --- a/src/google/adk/agents/common_configs.py +++ b/src/google/adk/agents/common_configs.py @@ -24,10 +24,11 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import model_validator -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ArgumentConfig(BaseModel): """An argument passed to a function or a class's constructor.""" @@ -43,7 +44,7 @@ class ArgumentConfig(BaseModel): """The argument value.""" -@experimental +@experimental(FeatureName.AGENT_CONFIG) class CodeConfig(BaseModel): """Code reference config for a variable, a function, or a class. @@ -81,7 +82,7 @@ class CodeConfig(BaseModel): """ -@experimental +@experimental(FeatureName.AGENT_CONFIG) class AgentRefConfig(BaseModel): """The config for the reference to another agent.""" diff --git a/src/google/adk/agents/config_agent_utils.py b/src/google/adk/agents/config_agent_utils.py index 446eac88..2c1c9bd9 100644 --- a/src/google/adk/agents/config_agent_utils.py +++ b/src/google/adk/agents/config_agent_utils.py @@ -22,7 +22,8 @@ from typing import List import yaml -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .agent_config import AgentConfig from .base_agent import BaseAgent from .base_agent_config import BaseAgentConfig @@ -30,7 +31,7 @@ from .common_configs import AgentRefConfig from .common_configs import CodeConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) def from_config(config_path: str) -> BaseAgent: """Build agent from a configfile path. @@ -102,7 +103,7 @@ def _load_config_from_path(config_path: str) -> AgentConfig: return AgentConfig.model_validate(config_data) -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_fully_qualified_name(name: str) -> Any: try: module_path, obj_name = name.rsplit(".", 1) @@ -112,7 +113,7 @@ def resolve_fully_qualified_name(name: str) -> Any: raise ValueError(f"Invalid fully qualified name: {name}") from e -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_agent_reference( ref_config: AgentRefConfig, referencing_agent_config_abs_path: str ) -> BaseAgent: @@ -170,7 +171,7 @@ def _resolve_agent_code_reference(code: str) -> Any: return obj -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_code_reference(code_config: CodeConfig) -> Any: """Resolve a code reference to actual Python object. @@ -199,7 +200,7 @@ def resolve_code_reference(code_config: CodeConfig) -> Any: return obj -@experimental +@experimental(FeatureName.AGENT_CONFIG) def resolve_callbacks(callbacks_config: List[CodeConfig]) -> Any: """Resolve callbacks from configuration. diff --git a/src/google/adk/agents/context_cache_config.py b/src/google/adk/agents/context_cache_config.py index 855e28c3..9e6d19ca 100644 --- a/src/google/adk/agents/context_cache_config.py +++ b/src/google/adk/agents/context_cache_config.py @@ -18,10 +18,11 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ContextCacheConfig(BaseModel): """Configuration for context caching across all agents in an app. diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 7a23a6cc..35b8dc97 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -24,6 +24,7 @@ from pydantic import ConfigDict from pydantic import Field from pydantic import PrivateAttr +from ..apps.app import EventsCompactionConfig from ..apps.app import ResumabilityConfig from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService @@ -200,6 +201,12 @@ class InvocationContext(BaseModel): resumability_config: Optional[ResumabilityConfig] = None """The resumability config that applies to all agents under this invocation.""" + events_compaction_config: Optional[EventsCompactionConfig] = None + """The compaction config for this invocation.""" + + token_compaction_checked: bool = False + """Whether token-threshold compaction ran during this invocation.""" + plugin_manager: PluginManager = Field(default_factory=PluginManager) """The manager for keeping track of plugins in this invocation.""" @@ -389,23 +396,20 @@ class InvocationContext(BaseModel): return False # TODO: Move this method from invocation_context to a dedicated module. - # TODO: Converge this method with find_matching_function_call in llm_flows. def _find_matching_function_call( self, function_response_event: Event ) -> Optional[Event]: """Finds the function call event in the current invocation that matches the function response id.""" + from ..flows.llm_flows.functions import find_event_by_function_call_id + function_responses = function_response_event.get_function_responses() if not function_responses: return None - function_call_id = function_responses[0].id - events = self._get_events(current_invocation=True) - # The last event is function_response_event, so we search backwards from the - # one before it. - for event in reversed(events[:-1]): - if any(fc.id == function_call_id for fc in event.get_function_calls()): - return event - return None + # Search backwards from the event before the current response event. + return find_event_by_function_call_id( + self._get_events(current_invocation=True)[:-1], function_responses[0].id + ) def new_invocation_context_id() -> str: diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 5294e056..0f7cc2b7 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import importlib import inspect import logging @@ -40,6 +41,8 @@ from typing_extensions import TypeAlias from ..code_executors.base_code_executor import BaseCodeExecutor from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..flows.llm_flows.auto_flow import AutoFlow from ..flows.llm_flows.base_llm_flow import BaseLlmFlow from ..flows.llm_flows.single_flow import SingleFlow @@ -53,8 +56,9 @@ from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool from ..tools.tool_configs import ToolConfig from ..tools.tool_context import ToolContext +from ..utils._schema_utils import SchemaType +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -316,9 +320,16 @@ class LlmAgent(BaseAgent): # Controlled input/output configurations - Start input_schema: Optional[type[BaseModel]] = None """The input schema when agent is used as a tool.""" - output_schema: Optional[type[BaseModel]] = None + output_schema: Optional[SchemaType] = None """The output schema when agent replies. + Supports all schema types that the underlying Google GenAI API supports: + - type[BaseModel]: e.g., MySchema + - list[type[BaseModel]]: e.g., list[MySchema] + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type + NOTE: When this is set, agent can ONLY reply and CANNOT use any tools, such as function tools, RAGs, agent transfer, etc. @@ -589,24 +600,27 @@ class LlmAgent(BaseAgent): return global_instruction, True async def canonical_tools( - self, ctx: ReadonlyContext = None + self, ctx: Optional[ReadonlyContext] = None ) -> list[BaseTool]: """The resolved self.tools field as a list of BaseTool based on the context. This method is only for use by Agent Development Kit. """ - resolved_tools = [] # We may need to wrap some built-in tools if there are other tools # because the built-in tools cannot be used together with other tools. # TODO(b/448114567): Remove once the workaround is no longer needed. multiple_tools = len(self.tools) > 1 model = self.canonical_model - for tool_union in self.tools: - resolved_tools.extend( - await _convert_tool_union_to_tools( - tool_union, ctx, model, multiple_tools - ) - ) + + results = await asyncio.gather(*( + _convert_tool_union_to_tools(tool_union, ctx, model, multiple_tools) + for tool_union in self.tools + )) + + resolved_tools = [] + for tools in results: + resolved_tools.extend(tools) + return resolved_tools @property @@ -815,12 +829,12 @@ class LlmAgent(BaseAgent): event.author, ) return - if ( - self.output_key - and event.is_final_response() - and event.content - and event.content.parts - ): + + if not self.output_key: + return + + # Handle text responses + if event.is_final_response() and event.content and event.content.parts: result = ''.join( part.text @@ -833,9 +847,7 @@ class LlmAgent(BaseAgent): # Do not attempt to parse it as JSON. if not result.strip(): return - result = self.output_schema.model_validate_json(result).model_dump( - exclude_none=True - ) + result = validate_schema(self.output_schema, result) event.actions.state_delta[self.output_key] = result @model_validator(mode='after') @@ -879,7 +891,7 @@ class LlmAgent(BaseAgent): ) @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _resolve_tools( cls, tool_configs: list[ToolConfig], config_abs_path: str ) -> list[Any]: @@ -938,7 +950,7 @@ class LlmAgent(BaseAgent): @override @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: Type[LlmAgent], config: LlmAgentConfig, diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 9296714f..2980f68a 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -26,8 +26,9 @@ from typing import Optional from typing_extensions import override from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -37,7 +38,7 @@ from .loop_agent_config import LoopAgentConfig logger = logging.getLogger('google_adk.' + __name__) -@experimental +@experimental(FeatureName.AGENT_STATE) class LoopAgentState(BaseAgentState): """State for LoopAgent.""" @@ -153,7 +154,7 @@ class LoopAgent(BaseAgent): @override @classmethod - @experimental + @experimental(FeatureName.AGENT_CONFIG) def _parse_config( cls: type[LoopAgent], config: LoopAgentConfig, diff --git a/src/google/adk/agents/loop_agent_config.py b/src/google/adk/agents/loop_agent_config.py index 1aaa0ef9..78fc790b 100644 --- a/src/google/adk/agents/loop_agent_config.py +++ b/src/google/adk/agents/loop_agent_config.py @@ -21,11 +21,12 @@ from typing import Optional from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) class LoopAgentConfig(BaseAgentConfig): """The config for the YAML schema of a LoopAgent.""" diff --git a/src/google/adk/agents/parallel_agent_config.py b/src/google/adk/agents/parallel_agent_config.py index 77eb1a68..96a75b65 100644 --- a/src/google/adk/agents/parallel_agent_config.py +++ b/src/google/adk/agents/parallel_agent_config.py @@ -19,11 +19,12 @@ from __future__ import annotations from pydantic import ConfigDict from pydantic import Field -from ..utils.feature_decorator import experimental +from ..features import experimental +from ..features import FeatureName from .base_agent_config import BaseAgentConfig -@experimental +@experimental(FeatureName.AGENT_CONFIG) class ParallelAgentConfig(BaseAgentConfig): """The config for the YAML schema of a ParallelAgent.""" diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 2da7a4fa..9b3a7b22 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -35,14 +35,17 @@ from a2a.client.errors import A2AClientHTTPError from a2a.client.middleware import ClientCallContext from a2a.types import AgentCard from a2a.types import Message as A2AMessage +from a2a.types import MessageSendConfiguration from a2a.types import Part as A2APart from a2a.types import Role +from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent from a2a.types import TransportProtocol as A2ATransport from google.genai import types as genai_types import httpx +from pydantic import BaseModel try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -50,6 +53,9 @@ except ImportError: # Fallback for older versions of a2a-sdk. AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" +from ..a2a.agent.config import A2aRemoteAgentConfig +from ..a2a.agent.utils import execute_after_request_interceptors +from ..a2a.agent.utils import execute_before_request_interceptors from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message @@ -57,6 +63,7 @@ from ..a2a.converters.part_converter import A2APartToGenAIPartConverter from ..a2a.converters.part_converter import convert_a2a_part_to_genai_part from ..a2a.converters.part_converter import convert_genai_part_to_a2a_part from ..a2a.converters.part_converter import GenAIPartToA2APartConverter +from ..a2a.converters.utils import _get_adk_metadata_key from ..a2a.experimental import a2a_experimental from ..a2a.logs.log_utils import build_a2a_request_log from ..a2a.logs.log_utils import build_a2a_response_log @@ -127,6 +134,7 @@ class RemoteA2aAgent(BaseAgent): Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, full_history_when_stateless: bool = False, + config: Optional[A2aRemoteAgentConfig] = None, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -147,6 +155,7 @@ class RemoteA2aAgent(BaseAgent): return Tasks or context IDs) will receive all session events on every request. If False, the default behavior of sending only events since the last reply from the agent will be used. + config: Optional configuration object. **kwargs: Additional arguments passed to BaseAgent Raises: @@ -174,6 +183,7 @@ class RemoteA2aAgent(BaseAgent): self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider self._full_history_when_stateless = full_history_when_stateless + self._config = config or A2aRemoteAgentConfig() # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -514,6 +524,76 @@ class RemoteA2aAgent(BaseAgent): branch=ctx.branch, ) + async def _handle_a2a_response_v2( + self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext + ) -> Optional[Event]: + """Handle A2A response and convert to Event. + + Args: + a2a_response: The A2A response object + ctx: The invocation context + + Returns: + Event object representing the response, or None if no event should be + emitted. + """ + try: + if isinstance(a2a_response, tuple): + task, update = a2a_response + event = None + if update is None: + # This is the initial response for a streaming task or the complete + # response for a non-streaming task. + event = self._config.a2a_task_converter( + task, self.name, ctx, self._config.a2a_part_converter + ) + elif isinstance(update, A2ATaskStatusUpdateEvent): + # This is a streaming task status update. + event = self._config.a2a_status_update_converter( + update, self.name, ctx, self._config.a2a_part_converter + ) + elif isinstance(update, A2ATaskArtifactUpdateEvent): + # This is a streaming task artifact update. + event = self._config.a2a_artifact_update_converter( + update, self.name, ctx, self._config.a2a_part_converter + ) + if not event: + return None + event.custom_metadata = event.custom_metadata or {} + event.custom_metadata[A2A_METADATA_PREFIX + "task_id"] = task.id + if task.context_id: + event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = ( + task.context_id + ) + + # Otherwise, it's a regular A2AMessage. + elif isinstance(a2a_response, A2AMessage): + event = self._config.a2a_message_converter( + a2a_response, self.name, ctx, self._config.a2a_part_converter + ) + event.custom_metadata = event.custom_metadata or {} + + if a2a_response.context_id: + event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = ( + a2a_response.context_id + ) + else: + event = Event( + author=self.name, + error_message="Unknown A2A response type", + invocation_id=ctx.invocation_id, + branch=ctx.branch, + ) + return event + except A2AClientError as e: + logger.error("Failed to handle A2A response: %s", e) + return Event( + author=self.name, + error_message=f"Failed to process A2A response: {e}", + invocation_id=ctx.invocation_id, + branch=ctx.branch, + ) + async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: @@ -558,18 +638,49 @@ class RemoteA2aAgent(BaseAgent): logger.debug(build_a2a_request_log(a2a_request)) try: - request_metadata = None - if self._a2a_request_meta_provider: - request_metadata = self._a2a_request_meta_provider(ctx, a2a_request) + a2a_request, parameters = await execute_before_request_interceptors( + self._config.request_interceptors, ctx, a2a_request + ) + if isinstance(a2a_request, Event): + yield a2a_request + return + + # Backward compatibility + if self._a2a_request_meta_provider: + parameters.request_metadata = self._a2a_request_meta_provider( + ctx, a2a_request + ) + + # TODO: Add support for requested_extension and + # message_send_configuration once they are supported by the A2A client. async for a2a_response in self._a2a_client.send_message( request=a2a_request, - request_metadata=request_metadata, - context=ClientCallContext(state=ctx.session.state), + request_metadata=parameters.request_metadata, + context=parameters.client_call_context, ): logger.debug(build_a2a_response_log(a2a_response)) - event = await self._handle_a2a_response(a2a_response, ctx) + metadata = None + if isinstance(a2a_response, tuple): + task = a2a_response[0] + if task: + metadata = task.metadata + else: + metadata = a2a_response.metadata + + if metadata and metadata.get( + _get_adk_metadata_key("agent_executor_v2") + ): + event = await self._handle_a2a_response_v2(a2a_response, ctx) + else: + event = await self._handle_a2a_response(a2a_response, ctx) + if not event: + continue + + event = await execute_after_request_interceptors( + self._config.request_interceptors, ctx, a2a_response, event + ) if not event: continue diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index eec1dea9..06a2377b 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -24,8 +24,9 @@ from typing import Type from typing_extensions import override from ..events.event import Event +from ..features import experimental +from ..features import FeatureName from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig @@ -36,7 +37,7 @@ from .sequential_agent_config import SequentialAgentConfig logger = logging.getLogger('google_adk.' + __name__) -@experimental +@experimental(FeatureName.AGENT_STATE) class SequentialAgentState(BaseAgentState): """State for SequentialAgent.""" diff --git a/src/google/adk/agents/sequential_agent_config.py b/src/google/adk/agents/sequential_agent_config.py index 763527e9..44551c42 100644 --- a/src/google/adk/agents/sequential_agent_config.py +++ b/src/google/adk/agents/sequential_agent_config.py @@ -19,11 +19,12 @@ from __future__ import annotations from pydantic import ConfigDict from pydantic import Field -from ..agents.base_agent import experimental from ..agents.base_agent_config import BaseAgentConfig +from ..features import experimental +from ..features import FeatureName -@experimental +@experimental(FeatureName.AGENT_CONFIG) class SequentialAgentConfig(BaseAgentConfig): """The config for the YAML schema of a SequentialAgent.""" diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index 4af7b512..61941bff 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -16,25 +16,53 @@ from __future__ import annotations import logging +from google.genai import types + +from ..agents.base_agent import BaseAgent from ..events.event import Event from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session from .app import App +from .app import EventsCompactionConfig from .llm_event_summarizer import LlmEventSummarizer logger = logging.getLogger('google_adk.' + __name__) -def _count_text_chars_in_event(event: Event) -> int: - """Returns the number of text characters in an event's content.""" +def _count_text_chars_in_content(content: types.Content | None) -> int: + """Returns the number of text characters in a content object.""" total_chars = 0 - if event.content and event.content.parts: - for part in event.content.parts: + if content and content.parts: + for part in content.parts: if part.text: total_chars += len(part.text) return total_chars +def _valid_compactions( + events: list[Event], +) -> list[tuple[int, float, float, Event]]: + """Returns compaction events with fully-defined compaction ranges.""" + compactions: list[tuple[int, float, float, Event]] = [] + for i, event in enumerate(events): + if not (event.actions and event.actions.compaction): + continue + compaction = event.actions.compaction + if ( + compaction.start_timestamp is None + or compaction.end_timestamp is None + or compaction.compacted_content is None + ): + continue + compactions.append(( + i, + compaction.start_timestamp, + compaction.end_timestamp, + event, + )) + return compactions + + def _is_compaction_subsumed( *, start_timestamp: float, @@ -60,67 +88,29 @@ def _is_compaction_subsumed( return False -def _estimate_prompt_token_count(events: list[Event]) -> int | None: +def _estimate_prompt_token_count( + *, + events: list[Event], + current_branch: str | None, + agent_name: str, +) -> int | None: """Returns an approximate prompt token count from session events. - This estimate is compaction-aware: it counts compaction summaries and only - counts raw events that would remain visible after applying compaction ranges. + This estimate mirrors the effective content-building path used by the + contents request processor. """ - compactions: list[tuple[int, float, float, Event]] = [] - for i, event in enumerate(events): - if not (event.actions and event.actions.compaction): - continue - compaction = event.actions.compaction - if ( - compaction.start_timestamp is None - or compaction.end_timestamp is None - or compaction.compacted_content is None - ): - continue - compactions.append(( - i, - compaction.start_timestamp, - compaction.end_timestamp, - Event( - timestamp=compaction.end_timestamp, - author='model', - content=compaction.compacted_content, - branch=event.branch, - invocation_id=event.invocation_id, - actions=event.actions, - ), - )) - - effective_compactions = [ - (i, start, end, summary_event) - for i, start, end, summary_event in compactions - if not _is_compaction_subsumed( - start_timestamp=start, - end_timestamp=end, - event_index=i, - compactions=compactions, - ) - ] - compaction_ranges = [ - (start, end) for _, start, end, _ in effective_compactions - ] - - def _is_timestamp_compacted(ts: float) -> bool: - for start_ts, end_ts in compaction_ranges: - if start_ts <= ts <= end_ts: - return True - return False + # Deferred import: contents depends on agents.invocation_context which + # imports from apps, so a top-level import would create a circular dependency. + from ..flows.llm_flows import contents + effective_contents = contents._get_contents( + current_branch=current_branch, + events=events, + agent_name=agent_name, + ) total_chars = 0 - for _, _, _, summary_event in effective_compactions: - total_chars += _count_text_chars_in_event(summary_event) - - for event in events: - if event.actions and event.actions.compaction: - continue - if _is_timestamp_compacted(event.timestamp): - continue - total_chars += _count_text_chars_in_event(event) + for content in effective_contents: + total_chars += _count_text_chars_in_content(content) if total_chars <= 0: return None @@ -129,7 +119,12 @@ def _estimate_prompt_token_count(events: list[Event]) -> int | None: return total_chars // 4 -def _latest_prompt_token_count(events: list[Event]) -> int | None: +def _latest_prompt_token_count( + events: list[Event], + *, + current_branch: str | None = None, + agent_name: str = '', +) -> int | None: """Returns the most recently observed prompt token count, if available.""" for event in reversed(events): if ( @@ -137,23 +132,29 @@ def _latest_prompt_token_count(events: list[Event]) -> int | None: and event.usage_metadata.prompt_token_count is not None ): return event.usage_metadata.prompt_token_count - return _estimate_prompt_token_count(events) + return _estimate_prompt_token_count( + events=events, + current_branch=current_branch, + agent_name=agent_name, + ) def _latest_compaction_event(events: list[Event]) -> Event | None: - """Returns the compaction event with the greatest covered end timestamp.""" + """Returns the latest non-subsumed compaction event by stream order.""" + compactions = _valid_compactions(events) latest_event = None - latest_end = 0.0 - for event in events: - if ( - event.actions - and event.actions.compaction - and event.actions.compaction.end_timestamp is not None + latest_index = -1 + for event_index, start_ts, end_ts, event in compactions: + if _is_compaction_subsumed( + start_timestamp=start_ts, + end_timestamp=end_ts, + event_index=event_index, + compactions=compactions, ): - end_ts = event.actions.compaction.end_timestamp - if end_ts is not None and end_ts >= latest_end: - latest_end = end_ts - latest_event = event + continue + if event_index > latest_index: + latest_index = event_index + latest_event = event return latest_event @@ -167,55 +168,73 @@ def _latest_compaction_end_timestamp(events: list[Event]) -> float: return latest_event.actions.compaction.end_timestamp -async def _run_compaction_for_token_threshold( - app: App, session: Session, session_service: BaseSessionService -): - """Runs post-invocation compaction based on a token threshold. +def _has_token_threshold_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether token-threshold compaction is fully configured.""" + return bool( + config + and config.token_threshold is not None + and config.event_retention_size is not None + ) - If triggered, this compacts older raw events and keeps the last - `event_retention_size` raw events un-compacted. - """ - config = app.events_compaction_config - if not config: - return False - if config.token_threshold is None or config.event_retention_size is None: - return False - prompt_token_count = _latest_prompt_token_count(session.events) - if prompt_token_count is None or prompt_token_count < config.token_threshold: - return False +def _has_sliding_window_config(config: EventsCompactionConfig | None) -> bool: + """Returns whether sliding-window compaction is fully configured.""" + return bool( + config + and config.compaction_interval is not None + and config.overlap_size is not None + ) - latest_compaction_event = _latest_compaction_event(session.events) - last_compacted_end_timestamp = 0.0 - if ( - latest_compaction_event - and latest_compaction_event.actions - and latest_compaction_event.actions.compaction - and latest_compaction_event.actions.compaction.end_timestamp is not None - ): - last_compacted_end_timestamp = ( - latest_compaction_event.actions.compaction.end_timestamp + +def _ensure_compaction_summarizer( + *, config: EventsCompactionConfig, agent: BaseAgent +) -> None: + """Ensures compaction config has a summarizer initialized.""" + if config.summarizer is not None: + return + + from ..agents.llm_agent import LlmAgent + + if not isinstance(agent, LlmAgent): + raise ValueError( + 'No LlmAgent model available for event compaction summarizer.' ) + config.summarizer = LlmEventSummarizer(llm=agent.canonical_model) + + +def _events_to_compact_for_token_threshold( + *, + events: list[Event], + event_retention_size: int, +) -> list[Event]: + """Collects token-threshold compaction candidates with rolling-summary seed. + + If a previous compaction exists, include its summary as the first event so + the next summary can supersede it. + """ + latest_compaction_event = _latest_compaction_event(events) + last_compacted_end_timestamp = _latest_compaction_end_timestamp(events) + candidate_events = [ - e - for e in session.events - if not (e.actions and e.actions.compaction) - and e.timestamp > last_compacted_end_timestamp + event + for event in events + if not (event.actions and event.actions.compaction) + and event.timestamp > last_compacted_end_timestamp ] + if len(candidate_events) <= event_retention_size: + return [] - if len(candidate_events) <= config.event_retention_size: - return False - - if config.event_retention_size == 0: + if event_retention_size == 0: events_to_compact = candidate_events else: - events_to_compact = candidate_events[: -config.event_retention_size] + split_index = _safe_token_compaction_split_index( + candidate_events=candidate_events, + event_retention_size=event_retention_size, + ) + events_to_compact = candidate_events[:split_index] if not events_to_compact: - return False + return [] - # Rolling summary: if a previous compaction exists, seed the next summary with - # the previous compaction summary content so new compactions can subsume older - # ones while still keeping `event_retention_size` raw events visible. if ( latest_compaction_event and latest_compaction_event.actions @@ -231,10 +250,101 @@ async def _run_compaction_for_token_threshold( branch=latest_compaction_event.branch, invocation_id=Event.new_id(), ) - events_to_compact = [seed_event] + events_to_compact + return [seed_event] + events_to_compact - if not config.summarizer: - config.summarizer = LlmEventSummarizer(llm=app.root_agent.canonical_model) + return events_to_compact + + +def _event_function_call_ids(event: Event) -> set[str]: + """Returns function call ids found in an event.""" + function_call_ids: set[str] = set() + for function_call in event.get_function_calls(): + if function_call.id: + function_call_ids.add(function_call.id) + return function_call_ids + + +def _event_function_response_ids(event: Event) -> set[str]: + """Returns function response ids found in an event.""" + function_response_ids: set[str] = set() + for function_response in event.get_function_responses(): + if function_response.id: + function_response_ids.add(function_response.id) + return function_response_ids + + +def _safe_token_compaction_split_index( + *, + candidate_events: list[Event], + event_retention_size: int, +) -> int: + """Returns a split index that avoids orphaning retained tool responses. + + Retained events (tail of candidate events) may contain function responses. + If their matching function call events are in the compacted prefix, contents + assembly can fail. This method shifts the split earlier so matching function + call events are retained together with their responses. + + Iterates backwards through candidate_events once, maintaining a running set + of unmatched response IDs. The latest valid split point where no unmatched + responses remain is returned. + """ + initial_split = len(candidate_events) - event_retention_size + if initial_split <= 0: + return 0 + + unmatched_response_ids: set[str] = set() + best_split = 0 + + for i in range(len(candidate_events) - 1, -1, -1): + event = candidate_events[i] + unmatched_response_ids.update(_event_function_response_ids(event)) + call_ids = _event_function_call_ids(event) + unmatched_response_ids -= call_ids + + if not unmatched_response_ids and i <= initial_split: + best_split = i + break + + return best_split + + +async def _run_compaction_for_token_threshold_config( + *, + config: EventsCompactionConfig | None, + session: Session, + session_service: BaseSessionService, + agent: BaseAgent, + agent_name: str = '', + current_branch: str | None = None, +) -> bool: + """Runs token-threshold compaction for a provided compaction config.""" + if not _has_token_threshold_config(config): + return False + if config is None: + return False + + if config.token_threshold is None or config.event_retention_size is None: + return False + + prompt_token_count = _latest_prompt_token_count( + session.events, + current_branch=current_branch, + agent_name=agent_name, + ) + if prompt_token_count is None or prompt_token_count < config.token_threshold: + return False + + events_to_compact = _events_to_compact_for_token_threshold( + events=session.events, + event_retention_size=config.event_retention_size, + ) + if not events_to_compact: + return False + + _ensure_compaction_summarizer(config=config, agent=agent) + if config.summarizer is None: + return False compaction_event = await config.summarizer.maybe_summarize_events( events=events_to_compact @@ -246,8 +356,30 @@ async def _run_compaction_for_token_threshold( return False -async def _run_compaction_for_sliding_window( +async def _run_compaction_for_token_threshold( app: App, session: Session, session_service: BaseSessionService +): + """Runs post-invocation compaction based on a token threshold. + + If triggered, this compacts older raw events and keeps the last + `event_retention_size` raw events un-compacted. + """ + return await _run_compaction_for_token_threshold_config( + config=app.events_compaction_config, + session=session, + session_service=session_service, + agent=app.root_agent, + agent_name='', + current_branch=None, + ) + + +async def _run_compaction_for_sliding_window( + app: App, + session: Session, + session_service: BaseSessionService, + *, + skip_token_compaction: bool = False, ): """Runs compaction for SlidingWindowCompactor. @@ -327,22 +459,30 @@ async def _run_compaction_for_sliding_window( app: The application instance. session: The session containing events to compact. session_service: The session service for appending events. + skip_token_compaction: Whether to skip token-threshold compaction. """ events = session.events if not events: return None + config = app.events_compaction_config + if config is None: + return None + # Prefer token-threshold compaction if configured and triggered. - if ( - app.events_compaction_config - and app.events_compaction_config.token_threshold is not None - ): + if not skip_token_compaction and _has_token_threshold_config(config): token_compacted = await _run_compaction_for_token_threshold( app, session, session_service ) if token_compacted: return None + if not _has_sliding_window_config(config): + return None + + if config.compaction_interval is None or config.overlap_size is None: + return None + # Find the last compaction event and its range. last_compacted_end_timestamp = 0.0 for event in reversed(events): @@ -373,7 +513,7 @@ async def _run_compaction_for_sliding_window( if invocation_latest_timestamps[inv_id] > last_compacted_end_timestamp ] - if len(new_invocation_ids) < app.events_compaction_config.compaction_interval: + if len(new_invocation_ids) < config.compaction_interval: return None # Not enough new invocations to trigger compaction. # Determine the range of invocations to compact. @@ -385,9 +525,7 @@ async def _run_compaction_for_sliding_window( first_new_inv_id = new_invocation_ids[0] first_new_inv_idx = unique_invocation_ids.index(first_new_inv_id) - start_idx = max( - 0, first_new_inv_idx - app.events_compaction_config.overlap_size - ) + start_idx = max(0, first_new_inv_idx - config.overlap_size) start_inv_id = unique_invocation_ids[start_idx] # Find the index of the last event with end_inv_id. @@ -419,15 +557,12 @@ async def _run_compaction_for_sliding_window( if not events_to_compact: return None - if not app.events_compaction_config.summarizer: - app.events_compaction_config.summarizer = LlmEventSummarizer( - llm=app.root_agent.canonical_model - ) + _ensure_compaction_summarizer(config=config, agent=app.root_agent) + if config.summarizer is None: + return None - compaction_event = ( - await app.events_compaction_config.summarizer.maybe_summarize_events( - events=events_to_compact - ) + compaction_event = await config.summarizer.maybe_summarize_events( + events=events_to_compact ) if compaction_event: await session_service.append_event(session=session, event=compaction_event) diff --git a/src/google/adk/artifacts/base_artifact_service.py b/src/google/adk/artifacts/base_artifact_service.py index 1a265f8a..23f5e44f 100644 --- a/src/google/adk/artifacts/base_artifact_service.py +++ b/src/google/adk/artifacts/base_artifact_service.py @@ -16,8 +16,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from datetime import datetime +import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from pydantic import alias_generators @@ -25,6 +27,8 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +logger = logging.getLogger("google_adk." + __name__) + class ArtifactVersion(BaseModel): """Metadata describing a specific version of an artifact.""" @@ -60,6 +64,26 @@ class ArtifactVersion(BaseModel): ) +def ensure_part(artifact: Union[types.Part, dict[str, Any]]) -> types.Part: + """Normalizes an artifact to a ``types.Part`` instance. + + External callers may provide artifacts as + plain dictionaries with camelCase keys (``inlineData``) instead of properly + deserialized ``types.Part`` objects. ``model_validate`` handles both + camelCase and snake_case dictionaries transparently via Pydantic aliases. + + Args: + artifact: A ``types.Part`` instance or a dictionary representation. + + Returns: + A validated ``types.Part`` instance. + """ + if isinstance(artifact, dict): + logger.debug("Normalizing artifact dict to types.Part: %s", list(artifact)) + return types.Part.model_validate(artifact) + return artifact + + class BaseArtifactService(ABC): """Abstract base class for artifact services.""" @@ -70,7 +94,7 @@ class BaseArtifactService(ABC): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -84,10 +108,12 @@ class BaseArtifactService(ABC): app_name: The app name. user_id: The user ID. filename: The filename of the artifact. - artifact: The artifact to save. If the artifact consists of `file_data`, - the artifact service assumes its content has been uploaded separately, - and this method will associate the `file_data` with the artifact if - necessary. + artifact: The artifact to save. Accepts a ``types.Part`` instance or a + plain dictionary (camelCase or snake_case keys) which will be + normalized via ``ensure_part``. If the artifact consists of + ``file_data``, the artifact service assumes its content has been + uploaded separately, and this method will associate the ``file_data`` + with the artifact if necessary. session_id: The session ID. If `None`, the artifact is user-scoped. custom_metadata: custom metadata to associate with the artifact. diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index be5adb48..b0078e27 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -22,6 +22,7 @@ from pathlib import PureWindowsPath import shutil from typing import Any from typing import Optional +from typing import Union from urllib.parse import unquote from urllib.parse import urlparse @@ -35,6 +36,7 @@ from typing_extensions import override from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." + __name__) @@ -314,7 +316,7 @@ class FileArtifactService(BaseArtifactService): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -339,11 +341,12 @@ class FileArtifactService(BaseArtifactService): self, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str], custom_metadata: Optional[dict[str, Any]], ) -> int: """Saves an artifact to disk and returns its version.""" + artifact = ensure_part(artifact) artifact_dir = self._artifact_dir( user_id=user_id, session_id=session_id, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index 4108cfb0..f8706ded 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -27,6 +27,7 @@ import asyncio import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from typing_extensions import override @@ -34,6 +35,7 @@ from typing_extensions import override from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." + __name__) @@ -61,7 +63,7 @@ class GcsArtifactService(BaseArtifactService): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -198,9 +200,10 @@ class GcsArtifactService(BaseArtifactService): user_id: str, session_id: Optional[str], filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + artifact = ensure_part(artifact) versions = self._list_versions( app_name=app_name, user_id=user_id, diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 45552b14..48e7afca 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -17,6 +17,7 @@ import dataclasses import logging from typing import Any from typing import Optional +from typing import Union from google.genai import types from pydantic import BaseModel @@ -27,6 +28,7 @@ from . import artifact_util from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService +from .base_artifact_service import ensure_part logger = logging.getLogger("google_adk." + __name__) @@ -99,10 +101,11 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: Union[types.Part, dict[str, Any]], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + artifact = ensure_part(artifact) path = self._artifact_path(app_name, user_id, filename, session_id) if path not in self.artifacts: self.artifacts[path] = [] diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index e205d9be..6160edcc 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -25,6 +25,7 @@ from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator class BaseModelWithConfig(BaseModel): @@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig): class ServiceAccount(BaseModelWithConfig): - """Represents Google Service Account configuration.""" + """Represents Google Service Account configuration. + + Attributes: + service_account_credential: The service account credential (JSON key). + scopes: The OAuth2 scopes to request. Optional; when omitted with + ``use_default_credential=True``, defaults to the cloud-platform scope. + use_default_credential: Whether to use Application Default Credentials. + use_id_token: Whether to exchange for an ID token instead of an access + token. Required for service-to-service authentication with Cloud Run, + Cloud Functions, and other Google Cloud services that require identity + verification. When True, ``audience`` must also be set. + audience: The target audience for the ID token, typically the URL of the + receiving service (e.g. ``https://my-service-xyz.run.app``). Required + when ``use_id_token`` is True. + """ service_account_credential: Optional[ServiceAccountCredential] = None - scopes: List[str] + scopes: Optional[List[str]] = None use_default_credential: Optional[bool] = False + use_id_token: Optional[bool] = False + audience: Optional[str] = None + + @model_validator(mode="after") + def _validate_config(self) -> ServiceAccount: + if ( + not self.use_default_credential + and self.service_account_credential is None + ): + raise ValueError( + "service_account_credential is required when" + " use_default_credential is False." + ) + if self.use_id_token and not self.audience: + raise ValueError( + "audience is required when use_id_token is True. Set it to the" + " URL of the target service" + " (e.g. 'https://my-service.run.app')." + ) + return self class AuthCredentialTypes(str, Enum): diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 37ad6745..76dd2dda 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator from typing_extensions import override @@ -25,6 +26,7 @@ from ..flows.llm_flows import functions from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ..models.llm_request import LlmRequest +from ..sessions.state import State from .auth_handler import AuthHandler from .auth_tool import AuthConfig from .auth_tool import AuthToolArguments @@ -35,6 +37,93 @@ from .auth_tool import AuthToolArguments TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' +async def _store_auth_and_collect_resume_targets( + events: list[Event], + auth_fc_ids: set[str], + auth_responses: dict[str, Any], + state: State, +) -> set[str]: + """Store auth credentials and return original function call IDs to resume. + + Scans session events for ``adk_request_credential`` function calls whose + IDs are in *auth_fc_ids*, extracts ``credential_key`` from their + ``AuthToolArguments`` args, merges ``credential_key`` into the + corresponding auth response, stores credentials via ``AuthHandler``, + and returns the set of original function call IDs that should be + re-executed (excluding toolset auth). + + Args: + events: Session events to scan. + auth_fc_ids: IDs of ``adk_request_credential`` function calls to match. + auth_responses: Mapping of FC ID -> auth config response dict from the + client. + state: Session state for temporary credential storage. + + Returns: + Set of original function call IDs to resume. + """ + # Step 1: Scan events for matching adk_request_credential function calls + # to extract AuthToolArguments (contains credential_key). + requested_auth_config_by_id: dict[str, AuthConfig] = {} + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + try: + for function_call in event_function_calls: + if ( + function_call.id in auth_fc_ids + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + requested_auth_config_by_id[function_call.id] = args.auth_config + except TypeError: + continue + + # Step 2: Store credentials. Merge credential_key from the original + # request into the client's auth response before storing. + for fc_id in auth_fc_ids: + if fc_id not in auth_responses: + continue + auth_config = AuthConfig.model_validate(auth_responses[fc_id]) + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if ( + requested_auth_config + and requested_auth_config.credential_key is not None + ): + auth_config.credential_key = requested_auth_config.credential_key + await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( + state=state + ) + + # Step 3: Collect original function call IDs to resume, skipping + # toolset auth entries which don't map to a resumable function call. + tools_to_resume: set[str] = set() + for fc_id in auth_fc_ids: + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if not requested_auth_config: + continue + # Re-parse to get function_call_id (AuthConfig doesn't carry it; + # AuthToolArguments does). + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + for function_call in event_function_calls: + if ( + function_call.id == fc_id + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + if args.function_call_id.startswith( + TOOLSET_AUTH_CREDENTIAL_ID_PREFIX + ): + continue + tools_to_resume.add(args.function_call_id) + + return tools_to_resume + + class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): """Handles auth information to build the LLM request.""" @@ -49,8 +138,8 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_euc_function_call_ids = set() - # find the last event with non-None content + # Find the last user-authored event with function responses to + # identify adk_request_credential responses. last_event_with_content = None for i in range(len(events) - 1, -1, -1): event = events[i] @@ -58,7 +147,6 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): last_event_with_content = event break - # check if the last event with content is authored by user if not last_event_with_content or last_event_with_content.author != 'user': return @@ -66,104 +154,55 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not responses: return - requested_auth_config_by_request_id = {} - # look for auth response + # Collect adk_request_credential function response IDs and their + # response dicts. + auth_fc_ids: set[str] = set() + auth_responses: dict[str, Any] = {} for function_call_response in responses: if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: continue - # found the function call response for the system long running request euc - # function call - request_euc_function_call_ids.add(function_call_response.id) - - if request_euc_function_call_ids: - for event in events: - function_calls = event.get_function_calls() - if not function_calls: - continue - try: - for function_call in function_calls: - if ( - function_call.id in request_euc_function_call_ids - and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): - args = AuthToolArguments.model_validate(function_call.args) - requested_auth_config_by_request_id[function_call.id] = ( - args.auth_config - ) - except TypeError: - continue - - for function_call_response in responses: - if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: - continue - - auth_config = AuthConfig.model_validate(function_call_response.response) - requested_auth_config = requested_auth_config_by_request_id.get( - function_call_response.id - ) - if ( - requested_auth_config - and requested_auth_config.credential_key is not None - ): - auth_config.credential_key = requested_auth_config.credential_key - await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state + auth_fc_ids.add(function_call_response.id) + auth_responses[function_call_response.id] = ( + function_call_response.response ) - if not request_euc_function_call_ids: + if not auth_fc_ids: return + # Store credentials and collect tools to resume. + tools_to_resume = await _store_auth_and_collect_resume_targets( + events, auth_fc_ids, auth_responses, invocation_context.session.state + ) + + if not tools_to_resume: + return + + # Find the original function call event and re-execute the tools + # that needed auth. for i in range(len(events) - 2, -1, -1): event = events[i] - # looking for the system long running request euc function call function_calls = event.get_function_calls() if not function_calls: continue - tools_to_resume = set() - - for function_call in function_calls: - if function_call.id not in request_euc_function_call_ids: - continue - args = AuthToolArguments.model_validate(function_call.args) - - # Skip toolset auth - auth response is already stored in session state - # and we don't need to resume a function call for toolsets - if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX): - continue - - tools_to_resume.add(args.function_call_id) - if not tools_to_resume: - continue - - # found the system long running request euc function call - # looking for original function call that requests euc - for j in range(i - 1, -1, -1): - event = events[j] - function_calls = event.get_function_calls() - if not function_calls: - continue - - if any([ - function_call.id in tools_to_resume - for function_call in function_calls - ]): - if function_response_event := await functions.handle_function_calls_async( - invocation_context, - event, - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # there could be parallel function calls that require auth - # auth response would be a dict keyed by function call id - tools_to_resume, - ): - yield function_response_event - return - return + if any([ + function_call.id in tools_to_resume + for function_call in function_calls + ]): + if function_response_event := await functions.handle_function_calls_async( + invocation_context, + event, + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume, + ): + yield function_response_event + return + return request_processor = _AuthLlmRequestProcessor() diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index c61f855f..afedb738 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -68,6 +68,7 @@ from ..auth.credential_service.base_credential_service import BaseCredentialServ from ..errors.already_exists_error import AlreadyExistsError from ..errors.input_validation_error import InputValidationError from ..errors.not_found_error import NotFoundError +from ..errors.session_not_found_error import SessionNotFoundError from ..evaluation.base_eval_service import InferenceConfig from ..evaluation.base_eval_service import InferenceRequest from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE @@ -207,6 +208,8 @@ class RunAgentRequest(common.BaseModel): new_message: Optional[types.Content] = None streaming: bool = False state_delta: Optional[dict[str, Any]] = None + # for long-running function resume requests (e.g., OAuth callback) + function_call_event_id: Optional[str] = None # for resume long-running functions invocation_id: Optional[str] = None @@ -587,7 +590,8 @@ class AdkWebServer: """Import a plugin object (class or instance) from a fully qualified name. Args: - qualified_name: Fully qualified name (e.g., 'my_package.my_plugin.MyPlugin') + qualified_name: Fully qualified name (e.g., + 'my_package.my_plugin.MyPlugin') Returns: The imported object, which can be either a class or an instance. @@ -687,6 +691,7 @@ class AdkWebServer: ] = lambda o, s: None, register_processors: Callable[[TracerProvider], None] = lambda o: None, otel_to_cloud: bool = False, + with_ui: bool = False, ): """Creates a FastAPI app for the ADK web server. @@ -699,7 +704,8 @@ class AdkWebServer: lifespan: The lifespan of the FastAPI app. allow_origins: The origins that are allowed to make cross-origin requests. Entries can be literal origins (e.g., 'https://example.com') or regex - patterns prefixed with 'regex:' (e.g., 'regex:https://.*\\.example\\.com'). + patterns prefixed with 'regex:' (e.g., + 'regex:https://.*\\.example\\.com'). web_assets_dir: The directory containing the web assets to serve. setup_observer: Callback for setting up the file system observer. tear_down_observer: Callback for cleaning up the file system observer. @@ -794,10 +800,93 @@ class AdkWebServer: raise HTTPException(status_code=404, detail="Trace not found") return event_dict - @app.get("/apps/{app_name}") - async def get_app_info(app_name: str) -> Any: - runner = await self.get_runner_async(app_name) - return runner.app + if web_assets_dir: + + @app.get("/dev/build_graph/{app_name}") + async def get_app_info(app_name: str) -> Any: + runner = await self.get_runner_async(app_name) + + if not runner.app: + raise HTTPException( + status_code=404, detail=f"App not found: {app_name}" + ) + + def serialize_agent(agent: BaseAgent) -> dict[str, Any]: + """Recursively serialize an agent, excluding non-serializable fields.""" + agent_dict = {} + + for field_name, field_info in agent.__class__.model_fields.items(): + # Skip non-serializable fields + if field_name in [ + "parent_agent", + "before_agent_callback", + "after_agent_callback", + "before_model_callback", + "after_model_callback", + "on_model_error_callback", + "before_tool_callback", + "after_tool_callback", + "on_tool_error_callback", + ]: + continue + + value = getattr(agent, field_name, None) + + # Handle sub_agents recursively + if field_name == "sub_agents" and value: + agent_dict[field_name] = [ + serialize_agent(sub_agent) for sub_agent in value + ] + elif value is None or field_name == "tools": + continue + else: + try: + if isinstance(value, (str, int, float, bool, list, dict)): + agent_dict[field_name] = value + elif hasattr(value, "model_dump"): + agent_dict[field_name] = value.model_dump( + mode="python", exclude_none=True + ) + else: + agent_dict[field_name] = str(value) + except Exception: + pass + + return agent_dict + + app_info = { + "name": runner.app.name, + "root_agent": serialize_agent(runner.app.root_agent), + } + + # Add optional fields if present + if runner.app.plugins: + app_info["plugins"] = [ + {"name": getattr(plugin, "name", type(plugin).__name__)} + for plugin in runner.app.plugins + ] + + if runner.app.context_cache_config: + try: + app_info["context_cache_config"] = ( + runner.app.context_cache_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + if runner.app.resumability_config: + try: + app_info["resumability_config"] = ( + runner.app.resumability_config.model_dump( + mode="python", exclude_none=True + ) + ) + except Exception: + pass + + return app_info @app.get("/debug/trace/session/{session_id}", tags=[TAG_DEBUG]) async def get_session_trace(session_id: str) -> Any: @@ -1533,7 +1622,8 @@ class AdkWebServer: update_memory_request: The memory request for the update Raises: - HTTPException: If the memory service is not configured or the request is invalid. + HTTPException: If the memory service is not configured or the request + is invalid. """ if not self.memory_service: raise HTTPException( @@ -1558,52 +1648,60 @@ class AdkWebServer: @app.post("/run", response_model_exclude_none=True) async def run_agent(req: RunAgentRequest) -> list[Event]: - session = await self.session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") runner = await self.get_runner_async(req.app_name) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - invocation_id=req.invocation_id, - ) - ) as agen: - events = [event async for event in agen] + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + invocation_id=req.invocation_id, + ) + ) as agen: + events = [event async for event in agen] + except SessionNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e logger.info("Generated %s events in agent run", len(events)) logger.debug("Events generated: %s", events) return events @app.post("/run_sse") async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: - # SSE endpoint - session = await self.session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") + stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE + runner = await self.get_runner_async(req.app_name) + + # Validate session existence before starting the stream. + # We check directly here instead of eagerly advancing the + # runner's async generator with anext(), because splitting + # generator consumption across two asyncio Tasks (request + # handler vs StreamingResponse) breaks OpenTelemetry context + # detachment. + if not runner.auto_create_session: + session = await self.session_service.get_session( + app_name=req.app_name, + user_id=req.user_id, + session_id=req.session_id, + ) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session not found: {req.session_id}", + ) # Convert the events to properly formatted SSE async def event_generator(): - try: - stream_mode = ( - StreamingMode.SSE if req.streaming else StreamingMode.NONE - ) - runner = await self.get_runner_async(req.app_name) - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - invocation_id=req.invocation_id, - ) - ) as agen: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + invocation_id=req.invocation_id, + ) + ) as agen: + try: async for event in agen: # ADK Web renders artifacts from `actions.artifactDelta` # during part processing *and* during action processing @@ -1611,7 +1709,8 @@ class AdkWebServer: # 2) a content-less "action-only" event carrying `artifactDelta` events_to_stream = [event] if ( - event.actions.artifact_delta + not req.function_call_event_id + and event.actions.artifact_delta and event.content and event.content.parts ): @@ -1630,9 +1729,9 @@ class AdkWebServer: "Generated event in agent run streaming: %s", sse_event ) yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - yield f"data: {json.dumps({'error': str(e)})}\n\n" + except Exception as e: + logger.exception("Error in event_generator: %s", e) + yield f"data: {json.dumps({'error': str(e)})}\n\n" # Returns a streaming response with the proper media type for SSE return StreamingResponse( diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 16eba88b..1d49f50d 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -28,6 +28,7 @@ from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from ..memory.base_memory_service import BaseMemoryService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session @@ -37,6 +38,7 @@ from .service_registry import load_services_module from .utils import envs from .utils.agent_loader import AgentLoader from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_memory_service_from_options from .utils.service_factory import create_session_service_from_options @@ -53,6 +55,7 @@ async def run_input_file( session_service: BaseSessionService, credential_service: BaseCredentialService, input_path: str, + memory_service: Optional[BaseMemoryService] = None, ) -> Session: app = ( agent_or_app @@ -63,6 +66,7 @@ async def run_input_file( app=app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, ) with open(input_path, 'r', encoding='utf-8') as f: @@ -93,6 +97,7 @@ async def run_interactively( session: Session, session_service: BaseSessionService, credential_service: BaseCredentialService, + memory_service: Optional[BaseMemoryService] = None, ) -> None: app = ( root_agent_or_app @@ -103,6 +108,7 @@ async def run_interactively( app=app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, ) while True: @@ -137,6 +143,7 @@ async def run_cli( session_id: Optional[str] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, use_local_storage: bool = True, ) -> None: """Runs an interactive CLI for a certain agent. @@ -154,6 +161,7 @@ async def run_cli( session_id: Optional[str], the session ID to save the session to on exit. session_service_uri: Optional[str], custom session service URI. artifact_service_uri: Optional[str], custom artifact service URI. + memory_service_uri: Optional[str], custom memory service URI. use_local_storage: bool, whether to use local .adk storage by default. """ agent_parent_path = Path(agent_parent_dir).resolve() @@ -171,6 +179,9 @@ async def run_cli( if isinstance(agent_or_app, App) and agent_or_app.name != agent_folder_name: app_name_to_dir = {agent_or_app.name: agent_folder_name} + if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): + envs.load_dotenv_for_agent(agent_folder_name, agents_dir) + # Create session and artifact services using factory functions. # Sessions persist under //.adk/session.db when enabled. session_service = create_session_service_from_options( @@ -185,10 +196,12 @@ async def run_cli( artifact_service_uri=artifact_service_uri, use_local_storage=use_local_storage, ) + memory_service = create_memory_service_from_options( + base_dir=agent_parent_path, + memory_service_uri=memory_service_uri, + ) credential_service = InMemoryCredentialService() - if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): - envs.load_dotenv_for_agent(agent_folder_name, agents_dir) # Helper function for printing events def _print_event(event) -> None: @@ -208,6 +221,7 @@ async def run_cli( agent_or_app=agent_or_app, artifact_service=artifact_service, session_service=session_service, + memory_service=memory_service, credential_service=credential_service, input_path=input_file, ) @@ -235,6 +249,7 @@ async def run_cli( session, session_service, credential_service, + memory_service=memory_service, ) else: session = await session_service.create_session( @@ -247,6 +262,7 @@ async def run_cli( session, session_service, credential_service, + memory_service=memory_service, ) if save_session: diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 5b5d3e5c..b817d4b4 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -645,14 +645,6 @@ def cli_run( """ logs.log_to_tmp_folder() - # Validation warning for memory_service_uri (not supported for adk run) - if memory_service_uri: - click.secho( - "WARNING: --memory_service_uri is not supported for adk run.", - fg="yellow", - err=True, - ) - agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) @@ -666,6 +658,7 @@ def cli_run( session_id=session_id, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, + memory_service_uri=memory_service_uri, use_local_storage=use_local_storage, ) ) @@ -1974,9 +1967,11 @@ def cli_deploy_agent_engine( Example: + \b # With Express Mode API Key adk deploy agent_engine --api_key=[api_key] my_agent + \b # With Google Cloud Project and Region adk deploy agent_engine --project=[project] --region=[region] --display_name=[app_name] my_agent diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 553629f2..8f78c15f 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -525,6 +525,7 @@ def get_fast_api_app( if a2a: from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -563,8 +564,12 @@ def get_fast_api_app( runner=create_a2a_runner_loader(app_name), ) + push_config_store = InMemoryPushNotificationConfigStore() + request_handler = DefaultRequestHandler( - agent_executor=agent_executor, task_store=a2a_task_store + agent_executor=agent_executor, + task_store=a2a_task_store, + push_config_store=push_config_store, ) with (p / "agent.json").open("r", encoding="utf-8") as f: diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 2ea286ef..b1328958 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -301,6 +301,11 @@ def _register_builtin_services(registry: ServiceRegistry) -> None: registry.register_artifact_service("file", file_artifact_factory) # -- Memory Services -- + def memory_memory_factory(_uri: str, **_): + from ..memory.in_memory_memory_service import InMemoryMemoryService + + return InMemoryMemoryService() + def rag_memory_factory(uri: str, **kwargs): from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService @@ -324,6 +329,7 @@ def _register_builtin_services(registry: ServiceRegistry) -> None: ) return VertexAiMemoryBankService(**params) + registry.register_memory_service("memory", memory_memory_factory) registry.register_memory_service("rag", rag_memory_factory) registry.register_memory_service("agentengine", agentengine_memory_factory) diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index 8b5805c5..efd24648 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -335,13 +335,18 @@ class AgentLoader(BaseAgentLoader): def list_agents(self) -> list[str]: """Lists all agents available in the agent loader (sorted alphabetically).""" base_path = Path.cwd() / self.agents_dir - agent_names = [ - x - for x in os.listdir(base_path) - if os.path.isdir(os.path.join(base_path, x)) - and not x.startswith(".") - and x != "__pycache__" - ] + agent_names = [] + for x in os.listdir(base_path): + if ( + os.path.isdir(os.path.join(base_path, x)) + and not x.startswith(".") + and x != "__pycache__" + ): + try: + self._determine_agent_language(x) + agent_names.append(x) + except ValueError: + continue agent_names.sort() return agent_names diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index f601d045..071d59dc 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -38,10 +38,15 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): sandbox_resource_name: If set, load the existing resource name of the code interpreter extension instead of creating a new one. Format: projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789 + agent_engine_resource_name: The resource name of the agent engine to use + to create the code execution sandbox. Format: + projects/123/locations/us-central1/reasoningEngines/456 """ sandbox_resource_name: str = None + agent_engine_resource_name: str = None + def __init__( self, sandbox_resource_name: Optional[str] = None, @@ -67,30 +72,19 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): agent_engine_resource_name_pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' if sandbox_resource_name is not None: - self.sandbox_resource_name = sandbox_resource_name self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( sandbox_resource_name, sandbox_resource_name_pattern ) ) + self.sandbox_resource_name = sandbox_resource_name elif agent_engine_resource_name is not None: - from vertexai import types - self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( agent_engine_resource_name, agent_engine_resource_name_pattern ) ) - # @TODO - Add TTL for sandbox creation after it is available - # in SDK. - operation = self._get_api_client().agent_engines.sandboxes.create( - spec={'code_execution_environment': {}}, - name=agent_engine_resource_name, - config=types.CreateAgentEngineSandboxConfig( - display_name='default_sandbox' - ), - ) - self.sandbox_resource_name = operation.response.name + self.agent_engine_resource_name = agent_engine_resource_name else: raise ValueError( 'Either sandbox_resource_name or agent_engine_resource_name must be' @@ -103,6 +97,45 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: + # default to the sandbox resource name if set. + sandbox_name = self.sandbox_resource_name + if self.sandbox_resource_name is None: + from google.api_core import exceptions + from vertexai import types + + # use sandbox name stored in session if available. + sandbox_name = invocation_context.session.state.get('sandbox_name', None) + create_new_sandbox = False + if sandbox_name is None: + create_new_sandbox = True + else: + # Check if the sandbox is still running OR already expired due to ttl. + try: + sandbox = self._get_api_client().agent_engines.sandboxes.get( + name=sandbox_name + ) + if sandbox is None or sandbox.state != 'STATE_RUNNING': + create_new_sandbox = True + except exceptions.NotFound: + create_new_sandbox = True + + if create_new_sandbox: + # Create a new sandbox and assign it to sandbox_name. + operation = self._get_api_client().agent_engines.sandboxes.create( + spec={'code_execution_environment': {}}, + name=self.agent_engine_resource_name, + config=types.CreateAgentEngineSandboxConfig( + # VertexAiSessionService has a default TTL of 1 year, so we set + # the sandbox TTL to 1 year as well. For the current code + # execution sandbox, if it hasn't been used for 14 days, the + # state will be lost. + display_name='default_sandbox', + ttl='31536000s', + ), + ) + sandbox_name = operation.response.name + invocation_context.session.state['sandbox_name'] = sandbox_name + # Execute the code. input_data = { 'code': code_execution_input.code, @@ -119,7 +152,7 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): code_execution_response = ( self._get_api_client().agent_engines.sandboxes.execute_code( - name=self.sandbox_resource_name, + name=sandbox_name, input_data=input_data, ) ) @@ -134,8 +167,8 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): or 'file_name' not in output.metadata.attributes ): json_output_data = json.loads(output.data.decode('utf-8')) - stdout = json_output_data.get('stdout', '') - stderr = json_output_data.get('stderr', '') + stdout = json_output_data.get('msg_out', '') + stderr = json_output_data.get('msg_err', '') else: file_name = '' if ( diff --git a/src/google/adk/code_executors/built_in_code_executor.py b/src/google/adk/code_executors/built_in_code_executor.py index 50a0b9f4..a4e32034 100644 --- a/src/google/adk/code_executors/built_in_code_executor.py +++ b/src/google/adk/code_executors/built_in_code_executor.py @@ -20,6 +20,7 @@ from typing_extensions import override from ..agents.invocation_context import InvocationContext from ..models import LlmRequest from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -42,7 +43,8 @@ class BuiltInCodeExecutor(BaseCodeExecutor): def process_llm_request(self, llm_request: LlmRequest) -> None: """Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool.""" - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( diff --git a/src/google/adk/code_executors/gke_code_executor.py b/src/google/adk/code_executors/gke_code_executor.py index 1dc46878..b44aa193 100644 --- a/src/google/adk/code_executors/gke_code_executor.py +++ b/src/google/adk/code_executors/gke_code_executor.py @@ -19,12 +19,24 @@ import uuid import kubernetes as k8s from kubernetes.watch import Watch +from pydantic import field_validator +from typing_extensions import Literal +from typing_extensions import override +from typing_extensions import TYPE_CHECKING from ..agents.invocation_context import InvocationContext from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult +try: + from agentic_sandbox import SandboxClient +except ImportError: + SandboxClient = None + +if TYPE_CHECKING: + from agentic_sandbox import SandboxClient + # Expose these for tests to monkeypatch. client = k8s.client config = k8s.config @@ -36,9 +48,19 @@ logger = logging.getLogger("google_adk." + __name__) class GkeCodeExecutor(BaseCodeExecutor): """Executes Python code in a secure gVisor-sandboxed Pod on GKE. - This executor securely runs code by dynamically creating a Kubernetes Job for - each execution request. The user's code is mounted via a ConfigMap, and the - Pod is hardened with a strict security context and resource limits. + This executor supports two modes of execution: 'job' and 'sandbox'. + + Job Mode (default): + Securely runs code by dynamically creating a Kubernetes Job for each execution + request. The user's code is mounted via a ConfigMap, and the Pod is hardened + with a strict security context and resource limits. + + Sandbox Mode: + Executes code using the Agent Sandbox Client. This mode requires additional + infrastructure to be deployed in the cluster, specifically: + - Agent-sandbox controller + - Sandbox templates (e.g., python-sandbox-template) + - Sandbox router and gateway Key Features: - Sandboxed execution using the gVisor runtime. @@ -70,6 +92,7 @@ class GkeCodeExecutor(BaseCodeExecutor): namespace: str = "default" image: str = "python:3.11-slim" timeout_seconds: int = 300 + executor_type: Literal["job", "sandbox"] = "job" cpu_requested: str = "200m" mem_requested: str = "256Mi" # The maximum CPU the container can use, in "millicores". 1000m is 1 full CPU core. @@ -79,6 +102,10 @@ class GkeCodeExecutor(BaseCodeExecutor): kubeconfig_path: str | None = None kubeconfig_context: str | None = None + # Sandbox constants + sandbox_gateway_name: str | None = None + sandbox_template: str | None = "python-sandbox-template" + _batch_v1: k8s.client.BatchV1Api _core_v1: k8s.client.CoreV1Api @@ -136,10 +163,46 @@ class GkeCodeExecutor(BaseCodeExecutor): self._batch_v1 = client.BatchV1Api() self._core_v1 = client.CoreV1Api() - def execute_code( - self, - invocation_context: InvocationContext, - code_execution_input: CodeExecutionInput, + @field_validator("executor_type") + @classmethod + def _check_sandbox_dependency(cls, v: str) -> str: + if v == "sandbox" and SandboxClient is None: + raise ImportError( + "k8s-agent-sandbox not found. To use Agent Sandbox, please install" + " google-adk with the extensions extra: pip install" + " google-adk[extensions]" + ) + return v + + def _execute_in_sandbox(self, code: str) -> CodeExecutionResult: + """Executes code using Agent Sandbox Client.""" + try: + with SandboxClient( + template_name=self.sandbox_template, + gateway_name=self.sandbox_gateway_name, + namespace=self.namespace, + ) as sandbox: + # Execute the code as a python script + sandbox.write("script.py", code) + result = sandbox.run("python3 script.py") + + return CodeExecutionResult(stdout=result.stdout, stderr=result.stderr) + except RuntimeError as e: + logger.error( + "SandboxClient failed to initialize or find gateway", exc_info=True + ) + raise RuntimeError(f"Sandbox infrastructure error: {e}") from e + except TimeoutError as e: + logger.error("Sandbox timed out", exc_info=True) + # Returning a result instead of raising allows the Agent to process + # the error gracefully. + return CodeExecutionResult(stderr=f"Sandbox timed out: {e}") + except Exception as e: + logger.error("Sandbox execution failed: %s", e, exc_info=True) + raise + + def _execute_as_job( + self, code: str, invocation_context: InvocationContext ) -> CodeExecutionResult: """Orchestrates the secure execution of a code snippet on GKE.""" job_name = f"adk-exec-{uuid.uuid4().hex[:10]}" @@ -150,7 +213,7 @@ class GkeCodeExecutor(BaseCodeExecutor): # 1. Create a ConfigMap to mount LLM-generated code into the Pod. # 2. Create a Job that runs the code from the ConfigMap. # 3. Set the Job as the ConfigMap's owner for automatic cleanup. - self._create_code_configmap(configmap_name, code_execution_input.code) + self._create_code_configmap(configmap_name, code) job_manifest = self._create_job_manifest( job_name, configmap_name, invocation_context ) @@ -162,7 +225,6 @@ class GkeCodeExecutor(BaseCodeExecutor): logger.info( f"Submitted Job '{job_name}' to namespace '{self.namespace}'." ) - logger.debug("Executing code:\n```\n%s\n```", code_execution_input.code) return self._watch_job_completion(job_name) except ApiException as e: @@ -186,6 +248,20 @@ class GkeCodeExecutor(BaseCodeExecutor): stderr=f"An unexpected executor error occurred: {e}" ) + @override + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Overrides the base method to route execution based on executor_type.""" + code = code_execution_input.code + if self.executor_type == "sandbox": + return self._execute_in_sandbox(code) + else: + # Fallback to existing GKE Job logic + return self._execute_as_job(code, invocation_context) + def _create_job_manifest( self, job_name: str, diff --git a/src/google/adk/errors/session_not_found_error.py b/src/google/adk/errors/session_not_found_error.py new file mode 100644 index 00000000..4fc3258e --- /dev/null +++ b/src/google/adk/errors/session_not_found_error.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +class SessionNotFoundError(ValueError): + """Raised when a session cannot be found. + + Inherits from ValueError (for backward compatibility). + """ + + def __init__(self, message="Session not found."): + super().__init__(message) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 1d9662bd..725bddc1 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -280,7 +280,7 @@ class EvaluationGenerator: invocations = [] for invocation_id, events in events_by_invocation_id.items(): final_response = None - user_content = "" + user_content = Content(parts=[]) invocation_timestamp = 0 app_details = None if ( diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 81089162..9b633c2d 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -26,6 +26,8 @@ from ..utils.env_utils import is_env_enabled class FeatureName(str, Enum): """Feature names.""" + AGENT_CONFIG = "AGENT_CONFIG" + AGENT_STATE = "AGENT_STATE" AUTHENTICATED_FUNCTION_TOOL = "AUTHENTICATED_FUNCTION_TOOL" BASE_AUTHENTICATED_TOOL = "BASE_AUTHENTICATED_TOOL" BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET" @@ -79,6 +81,12 @@ class FeatureConfig: # Central registry: FeatureName -> FeatureConfig _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = { + FeatureName.AGENT_CONFIG: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.AGENT_STATE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.AUTHENTICATED_FUNCTION_TOOL: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/flows/llm_flows/_output_schema_processor.py b/src/google/adk/flows/llm_flows/_output_schema_processor.py index 36fa8d56..284cc213 100644 --- a/src/google/adk/flows/llm_flows/_output_schema_processor.py +++ b/src/google/adk/flows/llm_flows/_output_schema_processor.py @@ -110,8 +110,12 @@ def get_structured_model_response(function_response_event: Event) -> str | None: for func_response in function_response_event.get_function_responses(): if func_response.name == 'set_model_response': - # Convert dict to JSON string - return json.dumps(func_response.response, ensure_ascii=False) + # Extract the actual result from the wrapped response. + # Tool results are wrapped as {'result': ...} when not already a dict. + response = func_response.response + if isinstance(response, dict) and 'result' in response: + response = response['result'] + return json.dumps(response, ensure_ascii=False) return None diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 424bb580..5368ca93 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -368,11 +368,17 @@ async def _run_and_handle_error( try: async with Aclosing(response_generator) as agen: - with tracing.use_generate_content_span( - llm_request, invocation_context, model_response_event - ) as span: + async with tracing.use_inference_span( + llm_request, + invocation_context, + model_response_event, + ) as gc_span: async for llm_response in agen: - tracing.trace_generate_content_result(span, llm_response) + if gc_span: + tracing.trace_inference_result( + gc_span, + llm_response, + ) yield llm_response except Exception as model_error: callback_context = CallbackContext( diff --git a/src/google/adk/flows/llm_flows/compaction.py b/src/google/adk/flows/llm_flows/compaction.py new file mode 100644 index 00000000..f4b60ba9 --- /dev/null +++ b/src/google/adk/flows/llm_flows/compaction.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Request processor that runs token-threshold event compaction.""" + +from __future__ import annotations + +from typing import AsyncGenerator +from typing import TYPE_CHECKING + +from ...apps.compaction import _has_token_threshold_config +from ...apps.compaction import _run_compaction_for_token_threshold_config +from ...events.event import Event +from ._base_llm_processor import BaseLlmRequestProcessor + +if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext + from ...models.llm_request import LlmRequest + + +class CompactionRequestProcessor(BaseLlmRequestProcessor): + """Compacts session events before contents are prepared for model calls.""" + + async def run_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + del llm_request + config = invocation_context.events_compaction_config + if not _has_token_threshold_config(config): + return + yield # Required for AsyncGenerator. + + token_compacted = await _run_compaction_for_token_threshold_config( + config=config, + session=invocation_context.session, + session_service=invocation_context.session_service, + agent=invocation_context.agent, + agent_name=invocation_context.agent.name, + current_branch=invocation_context.branch, + ) + if token_compacted: + invocation_context.token_compaction_checked = True + return + yield # Required for AsyncGenerator. + + +request_processor = CompactionRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 6f34e8fe..24057c37 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -17,6 +17,8 @@ from __future__ import annotations import asyncio +import base64 +import binascii from concurrent.futures import ThreadPoolExecutor import copy import functools @@ -31,10 +33,10 @@ from typing import Optional from typing import TYPE_CHECKING import uuid +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.genai import types from ...agents.active_streaming_tool import ActiveStreamingTool -from ...agents.invocation_context import InvocationContext from ...agents.live_request_queue import LiveRequestQueue from ...auth.auth_tool import AuthConfig from ...auth.auth_tool import AuthToolArguments @@ -49,6 +51,7 @@ from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing if TYPE_CHECKING: + from ...agents.invocation_context import InvocationContext from ...agents.llm_agent import LlmAgent AF_FUNCTION_CALL_ID_PREFIX = 'adk-' @@ -147,8 +150,8 @@ async def _call_tool_in_thread_pool( args_to_call = tool._preprocess_args(args) signature = inspect.signature(tool.func) valid_params = {param for param in signature.parameters} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if tool._context_param_name in valid_params: + args_to_call[tool._context_param_name] = tool_context args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } @@ -660,14 +663,65 @@ async def _execute_single_function_call_live( streaming_lock: asyncio.Lock, ) -> Optional[Event]: """Execute a single function call for live mode with thread safety.""" - tool, tool_context = _get_tool_and_context( - invocation_context, function_call, tools_dict - ) + async def _run_on_tool_error_callbacks( + *, + tool: BaseTool, + tool_args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ) -> Optional[dict[str, Any]]: + """Runs the on_tool_error_callbacks for the given tool.""" + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=tool_args, + tool_context=tool_context, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_tool_error_callbacks: + error_response = callback( + tool=tool, + args=tool_args, + tool_context=tool_context, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. function_args = ( copy.deepcopy(function_call.args) if function_call.args else {} ) + tool_context = _create_tool_context(invocation_context, function_call) + + try: + tool = _get_tool(function_call, tools_dict) + except ValueError as tool_error: + tool = BaseTool(name=function_call.name, description='Tool not found') + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + return __build_response_event( + tool, error_response, tool_context, invocation_context + ) + raise tool_error + async def _run_with_trace(): nonlocal function_args @@ -676,41 +730,77 @@ async def _execute_single_function_call_live( # Make a deep copy to avoid being modified. function_response = None - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. if function_response is None: - function_response = await _process_function_live_helper( - tool, - tool_context, - function_call, - function_args, - invocation_context, - streaming_lock, - ) + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break + # Step 3: Otherwise, proceed calling the tool normally. + if function_response is None: + try: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) + except Exception as tool_error: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, + ) + ) + + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. if altered_function_response is not None: function_response = altered_function_response @@ -940,6 +1030,50 @@ def _get_tool_and_context( return (tool, tool_context) +def _try_decode_computer_use_image( + tool: BaseTool, + function_result: dict[str, object], +) -> Optional[list[types.FunctionResponsePart]]: + """Decodes the image from the function result for a computer use tool. + + Args: + tool: The tool that produced the function result. + function_result: The dictionary containing the function's result. This + dictionary may be modified in-place to remove the 'image' key if an image + is successfully decoded. + + Returns: + A list containing a `types.FunctionResponsePart` with the decoded image + data, or None if no image was found or decoding failed. + """ + + if not isinstance(tool, ComputerUseTool) or not isinstance( + function_result, dict + ): + return None + + if ( + 'image' not in function_result + or 'data' not in function_result['image'] + or 'mimetype' not in function_result['image'] + ): + return None + + try: + image_data = base64.b64decode(function_result['image']['data']) + mime_type = function_result['image']['mimetype'] + + part = types.FunctionResponsePart.from_bytes( + data=image_data, mime_type=mime_type + ) + + del function_result['image'] + return [part] + except (binascii.Error, ValueError): + logger.exception('Failed to decode image from computer use tool') + return None + + async def __call_tool_live( tool: BaseTool, args: dict[str, object], @@ -977,8 +1111,16 @@ def __build_response_event( if not isinstance(function_result, dict): function_result = {'result': function_result} + function_response_parts = None + if isinstance(tool, ComputerUseTool): + function_response_parts = _try_decode_computer_use_image( + tool, function_result + ) + part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result + name=tool.name, + response=function_result, + parts=function_response_parts, ) part_function_response.function_response.id = tool_context.function_call_id @@ -1051,6 +1193,18 @@ def merge_parallel_function_response_events( return merged_event +def find_event_by_function_call_id( + events: list[Event], + function_call_id: str, +) -> Optional[Event]: + """Finds the function call event that matches the function call id.""" + for event in reversed(events): + for function_call in event.get_function_calls(): + if function_call.id == function_call_id: + return event + return None + + def find_matching_function_call( events: list[Event], ) -> Optional[Event]: @@ -1059,25 +1213,8 @@ def find_matching_function_call( return None last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): + function_responses = last_event.get_function_responses() + if not function_responses: + return None - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long-running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None + return find_event_by_function_call_id(events[:-1], function_responses[0].id) diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index f7b7f7f6..d066db79 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -15,6 +15,7 @@ from __future__ import annotations import json import logging +from typing import Any from typing import AsyncGenerator from typing import TYPE_CHECKING @@ -37,6 +38,65 @@ if TYPE_CHECKING: logger = logging.getLogger('google_adk.' + __name__) +def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation: + """Parse ToolConfirmation from a function response dict. + + Handles both the direct dict format and the ADK client's + ``{'response': json_string}`` wrapper format. + + """ + if response and len(response.values()) == 1 and 'response' in response.keys(): + return ToolConfirmation.model_validate(json.loads(response['response'])) + return ToolConfirmation.model_validate(response) + + +def _resolve_confirmation_targets( + events: list[Event], + confirmation_fc_ids: set[str], + confirmations_by_fc_id: dict[str, ToolConfirmation], +) -> tuple[dict[str, ToolConfirmation], dict[str, types.FunctionCall]]: + """Find original function calls for confirmed tools. + + Scans events for ``adk_request_confirmation`` function calls whose IDs + are in *confirmation_fc_ids*, extracts the ``originalFunctionCall`` from + their args, and maps each confirmation to the original FC ID. + + Args: + events: Session events to scan. + confirmation_fc_ids: IDs of ``adk_request_confirmation`` function calls. + confirmations_by_fc_id: Mapping of confirmation FC ID -> + ``ToolConfirmation``. + + Returns: + Tuple of ``(tool_confirmation_dict, original_fcs_dict)`` where both + are keyed by the ORIGINAL function call IDs. + """ + tool_confirmation_dict: dict[str, ToolConfirmation] = {} + original_fcs_dict: dict[str, types.FunctionCall] = {} + + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + + for function_call in event_function_calls: + if function_call.id not in confirmation_fc_ids: + continue + + args = function_call.args + if 'originalFunctionCall' not in args: + continue + original_function_call = types.FunctionCall( + **args['originalFunctionCall'] + ) + tool_confirmation_dict[original_function_call.id] = ( + confirmations_by_fc_id[function_call.id] + ) + original_fcs_dict[original_function_call.id] = original_function_call + + return tool_confirmation_dict, original_fcs_dict + + class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): """Handles tool confirmation information to build the LLM request.""" @@ -53,14 +113,12 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_confirmation_function_responses = ( - dict() - ) # {function call id, tool confirmation} - + # Step 1: Find the last user-authored event and parse confirmation + # responses from it. + confirmations_by_fc_id: dict[str, ToolConfirmation] = {} confirmation_event_index = -1 for k in range(len(events) - 1, -1, -1): event = events[k] - # Find the first event authored by user if not event.author or event.author != 'user': continue responses = event.get_function_responses() @@ -70,101 +128,58 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): for function_response in responses: if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: continue - - # Find the FunctionResponse event that contains the user provided tool - # confirmation - if ( + confirmations_by_fc_id[function_response.id] = _parse_tool_confirmation( function_response.response - and len(function_response.response.values()) == 1 - and 'response' in function_response.response.keys() - ): - # ADK client must send a resuming run request with a function response - # that always encapsulate the confirmation result with a 'response' - # key - tool_confirmation = ToolConfirmation.model_validate( - json.loads(function_response.response['response']) - ) - else: - tool_confirmation = ToolConfirmation.model_validate( - function_response.response - ) - request_confirmation_function_responses[function_response.id] = ( - tool_confirmation ) confirmation_event_index = k break - if not request_confirmation_function_responses: + if not confirmations_by_fc_id: return - for i in range(len(events) - 2, -1, -1): + # Step 2: Resolve confirmation targets using extracted helper. + confirmation_fc_ids = set(confirmations_by_fc_id.keys()) + tools_to_resume_with_confirmation, tools_to_resume_with_args = ( + _resolve_confirmation_targets( + events, confirmation_fc_ids, confirmations_by_fc_id + ) + ) + + if not tools_to_resume_with_confirmation: + return + + # Step 3: Remove tools that have already been confirmed (dedup). + for i in range(len(events) - 1, confirmation_event_index, -1): event = events[i] - # Find the system generated FunctionCall event requesting the tool - # confirmation - function_calls = event.get_function_calls() - if not function_calls: + fr_list = event.get_function_responses() + if not fr_list: continue - tools_to_resume_with_confirmation = ( - dict() - ) # {Function call id, tool confirmation} - tools_to_resume_with_args = dict() # {Function call id, function calls} - - for function_call in function_calls: - if ( - function_call.id - not in request_confirmation_function_responses.keys() - ): - continue - - args = function_call.args - if 'originalFunctionCall' not in args: - continue - original_function_call = types.FunctionCall( - **args['originalFunctionCall'] - ) - tools_to_resume_with_confirmation[original_function_call.id] = ( - request_confirmation_function_responses[function_call.id] - ) - tools_to_resume_with_args[original_function_call.id] = ( - original_function_call - ) + for function_response in fr_list: + if function_response.id in tools_to_resume_with_confirmation: + tools_to_resume_with_confirmation.pop(function_response.id) + tools_to_resume_with_args.pop(function_response.id) if not tools_to_resume_with_confirmation: - continue + break - # Remove the tools that have already been confirmed. - for i in range(len(events) - 1, confirmation_event_index, -1): - event = events[i] - function_response = event.get_function_responses() - if not function_response: - continue - - for function_response in event.get_function_responses(): - if function_response.id in tools_to_resume_with_confirmation: - tools_to_resume_with_confirmation.pop(function_response.id) - tools_to_resume_with_args.pop(function_response.id) - if not tools_to_resume_with_confirmation: - break - - if not tools_to_resume_with_confirmation: - continue - - if function_response_event := await functions.handle_function_call_list_async( - invocation_context, - tools_to_resume_with_args.values(), - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # There could be parallel function calls that require input - # response would be a dict keyed by function call id - tools_to_resume_with_confirmation.keys(), - tools_to_resume_with_confirmation, - ): - yield function_response_event + if not tools_to_resume_with_confirmation: return + # Step 4: Re-execute the confirmed tools. + if function_response_event := await functions.handle_function_call_list_async( + invocation_context, + tools_to_resume_with_args.values(), + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume_with_confirmation.keys(), + tools_to_resume_with_confirmation, + ): + yield function_response_event + return + request_processor = _RequestConfirmationLlmRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 0a26cdce..e0bd00ff 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -22,6 +22,7 @@ from . import _code_execution from . import _nl_planning from . import _output_schema_processor from . import basic +from . import compaction from . import contents from . import context_cache_processor from . import identity @@ -42,6 +43,9 @@ def _create_request_processors(): request_confirmation.request_processor, instructions.request_processor, identity.request_processor, + # Compaction should run before contents so compacted events are reflected + # in the model request context. + compaction.request_processor, contents.request_processor, # Context cache processor sets up cache config and finds # existing cache metadata. diff --git a/src/google/adk/integrations/README.md b/src/google/adk/integrations/README.md new file mode 100644 index 00000000..56ab2b33 --- /dev/null +++ b/src/google/adk/integrations/README.md @@ -0,0 +1,35 @@ +# ADK Integrations + +This directory houses modules that integrate ADK with external tools and +services. The goal is to provide an organized and scalable way to extend ADK's +capabilities. + +Integrations with external systems, such as the Agent Registry, BigQuery, +ApiHub, etc., should be developed within sub-packages in this folder. This +centralization makes it easier for developers to find, use, and contribute to +various integrations. + +## What Belongs Here? + +* Code that connects ADK to other services, APIs, or tools. +* Modules that depend on third-party libraries not included in the core ADK + dependencies. + +## Guidelines for Contributions + +1. **Self-Contained Packages:** Each integration should reside in its own + sub-directory (e.g., `integrations/my_service/`). +2. **Internal Structure:** Integration sub-packages are free to manage their + own internal code structure and design patterns. They do not need to + strictly follow the core ADK framework's structure. +3. **Dependencies:** To keep the core ADK lightweight, dependencies required + for a specific integration must be optional. These should be defined as + "extras" in the `pyproject.toml`. Users will install them using commands + like `pip install "google-adk[my_service]"`. The extra name should match the + integration directory name. +4. **Lazy Importing:** Implement lazy importing within the integration code. If + a user tries to use an integration without installing the necessary extras, + catch the `ModuleNotFoundError` and raise a descriptive error message + guiding the user to the correct installation command. +5. **Documentation:** Ensure clear documentation is provided for each + integration, including setup, configuration, and usage examples. diff --git a/src/google/adk/integrations/agent_registry/__init__.py b/src/google/adk/integrations/agent_registry/__init__.py new file mode 100644 index 00000000..995ad046 --- /dev/null +++ b/src/google/adk/integrations/agent_registry/__init__.py @@ -0,0 +1,18 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .agent_registry import _ProtocolType +from .agent_registry import AgentRegistry + +__all__ = [ + 'AgentRegistry', +] diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py new file mode 100644 index 00000000..93a91df4 --- /dev/null +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -0,0 +1,281 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client library for interacting with the Google Cloud Agent Registry within ADK.""" + +from __future__ import annotations + +from enum import Enum +import logging +import os +import re +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Union +from urllib.parse import parse_qs +from urllib.parse import urlparse + +from a2a.client.client_factory import minimal_agent_card +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentSkill +from a2a.types import TransportProtocol as A2ATransport +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth +import google.auth.transport.requests +import httpx + +logger = logging.getLogger("google_adk." + __name__) + +AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha" + + +class _ProtocolType(str, Enum): + """Supported agent protocol types.""" + + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + A2A_AGENT = "A2A_AGENT" + CUSTOM = "CUSTOM" + + +class AgentRegistry: + """Client for interacting with the Google Cloud Agent Registry service. + + Unlike a standard REST client library, this class provides higher-level + abstractions for ADK integration. It surfaces the agent registry service + methods along with helper methods like `get_mcp_toolset` and + `get_remote_a2a_agent` that automatically resolve connection details and + handle authentication to produce ready-to-use ADK components. + """ + + def __init__( + self, + project_id: Optional[str] = None, + location: Optional[str] = None, + header_provider: Optional[ + Callable[[ReadonlyContext], Dict[str, str]] + ] = None, + ): + """Initializes the AgentRegistry client. + + Args: + project_id: The Google Cloud project ID. + location: The Google Cloud location (region). + header_provider: Optional provider for custom headers. + """ + self.project_id = project_id + self.location = location + + if not self.project_id or not self.location: + raise ValueError("project_id and location must be provided") + + self._base_path = f"projects/{self.project_id}/locations/{self.location}" + self._header_provider = header_provider + try: + self._credentials, _ = google.auth.default() + except google.auth.exceptions.DefaultCredentialsError as e: + raise RuntimeError( + f"Failed to get default Google Cloud credentials: {e}" + ) from e + + def _get_auth_headers(self) -> Dict[str, str]: + """Refreshes credentials and returns authorization headers.""" + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + "Content-Type": "application/json", + } + quota_project_id = getattr(self._credentials, "quota_project_id", None) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id + return headers + except google.auth.exceptions.RefreshError as e: + raise RuntimeError( + f"Failed to refresh Google Cloud credentials: {e}" + ) from e + + def _make_request( + self, path: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Helper function to make GET requests to the Agent Registry API.""" + if path.startswith("projects/"): + url = f"{AGENT_REGISTRY_BASE_URL}/{path}" + else: + url = f"{AGENT_REGISTRY_BASE_URL}/{self._base_path}/{path}" + + try: + headers = self._get_auth_headers() + with httpx.Client() as client: + response = client.get(url, headers=headers, params=params) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"API request failed with status {e.response.status_code}:" + f" {e.response.text}" + ) from e + except httpx.RequestError as e: + raise RuntimeError(f"API request failed (network error): {e}") from e + except Exception as e: + raise RuntimeError(f"API request failed: {e}") from e + + def _get_connection_uri( + self, + resource_details: Dict[str, Any], + protocol_type: Optional[_ProtocolType] = None, + protocol_binding: Optional[A2ATransport] = None, + ) -> Optional[str]: + """Extracts the first matching URI based on type and binding filters.""" + protocols = list(resource_details.get("protocols", [])) + if "interfaces" in resource_details: + protocols.append({"interfaces": resource_details["interfaces"]}) + + for p in protocols: + if protocol_type and p.get("type") != protocol_type: + continue + for i in p.get("interfaces", []): + if protocol_binding and i.get("protocolBinding") != protocol_binding: + continue + if url := i.get("url"): + return url + + return None + + def _clean_name(self, name: str) -> str: + """Cleans a string to be a valid Python identifier for agent names.""" + clean = re.sub(r"[^a-zA-Z0-9_]", "_", name) + clean = re.sub(r"_+", "_", clean) + clean = clean.strip("_") + if clean and not clean[0].isalpha() and clean[0] != "_": + clean = "_" + clean + return clean + + # --- MCP Server Methods --- + + def list_mcp_servers( + self, + filter_str: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Fetches a list of MCP Servers.""" + params = {} + if filter_str: + params["filter"] = filter_str + if page_size: + params["pageSize"] = str(page_size) + if page_token: + params["pageToken"] = page_token + return self._make_request("mcpServers", params=params) + + def get_mcp_server(self, name: str) -> Dict[str, Any]: + """Retrieves details of a specific MCP Server.""" + return self._make_request(name) + + def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset: + """Constructs an McpToolset instance from a registered MCP Server.""" + server_details = self.get_mcp_server(mcp_server_name) + name = self._clean_name(server_details.get("displayName", mcp_server_name)) + + endpoint_uri = self._get_connection_uri( + server_details, protocol_binding=A2ATransport.jsonrpc + ) or self._get_connection_uri( + server_details, protocol_binding=A2ATransport.http_json + ) + if not endpoint_uri: + raise ValueError( + f"MCP Server endpoint URI not found for: {mcp_server_name}" + ) + + connection_params = StreamableHTTPConnectionParams( + url=endpoint_uri, headers=self._get_auth_headers() + ) + return McpToolset( + connection_params=connection_params, + tool_name_prefix=name, + header_provider=self._header_provider, + ) + + # --- Agent Methods --- + + def list_agents( + self, + filter_str: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Fetches a list of registered A2A Agents.""" + params = {} + if filter_str: + params["filter"] = filter_str + if page_size: + params["pageSize"] = str(page_size) + if page_token: + params["pageToken"] = page_token + return self._make_request("agents", params=params) + + def get_agent_info(self, name: str) -> Dict[str, Any]: + """Retrieves detailed metadata of a specific A2A Agent.""" + return self._make_request(name) + + def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent: + """Creates a RemoteA2aAgent instance for a registered A2A Agent.""" + agent_info = self.get_agent_info(agent_name) + name = self._clean_name(agent_info.get("displayName", agent_name)) + description = agent_info.get("description", "") + version = agent_info.get("version", "") + + url = self._get_connection_uri( + agent_info, protocol_type=_ProtocolType.A2A_AGENT + ) + if not url: + raise ValueError(f"A2A connection URI not found for Agent: {agent_name}") + + skills = [] + for s in agent_info.get("skills", []): + skills.append( + AgentSkill( + id=s.get("id"), + name=s.get("name"), + description=s.get("description", ""), + tags=s.get("tags", []), + examples=s.get("examples", []), + ) + ) + + agent_card = AgentCard( + name=name, + description=description, + version=version, + url=url, + skills=skills, + capabilities=AgentCapabilities(streaming=False, polling=False), + defaultInputModes=["text"], + defaultOutputModes=["text"], + ) + + return RemoteA2aAgent( + name=name, + agent_card=agent_card, + description=description, + ) diff --git a/src/google/adk/integrations/api_registry/__init__.py b/src/google/adk/integrations/api_registry/__init__.py new file mode 100644 index 00000000..1179bc86 --- /dev/null +++ b/src/google/adk/integrations/api_registry/__init__.py @@ -0,0 +1,17 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .api_registry import ApiRegistry + +__all__ = [ + 'ApiRegistry', +] diff --git a/src/google/adk/integrations/api_registry/api_registry.py b/src/google/adk/integrations/api_registry/api_registry.py new file mode 100644 index 00000000..966ad68b --- /dev/null +++ b/src/google/adk/integrations/api_registry/api_registry.py @@ -0,0 +1,140 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any +from typing import Callable + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.tools.base_toolset import ToolPredicate +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import google.auth +import google.auth.transport.requests +import httpx + +API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" + + +class ApiRegistry: + """Registry that provides McpToolsets for MCP servers registered in API Registry.""" + + def __init__( + self, + api_registry_project_id: str, + location: str = "global", + header_provider: ( + Callable[[ReadonlyContext], dict[str, str]] | None + ) = None, + ): + """Initialize the API Registry. + + Args: + api_registry_project_id: The project ID for the Google Cloud API Registry. + location: The location of the API Registry resources. + header_provider: Optional function to provide additional headers for MCP + server calls. + """ + self.api_registry_project_id = api_registry_project_id + self.location = location + self._credentials, _ = google.auth.default() + self._mcp_servers: dict[str, dict[str, Any]] = {} + self._header_provider = header_provider + + url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" + + try: + headers = self._get_auth_headers() + headers["Content-Type"] = "application/json" + page_token = None + with httpx.Client() as client: + while True: + params = {} + if page_token: + params["pageToken"] = page_token + + response = client.get(url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + mcp_servers_list = data.get("mcpServers", []) + for server in mcp_servers_list: + server_name = server.get("name", "") + if server_name: + self._mcp_servers[server_name] = server + + page_token = data.get("nextPageToken") + if not page_token: + break + except (httpx.HTTPError, ValueError) as e: + # Handle error in fetching or parsing tool definitions + raise RuntimeError( + f"Error fetching MCP servers from API Registry: {e}" + ) from e + + def get_toolset( + self, + mcp_server_name: str, + tool_filter: ToolPredicate | list[str] | None = None, + tool_name_prefix: str | None = None, + ) -> McpToolset: + """Return the MCP Toolset based on the params. + + Args: + mcp_server_name: Filter to select the MCP server name to get tools from. + tool_filter: Optional filter to select specific tools. Can be a list of + tool names or a ToolPredicate function. + tool_name_prefix: Optional prefix to prepend to the names of the tools + returned by the toolset. + + Returns: + McpToolset: A toolset for the MCP server specified. + """ + server = self._mcp_servers.get(mcp_server_name) + if not server: + raise ValueError( + f"MCP server {mcp_server_name} not found in API Registry." + ) + if not server.get("urls"): + raise ValueError(f"MCP server {mcp_server_name} has no URLs.") + + mcp_server_url = server["urls"][0] + headers = self._get_auth_headers() + + # Only prepend "https://" if the URL doesn't already have a scheme + if not mcp_server_url.startswith(("http://", "https://")): + mcp_server_url = "https://" + mcp_server_url + + return McpToolset( + connection_params=StreamableHTTPConnectionParams( + url=mcp_server_url, + headers=headers, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=self._header_provider, + ) + + def _get_auth_headers(self) -> dict[str, str]: + """Refreshes credentials and returns authorization headers.""" + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + # Add quota project header if available in ADC + quota_project_id = getattr(self._credentials, "quota_project_id", None) + if quota_project_id: + headers["x-goog-user-project"] = quota_project_id + return headers diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 7bb18efa..2218c874 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -65,6 +65,11 @@ _CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({ 'wait_for_completion', }) +_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation' +# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow +# at most 5 direct_memories per request. +_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5 + def _supports_generate_memories_metadata() -> bool: """Returns whether installed Vertex SDK supports config.metadata.""" @@ -160,6 +165,11 @@ class VertexAiMemoryBankService(BaseMemoryService): not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ + if not agent_engine_id: + raise ValueError( + 'agent_engine_id is required for VertexAiMemoryBankService.' + ) + self._project = project self._location = location self._agent_engine_id = agent_engine_id @@ -219,7 +229,22 @@ class VertexAiMemoryBankService(BaseMemoryService): memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: - """Adds explicit memory items via Vertex memories.create.""" + """Adds explicit memory items using Vertex Memory Bank. + + By default, this writes directly via `memories.create`. + If `custom_metadata["enable_consolidation"]` is set to True, this uses + `memories.generate` with `direct_memories_source` so provided memories are + consolidated server-side. + """ + if _is_consolidation_enabled(custom_metadata): + await self._add_memories_via_generate_direct_memories_source( + app_name=app_name, + user_id=user_id, + memories=memories, + custom_metadata=custom_metadata, + ) + return + await self._add_memories_via_create( app_name=app_name, user_id=user_id, @@ -235,9 +260,6 @@ class VertexAiMemoryBankService(BaseMemoryService): events_to_process: Sequence[Event], custom_metadata: Mapping[str, object] | None = None, ) -> None: - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - direct_events = [] for event in events_to_process: if _should_filter_out_event(event.content): @@ -272,9 +294,6 @@ class VertexAiMemoryBankService(BaseMemoryService): custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds direct memory items without server-side extraction.""" - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - normalized_memories = _normalize_memories_for_create(memories) api_client = self._get_api_client() for index, memory in enumerate(normalized_memories): @@ -300,11 +319,41 @@ class VertexAiMemoryBankService(BaseMemoryService): logger.info('Create memory response received.') logger.debug('Create memory response: %s', operation) + async def _add_memories_via_generate_direct_memories_source( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds memories via generate API with direct_memories_source.""" + normalized_memories = _normalize_memories_for_create(memories) + memory_texts = [ + _memory_entry_to_fact(m, index=i) + for i, m in enumerate(normalized_memories) + ] + api_client = self._get_api_client() + config = _build_generate_memories_config(custom_metadata) + for memory_batch in _iter_memory_batches(memory_texts): + operation = await api_client.agent_engines.memories.generate( + name='reasoningEngines/' + self._agent_engine_id, + direct_memories_source={ + 'direct_memories': [ + {'fact': memory_text} for memory_text in memory_batch + ] + }, + scope={ + 'app_name': app_name, + 'user_id': user_id, + }, + config=config, + ) + logger.info('Generate direct memory response received.') + logger.debug('Generate direct memory response: %s', operation) + @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - api_client = self._get_api_client() retrieved_memories_iterator = ( await api_client.agent_engines.memories.retrieve( @@ -379,6 +428,8 @@ def _build_generate_memories_config( metadata_by_key: dict[str, object] = {} for key, value in custom_metadata.items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'ttl': if value is None: continue @@ -456,6 +507,8 @@ def _build_create_memory_config( metadata_by_key: dict[str, object] = {} custom_revision_labels: dict[str, str] = {} for key, value in (custom_metadata or {}).items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'metadata': if value is None: continue @@ -641,6 +694,32 @@ def _extract_revision_labels( return revision_labels +def _is_consolidation_enabled( + custom_metadata: Mapping[str, object] | None, +) -> bool: + """Returns whether direct memories should be consolidated via generate API.""" + if not custom_metadata: + return False + enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY) + if enable_consolidation is None: + return False + if not isinstance(enable_consolidation, bool): + raise TypeError( + f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.' + ) + return enable_consolidation + + +def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]: + """Returns memory slices that comply with direct_memories limits.""" + memory_batches: list[Sequence[str]] = [] + for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL): + memory_batches.append( + memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL] + ) + return memory_batches + + def _build_vertex_metadata( metadata_by_key: Mapping[str, object], ) -> dict[str, object]: diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 9811733f..42ad3e93 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -17,7 +17,9 @@ from __future__ import annotations import base64 +import dataclasses from functools import cached_property +import json import logging import os from typing import Any @@ -31,6 +33,7 @@ from typing import Union from anthropic import AsyncAnthropic from anthropic import AsyncAnthropicVertex from anthropic import NOT_GIVEN +from anthropic import NotGiven from anthropic import types as anthropic_types from google.genai import types from pydantic import BaseModel @@ -48,6 +51,15 @@ __all__ = ["AnthropicLlm", "Claude"] logger = logging.getLogger("google_adk." + __name__) +@dataclasses.dataclass +class _ToolUseAccumulator: + """Accumulates streamed tool_use content block data.""" + + id: str + name: str + args_json: str + + class ClaudeRequest(BaseModel): system_instruction: str messages: Iterable[anthropic_types.MessageParam] @@ -115,12 +127,15 @@ def part_to_message_block( else: content_items.append(str(item)) content = "\n".join(content_items) if content_items else "" - # Handle traditional result format - elif "result" in response_data and response_data["result"]: - # Transformation is required because the content is a list of dict. - # ToolResultBlockParam content doesn't support list of dict. Converting - # to str to prevent anthropic.BadRequestError from being thrown. - content = str(response_data["result"]) + # We serialize to str here + # SDK ref: anthropic.types.tool_result_block_param + # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_result_block_param.py + elif "result" in response_data and response_data["result"] is not None: + result = response_data["result"] + if isinstance(result, (dict, list)): + content = json.dumps(result) + else: + content = str(result) return anthropic_types.ToolResultBlockParam( tool_use_id=part.function_response.id or "", @@ -302,16 +317,111 @@ class AnthropicLlm(BaseLlm): if llm_request.tools_dict else NOT_GIVEN ) - # TODO(b/421255973): Enable streaming for anthropic models. - message = await self._anthropic_client.messages.create( + + if not stream: + message = await self._anthropic_client.messages.create( + model=llm_request.model, + system=llm_request.config.system_instruction, + messages=messages, + tools=tools, + tool_choice=tool_choice, + max_tokens=self.max_tokens, + ) + yield message_to_generate_content_response(message) + else: + async for response in self._generate_content_streaming( + llm_request, messages, tools, tool_choice + ): + yield response + + async def _generate_content_streaming( + self, + llm_request: LlmRequest, + messages: list[anthropic_types.MessageParam], + tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven], + tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven], + ) -> AsyncGenerator[LlmResponse, None]: + """Handles streaming responses from Anthropic models. + + Yields partial LlmResponse objects as content arrives, followed by + a final aggregated LlmResponse with all content. + """ + raw_stream = await self._anthropic_client.messages.create( model=llm_request.model, system=llm_request.config.system_instruction, messages=messages, tools=tools, tool_choice=tool_choice, max_tokens=self.max_tokens, + stream=True, + ) + + # Track content blocks being built during streaming. + # Each entry maps a block index to its accumulated state. + text_blocks: dict[int, str] = {} + tool_use_blocks: dict[int, _ToolUseAccumulator] = {} + input_tokens = 0 + output_tokens = 0 + + async for event in raw_stream: + if event.type == "message_start": + input_tokens = event.message.usage.input_tokens + output_tokens = event.message.usage.output_tokens + + elif event.type == "content_block_start": + block = event.content_block + if isinstance(block, anthropic_types.TextBlock): + text_blocks[event.index] = block.text + elif isinstance(block, anthropic_types.ToolUseBlock): + tool_use_blocks[event.index] = _ToolUseAccumulator( + id=block.id, + name=block.name, + args_json="", + ) + + elif event.type == "content_block_delta": + delta = event.delta + if isinstance(delta, anthropic_types.TextDelta): + text_blocks.setdefault(event.index, "") + text_blocks[event.index] += delta.text + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text=delta.text)], + ), + partial=True, + ) + elif isinstance(delta, anthropic_types.InputJSONDelta): + if event.index in tool_use_blocks: + tool_use_blocks[event.index].args_json += delta.partial_json + + elif event.type == "message_delta": + output_tokens = event.usage.output_tokens + + # Build the final aggregated response with all content. + all_parts: list[types.Part] = [] + all_indices = sorted( + set(list(text_blocks.keys()) + list(tool_use_blocks.keys())) + ) + for idx in all_indices: + if idx in text_blocks: + all_parts.append(types.Part.from_text(text=text_blocks[idx])) + if idx in tool_use_blocks: + acc = tool_use_blocks[idx] + args = json.loads(acc.args_json) if acc.args_json else {} + part = types.Part.from_function_call(name=acc.name, args=args) + part.function_call.id = acc.id + all_parts.append(part) + + yield LlmResponse( + content=types.Content(role="model", parts=all_parts), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=input_tokens, + candidates_token_count=output_tokens, + total_token_count=input_tokens + output_tokens, + ), + partial=False, ) - yield message_to_generate_content_response(message) @cached_property def _anthropic_client(self) -> AsyncAnthropic: diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 92f94c75..fc4928cb 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -12,21 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations +import asyncio +import atexit +import base64 +import collections.abc +import enum from functools import cached_property +import json import logging import os +from typing import Any +from typing import AsyncGenerator +from typing import Generator from typing import Optional from typing import TYPE_CHECKING from google.adk import version as adk_version from google.genai import types +import httpx +import tenacity from typing_extensions import override from ..utils.env_utils import is_env_enabled from .google_llm import Gemini +from .llm_response import LlmResponse if TYPE_CHECKING: from google.genai import Client @@ -41,6 +52,14 @@ _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_USE_VERTEXAI' _PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT' _LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION' +_CUSTOM_METADATA_FIELDS = ( + 'id', + 'created', + 'model', + 'service_tier', + 'object', +) + class ApigeeLlm(Gemini): """A BaseLlm implementation for calling Apigee proxy. @@ -49,6 +68,20 @@ class ApigeeLlm(Gemini): model: The name of the Gemini model. """ + class ApiType(str, enum.Enum): + """The supported API types for Apigee LLM.""" + + UNKNOWN = 'unknown' + CHAT_COMPLETIONS = 'chat_completions' + GENAI = 'genai' + + @classmethod + def _missing_(cls, value): + # Empty string or None should return UNKNOWN. + if not value: + return cls.UNKNOWN + return super()._missing_(value) + def __init__( self, *, @@ -56,6 +89,7 @@ class ApigeeLlm(Gemini): proxy_url: str | None = None, custom_headers: dict[str, str] | None = None, retry_options: Optional[types.HttpRetryOptions] = None, + api_type: ApiType | str = ApiType.UNKNOWN, ): """Initializes the Apigee LLM backend. @@ -80,19 +114,31 @@ class ApigeeLlm(Gemini): - `apigee/vertex_ai/gemini-2.5-flash` - `apigee/gemini/v1/gemini-2.5-flash` - `apigee/vertex_ai/v1beta/gemini-2.5-flash` - proxy_url: The URL of the Apigee proxy. custom_headers: A dictionary of headers to be sent with the request. + If needed, you can add authorization headers here, for example: + {'Authorization': f'Bearer {API_KEY}'}. ApigeeLlm already handles + authorization headers in Vertex AI and Gemini API calls. retry_options: Allow google-genai to retry failed responses. - """ + api_type: The type of API to use. One of `ApiType` or string. + """ # fmt: skip super().__init__(model=model, retry_options=retry_options) # Validate the model string. Create a helper method to validate the model # string. if not _validate_model_string(model): raise ValueError(f'Invalid model string: {model}') - - self._isvertexai = _identify_vertexai(model) + if isinstance(api_type, str): + api_type = ApigeeLlm.ApiType(api_type) + if api_type and api_type != ApigeeLlm.ApiType.UNKNOWN: + self._api_type = api_type + elif model.startswith(('apigee/gemini/', 'apigee/vertex_ai/')): + self._api_type = ApigeeLlm.ApiType.GENAI + elif model.startswith('apigee/openai/'): + self._api_type = ApigeeLlm.ApiType.CHAT_COMPLETIONS + else: + self._api_type = ApigeeLlm.ApiType.GENAI + self._isvertexai = _identify_vertexai(model, self._api_type) # Set the project and location for Vertex AI. if self._isvertexai: @@ -131,6 +177,42 @@ class ApigeeLlm(Gemini): r'apigee\/.*', ] + @cached_property + def _completions_http_client(self) -> CompletionsHTTPClient: + """Provides the completions HTTP client.""" + return CompletionsHTTPClient( + base_url=self._proxy_url, + headers=self._merge_tracking_headers(self._custom_headers), + retry_options=self.retry_options, + ) + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + if self._api_type == ApigeeLlm.ApiType.CHAT_COMPLETIONS: + await self._preprocess_other_requests(llm_request) + async for ( + response + ) in self._completions_http_client.generate_content_async( + llm_request, stream + ): + yield response + else: + async for response in super().generate_content_async(llm_request, stream): + yield response + + async def _preprocess_other_requests(self, llm_request: LlmRequest) -> None: + """Preprocesses the request for non-Gemini/Vertex AI models.""" + llm_request.model = _get_model_id(llm_request.model) + if llm_request.config and llm_request.config.tools: + # Check if computer use is configured + for tool in llm_request.config.tools: + if isinstance(tool, types.Tool) and tool.computer_use: + llm_request.config.system_instruction = None + await self._adapt_computer_use_tool(llm_request) + self._maybe_append_user_content(llm_request) + @cached_property def api_client(self) -> Client: """Provides the api client. @@ -167,11 +249,25 @@ class ApigeeLlm(Gemini): await super()._preprocess_request(llm_request) -def _identify_vertexai(model: str) -> bool: - """Returns True if the model spec starts with apigee/vertex_ai.""" - return not model.startswith('apigee/gemini/') and ( - model.startswith('apigee/vertex_ai/') - or is_env_enabled(_GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME) +def _identify_vertexai(model: str, api_type: ApigeeLlm.ApiType) -> bool: + """Returns if a model is Vertex AI. + + 1. The api_type is GENAI or UNKNOWN. + 2. The model is provider is Vertex AI model or the + GOOGLE_GENAI_USE_VERTEXAI environment variable is set to TRUE or 1. + + Args: + model: The model string. + api_type: The type of API to use. + """ + if api_type not in (ApigeeLlm.ApiType.GENAI, ApigeeLlm.ApiType.UNKNOWN): + return False + if model.startswith('apigee/gemini/'): + return False + if model.startswith('apigee/openai/'): + return False + return model.startswith('apigee/vertex_ai/') or is_env_enabled( + _GOOGLE_GENAI_USE_VERTEXAI_ENV_VARIABLE_NAME ) @@ -203,6 +299,45 @@ def _get_model_id(model: str) -> str: return components[-1] +def _parse_logprobs( + logprobs_data: dict[str, Any] | None, +) -> types.LogprobsResult | None: + """Parses OpenAI logprobs data into LogprobsResult.""" + if not logprobs_data or 'content' not in logprobs_data: + return None + + chosen_candidates = [] + top_candidates = [] + + for item in logprobs_data['content']: + chosen_candidates.append( + types.LogprobsResultCandidate( + token=item.get('token'), + log_probability=item.get('logprob'), + # OpenAI text format usually doesn't expose ID easily here + token_id=None, + ) + ) + + if 'top_logprobs' in item: + current_top_candidates = [] + for top_item in item['top_logprobs']: + current_top_candidates.append( + types.LogprobsResultCandidate( + token=top_item.get('token'), + log_probability=top_item.get('logprob'), + token_id=None, + ) + ) + top_candidates.append( + types.LogprobsResultTopCandidates(candidates=current_top_candidates) + ) + + return types.LogprobsResult( + chosen_candidates=chosen_candidates, top_candidates=top_candidates + ) + + def _validate_model_string(model: str) -> bool: """Validates the model string for Apigee LLM. @@ -240,7 +375,7 @@ def _validate_model_string(model: str) -> bool: # and model_id are present. This is a valid format. if len(components) == 3: # Format: // - if components[0] not in ('vertex_ai', 'gemini'): + if components[0] not in ('vertex_ai', 'gemini', 'openai'): return False if not components[1].startswith('v'): return False @@ -249,10 +384,762 @@ def _validate_model_string(model: str) -> bool: # If the model string has 2 components, it means either the provider or the # version (but not both), and model_id are present. if len(components) == 2: - if components[0] in ['vertex_ai', 'gemini']: + if components[0] in ['vertex_ai', 'gemini', 'openai']: return True if components[0].startswith('v'): return True return False return False + + +class CompletionsHTTPClient: + """A generic HTTP client for completions, compatible with OpenAI API.""" + + def __init__( + self, + base_url: str, + headers: dict[str, str] | None = None, + retry_options: Optional[types.HttpRetryOptions] = None, + ): + self._base_url = base_url + self._headers = headers or {} + self.retry_options = retry_options + + def __del__(self) -> None: + self.close() + + @cached_property + def _client(self) -> httpx.AsyncClient: + """Provides the httpx client.""" + client = httpx.AsyncClient( + base_url=self._base_url, + headers=self._headers, + timeout=None, + follow_redirects=True, + ) + atexit.register(self._cleanup_client, client) + return client + + @staticmethod + def _cleanup_client(client: httpx.AsyncClient) -> None: + """Cleans up the httpx client.""" + if client.is_closed: + return + try: + loop = asyncio.get_running_loop() + loop.create_task(client.aclose()) + except RuntimeError: + try: + # This fails if asyncio.run is already called in main and is closing. + asyncio.run(client.aclose()) + except RuntimeError: + pass + + def close(self) -> None: + if '_client' not in self.__dict__: + return + self._cleanup_client(self._client) + + async def aclose(self) -> None: + if '_client' not in self.__dict__: + return + if self._client.is_closed: + return + await self._client.aclose() + + def _get_retry_kwargs(self) -> dict[str, Any]: + """Returns the retry kwargs for tenacity.""" + if not self.retry_options: + return {'stop': tenacity.stop_after_attempt(1), 'reraise': True} + + default_attempts = 5 + default_initial_delay = 1.0 + default_max_delay = 60.0 + default_exp_base = 2 + default_jitter = 1 + default_status_codes = (408, 429, 500, 502, 503, 504) + + opts = self.retry_options + stop = tenacity.stop_after_attempt( + opts.attempts if opts.attempts is not None else default_attempts + ) + + retriable_codes = ( + opts.http_status_codes + if opts.http_status_codes is not None + else default_status_codes + ) + + retry_network = tenacity.retry_if_exception_type(httpx.NetworkError) + + def is_retriable(e: Exception) -> bool: + if isinstance(e, httpx.HTTPStatusError): + return e.response.status_code in retriable_codes + return False + + retry_status = tenacity.retry_if_exception(is_retriable) + + wait = tenacity.wait_exponential_jitter( + initial=( + opts.initial_delay + if opts.initial_delay is not None + else default_initial_delay + ), + max=( + opts.max_delay if opts.max_delay is not None else default_max_delay + ), + exp_base=( + opts.exp_base if opts.exp_base is not None else default_exp_base + ), + jitter=opts.jitter if opts.jitter is not None else default_jitter, + ) + + return { + 'stop': stop, + 'retry': tenacity.retry_any(retry_network, retry_status), + 'reraise': True, + 'wait': wait, + } + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool + ) -> AsyncGenerator[LlmResponse, None]: + """Generates content using the OpenAI-compatible HTTP API.""" + payload = self._construct_payload(llm_request, stream) + headers = self._headers.copy() + headers['Content-Type'] = 'application/json' + + url = self._base_url + if not url: + raise ValueError('Base URL is not set.') + + if not url.endswith('/chat/completions'): + url = f"{url.rstrip('/')}/chat/completions" + + if stream: + async for stream_res in self._handle_streaming(url, payload, headers): + yield stream_res + else: + response = await self._httpx_post_with_retry(url, payload, headers) + data = response.json() + yield self._parse_response(data) + + async def _httpx_post_with_retry( + self, url: str, payload: dict[str, Any], headers: dict[str, str] + ) -> httpx.Response: + """Sends a POST request and handles retries.""" + retry_kwargs = self._get_retry_kwargs() + async for attempt in tenacity.AsyncRetrying(**retry_kwargs): + with attempt: + response = await self._client.post(url, json=payload, headers=headers) + response.raise_for_status() + return response + + async def _handle_streaming( + self, url: str, payload: dict[str, Any], headers: dict[str, str] + ) -> AsyncGenerator[LlmResponse, None]: + """Handles streaming response from OpenAI-compatible API.""" + accumulator = ChatCompletionsResponseHandler() + async with self._client.stream( + 'POST', + url, + json=payload, + headers=headers, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + line = line.strip() + if line.startswith('data:'): + line = line.removeprefix('data:') + line = line.lstrip() + if line == '[DONE]': + break + try: + for res in self._parse_streaming_line(line, accumulator): + yield res + except json.JSONDecodeError: + logger.warning('Failed to parse JSON chunk: %s', line) + continue + + def _construct_payload( + self, llm_request: LlmRequest, stream: bool + ) -> dict[str, Any]: + """Constructs the payload from the LlmRequest.""" + messages = [] + if llm_request.config and llm_request.config.system_instruction: + content = self._serialize_system_instruction( + llm_request.config.system_instruction + ) + if content: + messages.append({ + 'role': 'system', + 'content': content, + }) + + for content in llm_request.contents: + messages += self._content_to_messages(content) + + payload = { + 'model': _get_model_id(llm_request.model), + 'messages': messages, + 'stream': stream, + } + + if llm_request.config: + self._map_config_parameters(llm_request.config, payload) + self._map_tools(llm_request.config, payload) + + return payload + + def _map_config_parameters( + self, config: types.GenerateContentConfig, payload: dict[str, Any] + ) -> None: + """Maps configuration parameters to the payload.""" + if config.temperature is not None: + payload['temperature'] = config.temperature + if config.top_p is not None: + payload['top_p'] = config.top_p + if config.max_output_tokens is not None: + payload['max_tokens'] = config.max_output_tokens + if config.stop_sequences: + payload['stop'] = config.stop_sequences + if config.frequency_penalty is not None: + payload['frequency_penalty'] = config.frequency_penalty + if config.presence_penalty is not None: + payload['presence_penalty'] = config.presence_penalty + if config.seed is not None: + payload['seed'] = config.seed + if config.candidate_count is not None: + payload['n'] = config.candidate_count + if config.response_logprobs: + payload['logprobs'] = True + if config.logprobs is not None: + payload['top_logprobs'] = config.logprobs + + if config.response_json_schema: + payload['response_format'] = { + 'type': 'json_schema', + 'json_schema': config.response_json_schema, + } + elif config.response_mime_type == 'application/json': + payload['response_format'] = {'type': 'json_object'} + + def _map_tools( + self, config: types.GenerateContentConfig, payload: dict[str, Any] + ) -> None: + """Maps tools and tool configuration to the payload.""" + if config.tools: + tools = [] + for tool in config.tools: + if tool.function_declarations: + for func in tool.function_declarations: + tools.append(self._function_declaration_to_tool(func)) + if tools: + payload['tools'] = tools + if config.tool_config and config.tool_config.function_calling_config: + mode = config.tool_config.function_calling_config.mode + if mode == types.FunctionCallingConfigMode.ANY: + payload['tool_choice'] = 'required' + elif mode == types.FunctionCallingConfigMode.NONE: + payload['tool_choice'] = 'none' + elif mode == types.FunctionCallingConfigMode.AUTO: + payload['tool_choice'] = 'auto' + + def _content_to_messages( + self, content: types.Content + ) -> list[dict[str, Any]]: + """Converts a Content object to /chat/completions messages.""" + role = content.role + if role == 'model': + role = 'assistant' + + tool_calls = [] + content_parts = [] + + function_responses = [] + + for part in content.parts or []: + self._process_content_part(content, part, tool_calls, content_parts) + if part.function_response: + function_responses.append({ + 'role': 'tool', + 'tool_call_id': part.function_response.id, + 'content': json.dumps(part.function_response.response), + }) + if function_responses: + return function_responses + + message = {'role': role} + if tool_calls: + message['tool_calls'] = tool_calls + if not content_parts: + message['content'] = None + + if content_parts: + if len(content_parts) == 1 and content_parts[0]['type'] == 'text': + message['content'] = content_parts[0]['text'] + else: + message['content'] = content_parts + return [message] + + def _process_content_part( + self, + content: types.Content, + part: types.Part, + tool_calls: list[dict[str, Any]], + content_parts: list[dict[str, Any]], + ) -> None: + """Processes a single Part and updates tool_calls or content_parts.""" + if content.role != 'user' and ( + part.inline_data + or ( + part.file_data + and part.file_data.mime_type + and part.file_data.mime_type.startswith('image') + ) + ): + logger.warning('Image data is not supported for assistant turns.') + return + + if part.function_call: + tool_call = { + 'id': part.function_call.id or 'call_' + part.function_call.name, + 'type': 'function', + 'function': { + 'name': part.function_call.name, + 'arguments': ( + json.dumps(part.function_call.args) + if part.function_call.args + else '{}' + ), + }, + } + if part.thought_signature: + sig = part.thought_signature + if isinstance(sig, bytes): + sig = base64.b64encode(sig).decode('utf-8') + tool_call['extra_content'] = { + 'google': { + 'thought_signature': sig, + }, + } + tool_calls.append(tool_call) + elif part.function_response: + # Handled in the loop to return immediately + pass + elif part.text: + content_parts.append({'type': 'text', 'text': part.text}) + elif part.inline_data: + mime_type = part.inline_data.mime_type + data = base64.b64encode(part.inline_data.data).decode('utf-8') + url = f'data:{mime_type};base64,{data}' + content_parts.append({'type': 'image_url', 'image_url': {'url': url}}) + elif part.file_data: + if part.file_data.file_uri: + content_parts.append({ + 'type': 'image_url', + 'image_url': {'url': part.file_data.file_uri}, + }) + elif part.executable_code: + logger.warning( + 'Executable code is not supported in the standard Chat Completions' + ' API.' + ) + elif part.code_execution_result: + logger.warning( + 'Code execution result is not supported in the standard Chat' + ' Completions API.' + ) + + def _function_declaration_to_tool( + self, func: types.FunctionDeclaration + ) -> dict[str, Any]: + """Converts a FunctionDeclaration to an OpenAI tool dictionary.""" + parameters = {} + if func.parameters_json_schema: + parameters = func.parameters_json_schema + elif func.parameters: + parameters = func.parameters.model_dump(exclude_none=True) + + return { + 'type': 'function', + 'function': { + 'name': func.name, + 'description': func.description, + 'parameters': parameters, + }, + } + + def _serialize_system_instruction( + self, system_instruction: Optional[types.ContentUnion] + ) -> str | None: + """Serializes system instruction to a string from ContentUnion type.""" + if not system_instruction: + return None + if isinstance(system_instruction, str): + return system_instruction + if isinstance(system_instruction, types.Part): + return system_instruction.text + if isinstance(system_instruction, types.Content): + return ''.join( + part.text for part in system_instruction.parts if part.text + ) + if isinstance(system_instruction, dict): + part = types.Part(**system_instruction) + return part.text + if isinstance(system_instruction, collections.abc.Iterable): + parts = [] + for item in system_instruction: + if isinstance(item, str): + parts.append(types.Part(text=item)) + elif isinstance(item, types.Part): + parts.append(item) + elif isinstance(item, dict): + parts.append(types.Part(**item)) + return ''.join(part.text for part in parts if part.text) + return None + + def _parse_response(self, response: dict[str, Any]) -> LlmResponse: + """Parses an OpenAI response dictionary into an LlmResponse.""" + handler = ChatCompletionsResponseHandler() + return handler.process_response(response) + + def _parse_streaming_line( + self, + line: str, + accumulator: ChatCompletionsResponseHandler, + ) -> Generator[LlmResponse]: + """Parses a single line from the streaming response. + + Args: + line: A single line from the streaming response, expected to be a JSON + string. + accumulator: An accumulator to manage partial chat completion choices + across multiple chunks. + + Yields: + An LlmResponse object parsed from the streaming line. + """ + chunk = json.loads(line) + for response in accumulator.process_chunk(chunk): + yield response + + +class ChatCompletionsResponseHandler: + """Accumulates responses from the /chat/completions endpoint. + + Useful for both streaming and non-streaming responses. + """ + + def __init__(self): + self.content_parts = '' + self.tool_call_parts = {} + self.role = '' + self.streaming_complete = False + self.model = '' + self.usage = {} + self.logprobs = {} + self.custom_metadata = {} + + def process_response(self, response: dict[str, Any]) -> LlmResponse: + """Processes a complete non-streaming response.""" + choices = response.get('choices', []) + if not choices: + raise ValueError('No choices found in response.') + if len(choices) > 1: + logging.error( + 'Multiple choices found in response but only the first one will be' + ' used.' + ) + choice = choices[0] + message = choice.get('message', {}) + _, role = self._add_chat_completion_message(message) + parts = self._get_content_parts() + + usage = response.get('usage', {}) + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get('prompt_tokens', 0), + candidates_token_count=usage.get('completion_tokens', 0), + total_token_count=usage.get('total_tokens', 0), + ) + logprobs_result = _parse_logprobs(choice.get('logprobs')) + + custom_metadata = {} + for k in _CUSTOM_METADATA_FIELDS: + v = response.get(k) + if v is not None: + custom_metadata[k] = v + + return LlmResponse( + content=types.Content(role=role, parts=parts), + usage_metadata=usage_metadata, + finish_reason=self._map_finish_reason(choice.get('finish_reason')), + logprobs_result=logprobs_result, + model_version=response.get('model'), + custom_metadata=custom_metadata, + ) + + def process_chunk( + self, chunk: dict[str, Any] + ) -> Generator[LlmResponse, None, None]: + """Processes a chunk and yields responses.""" + if 'model' in chunk: + self.model = chunk['model'] + if 'usage' in chunk and chunk['usage']: + self.usage.update(chunk['usage']) + + for k in _CUSTOM_METADATA_FIELDS: + v = chunk.get(k) + if v is not None: + self.custom_metadata[k] = v + + usage_metadata = None + if self.usage: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=self.usage.get('prompt_tokens', 0), + candidates_token_count=self.usage.get('completion_tokens', 0), + total_token_count=self.usage.get('total_tokens', 0), + ) + + choices = chunk.get('choices') + if not choices: + # If no choices, but we have usage or other metadata updates, yield them. + if usage_metadata or self.custom_metadata: + yield LlmResponse( + partial=True, + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + ) + return + + if len(choices) > 1: + logging.error( + 'Multiple choices found in streaming response but only the first one' + ' will be used.' + ) + choice = choices[0] + + # Accumulate logprobs if present + if 'logprobs' in choice and choice['logprobs']: + self._accumulate_logprobs(choice['logprobs']) + + logprobs_result = None + if self.logprobs: + logprobs_result = _parse_logprobs(self.logprobs) + + delta = choice.get('delta', {}) + partial_parts, role = self._add_chat_completion_chunk_delta(delta) + + yield LlmResponse( + partial=True, + content=types.Content(role=role, parts=partial_parts), + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + logprobs_result=logprobs_result, + ) + + finish_reason = choice.get('finish_reason') + if finish_reason: + yield LlmResponse( + content=types.Content( + role=role, + parts=self._get_content_parts(), + ), + finish_reason=self._map_finish_reason(finish_reason), + custom_metadata=self.custom_metadata, + model_version=self.model, + usage_metadata=usage_metadata, + logprobs_result=logprobs_result, + ) + # Exit because the 'finish_reason' chunk is the final chunk. + return + + def _map_finish_reason(self, reason: str | None) -> types.FinishReason: + if reason == 'stop': + return types.FinishReason.STOP + if reason == 'length': + return types.FinishReason.MAX_TOKENS + if reason == 'tool_calls': + return types.FinishReason.STOP + if reason == 'content_filter': + return types.FinishReason.SAFETY + return types.FinishReason.FINISH_REASON_UNSPECIFIED + + def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None: + """Accumulates logprobs from a chunk.""" + if not self.logprobs: + self.logprobs = {'content': [], 'refusal': []} + + if 'content' in logprobs_chunk and logprobs_chunk['content']: + if 'content' not in self.logprobs: + self.logprobs['content'] = [] + self.logprobs['content'].extend(logprobs_chunk['content']) + + if 'refusal' in logprobs_chunk and logprobs_chunk['refusal']: + if 'refusal' not in self.logprobs: + self.logprobs['refusal'] = [] + self.logprobs['refusal'].extend(logprobs_chunk['refusal']) + + def _append_content(self, content: str, refusal: str) -> str: + if content and refusal: + content += '\n' + content += refusal + elif refusal: + content = refusal + if content: + self.content_parts += content + return content + + def _add_chat_completion_chunk_delta( + self, delta: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a chunk delta from a streaming chat completions response. + + This method processes a single delta chunk from a streaming chat completions + response, accumulating partial content and tool calls. + + Args: + delta: A dictionary representing a single delta from the streaming chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this chunk. + - The role associated with the message. + """ + parts = [] + for tool_call in delta.get('tool_calls', []): + chunk_part = self._upsert_tool_call(tool_call) + parts.append(chunk_part) + content = delta.get('content') + refusal = delta.get('refusal') + merged_content = self._append_content(content, refusal) + if merged_content: + parts.append(types.Part.from_text(text=merged_content)) + + self._get_or_create_role(delta.get('role', 'model')) + return parts, self.role + + def _add_chat_completion_message( + self, message: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a complete chat completion message to the accumulator. + + This method processes a single message from a non-streaming chat completions + response, extracting and accumulating content and tool calls. + + Args: + message: A dictionary representing a single message from the chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this message. + - The role associated with the message. + """ + for tool_call in message.get('tool_calls', []): + self._upsert_tool_call(tool_call) + function_call = message.get('function_call') + if function_call: + # function_call is a single tool call and does not have an id. + self._upsert_tool_call({ + 'type': 'function', + 'function': function_call, + }) + content = message.get('content') + refusal = message.get('refusal') + self._append_content(content, refusal) + + self._get_or_create_role(message.get('role', 'model')) + return self._get_content_parts(), self.role + + def _get_content_parts(self) -> list[types.Part]: + """Returns the content parts from the accumulated response.""" + parts = [] + if self.content_parts: + parts.append(types.Part.from_text(text=self.content_parts)) + sorted_indices = sorted(self.tool_call_parts.keys()) + for index in sorted_indices: + parts.append(self.tool_call_parts[index]) + return parts + + def _upsert_tool_call(self, tool_call: dict[str, Any]) -> types.Part: + """Upserts a tool call into the accumulated tool call parts. + + This method handles partial tool call chunks in streaming responses by + updating existing tool call parts or creating new ones. + + Args: + tool_call: A dictionary representing a tool call or a delta of a tool call + from the chat completions API. + + Returns: + A `types.Part` object representing the updated or newly created tool call. + """ + index = tool_call.get('index') + if index is None: + # If index is not provided, we might be in a non-streaming response. + # We just append it as a new tool call. + index = len(self.tool_call_parts) + + if index not in self.tool_call_parts: + self.tool_call_parts[index] = types.Part( + function_call=types.FunctionCall() + ) + part = self.tool_call_parts[index] + chunk_part = types.Part(function_call=types.FunctionCall()) + call_type = tool_call.get('type') + # TODO: Add support for 'custom' type. + if call_type is not None and call_type != 'function': + raise ValueError( + f'Unsupported tool_call type: {call_type} in call {tool_call}' + ) + func = tool_call.get('function', {}) + args_delta = func.get('arguments', '') + if args_delta: + try: + args = json.loads(args_delta) + chunk_part.function_call.args = args + if not part.function_call.args: + part.function_call.args = dict(args) + else: + part.function_call.args.update(args) + except json.JSONDecodeError as e: + raise ValueError(f'Failed to parse arguments: {args_delta}') from e + + func_name = func.get('name') + if func_name: + part.function_call.name = func_name + chunk_part.function_call.name = func_name + tool_call_id = tool_call.get('id') + if tool_call_id: + part.function_call.id = tool_call_id + chunk_part.function_call.id = tool_call_id + + # Add support for gemini's thought_signature. + thought_signature = ( + tool_call.get('extra_content', {}) + .get('google', {}) + .get('thought_signature', '') + ) + if thought_signature: + if isinstance(thought_signature, str): + thought_signature = base64.b64decode(thought_signature) + part.thought_signature = thought_signature + chunk_part.thought_signature = thought_signature + return chunk_part + + def _get_or_create_role(self, role: str = '') -> str: + if self.role: + return self.role + if role == 'assistant': + role = 'model' + self.role = role + return self.role diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index b954d8a0..8c1568cc 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -70,7 +70,9 @@ if TYPE_CHECKING: from litellm import Function from litellm import Message from litellm import ModelResponse + from litellm import ModelResponseStream from litellm import OpenAIMessageContent + from litellm.types.utils import Delta else: litellm = None acompletion = None @@ -85,7 +87,9 @@ else: Function = None Message = None ModelResponse = None + Delta = None OpenAIMessageContent = None + ModelResponseStream = None logger = logging.getLogger("google_adk." + __name__) @@ -151,6 +155,7 @@ _LITELLM_GLOBAL_SYMBOLS = ( "Function", "Message", "ModelResponse", + "ModelResponseStream", "OpenAIMessageContent", "acompletion", "completion", @@ -382,15 +387,19 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: ] -def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: - """Fetches the reasoning payload from a LiteLLM message or dict.""" +def _extract_reasoning_value(message: Message | Delta | None) -> Any: + """Fetches the reasoning payload from a LiteLLM message. + + Checks for both 'reasoning_content' (LiteLLM standard, used by Azure/Foundry, + Ollama via LiteLLM) and 'reasoning' (used by LM Studio, vLLM). + Prioritizes 'reasoning_content' when both are present. + """ if message is None: return None - if hasattr(message, "reasoning_content"): - return getattr(message, "reasoning_content") - if isinstance(message, dict): - return message.get("reasoning_content") - return None + reasoning_content = message.get("reasoning_content") + if reasoning_content is not None: + return reasoning_content + return message.get("reasoning") class ChatCompletionFileUrlObject(TypedDict, total=False): @@ -1264,7 +1273,7 @@ def _function_declaration_to_tool_param( def _model_response_to_chunk( - response: ModelResponse, + response: ModelResponse | ModelResponseStream, ) -> Generator[ Tuple[ Optional[ @@ -1282,6 +1291,9 @@ def _model_response_to_chunk( ]: """Converts a litellm message to text, function or usage metadata chunk. + LiteLLM streaming chunks carry `delta`, while non-streaming chunks carry + `message`. + Args: response: The response from the model. @@ -1290,18 +1302,46 @@ def _model_response_to_chunk( """ _ensure_litellm_imported() - message = None - if response.get("choices", None): - message = response["choices"][0].get("message", None) - finish_reason = response["choices"][0].get("finish_reason", None) - # check streaming delta - if message is None and response["choices"][0].get("delta", None): - message = response["choices"][0]["delta"] + def _has_meaningful_signal(message: Message | Delta | None) -> bool: + if message is None: + return False + return bool( + message.get("content") + or message.get("tool_calls") + or message.get("function_call") + or message.get("reasoning_content") + or message.get("reasoning") + ) + + if isinstance(response, ModelResponseStream): + message_field = "delta" + elif isinstance(response, ModelResponse): + message_field = "message" + else: + raise TypeError( + "Unexpected response type from LiteLLM: %r" % (type(response),) + ) + + choices = response.get("choices") + if not choices: + yield None, None + else: + choice = choices[0] + finish_reason = choice.get("finish_reason") + if message_field == "delta": + message = choice.get("delta") + else: + message = choice.get("message") + + if message is not None and not _has_meaningful_signal(message): + message = None message_content: Optional[OpenAIMessageContent] = None tool_calls: list[ChatCompletionMessageToolCall] = [] reasoning_parts: List[types.Part] = [] + if message is not None: + # Both Delta and Message support dict-like .get() access ( message_content, tool_calls, @@ -1318,39 +1358,46 @@ def _model_response_to_chunk( if tool_calls: for idx, tool_call in enumerate(tool_calls): - # aggregate tool_call - if tool_call.type == "function": - func_name = tool_call.function.name - func_args = tool_call.function.arguments - func_index = getattr(tool_call, "index", idx) + # LiteLLM tool call objects support dict-like .get() access + if tool_call.get("type") == "function": + function_obj = tool_call.get("function") + if not function_obj: + continue + func_name = function_obj.get("name") + func_args = function_obj.get("arguments") + func_index = tool_call.get("index", idx) + tool_call_id = tool_call.get("id") # Ignore empty chunks that don't carry any information. if not func_name and not func_args: continue yield FunctionChunk( - id=tool_call.id, + id=tool_call_id, name=func_name, args=func_args, index=func_index, ), finish_reason - if finish_reason and not (message_content or tool_calls): + if finish_reason and not (message_content or tool_calls or reasoning_parts): yield None, finish_reason - if not message: - yield None, None - # Ideally usage would be expected with the last ModelResponseStream with a # finish_reason set. But this is not the case we are observing from litellm. # So we are sending it as a separate chunk to be set on the llm_response. - if response.get("usage", None): - yield UsageMetadataChunk( - prompt_tokens=response["usage"].get("prompt_tokens", 0), - completion_tokens=response["usage"].get("completion_tokens", 0), - total_tokens=response["usage"].get("total_tokens", 0), - cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]), - ), None + usage = response.get("usage") + if usage: + try: + yield UsageMetadataChunk( + prompt_tokens=usage.get("prompt_tokens", 0) or 0, + completion_tokens=usage.get("completion_tokens", 0) or 0, + total_tokens=usage.get("total_tokens", 0) or 0, + cached_prompt_tokens=_extract_cached_prompt_tokens(usage), + ), None + except AttributeError as e: + raise TypeError( + "Unexpected LiteLLM usage type: %r" % (type(usage),) + ) from e def _model_response_to_generate_content_response( @@ -1453,6 +1500,54 @@ def _message_to_generate_content_response( ) +def _enforce_strict_openai_schema(schema: dict[str, Any]) -> None: + """Recursively transforms a JSON schema for OpenAI strict structured outputs. + + OpenAI strict mode requires: + 1. additionalProperties: false on all object schemas (including nested/$defs). + 2. All properties listed in 'required' (no optional omissions). + 3. $ref nodes must have no sibling keywords (e.g., no 'description' next to + '$ref'). + + This function mutates the schema dict in place. + + Args: + schema: A JSON schema dictionary to transform. + """ + if not isinstance(schema, dict): + return + + # Strip sibling keywords from $ref nodes (OpenAI rejects them). + if "$ref" in schema: + for key in list(schema.keys()): + if key != "$ref": + del schema[key] + return + + # Ensure all object schemas have additionalProperties: false and list every + # property as required. + if schema.get("type") == "object" and "properties" in schema: + schema["additionalProperties"] = False + schema["required"] = sorted(schema["properties"].keys()) + + # Recurse into $defs (Pydantic's nested model definitions). + for defn in schema.get("$defs", {}).values(): + _enforce_strict_openai_schema(defn) + + # Recurse into property schemas. + for prop in schema.get("properties", {}).values(): + _enforce_strict_openai_schema(prop) + + # Recurse into combinators. + for key in ("anyOf", "oneOf", "allOf"): + for item in schema.get(key, []): + _enforce_strict_openai_schema(item) + + # Recurse into array item schemas. + if "items" in schema and isinstance(schema["items"], dict): + _enforce_strict_openai_schema(schema["items"]) + + def _to_litellm_response_format( response_schema: types.SchemaUnion, model: str, @@ -1477,7 +1572,7 @@ def _to_litellm_response_format( and schema_type.lower() in _LITELLM_STRUCTURED_TYPES ): return response_schema - schema_dict = dict(response_schema) + schema_dict = copy.deepcopy(response_schema) if "title" in schema_dict: schema_name = str(schema_dict["title"]) elif isinstance(response_schema, type) and issubclass( @@ -1488,14 +1583,18 @@ def _to_litellm_response_format( elif isinstance(response_schema, BaseModel): if isinstance(response_schema, types.Schema): # GenAI Schema instances already represent JSON schema definitions. - schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + schema_dict = copy.deepcopy( + response_schema.model_dump(exclude_none=True, mode="json") + ) if "title" in schema_dict: schema_name = str(schema_dict["title"]) else: schema_dict = response_schema.__class__.model_json_schema() schema_name = response_schema.__class__.__name__ elif hasattr(response_schema, "model_dump"): - schema_dict = response_schema.model_dump(exclude_none=True, mode="json") + schema_dict = copy.deepcopy( + response_schema.model_dump(exclude_none=True, mode="json") + ) schema_name = response_schema.__class__.__name__ else: logger.warning( @@ -1513,14 +1612,8 @@ def _to_litellm_response_format( # OpenAI-compatible format (default) per LiteLLM docs: # https://docs.litellm.ai/docs/completion/json_mode - if ( - isinstance(schema_dict, dict) - and schema_dict.get("type") == "object" - and "additionalProperties" not in schema_dict - ): - # OpenAI structured outputs require explicit additionalProperties: false. - schema_dict = dict(schema_dict) - schema_dict["additionalProperties"] = False + if isinstance(schema_dict, dict): + _enforce_strict_openai_schema(schema_dict) return { "type": "json_schema", @@ -1902,6 +1995,57 @@ class LiteLlm(BaseLlm): aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 + + def _finalize_tool_call_response( + *, model_version: str, finish_reason: str + ) -> LlmResponse: + tool_calls = [] + for index, func_data in function_calls.items(): + if func_data["id"]: + tool_calls.append( + ChatCompletionMessageToolCall( + type="function", + id=func_data["id"], + function=Function( + name=func_data["name"], + arguments=func_data["args"], + index=index, + ), + ) + ) + llm_response = _message_to_generate_content_response( + ChatCompletionAssistantMessage( + role="assistant", + content=text, + tool_calls=tool_calls, + ), + model_version=model_version, + thought_parts=list(reasoning_parts) if reasoning_parts else None, + ) + llm_response.finish_reason = _map_finish_reason(finish_reason) + return llm_response + + def _finalize_text_response( + *, model_version: str, finish_reason: str + ) -> LlmResponse: + message_content = text if text else None + llm_response = _message_to_generate_content_response( + ChatCompletionAssistantMessage( + role="assistant", + content=message_content, + ), + model_version=model_version, + thought_parts=list(reasoning_parts) if reasoning_parts else None, + ) + llm_response.finish_reason = _map_finish_reason(finish_reason) + return llm_response + + def _reset_stream_buffers() -> None: + nonlocal text, reasoning_parts + text = "" + reasoning_parts = [] + function_calls.clear() + async for part in await self.llm_client.acompletion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): @@ -1951,58 +2095,49 @@ class LiteLlm(BaseLlm): cached_content_token_count=chunk.cached_prompt_tokens, ) - if ( - finish_reason == "tool_calls" or finish_reason == "stop" - ) and function_calls: - tool_calls = [] - for index, func_data in function_calls.items(): - if func_data["id"]: - tool_calls.append( - ChatCompletionMessageToolCall( - type="function", - id=func_data["id"], - function=Function( - name=func_data["name"], - arguments=func_data["args"], - index=index, - ), - ) - ) + # LiteLLM 1.81+ can set finish_reason="stop" on partial chunks. Only + # finalize tool calls on an explicit tool_calls finish_reason, or on a + # stop-only chunk (no content/tool deltas). + if function_calls and ( + finish_reason == "tool_calls" + or (finish_reason == "stop" and chunk is None) + ): aggregated_llm_response_with_tool_call = ( - _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", - content=text, - tool_calls=tool_calls, - ), + _finalize_tool_call_response( model_version=part.model, - thought_parts=list(reasoning_parts) - if reasoning_parts - else None, + finish_reason=finish_reason, ) ) - aggregated_llm_response_with_tool_call.finish_reason = ( - _map_finish_reason(finish_reason) - ) - text = "" - reasoning_parts = [] - function_calls.clear() - elif finish_reason == "stop" and (text or reasoning_parts): - message_content = text if text else None - aggregated_llm_response = _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", content=message_content - ), + _reset_stream_buffers() + elif ( + finish_reason == "stop" + and (text or reasoning_parts) + and chunk is None + and not function_calls + ): + # Only aggregate text response when we have a true stop signal + # chunk is None means no content in this chunk, just finish signal. + # LiteLLM 1.81+ sets finish_reason="stop" on partial chunks with + # content. + aggregated_llm_response = _finalize_text_response( model_version=part.model, - thought_parts=list(reasoning_parts) - if reasoning_parts - else None, + finish_reason=finish_reason, ) - aggregated_llm_response.finish_reason = _map_finish_reason( - finish_reason - ) - text = "" - reasoning_parts = [] + _reset_stream_buffers() + + if function_calls and not aggregated_llm_response_with_tool_call: + aggregated_llm_response_with_tool_call = _finalize_tool_call_response( + model_version=part.model, + finish_reason="tool_calls", + ) + _reset_stream_buffers() + + if (text or reasoning_parts) and not aggregated_llm_response: + aggregated_llm_response = _finalize_text_response( + model_version=part.model, + finish_reason="stop", + ) + _reset_stream_buffers() # waiting until streaming ends to yield the llm_response as litellm tends # to send chunk that contains usage_metadata after the chunk with diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py index 08d6b861..37f1852b 100644 --- a/src/google/adk/models/llm_request.py +++ b/src/google/adk/models/llm_request.py @@ -25,6 +25,7 @@ from pydantic import Field from ..agents.context_cache_config import ContextCacheConfig from ..tools.base_tool import BaseTool +from ..utils._schema_utils import SchemaType from .cache_metadata import CacheMetadata @@ -273,12 +274,27 @@ class LlmRequest(BaseModel): # No existing tool with function_declarations, create new one self.config.tools.append(types.Tool(function_declarations=declarations)) - def set_output_schema(self, base_model: type[BaseModel]) -> None: + def set_output_schema( + self, + output_schema: Optional[SchemaType] = None, + *, + base_model: Optional[SchemaType] = None, + ) -> None: """Sets the output schema for the request. Args: - base_model: The pydantic base model to set the output schema to. + output_schema: The output schema to set. Supports all types from + SchemaUnion: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - list[type[BaseModel]]: A generic list type (e.g., list[MySchema]) + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type + base_model: Deprecated alias for output_schema. Use output_schema instead. """ + schema = output_schema or base_model + if schema is None: + raise ValueError("Either output_schema or base_model must be provided.") - self.config.response_schema = base_model + self.config.response_schema = schema self.config.response_mime_type = "application/json" diff --git a/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py b/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py new file mode 100644 index 00000000..0627aced --- /dev/null +++ b/src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py @@ -0,0 +1,323 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from typing import Optional + +from google.genai import types as genai_types +from pydantic import BaseModel +from pydantic import Field + +from ..agents.llm_agent import Agent +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from ..models.registry import LLMRegistry +from ..utils.context_utils import Aclosing +from ..utils.feature_decorator import experimental +from .agent_optimizer import AgentOptimizer +from .data_types import BaseAgentWithScores +from .data_types import OptimizerResult +from .data_types import UnstructuredSamplingResult +from .sampler import Sampler + +_logger = logging.getLogger("google_adk." + __name__) + +_AGENT_PROMPT_NAME = "agent_prompt" + + +class GEPARootAgentPromptOptimizerConfig(BaseModel): + """Contains configuration options required by the GEPARootAgentPromptOptimizer.""" + + optimizer_model: str = Field( + default="gemini-2.5-flash", + description=( + "The model used to analyze the eval results and optimize the agent." + ), + ) + + model_configuration: genai_types.GenerateContentConfig = Field( + default_factory=lambda: genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=10240, + ) + ), + description="The configuration for the optimizer model.", + ) + + max_metric_calls: int = Field( + default=100, + description="The maximum number of metric calls (evaluations) to make.", + ) + + reflection_minibatch_size: int = Field( + default=3, + description="The number of examples to use for reflection.", + ) + + run_dir: Optional[str] = Field( + default=None, + description=( + "The directory to save the intermediate/final optimization results." + ), + ) + + +class GEPARootAgentPromptOptimizerResult(OptimizerResult[BaseAgentWithScores]): + """The final result of the GEPARootAgentPromptOptimizer.""" + + gepa_result: Optional[dict[str, Any]] = Field( + default=None, + description="The raw result dictionary from the GEPA optimizer.", + ) + + +def _create_agent_gepa_adapter_class(): + """Creates the _AgentGEPAAdapter class dynamically to avoid top-level gepa imports.""" + from gepa.core.adapter import EvaluationBatch + from gepa.core.adapter import GEPAAdapter + + class _AgentGEPAAdapter(GEPAAdapter[str, dict[str, Any], dict[str, Any]]): + """A GEPA adapter for ADK agents.""" + + def __init__( + self, + initial_agent: Agent, + sampler: Sampler[UnstructuredSamplingResult], + main_loop: asyncio.AbstractEventLoop, + ): + self._initial_agent = initial_agent + self._sampler = sampler + self._main_loop = main_loop + + self._train_example_ids = set(sampler.get_train_example_ids()) + self._validation_example_ids = set(sampler.get_validation_example_ids()) + + def evaluate( + self, + batch: list[str], + candidate: dict[str, str], + capture_traces: bool = False, + ) -> EvaluationBatch[dict[str, Any], dict[str, Any]]: + prompt = candidate[_AGENT_PROMPT_NAME] + _logger.info( + "Evaluating agent on batch:\n%s\nwith prompt:\n%s", batch, prompt + ) + # Clone the agent and update the instruction + new_agent = self._initial_agent.clone(update={"instruction": prompt}) + + if set(batch) <= self._train_example_ids: + example_set = "train" + elif set(batch) <= self._validation_example_ids: + example_set = "validation" + else: + raise ValueError(f"Invalid batch composition: {batch}") + + # Run the evaluation in the main loop + future = asyncio.run_coroutine_threadsafe( + self._sampler.sample_and_score( + new_agent, + example_set=example_set, + batch=batch, + capture_full_eval_data=capture_traces, + ), + self._main_loop, + ) + result: UnstructuredSamplingResult = future.result() + + scores = [] + outputs = [] + trajectories = [] + + for example_id in batch: + score = result.scores[example_id] + scores.append(score) + + eval_data = result.data.get(example_id, {}) if result.data else {} + outputs.append(eval_data) + trajectories.append(eval_data) + + return EvaluationBatch( + outputs=outputs, scores=scores, trajectories=trajectories + ) + + def make_reflective_dataset( + self, + candidate: dict[str, str], + eval_batch: EvaluationBatch[dict[str, Any], dict[str, Any]], + components_to_update: list[str], + ) -> dict[str, list[dict[str, Any]]]: + dataset: list[dict[str, Any]] = [] + trace_instances: list[tuple[float, dict[str, Any]]] = list( + zip( + eval_batch.scores, + eval_batch.trajectories, + strict=True, + ) + ) + for trace_instance in trace_instances: + score, eval_data = trace_instance + + dataset.append({ + _AGENT_PROMPT_NAME: candidate[_AGENT_PROMPT_NAME], + "score": score, + "eval_data": eval_data, + }) + + # same data for all components (should be only one) + result = {comp: dataset for comp in components_to_update} + + return result + + return _AgentGEPAAdapter + + +@experimental +class GEPARootAgentPromptOptimizer( + AgentOptimizer[UnstructuredSamplingResult, BaseAgentWithScores] +): + """An optimizer that improves the root agent prompt using the GEPA framework.""" + + def __init__( + self, + config: GEPARootAgentPromptOptimizerConfig, + ): + self._config = config + llm_registry = LLMRegistry() + self._llm_class = llm_registry.resolve(self._config.optimizer_model) + + async def optimize( + self, + initial_agent: Agent, + sampler: Sampler[UnstructuredSamplingResult], + ) -> GEPARootAgentPromptOptimizerResult: + """Runs the GEPARootAgentPromptOptimizer. + + Args: + initial_agent: The initial agent whose prompt is to be optimized. Only the + root agent prompt will be optimized. + sampler: The interface used to get training and validation example UIDs, + request agent evaluations, and get useful data for optimizing the agent. + + Returns: + The final result of the optimization process, containing the optimized + agent instance, its scores on the validation examples, and other metrics. + """ + if initial_agent.sub_agents: + _logger.warning( + "The GEPARootAgentPromptOptimizer will not optimize prompts for" + " sub-agents." + ) + + _logger.info("Setting up the GEPA optimizer...") + + try: + import gepa # lazy import as gepa is not in core ADK package + + _AgentGEPAAdapter = _create_agent_gepa_adapter_class() + except ImportError as e: + raise ImportError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e + + loop = asyncio.get_running_loop() + + adapter = _AgentGEPAAdapter( + initial_agent=initial_agent, + sampler=sampler, + main_loop=loop, + ) + + llm = self._llm_class(model=self._config.optimizer_model) + + def reflection_lm(prompt: str) -> str: + llm_request = LlmRequest( + model=self._config.optimizer_model, + config=self._config.model_configuration, + contents=[ + genai_types.Content( + parts=[genai_types.Part(text=prompt)], + role="user", + ) + ], + ) + + async def _generate(): + response_text = "" + async with Aclosing(llm.generate_content_async(llm_request)) as agen: + async for llm_response in agen: + llm_response: LlmResponse + generated_content: genai_types.Content = llm_response.content + if not generated_content.parts: + continue + response_text = "".join( + part.text + for part in generated_content.parts + if part.text and not part.thought + ) + return response_text + + future = asyncio.run_coroutine_threadsafe(_generate(), loop) + return future.result() + + train_ids = sampler.get_train_example_ids() + val_ids = sampler.get_validation_example_ids() + + if set(train_ids).intersection(val_ids): + _logger.warning( + "The training and validation example UIDs overlap. This WILL cause" + " aliasing issues unless each common UID refers to the same example" + " in both sets." + ) + + def run_gepa(): + return gepa.optimize( + seed_candidate={_AGENT_PROMPT_NAME: initial_agent.instruction}, + trainset=train_ids, + valset=val_ids, + adapter=adapter, + max_metric_calls=self._config.max_metric_calls, + reflection_lm=reflection_lm, + reflection_minibatch_size=self._config.reflection_minibatch_size, + run_dir=self._config.run_dir, + ) + + _logger.info("Running the GEPA optimizer...") + + gepa_results = await loop.run_in_executor(None, run_gepa) + + _logger.info("GEPA optimization finished. Preparing final results...") + + optimized_prompts = [ + candidate[_AGENT_PROMPT_NAME] for candidate in gepa_results.candidates + ] + scores = gepa_results.val_aggregate_scores + + optimized_agents = [ + BaseAgentWithScores( + optimized_agent=initial_agent.clone( + update={"instruction": optimized_prompt}, + ), + overall_score=score, + ) + for optimized_prompt, score in zip(optimized_prompts, scores) + ] + + return GEPARootAgentPromptOptimizerResult( + optimized_agents=optimized_agents, + gepa_result=gepa_results.to_dict(), + ) diff --git a/src/google/adk/optimization/local_eval_sampler.py b/src/google/adk/optimization/local_eval_sampler.py new file mode 100644 index 00000000..b00c3428 --- /dev/null +++ b/src/google/adk/optimization/local_eval_sampler.py @@ -0,0 +1,367 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any +from typing import Literal +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + +from ..agents.llm_agent import Agent +from ..evaluation.base_eval_service import EvaluateConfig +from ..evaluation.base_eval_service import EvaluateRequest +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.base_eval_service import InferenceResult +from ..evaluation.eval_case import get_all_tool_calls_with_responses +from ..evaluation.eval_case import IntermediateData +from ..evaluation.eval_case import Invocation +from ..evaluation.eval_case import InvocationEvents +from ..evaluation.eval_config import EvalConfig +from ..evaluation.eval_config import get_eval_metrics_from_config +from ..evaluation.eval_metrics import EvalStatus +from ..evaluation.eval_result import EvalCaseResult +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..evaluation.local_eval_service import LocalEvalService +from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider +from ..utils.context_utils import Aclosing +from .data_types import UnstructuredSamplingResult +from .sampler import Sampler + +logger = logging.getLogger("google_adk." + __name__) + + +def _log_eval_summary(eval_results: list[EvalCaseResult]): + """Logs a summary of eval results.""" + num_pass, num_fail, num_other = 0, 0, 0 + for eval_result in eval_results: + eval_result: EvalCaseResult + if eval_result.final_eval_status == EvalStatus.PASSED: + num_pass += 1 + elif eval_result.final_eval_status == EvalStatus.FAILED: + num_fail += 1 + else: + num_other += 1 + log_str = f"Evaluation summary: {num_pass} PASSED, {num_fail} FAILED" + if num_other: + log_str += f", {num_other} OTHER" + logger.info(log_str) + + +def extract_tool_call_data( + intermediate_data: IntermediateData | InvocationEvents, +) -> list[dict[str, Any]]: + """Extracts tool calls and their responses from intermediate data.""" + call_response_pairs = get_all_tool_calls_with_responses(intermediate_data) + result = [] + for tool_call, tool_response in call_response_pairs: + result.append({ + "name": tool_call.name, + "args": tool_call.args, + "response": tool_response.response if tool_response else None, + }) + return result + + +def extract_single_invocation_info( + invocation: Invocation, +) -> dict[str, Any]: + """Extracts useful information from a single invocation.""" + user_prompt = "" + for part in invocation.user_content.parts: + if part.text and not part.thought: + user_prompt += part.text + agent_response = "" + if invocation.final_response: + for part in invocation.final_response.parts: + if part.text and not part.thought: + agent_response += part.text + result = {"user_prompt": user_prompt, "agent_response": agent_response} + if invocation.intermediate_data: + tool_call_data = extract_tool_call_data(invocation.intermediate_data) + result["tool_calls"] = tool_call_data + return result + + +class LocalEvalSamplerConfig(BaseModel): + """Contains configuration options required by the LocalEvalServiceInterface.""" + + eval_config: EvalConfig = Field( + required=True, + description="The configuration for the evaluation.", + ) + + app_name: str = Field( + required=True, + description="The app name to use for evaluation.", + ) + + train_eval_set: str = Field( + required=True, + description="The name of the eval set to use for optimization.", + ) + + train_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for optimization. If not provided," + " all eval cases in the train_eval_set will be used." + ), + ) + + validation_eval_set: Optional[str] = Field( + default=None, + description=( + "The name of the eval set to use for validating the optimized agent." + " If not provided, the train_eval_set will also be used for" + " validation." + ), + ) + + validation_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for validating the optimized agent." + " If not provided, all eval cases in the validation_eval_set will be" + " used. If validation_eval_set is also not provided, all train eval" + " cases will be used." + ), + ) + + +class LocalEvalSampler(Sampler[UnstructuredSamplingResult]): + """Evaluates candidate agents with the ADK's LocalEvalService.""" + + def __init__( + self, + config: LocalEvalSamplerConfig, + eval_sets_manager: EvalSetsManager, + ): + self._config = config + self._eval_sets_manager = eval_sets_manager + + self._train_eval_set = self._config.train_eval_set + self._train_eval_case_ids = ( + self._config.train_eval_case_ids + or self._get_eval_case_ids(self._train_eval_set) + ) + + self._validation_eval_set = ( + self._config.validation_eval_set or self._train_eval_set + ) + if self._config.validation_eval_case_ids: + self._validation_eval_case_ids = self._config.validation_eval_case_ids + elif self._config.validation_eval_set: + self._validation_eval_case_ids = self._get_eval_case_ids( + self._validation_eval_set + ) + else: + self._validation_eval_case_ids = self._train_eval_case_ids + + def _get_selected_example_set_id( + self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET] + ) -> str: + """Returns the ID of the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_set, + Sampler.VALIDATION_SET: self._validation_eval_set, + }[example_set] + + def _get_all_example_ids( + self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET] + ) -> list[str]: + """Returns the IDs of all examples in the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_case_ids, + Sampler.VALIDATION_SET: self._validation_eval_case_ids, + }[example_set] + + def _get_eval_case_ids(self, eval_set_id: str) -> list[str]: + """Returns the ids of eval cases in the given eval set.""" + eval_set = self._eval_sets_manager.get_eval_set( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + ) + if eval_set: + return [eval_case.eval_id for eval_case in eval_set.eval_cases] + else: + raise ValueError( + f"Eval set `{eval_set_id}` does not exist for app" + f" `{self._config.app_name}`." + ) + + async def _evaluate_agent( + self, + agent: Agent, + eval_set_id: str, + eval_case_ids: list[str], + ) -> list[EvalCaseResult]: + """Evaluates the agent on the requested eval cases and returns the results. + + Args: + agent: The agent to evaluate. + eval_set_id: The id of the eval set to use for evaluation. + eval_case_ids: The ids of the eval cases to use for evaluation. + + Returns: + A list of EvalCaseResult, one per eval case. + """ + # create the inference request + inference_request = InferenceRequest( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_ids=eval_case_ids, + inference_config=InferenceConfig(), + ) + + # create the LocalEvalService + user_simulator_provider = UserSimulatorProvider( + self._config.eval_config.user_simulator_config + ) + eval_service = LocalEvalService( + root_agent=agent, + eval_sets_manager=self._eval_sets_manager, + user_simulator_provider=user_simulator_provider, + ) + + # inference/sampling + async with Aclosing( + eval_service.perform_inference(inference_request=inference_request) + ) as agen: + inference_results: list[InferenceResult] = [ + inference_result async for inference_result in agen + ] + + # evaluation + eval_metrics = get_eval_metrics_from_config(self._config.eval_config) + evaluate_request = EvaluateRequest( + inference_results=inference_results, + evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), + ) + async with Aclosing( + eval_service.evaluate(evaluate_request=evaluate_request) + ) as agen: + eval_results: list[EvalCaseResult] = [ + eval_result async for eval_result in agen + ] + + return eval_results + + def _extract_eval_data( + self, + eval_set_id: str, + eval_results: list[EvalCaseResult], + ) -> dict[str, dict[str, Any]]: + """Extracts evaluation data from the eval results.""" + eval_data = {} + for eval_result in eval_results: + eval_result_dict = {} + eval_case = self._eval_sets_manager.get_eval_case( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_id=eval_result.eval_id, + ) + if eval_case and eval_case.conversation_scenario: + eval_result_dict["conversation_scenario"] = ( + eval_case.conversation_scenario + ) + + per_invocation_results = [] + for ( + per_invocation_result + ) in eval_result.eval_metric_result_per_invocation: + eval_metric_results = [] + for eval_metric_result in per_invocation_result.eval_metric_results: + eval_metric_results.append({ + "metric_name": eval_metric_result.metric_name, + "score": round(eval_metric_result.score, 2), # accurate enough + "eval_status": eval_metric_result.eval_status.name, + }) + per_invocation_result_dict = { + "actual_invocation": extract_single_invocation_info( + per_invocation_result.actual_invocation + ), + "eval_metric_results": eval_metric_results, + } + if per_invocation_result.expected_invocation: + per_invocation_result_dict["expected_invocation"] = ( + extract_single_invocation_info( + per_invocation_result.expected_invocation + ) + ) + per_invocation_results.append(per_invocation_result_dict) + eval_result_dict["invocations"] = per_invocation_results + eval_data[eval_result.eval_id] = eval_result_dict + + return eval_data + + def get_train_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for training the agent.""" + return self._train_eval_case_ids + + def get_validation_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for validating the optimized agent.""" + return self._validation_eval_case_ids + + async def sample_and_score( + self, + candidate: Agent, + example_set: Literal[ + Sampler.TRAIN_SET, Sampler.VALIDATION_SET + ] = Sampler.VALIDATION_SET, + batch: Optional[list[str]] = None, + capture_full_eval_data: bool = False, + ) -> UnstructuredSamplingResult: + """Evaluates the candidate agent on the batch of examples using the ADK LocalEvalService. + + Args: + candidate: The candidate agent to be evaluated. + example_set: The set of examples to evaluate the candidate agent on. + Possible values are "train" and "validation". + batch: UIDs of examples to evaluate the candidate agent on. If not + provided, all examples from the chosen set will be used. + capture_full_eval_data: If false, it is enough to only calculate the + scores for each example. If true, this method should also capture all + other data required for optimizing the agent (e.g., outputs, + trajectories, and tool calls). + + Returns: + The evaluation results, containing the scores for each example and (if + requested) other data required for optimization. + """ + eval_set_id = self._get_selected_example_set_id(example_set) + if batch is None: + batch = self._get_all_example_ids(example_set) + + eval_results = await self._evaluate_agent(candidate, eval_set_id, batch) + _log_eval_summary(eval_results) + + scores = { + eval_result.eval_id: ( + 1.0 if eval_result.final_eval_status == EvalStatus.PASSED else 0.0 + ) + for eval_result in eval_results + } + + eval_data = ( + self._extract_eval_data(eval_set_id, eval_results) + if capture_full_eval_data + else None + ) + + return UnstructuredSamplingResult(scores=scores, data=eval_data) diff --git a/src/google/adk/optimization/sampler.py b/src/google/adk/optimization/sampler.py index 0a0ff45d..632e5d3d 100644 --- a/src/google/adk/optimization/sampler.py +++ b/src/google/adk/optimization/sampler.py @@ -32,6 +32,9 @@ class Sampler(ABC, Generic[SamplingResult]): to get evaluation results for the candidate agent on the batch of examples. """ + TRAIN_SET = "train" + VALIDATION_SET = "validation" + @abstractmethod def get_train_example_ids(self) -> list[str]: """Returns the UIDs of examples to use for training the agent.""" @@ -46,7 +49,7 @@ class Sampler(ABC, Generic[SamplingResult]): async def sample_and_score( self, candidate: Agent, - example_set: Literal["train", "validation"] = "validation", + example_set: Literal[TRAIN_SET, VALIDATION_SET] = VALIDATION_SET, batch: Optional[list[str]] = None, capture_full_eval_data: bool = False, ) -> SamplingResult: diff --git a/src/google/adk/planners/plan_re_act_planner.py b/src/google/adk/planners/plan_re_act_planner.py index f7930b14..dab3a1fe 100644 --- a/src/google/adk/planners/plan_re_act_planner.py +++ b/src/google/adk/planners/plan_re_act_planner.py @@ -168,7 +168,7 @@ Follow this format when answering the question: (1) The planning part should be planning_preamble = f""" Below are the requirements for the planning: The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take. -If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under {REPLANNING_TAG}. Then use tools to follow the new plan. +If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be under {REPLANNING_TAG}. Then use tools to follow the new plan. """ reasoning_preamble = """ diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 5b0fcf55..ce028cf4 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -27,6 +27,14 @@ import functools import json import logging import mimetypes +import os + +# Enable gRPC fork support so child processes created via os.fork() +# can safely create new gRPC channels. Must be set before grpc's +# C-core is loaded (which happens through the google.api_core +# imports below). setdefault respects any explicit user override. +os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1") + import random import time from types import MappingProxyType @@ -51,7 +59,6 @@ from google.cloud.bigquery import schema as bq_schema from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types -from opentelemetry import context from opentelemetry import trace import pyarrow as pa @@ -71,6 +78,34 @@ tracer = trace.get_tracer( "google.adk.plugins.bigquery_agent_analytics", __version__ ) +# Bumped when the schema changes (1 → 2 → 3 …). Used as a table +# label for governance and to decide whether auto-upgrade should run. +_SCHEMA_VERSION = "1" +_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version" + +_HITL_EVENT_MAP = MappingProxyType({ + "adk_request_credential": "HITL_CREDENTIAL_REQUEST", + "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST", + "adk_request_input": "HITL_INPUT_REQUEST", +}) + +# Track all living plugin instances so the fork handler can reset +# them proactively in the child, before _ensure_started runs. +_LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet() + + +def _after_fork_in_child() -> None: + """Reset every living plugin instance after os.fork().""" + for plugin in list(_LIVE_PLUGINS): + try: + plugin._reset_runtime_state() + except Exception: + pass + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=_after_fork_in_child) + def _safe_callback(func): """Decorator that catches and logs exceptions in plugin callbacks. @@ -132,6 +167,47 @@ def _format_content( return " | ".join(parts), truncated +def _get_tool_origin(tool: "BaseTool") -> str: + """Returns the provenance category of a tool. + + Uses lazy imports to avoid circular dependencies. + + Args: + tool: The tool instance. + + Returns: + One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN. + """ + # Import lazily to avoid circular dependencies. + # pylint: disable=g-import-not-at-top + from ..tools.agent_tool import AgentTool # pytype: disable=import-error + from ..tools.function_tool import FunctionTool # pytype: disable=import-error + from ..tools.transfer_to_agent_tool import TransferToAgentTool # pytype: disable=import-error + + try: + from ..tools.mcp_tool.mcp_tool import McpTool # pytype: disable=import-error + except ImportError: + McpTool = None + + try: + from ..agents.remote_a2a_agent import RemoteA2aAgent # pytype: disable=import-error + except ImportError: + RemoteA2aAgent = None + + # Order matters: TransferToAgentTool is a subclass of FunctionTool. + if McpTool is not None and isinstance(tool, McpTool): + return "MCP" + if isinstance(tool, TransferToAgentTool): + return "TRANSFER_AGENT" + if isinstance(tool, AgentTool): + if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent): + return "A2A" + return "SUB_AGENT" + if isinstance(tool, FunctionTool): + return "LOCAL" + return "UNKNOWN" + + def _recursive_smart_truncate( obj: Any, max_len: int, seen: Optional[set[int]] = None ) -> tuple[Any, bool]: @@ -412,7 +488,7 @@ class BigQueryLoggerConfig: event_allowlist: list[str] | None = None event_denylist: list[str] | None = None max_content_length: int = 500 * 1024 # Defaults to 500KB per text block - table_id: str = "agent_events_v2" + table_id: str = "agent_events" # V2 Configuration clustering_fields: list[str] = field( @@ -435,27 +511,56 @@ class BigQueryLoggerConfig: log_session_metadata: bool = True # Static custom tags (e.g. {"agent_role": "sales"}) custom_tags: dict[str, Any] = field(default_factory=dict) + # Automatically add new columns to existing tables when the plugin + # schema evolves. Only additive changes are made (columns are never + # dropped or altered). Safe to leave enabled; a version label on the + # table ensures the diff runs at most once per schema version. + auto_schema_upgrade: bool = True + # Automatically create per-event-type BigQuery views that unnest + # JSON columns into typed, queryable columns. + create_views: bool = True # ============================================================================== # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== +# NOTE: These contextvars are module-global, not plugin-instance-scoped. +# This is safe in practice for two reasons: +# 1. PluginManager enforces name-uniqueness, preventing two BQ plugin +# instances on the same Runner. +# 2. Concurrent asyncio tasks (e.g. two Runners in asyncio.gather) each +# get an isolated contextvar copy, so they don't interfere. +# The only problematic case would be two plugin instances interleaved +# within the *same* asyncio task without task boundaries — which the +# framework's PluginManager already prevents. _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) +# Tracks the invocation_id that owns the current span stack so that +# ensure_invocation_span() can distinguish "same invocation re-entry" +# (idempotent) from "stale records from a previous invocation" (clear). +_active_invocation_id_ctx: contextvars.ContextVar[Optional[str]] = ( + contextvars.ContextVar("_bq_analytics_active_invocation_id", default=None) +) + @dataclass class _SpanRecord: """A single record on the unified span stack. - Consolidates span, token, id, ownership, and timing into one object + Consolidates span, id, ownership, and timing into one object so all stacks stay in sync by construction. + + Note: The plugin intentionally does NOT attach its spans to the + ambient OTel context (no ``context.attach``). This prevents the + plugin from corrupting the framework's span hierarchy when an + external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``) + is active. See https://github.com/google/adk-python/issues/4561. """ span: trace.Span - token: Any # opentelemetry context token span_id: str owns_span: bool start_time_ns: int @@ -485,12 +590,13 @@ class TraceManager: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _root_agent_name_ctx.get() is None: - try: - root_agent = callback_context._invocation_context.agent.root_agent - _root_agent_name_ctx.set(root_agent.name) - except (AttributeError, ValueError): - pass + # Always refresh root_agent_name — it can change between + # invocations (e.g. different root agents in the same task). + try: + root_agent = callback_context._invocation_context.agent.root_agent + _root_agent_name_ctx.set(root_agent.name) + except (AttributeError, ValueError): + pass # Ensure records stack is initialized TraceManager._get_records() @@ -513,17 +619,35 @@ class TraceManager: @staticmethod def push_span( - callback_context: CallbackContext, span_name: Optional[str] = "adk-span" + callback_context: CallbackContext, + span_name: Optional[str] = "adk-span", ) -> str: """Starts a new span and pushes it onto the stack. - If OTel is not configured (returning non-recording spans), a UUID fallback - is generated to ensure span_id and parent_span_id are populated in logs. + The span is created but NOT attached to the ambient OTel context, + so it cannot corrupt the framework's own span hierarchy. The + plugin tracks span_id / parent_span_id internally via its own + contextvar stack. + + If OTel is not configured (returning non-recording spans), a UUID + fallback is generated to ensure span_id and parent_span_id are + populated in BigQuery logs. """ TraceManager.init_trace(callback_context) - span = tracer.start_span(span_name) - token = context.attach(trace.set_span_in_context(span)) + # Create the span without attaching it to the ambient context. + # This avoids re-parenting framework spans like ``call_llm`` + # or ``execute_tool``. See #4561. + # + # If the internal stack already has a span, create the new span + # as a child so it shares the same trace_id. Without this, each + # ``start_span`` would be an independent root with its own + # trace_id — causing trace_id fracture (see #4645). + records = TraceManager._get_records() + parent_ctx = None + if records and records[-1].span.get_span_context().is_valid: + parent_ctx = trace.set_span_in_context(records[-1].span) + span = tracer.start_span(span_name, context=parent_ctx) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -532,13 +656,11 @@ class TraceManager: record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=True, start_time_ns=time.time_ns(), ) - records = TraceManager._get_records() new_records = list(records) + [record] _span_records_ctx.set(new_records) @@ -548,11 +670,14 @@ class TraceManager: def attach_current_span( callback_context: CallbackContext, ) -> str: - """Attaches the current OTEL span to the stack without owning it.""" + """Records the current OTel span on the stack without owning it. + + The span is NOT re-attached to the ambient context; it is only + tracked internally for span_id / parent_span_id resolution. + """ TraceManager.init_trace(callback_context) span = trace.get_current_span() - token = context.attach(trace.set_span_in_context(span)) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -561,7 +686,6 @@ class TraceManager: record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=False, start_time_ns=time.time_ns(), @@ -573,9 +697,56 @@ class TraceManager: return span_id_str + @staticmethod + def ensure_invocation_span( + callback_context: CallbackContext, + ) -> None: + """Ensures a root span exists on the plugin stack for this invocation. + + Must be called before any events are logged so that every event in + the invocation shares the same trace_id. + + * If the stack has entries for the *current* invocation → no-op + (idempotent within the same invocation). + * If the stack has entries from a *different* invocation → clear + stale records and re-initialise (safety net for abnormal exit). + * If the ambient OTel span is valid → ``attach_current_span`` + (reuse the runner's span without owning it). + * Otherwise → ``push_span("invocation")`` (create a new root + span that will be popped in ``after_run_callback``). + """ + current_inv = callback_context.invocation_id + active_inv = _active_invocation_id_ctx.get() + + records = _span_records_ctx.get() + if records: + if active_inv == current_inv: + return # Already initialised for this invocation. + # Stale records from a previous invocation that wasn't cleaned + # up (e.g. exception skipped after_run_callback). Clear and + # re-init. + logger.debug( + "Clearing %d stale span records from previous invocation.", + len(records), + ) + TraceManager.clear_stack() + + _active_invocation_id_ctx.set(current_inv) + + # Check for a valid ambient span (e.g. the Runner's invocation span). + ambient = trace.get_current_span() + if ambient.get_span_context().is_valid: + TraceManager.attach_current_span(callback_context) + else: + TraceManager.push_span(callback_context, "invocation") + @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - """Ends the current span and pops it from the stack.""" + """Ends the current span and pops it from the stack. + + No ambient OTel context is detached because we never attached + one in the first place (see ``push_span``). + """ records = _span_records_ctx.get() if not records: return None, None @@ -595,10 +766,19 @@ class TraceManager: if record.owns_span: record.span.end() - context.detach(record.token) - return record.span_id, duration_ms + @staticmethod + def clear_stack() -> None: + """Clears all span records. Safety net for cross-invocation cleanup.""" + records = _span_records_ctx.get() + if records: + # End any owned spans to avoid OTel resource leaks. + for record in reversed(records): + if record.owns_span: + record.span.end() + _span_records_ctx.set([]) + @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: """Gets current span_id and parent span_id.""" @@ -1244,7 +1424,10 @@ class HybridContentParser: if content.config and getattr(content.config, "system_instruction", None): si = content.config.system_instruction if isinstance(si, str): - json_payload["system_prompt"] = si + truncated_si, trunc = process_text(si) + if trunc: + is_truncated = True + json_payload["system_prompt"] = truncated_si else: summary, parts, trunc = await self._parse_content_object(si) if trunc: @@ -1501,6 +1684,115 @@ def _get_events_schema() -> list[bigquery.SchemaField]: ] +# ============================================================================== +# ANALYTICS VIEW DEFINITIONS +# ============================================================================== + +# Columns included in every per-event-type view. +_VIEW_COMMON_COLUMNS = ( + "timestamp", + "event_type", + "agent", + "session_id", + "invocation_id", + "user_id", + "trace_id", + "span_id", + "parent_span_id", + "status", + "error_message", + "is_truncated", +) + +# Per-event-type column extractions. Each value is a list of +# ``"SQL_EXPR AS alias"`` strings that will be appended after the +# common columns in the view SELECT. +_EVENT_VIEW_DEFS: dict[str, list[str]] = { + "USER_MESSAGE_RECEIVED": [], + "LLM_REQUEST": [ + "JSON_VALUE(attributes, '$.model') AS model", + "content AS request_content", + "JSON_QUERY(attributes, '$.llm_config') AS llm_config", + "JSON_QUERY(attributes, '$.tools') AS tools", + ], + "LLM_RESPONSE": [ + "JSON_QUERY(content, '$.response') AS response", + ( + "CAST(JSON_VALUE(content, '$.usage.prompt')" + " AS INT64) AS usage_prompt_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.completion')" + " AS INT64) AS usage_completion_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.total')" + " AS INT64) AS usage_total_tokens" + ), + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ( + "CAST(JSON_VALUE(latency_ms," + " '$.time_to_first_token_ms') AS INT64) AS ttft_ms" + ), + "JSON_VALUE(attributes, '$.model_version') AS model_version", + "JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata", + ], + "LLM_ERROR": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_STARTING": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + ], + "TOOL_COMPLETED": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.result') AS tool_result", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_ERROR": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "AGENT_STARTING": [ + "JSON_VALUE(content, '$.text_summary') AS agent_instruction", + ], + "AGENT_COMPLETED": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "INVOCATION_STARTING": [], + "INVOCATION_COMPLETED": [], + "STATE_DELTA": [ + "JSON_QUERY(attributes, '$.state_delta') AS state_delta", + ], + "HITL_CREDENTIAL_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_CONFIRMATION_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_INPUT_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], +} + +_VIEW_SQL_TEMPLATE = """\ +CREATE OR REPLACE VIEW `{project}.{dataset}.{view_name}` AS +SELECT + {columns} +FROM + `{project}.{dataset}.{table}` +WHERE + event_type = '{event_type}' +""" + + # ============================================================================== # MAIN PLUGIN # ============================================================================== @@ -1512,7 +1804,7 @@ class _LoopState: batch_processor: BatchProcessor -@dataclass +@dataclass(kw_only=True) class EventData: """Typed container for structured fields passed to _log_event.""" @@ -1526,10 +1818,11 @@ class EventData: status: str = "OK" error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) + trace_id_override: Optional[str] = None class BigQueryAgentAnalyticsPlugin(BasePlugin): - """BigQuery Agent Analytics Plugin (v2.0 using Write API). + """BigQuery Agent Analytics Plugin using Write API. Logs agent events (LLM requests, tool calls, etc.) to BigQuery for analytics. Uses the BigQuery Write API for efficient, asynchronous, and reliable logging. @@ -1570,6 +1863,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self.location = location self._started = False + self._startup_error: Optional[Exception] = None self._is_shutting_down = False self._setup_lock = None self.client = None @@ -1580,6 +1874,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self.parser: Optional[HybridContentParser] = None self._schema = None self.arrow_schema = None + self._init_pid = os.getpid() + _LIVE_PLUGINS.add(self) def _cleanup_stale_loop_states(self) -> None: """Removes entries for event loops that have been closed.""" @@ -1822,20 +2118,34 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) def _ensure_schema_exists(self) -> None: - """Ensures the BigQuery table exists with the correct schema.""" + """Ensures the BigQuery table exists with the correct schema. + + When ``config.auto_schema_upgrade`` is True and the table already + exists, missing columns are added automatically (additive only). + A ``adk_schema_version`` label is written for governance. + """ try: - self.client.get_table(self.full_table_id) + existing_table = self.client.get_table(self.full_table_id) + if self.config.auto_schema_upgrade: + self._maybe_upgrade_schema(existing_table) + if self.config.create_views: + self._create_analytics_views() except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) tbl.time_partitioning = bigquery.TimePartitioning( - type_=bigquery.TimePartitioningType.DAY, field="timestamp" + type_=bigquery.TimePartitioningType.DAY, + field="timestamp", ) tbl.clustering_fields = self.config.clustering_fields + tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} + table_ready = False try: self.client.create_table(tbl) + table_ready = True except cloud_exceptions.Conflict: - pass + # Another process created it concurrently — still usable. + table_ready = True except Exception as e: logger.error( "Could not create table %s: %s", @@ -1843,6 +2153,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): e, exc_info=True, ) + if table_ready and self.config.create_views: + self._create_analytics_views() except Exception as e: logger.error( "Error checking for table %s: %s", @@ -1851,6 +2163,173 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): exc_info=True, ) + @staticmethod + def _schema_fields_match( + existing: list[bq_schema.SchemaField], + desired: list[bq_schema.SchemaField], + ) -> tuple[ + list[bq_schema.SchemaField], + list[bq_schema.SchemaField], + ]: + """Compares existing vs desired schema fields recursively. + + Returns: + A tuple of (new_top_level_fields, updated_record_fields). + ``new_top_level_fields`` are fields in *desired* that are + entirely absent from *existing*. + ``updated_record_fields`` are RECORD fields that exist in + both but have new sub-fields in *desired*; each entry is a + copy of the existing field with the missing sub-fields + appended. + """ + existing_by_name = {f.name: f for f in existing} + new_fields: list[bq_schema.SchemaField] = [] + updated_records: list[bq_schema.SchemaField] = [] + + for desired_field in desired: + existing_field = existing_by_name.get(desired_field.name) + if existing_field is None: + new_fields.append(desired_field) + elif ( + desired_field.field_type == "RECORD" + and existing_field.field_type == "RECORD" + and desired_field.fields + ): + # Recurse into nested RECORD fields. + sub_new, sub_updated = ( + BigQueryAgentAnalyticsPlugin._schema_fields_match( + list(existing_field.fields), + list(desired_field.fields), + ) + ) + if sub_new or sub_updated: + # Build a merged sub-field list. + merged_sub = list(existing_field.fields) + # Replace updated nested records in-place. + updated_names = {f.name for f in sub_updated} + merged_sub = [ + next(u for u in sub_updated if u.name == f.name) + if f.name in updated_names + else f + for f in merged_sub + ] + # Append entirely new sub-fields. + merged_sub.extend(sub_new) + # Rebuild via API representation to preserve all + # existing field attributes (policy_tags, etc.). + api_repr = existing_field.to_api_repr() + api_repr["fields"] = [sf.to_api_repr() for sf in merged_sub] + updated_records.append(bq_schema.SchemaField.from_api_repr(api_repr)) + + return new_fields, updated_records + + def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: + """Adds missing columns to an existing table (additive only). + + Handles nested RECORD fields by recursing into sub-fields. + The version label is only stamped after a successful update + so that a failed attempt is retried on the next run. + + Args: + existing_table: The current BigQuery table object. + """ + stored_version = (existing_table.labels or {}).get( + _SCHEMA_VERSION_LABEL_KEY + ) + if stored_version == _SCHEMA_VERSION: + return + + new_fields, updated_records = self._schema_fields_match( + list(existing_table.schema), list(self._schema) + ) + + if new_fields or updated_records: + # Build merged top-level schema. + updated_names = {f.name for f in updated_records} + merged = [ + next(u for u in updated_records if u.name == f.name) + if f.name in updated_names + else f + for f in existing_table.schema + ] + merged.extend(new_fields) + existing_table.schema = merged + + change_desc = [] + if new_fields: + change_desc.append(f"new columns {[f.name for f in new_fields]}") + if updated_records: + change_desc.append( + f"updated RECORD fields {[f.name for f in updated_records]}" + ) + logger.info( + "Auto-upgrading table %s: %s", + self.full_table_id, + ", ".join(change_desc), + ) + + try: + # Stamp the version label inside the try block so that + # on failure the label is NOT persisted and the next run + # retries the upgrade. + labels = dict(existing_table.labels or {}) + labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION + existing_table.labels = labels + + update_fields = ["schema", "labels"] + self.client.update_table(existing_table, update_fields) + except Exception as e: + logger.error( + "Schema auto-upgrade failed for %s: %s", + self.full_table_id, + e, + exc_info=True, + ) + + def _create_analytics_views(self) -> None: + """Creates per-event-type BigQuery views (idempotent). + + Each view filters the events table by ``event_type`` and + extracts JSON columns into typed, queryable columns. Uses + ``CREATE OR REPLACE VIEW`` so it is safe to call repeatedly. + Errors are logged but never raised. + """ + for event_type, extra_cols in _EVENT_VIEW_DEFS.items(): + view_name = "v_" + event_type.lower() + columns = ",\n ".join(list(_VIEW_COMMON_COLUMNS) + extra_cols) + sql = _VIEW_SQL_TEMPLATE.format( + project=self.project_id, + dataset=self.dataset_id, + view_name=view_name, + columns=columns, + table=self.table_id, + event_type=event_type, + ) + try: + self.client.query(sql).result() + except Exception as e: + logger.error( + "Failed to create view %s: %s", + view_name, + e, + exc_info=True, + ) + + async def create_analytics_views(self) -> None: + """Public async helper to (re-)create all analytics views. + + Useful when views need to be refreshed explicitly, for example + after a schema upgrade. Ensures the plugin is initialized + before attempting view creation. + """ + await self._ensure_started() + if not self._started: + raise RuntimeError( + "Plugin initialization failed; cannot create analytics views." + ) from self._startup_error + loop = asyncio.get_running_loop() + await loop.run_in_executor(self._executor, self._create_analytics_views) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -1868,6 +2347,22 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): if loop in self._loop_state_by_loop: await self._loop_state_by_loop[loop].batch_processor.shutdown(timeout=t) + # 1b. Drain batch processors on other (non-current) loops. + for other_loop, state in self._loop_state_by_loop.items(): + if other_loop is loop or other_loop.is_closed(): + continue + try: + future = asyncio.run_coroutine_threadsafe( + state.batch_processor.shutdown(timeout=t), + other_loop, + ) + future.result(timeout=t) + except Exception: + logger.warning( + "Could not drain batch processor on loop %s", + other_loop, + ) + # 2. Close clients for all states for state in self._loop_state_by_loop.values(): if state.write_client and getattr( @@ -1902,13 +2397,71 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): state["offloader"] = None state["parser"] = None state["_started"] = False + state["_startup_error"] = None state["_is_shutting_down"] = False + state["_init_pid"] = 0 return state def __setstate__(self, state): """Custom unpickling to restore state.""" + # Backfill keys that may be absent in pickled state from older + # code versions so _ensure_started does not raise AttributeError. + state.setdefault("_init_pid", 0) self.__dict__.update(state) + def _reset_runtime_state(self) -> None: + """Resets all runtime state after a fork. + + gRPC channels and asyncio locks are not safe to use after + ``os.fork()``. This method clears them so the next call to + ``_ensure_started()`` re-initializes everything in the child + process. Pure-data fields like ``_schema`` and + ``arrow_schema`` are kept because they are safe across fork. + """ + logger.warning( + "Fork detected (parent PID %s, child PID %s). Resetting" + " gRPC state for BigQuery analytics plugin. Note: gRPC" + " bidirectional streaming (used by the BigQuery Storage" + " Write API) is not fork-safe. If writes hang or time" + " out, configure the 'spawn' start method at your program" + " entry-point before creating child processes:" + " multiprocessing.set_start_method('spawn')", + self._init_pid, + os.getpid(), + ) + # Best-effort: close inherited gRPC channels so broken + # finalizers don't interfere with newly created channels. + # For grpc.aio channels, close() is a coroutine. We cannot + # await here (called from sync context / fork handler), so + # we skip async channels and only close sync ones. + for loop_state in self._loop_state_by_loop.values(): + wc = getattr(loop_state, "write_client", None) + transport = getattr(wc, "transport", None) + if transport is not None: + try: + channel = getattr(transport, "_grpc_channel", None) + if channel is not None and hasattr(channel, "close"): + result = channel.close() + # If close() returned a coroutine (grpc.aio channel), + # discard it to avoid unawaited-coroutine warnings. + if asyncio.iscoroutine(result): + result.close() + except Exception: + pass + + # Clear all runtime state. + self._setup_lock = None + self.client = None + self._loop_state_by_loop = {} + self._write_stream_name = None + self._executor = None + self.offloader = None + self.parser = None + self._started = False + self._startup_error = None + self._is_shutting_down = False + self._init_pid = os.getpid() + async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin: await self._ensure_started() return self @@ -1918,6 +2471,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def _ensure_started(self, **kwargs) -> None: """Ensures that the plugin is started and initialized.""" + if os.getpid() != self._init_pid: + self._reset_runtime_state() if not self._started: # Kept original lock name as it was not explicitly changed. if self._setup_lock is None: @@ -1927,31 +2482,59 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): try: await self._lazy_setup(**kwargs) self._started = True + self._startup_error = None except Exception as e: + self._startup_error = e logger.error("Failed to initialize BigQuery Plugin: %s", e) @staticmethod - def _resolve_span_ids( + def _resolve_ids( event_data: EventData, - ) -> tuple[str, str]: - """Reads span/parent overrides from EventData, falling back to TraceManager. + callback_context: CallbackContext, + ) -> tuple[Optional[str], Optional[str], Optional[str]]: + """Resolves trace_id, span_id, and parent_span_id for a log row. + + Priority order (highest first): + 1. Explicit ``EventData`` overrides (needed for post-pop callbacks). + 2. Ambient OTel span (the framework's ``start_as_current_span``). + When present this aligns BQ rows with Cloud Trace / o11y. + 3. Plugin's internal span stack (``TraceManager``). + 4. ``invocation_id`` fallback for trace_id. Returns: - (span_id, parent_span_id) + (trace_id, span_id, parent_span_id) """ - current_span_id, current_parent_span_id = ( + # --- Layer 3: plugin stack baseline --- + trace_id = TraceManager.get_trace_id(callback_context) + plugin_span_id, plugin_parent_span_id = ( TraceManager.get_current_span_and_parent() ) + span_id = plugin_span_id + parent_span_id = plugin_parent_span_id - span_id = current_span_id + # --- Layer 2: ambient OTel span --- + ambient = trace.get_current_span() + ambient_ctx = ambient.get_span_context() + if ambient_ctx.is_valid: + trace_id = format(ambient_ctx.trace_id, "032x") + span_id = format(ambient_ctx.span_id, "016x") + # Reset parent — stale plugin-stack parent must not leak through + # when the ambient span is a root (no parent). + parent_span_id = None + # SDK spans expose .parent; non-recording spans do not. + parent_ctx = getattr(ambient, "parent", None) + if parent_ctx is not None and parent_ctx.span_id: + parent_span_id = format(parent_ctx.span_id, "016x") + + # --- Layer 1: explicit EventData overrides --- + if event_data.trace_id_override is not None: + trace_id = event_data.trace_id_override if event_data.span_id_override is not None: span_id = event_data.span_id_override - - parent_span_id = current_parent_span_id if event_data.parent_span_id_override is not None: parent_span_id = event_data.parent_span_id_override - return span_id, parent_span_id + return trace_id, span_id, parent_span_id @staticmethod def _extract_latency( @@ -2011,7 +2594,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): # Include session state if non-empty (contains user-set metadata # like gchat thread-id, customer_id, etc.) if session.state: - session_meta["state"] = dict(session.state) + truncated_state, _ = _recursive_smart_truncate( + dict(session.state), + self.config.max_content_length, + ) + session_meta["state"] = truncated_state attrs["session_metadata"] = session_meta except Exception: pass @@ -2064,8 +2651,9 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): except Exception as e: logger.warning("Content formatter failed: %s", e) - trace_id = TraceManager.get_trace_id(callback_context) - span_id, parent_span_id = self._resolve_span_ids(event_data) + trace_id, span_id, parent_span_id = self._resolve_ids( + event_data, callback_context + ) if not self.parser: logger.warning("Parser not initialized; skipping event %s.", event_type) @@ -2123,16 +2711,43 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) -> None: """Parity with V1: Logs USER_MESSAGE_RECEIVED event. + Also detects HITL completion responses (user-sent + ``FunctionResponse`` parts with ``adk_request_*`` names) and emits + dedicated ``HITL_*_COMPLETED`` events. + Args: invocation_context: The context of the current invocation. user_message: The message content received from the user. """ + callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "USER_MESSAGE_RECEIVED", - CallbackContext(invocation_context), + callback_ctx, raw_content=user_message, ) + # Detect HITL completion responses in the user message. + if user_message and user_message.parts: + for part in user_message.parts: + if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + @_safe_callback async def on_event_callback( self, @@ -2140,24 +2755,76 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): invocation_context: InvocationContext, event: "Event", ) -> None: - """Logs state changes from events to BigQuery. + """Logs state changes and HITL events from the event stream. - Checks each event for a non-empty state_delta and logs it as a - STATE_DELTA event. This captures state changes from all sources - (tools, agents, LLM, manual), not just tool callbacks. + - Checks each event for a non-empty state_delta and logs it as a + STATE_DELTA event. + - Detects synthetic ``adk_request_*`` function calls (HITL pause + events) and their corresponding function responses (HITL + completions) and emits dedicated HITL event types. + + The HITL detection must happen here (not in tool callbacks) because + ``adk_request_credential``, ``adk_request_confirmation``, and + ``adk_request_input`` are synthetic function calls injected by the + framework — they never go through ``before_tool_callback`` / + ``after_tool_callback``. Args: invocation_context: The context for the current invocation. event: The event raised by the runner. """ + callback_ctx = CallbackContext(invocation_context) + + # --- State delta logging --- if event.actions and event.actions.state_delta: await self._log_event( "STATE_DELTA", - CallbackContext(invocation_context), + callback_ctx, event_data=EventData( extra_attributes={"state_delta": dict(event.actions.state_delta)} ), ) + + # --- HITL event logging --- + if event.content and event.content.parts: + for part in event.content.parts: + # Detect HITL function calls (request events). + if part.function_call: + hitl_event = _HITL_EVENT_MAP.get(part.function_call.name) + if hitl_event: + args_truncated, is_truncated = _recursive_smart_truncate( + part.function_call.args or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_call.name, + "args": args_truncated, + } + await self._log_event( + hitl_event, + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + # Detect HITL function responses (completion events). + if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + return None async def on_state_change_callback( @@ -2188,9 +2855,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): invocation_context: The context of the current invocation. """ await self._ensure_started() + callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "INVOCATION_STARTING", - CallbackContext(invocation_context), + callback_ctx, ) @_safe_callback @@ -2202,12 +2871,40 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): Args: invocation_context: The context of the current invocation. """ - await self._log_event( - "INVOCATION_COMPLETED", - CallbackContext(invocation_context), - ) - # Ensure all logs are flushed before the agent returns - await self.flush() + try: + # Capture trace_id BEFORE popping the invocation-root span so + # that INVOCATION_COMPLETED shares the same trace_id as all + # earlier events in this invocation (fixes #4645). + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) + + # Pop the invocation-root span pushed by ensure_invocation_span. + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() + + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + + await self._log_event( + "INVOCATION_COMPLETED", + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, + ), + ) + finally: + # Cleanup must run even if _log_event raises, otherwise + # stale invocation metadata leaks into the next invocation. + TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) + # Ensure all logs are flushed before the agent returns. + await self.flush() @_safe_callback async def before_agent_callback( @@ -2238,18 +2935,20 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): callback_context: The callback context. """ span_id, duration = TraceManager.pop_span() - # When popping, the current stack now points to parent. - # The event we are logging ("AGENT_COMPLETED") belongs to the span we just popped. - # So we must override span_id to be the popped span, and parent to be current top of stack. parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "AGENT_COMPLETED", callback_context, event_data=EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) @@ -2399,6 +3098,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): # Otherwise log_event will fetch current stack (which is parent). span_id = popped_span_id or span_id + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping LLM_REQUEST/LLM_RESPONSE pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + use_override = is_popped and not has_ambient + await self._log_event( "LLM_RESPONSE", callback_context, @@ -2409,8 +3114,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): time_to_first_token_ms=tfft, model_version=llm_response.model_version, usage_metadata=llm_response.usage_metadata, - span_id_override=span_id if is_popped else None, - parent_span_id_override=(parent_span_id if is_popped else None), + span_id_override=span_id if use_override else None, + parent_span_id_override=parent_span_id if use_override else None, ), ) @@ -2431,14 +3136,19 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "LLM_ERROR", callback_context, event_data=EventData( + status="ERROR", error_message=str(error), latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) @@ -2460,7 +3170,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } TraceManager.push_span(tool_context, "tool") await self._log_event( "TOOL_STARTING", @@ -2489,20 +3204,29 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): resp_truncated, is_truncated = _recursive_smart_truncate( result, self.config.max_content_length ) - content_dict = {"tool": tool.name, "result": resp_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "result": resp_truncated, + "tool_origin": tool_origin, + } span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + + event_data = EventData( + latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, + ) await self._log_event( "TOOL_COMPLETED", tool_context, raw_content=content_dict, is_truncated=is_truncated, - event_data=EventData( - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, - ), + event_data=event_data, ) @_safe_callback @@ -2525,15 +3249,28 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} - _, duration = TraceManager.pop_span() + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } + span_id, duration = TraceManager.pop_span() + parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "TOOL_ERROR", tool_context, raw_content=content_dict, is_truncated=is_truncated, event_data=EventData( + status="ERROR", error_message=str(error), latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=None if has_ambient else parent_span_id, ), ) diff --git a/src/google/adk/plugins/logging_plugin.py b/src/google/adk/plugins/logging_plugin.py index df37ee7e..b95e178d 100644 --- a/src/google/adk/plugins/logging_plugin.py +++ b/src/google/adk/plugins/logging_plugin.py @@ -19,6 +19,7 @@ from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from typing_extensions import override from ..agents.base_agent import BaseAgent from ..agents.callback_context import CallbackContext @@ -66,6 +67,7 @@ class LoggingPlugin(BasePlugin): """ super().__init__(name) + @override async def on_user_message_callback( self, *, @@ -87,6 +89,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Branch: {invocation_context.branch}") return None + @override async def before_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[types.Content]: @@ -99,6 +102,7 @@ class LoggingPlugin(BasePlugin): ) return None + @override async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: @@ -122,6 +126,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def after_run_callback( self, *, invocation_context: InvocationContext ) -> Optional[None]: @@ -134,6 +139,7 @@ class LoggingPlugin(BasePlugin): ) return None + @override async def before_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: @@ -145,6 +151,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Branch: {callback_context._invocation_context.branch}") return None + @override async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: @@ -154,6 +161,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Invocation ID: {callback_context.invocation_id}") return None + @override async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: @@ -179,6 +187,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse ) -> Optional[LlmResponse]: @@ -206,6 +215,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def before_tool_callback( self, *, @@ -221,6 +231,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Arguments: {self._format_args(tool_args)}") return None + @override async def after_tool_callback( self, *, @@ -237,6 +248,7 @@ class LoggingPlugin(BasePlugin): self._log(f" Result: {self._format_args(result)}") return None + @override async def on_model_error_callback( self, *, @@ -251,6 +263,7 @@ class LoggingPlugin(BasePlugin): return None + @override async def on_tool_error_callback( self, *, diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a8..d6230752 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -45,9 +45,11 @@ from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor +from .errors.session_not_found_error import SessionNotFoundError from .events.event import Event from .events.event import EventActions from .flows.llm_flows import contents +from .flows.llm_flows.functions import find_event_by_function_call_id from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService @@ -69,6 +71,16 @@ def _is_tool_call_or_response(event: Event) -> bool: return bool(event.get_function_calls() or event.get_function_responses()) +def _get_function_responses_from_content( + content: types.Content, +) -> list[types.FunctionResponse]: + if not content: + return [] + return [ + part.function_response for part in content.parts if part.function_response + ] + + def _is_transcription(event: Event) -> bool: return ( event.input_transcription is not None @@ -340,6 +352,35 @@ class Runner: self._app_name_alignment_hint = f'{mismatch_details} {resolution}' logger.warning('App name mismatch detected. %s', mismatch_details) + def _resolve_invocation_id( + self, + session: Session, + new_message: Optional[types.Content], + invocation_id: Optional[str], + ) -> Optional[str]: + """Infers invocation_id from new_message if it is a function response.""" + function_responses = _get_function_responses_from_content(new_message) + if not function_responses: + return invocation_id + + fc_event = find_event_by_function_call_id( + session.events, function_responses[0].id + ) + if not fc_event: + raise ValueError( + 'Function call event not found for function response id:' + f' {function_responses[0].id}' + ) + + if invocation_id and invocation_id != fc_event.invocation_id: + logger.warning( + 'Provided invocation_id %s is ignored because new_message has a ' + 'function response with invocation_id %s.', + invocation_id, + fc_event.invocation_id, + ) + return fc_event.invocation_id + def _format_session_not_found_message(self, session_id: str) -> str: message = f'Session not found: {session_id}' if not self._app_name_alignment_hint: @@ -358,7 +399,7 @@ class Runner: This helper first attempts to retrieve the session. If not found and auto_create_session is True, it creates a new session with the provided - identifiers. Otherwise, it raises a ValueError with a helpful message. + identifiers. Otherwise, it raises a SessionNotFoundError. Args: user_id: The user ID of the session. @@ -368,7 +409,8 @@ class Runner: The existing or newly created `Session`. Raises: - ValueError: If the session is not found and auto_create_session is False. + SessionNotFoundError: If the session is not found and + auto_create_session is False. """ session = await self.session_service.get_session( app_name=self.app_name, user_id=user_id, session_id=session_id @@ -380,7 +422,7 @@ class Runner: ) else: message = self._format_session_not_found_message(session_id) - raise ValueError(message) + raise SessionNotFoundError(message) return session def run( @@ -495,6 +537,7 @@ class Runner: session = await self._get_or_create_session( user_id=user_id, session_id=session_id ) + if not invocation_id and not new_message: raise ValueError( 'Running an agent requires either a new_message or an ' @@ -502,35 +545,49 @@ class Runner: f'Session: {session_id}, User: {user_id}' ) - if invocation_id: - if ( - not self.resumability_config - or not self.resumability_config.is_resumable - ): - raise ValueError( - f'invocation_id: {invocation_id} is provided but the app is not' - ' resumable.' - ) - invocation_context = await self._setup_context_for_resumed_invocation( - session=session, - new_message=new_message, - invocation_id=invocation_id, - run_config=run_config, - state_delta=state_delta, + is_resumable = ( + self.resumability_config and self.resumability_config.is_resumable + ) + if not is_resumable and not new_message: + raise ValueError( + 'Running an agent requires a new_message or a resumable app. ' + f'Session: {session_id}, User: {user_id}' ) - if invocation_context.end_of_agents.get( - invocation_context.agent.name - ): - # Directly return if the current agent in invocation context is - # already final. - return - else: + + if not is_resumable: invocation_context = await self._setup_context_for_new_invocation( session=session, - new_message=new_message, # new_message is not None. + new_message=new_message, run_config=run_config, state_delta=state_delta, ) + else: + invocation_id = self._resolve_invocation_id( + session, new_message, invocation_id + ) + if not invocation_id: + invocation_context = await self._setup_context_for_new_invocation( + session=session, + new_message=new_message, + run_config=run_config, + state_delta=state_delta, + ) + else: + invocation_context = ( + await self._setup_context_for_resumed_invocation( + session=session, + new_message=new_message, + invocation_id=invocation_id, + run_config=run_config, + state_delta=state_delta, + ) + ) + if invocation_context.end_of_agents.get( + invocation_context.agent.name + ): + # Directly return if the current agent in invocation context is + # already final. + return async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(ctx.agent.run_async(ctx)) as agen: @@ -553,7 +610,10 @@ class Runner: if self.app and self.app.events_compaction_config: logger.debug('Running event compactor.') await _run_compaction_for_sliding_window( - self.app, session, self.session_service + self.app, + session, + self.session_service, + skip_token_compaction=invocation_context.token_compaction_checked, ) async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: @@ -1321,6 +1381,10 @@ class Runner: return event.content return None + def _create_invocation_context(self, **kwargs) -> InvocationContext: + """Creates an InvocationContext instance.""" + return InvocationContext(**kwargs) + def _new_invocation_context( self, session: Session, @@ -1355,13 +1419,16 @@ class Runner: if not isinstance(self.agent.code_executor, BuiltInCodeExecutor): self.agent.code_executor = BuiltInCodeExecutor() - return InvocationContext( + return self._create_invocation_context( artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, credential_service=self.credential_service, plugin_manager=self.plugin_manager, context_cache_config=self.context_cache_config, + events_compaction_config=( + self.app.events_compaction_config if self.app else None + ), invocation_id=invocation_id, agent=self.agent, session=session, diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83..eb22a83b 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -106,13 +106,35 @@ class BaseSessionService(abc.ABC): """Appends an event to a session object.""" if event.partial: return event + # Apply temp-scoped state to the in-memory session BEFORE trimming the + # event delta, so that subsequent agents within the same invocation can + # read temp values (e.g. output_key='temp:my_key' in SequentialAgent). + self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) self._update_session_state(session, event) session.events.append(event) return event + def _apply_temp_state(self, session: Session, event: Event) -> None: + """Applies temp-scoped state delta to the in-memory session state. + + Temp state is ephemeral: it lives in the session's in-memory state for + the duration of the current invocation but is NOT persisted to storage + (the event delta is trimmed separately by _trim_temp_delta_state). + """ + if not event.actions or not event.actions.state_delta: + return + for key, value in event.actions.state_delta.items(): + if key.startswith(State.TEMP_PREFIX): + session.state[key] = value + def _trim_temp_delta_state(self, event: Event) -> Event: - """Removes temporary state delta keys from the event.""" + """Removes temporary state delta keys from the event. + + This prevents temp-scoped state from being persisted, while the + in-memory session state (updated by _apply_temp_state) retains the + values for the duration of the current invocation. + """ if not event.actions or not event.actions.state_delta: return event @@ -128,6 +150,4 @@ class BaseSessionService(abc.ABC): if not event.actions or not event.actions.state_delta: return for key, value in event.actions.state_delta.items(): - if key.startswith(State.TEMP_PREFIX): - continue session.state.update({key: value}) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 24f525ba..321a5cc6 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -522,6 +522,9 @@ class DatabaseSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) @@ -531,6 +534,16 @@ class DatabaseSessionService(BaseSessionService): schema = self._get_schema_classes() is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT use_row_level_locking = self._supports_row_level_locking() + + state_delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + state_deltas = _session_util.extract_state_delta(state_delta) + has_app_delta = bool(state_deltas["app"]) + has_user_delta = bool(state_deltas["user"]) + async with self._with_session_lock( app_name=session.app_name, user_id=session.user_id, @@ -554,7 +567,7 @@ class DatabaseSessionService(BaseSessionService): sql_session=sql_session, state_model=schema.StorageAppState, predicates=(schema.StorageAppState.app_name == session.app_name,), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_app_delta, missing_message=( "App state missing for app_name=" f"{session.app_name!r}. Session state tables should be " @@ -568,7 +581,7 @@ class DatabaseSessionService(BaseSessionService): schema.StorageUserState.app_name == session.app_name, schema.StorageUserState.user_id == session.user_id, ), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_user_delta, missing_message=( "User state missing for app_name=" f"{session.app_name!r}, user_id={session.user_id!r}. " @@ -599,23 +612,19 @@ class DatabaseSessionService(BaseSessionService): storage_events = [e async for e in result] session.events = [e.to_event() for e in storage_events] - # Extract state delta - if event.actions and event.actions.state_delta: - state_deltas = _session_util.extract_state_delta( - event.actions.state_delta + # Merge pre-extracted state deltas into storage. + if has_app_delta: + storage_app_state.state = ( + storage_app_state.state | state_deltas["app"] + ) + if has_user_delta: + storage_user_state.state = ( + storage_user_state.state | state_deltas["user"] + ) + if state_deltas["session"]: + storage_session.state = ( + storage_session.state | state_deltas["session"] ) - app_state_delta = state_deltas["app"] - user_state_delta = state_deltas["user"] - session_state_delta = state_deltas["session"] - # Merge state and update storage - if app_state_delta: - storage_app_state.state = storage_app_state.state | app_state_delta - if user_state_delta: - storage_user_state.state = ( - storage_user_state.state | user_state_delta - ) - if session_state_delta: - storage_session.state = storage_session.state | session_state_delta if is_sqlite: update_time = datetime.fromtimestamp( diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d23c8278..600f89c4 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -361,6 +361,9 @@ class SqliteSessionService(BaseSessionService): if event.partial: return event + # Apply temp state to in-memory session before trimming, so that + # subsequent agents within the same invocation can read temp values. + self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) event_timestamp = event.timestamp diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 1837a907..8cb7109e 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: from . import _session_util from ..events.event import Event from ..events.event_actions import EventActions +from ..events.event_actions import EventCompaction from ..utils.vertex_ai_utils import get_express_mode_api_key from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig @@ -267,8 +268,9 @@ class VertexAiSessionService(BaseSessionService): k: json.loads(v.model_dump_json(exclude_none=True, by_alias=True)) for k, v in event.actions.requested_auth_configs.items() }, - # TODO: add requested_tool_confirmations, compaction, agent_state once + # TODO: add requested_tool_confirmations, agent_state once # they are available in the API. + # Note: compaction is stored via event_metadata.custom_metadata. } if event.error_code: config['error_code'] = event.error_code @@ -291,6 +293,19 @@ class VertexAiSessionService(BaseSessionService): metadata_dict['grounding_metadata'] = event.grounding_metadata.model_dump( exclude_none=True, mode='json' ) + # Store compaction data in custom_metadata since the Vertex AI service + # does not yet support the compaction field. + # TODO: Stop writing to custom_metadata once the Vertex AI service + # supports the compaction field natively in EventActions. + if event.actions and event.actions.compaction: + compaction_dict = event.actions.compaction.model_dump( + exclude_none=True, mode='json' + ) + existing_custom = metadata_dict.get('custom_metadata') or {} + metadata_dict['custom_metadata'] = { + **existing_custom, + '_compaction': compaction_dict, + } config['event_metadata'] = metadata_dict async with self._get_api_client() as api_client: @@ -347,16 +362,6 @@ class VertexAiSessionService(BaseSessionService): def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: """Converts an API event object to an Event object.""" actions = getattr(api_event_obj, 'actions', None) - if actions: - actions_dict = actions.model_dump(exclude_none=True, mode='python') - rename_map = {'transfer_agent': 'transfer_to_agent'} - renamed_actions_dict = { - rename_map.get(k, k): v for k, v in actions_dict.items() - } - event_actions = EventActions.model_validate(renamed_actions_dict) - else: - event_actions = EventActions() - event_metadata = getattr(api_event_obj, 'event_metadata', None) if event_metadata: long_running_tool_ids_list = getattr( @@ -370,6 +375,16 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: interrupted = getattr(event_metadata, 'interrupted', None) branch = getattr(event_metadata, 'branch', None) custom_metadata = getattr(event_metadata, 'custom_metadata', None) + # Extract compaction data stored in custom_metadata. + # NOTE: This read path must be kept permanently because sessions + # written before native compaction support store compaction data + # in custom_metadata under the '_compaction' key. + compaction_data = None + if custom_metadata and '_compaction' in custom_metadata: + custom_metadata = dict(custom_metadata) # avoid mutating the API response + compaction_data = custom_metadata.pop('_compaction') + if not custom_metadata: + custom_metadata = None grounding_metadata = _session_util.decode_model( getattr(event_metadata, 'grounding_metadata', None), types.GroundingMetadata, @@ -381,8 +396,26 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: interrupted = None branch = None custom_metadata = None + compaction_data = None grounding_metadata = None + if actions: + actions_dict = actions.model_dump(exclude_none=True, mode='python') + rename_map = {'transfer_agent': 'transfer_to_agent'} + renamed_actions_dict = { + rename_map.get(k, k): v for k, v in actions_dict.items() + } + if compaction_data: + renamed_actions_dict['compaction'] = compaction_data + event_actions = EventActions.model_validate(renamed_actions_dict) + else: + if compaction_data: + event_actions = EventActions( + compaction=EventCompaction.model_validate(compaction_data) + ) + else: + event_actions = EventActions() + return Event( id=api_event_obj.name.split('/')[-1], invocation_id=api_event_obj.invocation_id, diff --git a/src/google/adk/skills/__init__.py b/src/google/adk/skills/__init__.py index 73184b2b..86724bd0 100644 --- a/src/google/adk/skills/__init__.py +++ b/src/google/adk/skills/__init__.py @@ -14,16 +14,38 @@ """Agent Development Kit - Skills.""" +from typing import Any +import warnings + +from ._utils import _load_skill_from_dir as load_skill_from_dir from .models import Frontmatter from .models import Resources from .models import Script from .models import Skill -from .utils import load_skill_from_dir __all__ = [ + "DEFAULT_SKILL_SYSTEM_INSTRUCTION", "Frontmatter", "Resources", "Script", "Skill", "load_skill_from_dir", ] + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + + from ..tools import skill_toolset + + warnings.warn( + ( + "Importing DEFAULT_SKILL_SYSTEM_INSTRUCTION from" + " google.adk.skills is deprecated." + " Please import it from google.adk.tools.skill_toolset instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/google/adk/skills/_utils.py b/src/google/adk/skills/_utils.py new file mode 100644 index 00000000..0bfbf30e --- /dev/null +++ b/src/google/adk/skills/_utils.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Agent Skills.""" + +from __future__ import annotations + +import pathlib +from typing import Union + +import yaml + +from . import models + +_ALLOWED_FRONTMATTER_KEYS = frozenset({ + "name", + "description", + "license", + "allowed-tools", + "allowed_tools", + "metadata", + "compatibility", +}) + + +def _load_dir(directory: pathlib.Path) -> dict[str, str]: + """Recursively load files from a directory into a dictionary. + + Args: + directory: Path to the directory to load. + + Returns: + Dictionary mapping relative file paths to their string content. + """ + files = {} + if directory.exists() and directory.is_dir(): + for file_path in directory.rglob("*"): + if "__pycache__" in file_path.parts: + continue + if file_path.is_file(): + relative_path = file_path.relative_to(directory) + try: + files[str(relative_path)] = file_path.read_text(encoding="utf-8") + except UnicodeDecodeError: + # Binary files or non-UTF-8 files are skipped for text content. + continue + return files + + +def _parse_skill_md( + skill_dir: pathlib.Path, +) -> tuple[dict, str, pathlib.Path]: + """Parse SKILL.md from a skill directory. + + Args: + skill_dir: Resolved path to the skill directory. + + Returns: + Tuple of (parsed_frontmatter_dict, body_string, skill_md_path). + + Raises: + FileNotFoundError: If the directory or SKILL.md is not found. + ValueError: If SKILL.md is invalid. + """ + if not skill_dir.is_dir(): + raise FileNotFoundError(f"Skill directory '{skill_dir}' not found.") + + skill_md = None + for name in ("SKILL.md", "skill.md"): + path = skill_dir / name + if path.exists(): + skill_md = path + break + + if skill_md is None: + raise FileNotFoundError(f"SKILL.md not found in '{skill_dir}'.") + + content = skill_md.read_text(encoding="utf-8") + if not content.startswith("---"): + raise ValueError("SKILL.md must start with YAML frontmatter (---)") + + parts = content.split("---", 2) + if len(parts) < 3: + raise ValueError("SKILL.md frontmatter not properly closed with ---") + + frontmatter_str = parts[1] + body = parts[2].strip() + + try: + parsed = yaml.safe_load(frontmatter_str) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in frontmatter: {e}") from e + + if not isinstance(parsed, dict): + raise ValueError("SKILL.md frontmatter must be a YAML mapping") + + return parsed, body, skill_md + + +def _load_skill_from_dir(skill_dir: Union[str, pathlib.Path]) -> models.Skill: + """Load a complete skill from a directory. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Skill object with all components loaded. + + Raises: + FileNotFoundError: If the skill directory or SKILL.md is not found. + ValueError: If SKILL.md is invalid or the skill name does not match + the directory name. + """ + skill_dir = pathlib.Path(skill_dir).resolve() + + parsed, body, skill_md = _parse_skill_md(skill_dir) + + # Use model_validate to handle aliases like allowed-tools + frontmatter = models.Frontmatter.model_validate(parsed) + + # Validate that skill name matches the directory name + if skill_dir.name != frontmatter.name: + raise ValueError( + f"Skill name '{frontmatter.name}' does not match directory" + f" name '{skill_dir.name}'." + ) + + references = _load_dir(skill_dir / "references") + assets = _load_dir(skill_dir / "assets") + raw_scripts = _load_dir(skill_dir / "scripts") + scripts = { + name: models.Script(src=content) for name, content in raw_scripts.items() + } + + resources = models.Resources( + references=references, + assets=assets, + scripts=scripts, + ) + + return models.Skill( + frontmatter=frontmatter, + instructions=body, + resources=resources, + ) + + +def _validate_skill_dir( + skill_dir: Union[str, pathlib.Path], +) -> list[str]: + """Validate a skill directory without fully loading it. + + Checks that the directory exists, contains a valid SKILL.md with correct + frontmatter, and that the skill name matches the directory name. + + Args: + skill_dir: Path to the skill directory. + + Returns: + List of problem strings. Empty list means the skill is valid. + """ + problems: list[str] = [] + skill_dir = pathlib.Path(skill_dir).resolve() + + if not skill_dir.exists(): + return [f"Directory '{skill_dir}' does not exist."] + if not skill_dir.is_dir(): + return [f"'{skill_dir}' is not a directory."] + + skill_md = None + for name in ("SKILL.md", "skill.md"): + path = skill_dir / name + if path.exists(): + skill_md = path + break + if skill_md is None: + return [f"SKILL.md not found in '{skill_dir}'."] + + try: + parsed, _, _ = _parse_skill_md(skill_dir) + except (FileNotFoundError, ValueError) as e: + return [str(e)] + + unknown = set(parsed.keys()) - _ALLOWED_FRONTMATTER_KEYS + if unknown: + problems.append(f"Unknown frontmatter fields: {sorted(unknown)}") + + try: + frontmatter = models.Frontmatter.model_validate(parsed) + except Exception as e: + problems.append(f"Frontmatter validation error: {e}") + return problems + + if skill_dir.name != frontmatter.name: + problems.append( + f"Skill name '{frontmatter.name}' does not match directory" + f" name '{skill_dir.name}'." + ) + + return problems + + +def _read_skill_properties( + skill_dir: Union[str, pathlib.Path], +) -> models.Frontmatter: + """Read only the frontmatter properties from a skill directory. + + This is a lightweight alternative to ``load_skill_from_dir`` when you + only need the skill metadata without loading instructions or resources. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Frontmatter object with the skill's metadata. + + Raises: + FileNotFoundError: If the directory or SKILL.md is not found. + ValueError: If the frontmatter is invalid. + """ + skill_dir = pathlib.Path(skill_dir).resolve() + parsed, _, _ = _parse_skill_md(skill_dir) + return models.Frontmatter.model_validate(parsed) diff --git a/src/google/adk/skills/models.py b/src/google/adk/skills/models.py index 7f5d75b4..f7674cd9 100644 --- a/src/google/adk/skills/models.py +++ b/src/google/adk/skills/models.py @@ -16,9 +16,17 @@ from __future__ import annotations +import re +from typing import Any from typing import Optional +import unicodedata from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_validator + +_NAME_PATTERN = re.compile(r"^[a-z0-9]+(-[a-z0-9]+)*$") class Frontmatter(BaseModel): @@ -30,17 +38,68 @@ class Frontmatter(BaseModel): (required). license: License for the skill (optional). compatibility: Compatibility information for the skill (optional). - allowed_tools: Tool patterns the skill requires (optional, experimental). + allowed_tools: A space-delimited list of tools that are pre-approved to + run (optional, experimental). Accepts both ``allowed_tools`` and the + YAML-friendly ``allowed-tools`` key. For more details, see + https://agentskills.io/specification#allowed-tools-field. metadata: Key-value pairs for client-specific properties (defaults to - empty dict). + empty dict). For example, to include additional tools, use the + ``adk_additional_tools`` key with a list of tools. """ + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + name: str description: str license: Optional[str] = None compatibility: Optional[str] = None - allowed_tools: Optional[str] = None - metadata: dict[str, str] = {} + allowed_tools: Optional[str] = Field( + default=None, + alias="allowed-tools", + serialization_alias="allowed-tools", + ) + metadata: dict[str, Any] = {} + + @field_validator("metadata") + @classmethod + def _validate_metadata(cls, v: dict[str, Any]) -> dict[str, Any]: + if "adk_additional_tools" in v: + tools = v["adk_additional_tools"] + if not isinstance(tools, list): + raise ValueError("adk_additional_tools must be a list of strings") + return v + + @field_validator("name") + @classmethod + def _validate_name(cls, v: str) -> str: + v = unicodedata.normalize("NFKC", v) + if len(v) > 64: + raise ValueError("name must be at most 64 characters") + if not _NAME_PATTERN.match(v): + raise ValueError( + "name must be lowercase kebab-case (a-z, 0-9, hyphens)," + " with no leading, trailing, or consecutive hyphens" + ) + return v + + @field_validator("description") + @classmethod + def _validate_description(cls, v: str) -> str: + if not v: + raise ValueError("description must not be empty") + if len(v) > 1024: + raise ValueError("description must be at most 1024 characters") + return v + + @field_validator("compatibility") + @classmethod + def _validate_compatibility(cls, v: Optional[str]) -> Optional[str]: + if v is not None and len(v) > 500: + raise ValueError("compatibility must be at most 500 characters") + return v class Script(BaseModel): diff --git a/src/google/adk/skills/prompt.py b/src/google/adk/skills/prompt.py index e9840ab2..3c352036 100644 --- a/src/google/adk/skills/prompt.py +++ b/src/google/adk/skills/prompt.py @@ -17,16 +17,21 @@ from __future__ import annotations import html +from typing import Any from typing import List +from typing import Union +import warnings from . import models -def format_skills_as_xml(skills: List[models.Frontmatter]) -> str: +def format_skills_as_xml( + skills: List[Union[models.Frontmatter, models.Skill]], +) -> str: """Formats available skills into a standard XML string. Args: - skills: A list of skill frontmatter objects. + skills: A list of skill frontmatter or full skill objects. Returns: XML string with block containing each skill's @@ -38,16 +43,34 @@ def format_skills_as_xml(skills: List[models.Frontmatter]) -> str: lines = [""] - for skill in skills: + for item in skills: lines.append("") lines.append("") - lines.append(html.escape(skill.name)) + lines.append(html.escape(item.name)) lines.append("") lines.append("") - lines.append(html.escape(skill.description)) + lines.append(html.escape(item.description)) lines.append("") lines.append("") lines.append("") return "\n".join(lines) + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + + from ..tools import skill_toolset + + warnings.warn( + ( + "Importing DEFAULT_SKILL_SYSTEM_INSTRUCTION from" + " google.adk.skills.prompt is deprecated." + " Please import it from google.adk.tools.skill_toolset instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/google/adk/skills/utils.py b/src/google/adk/skills/utils.py deleted file mode 100644 index deb10b2a..00000000 --- a/src/google/adk/skills/utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for Agent Skills.""" - -from __future__ import annotations - -import pathlib -from typing import Union - -import yaml - -from . import models - - -def _load_dir(directory: pathlib.Path) -> dict[str, str]: - """Recursively load files from a directory into a dictionary. - - Args: - directory: Path to the directory to load. - - Returns: - Dictionary mapping relative file paths to their string content. - """ - files = {} - if directory.exists() and directory.is_dir(): - for file_path in directory.rglob("*"): - if "__pycache__" in file_path.parts: - continue - if file_path.is_file(): - relative_path = file_path.relative_to(directory) - try: - files[str(relative_path)] = file_path.read_text(encoding="utf-8") - except UnicodeDecodeError: - # Binary files or non-UTF-8 files are skipped for text content. - continue - return files - - -def load_skill_from_dir(skill_dir: Union[str, pathlib.Path]) -> models.Skill: - """Load a complete skill from a directory. - - Args: - skill_dir: Path to the skill directory. - - Returns: - Skill object with all components loaded. - - Raises: - FileNotFoundError: If the skill directory or SKILL.md is not found. - ValueError: If SKILL.md is invalid. - """ - skill_dir = pathlib.Path(skill_dir).resolve() - - if not skill_dir.is_dir(): - raise FileNotFoundError(f"Skill directory '{skill_dir}' not found.") - - skill_md = None - for name in ("SKILL.md", "skill.md"): - path = skill_dir / name - if path.exists(): - skill_md = path - break - - if skill_md is None: - raise FileNotFoundError(f"SKILL.md not found in '{skill_dir}'.") - - content = skill_md.read_text(encoding="utf-8") - if not content.startswith("---"): - raise ValueError("SKILL.md must start with YAML frontmatter (---)") - - parts = content.split("---", 2) - if len(parts) < 3: - raise ValueError("SKILL.md frontmatter not properly closed with ---") - - frontmatter_str = parts[1] - body = parts[2].strip() - - try: - parsed = yaml.safe_load(frontmatter_str) - except yaml.YAMLError as e: - raise ValueError(f"Invalid YAML in frontmatter: {e}") from e - - if not isinstance(parsed, dict): - raise ValueError("SKILL.md frontmatter must be a YAML mapping") - - # Frontmatter class handles required field validation - frontmatter = models.Frontmatter(**parsed) - - references = _load_dir(skill_dir / "references") - assets = _load_dir(skill_dir / "assets") - raw_scripts = _load_dir(skill_dir / "scripts") - scripts = { - name: models.Script(src=content) for name, content in raw_scripts.items() - } - - resources = models.Resources( - references=references, - assets=assets, - scripts=scripts, - ) - - return models.Skill( - frontmatter=frontmatter, - instructions=body, - resources=resources, - ) diff --git a/src/google/adk/telemetry/_experimental_semconv.py b/src/google/adk/telemetry/_experimental_semconv.py new file mode 100644 index 00000000..dbfb3f14 --- /dev/null +++ b/src/google/adk/telemetry/_experimental_semconv.py @@ -0,0 +1,518 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Provides instrumentation for experimental semantic convention https://github.com/open-telemetry/semantic-conventions/blob/v1.39.0/docs/gen-ai/gen-ai-events.md.""" + +from __future__ import annotations + +from collections.abc import Mapping +from collections.abc import MutableMapping +import contextvars +import json +import os +from typing import Any +from typing import Literal +from typing import TypedDict + +from google.genai import types +from google.genai.models import t as transformers +from mcp import ClientSession as McpClientSession +from mcp import Tool as McpTool +from opentelemetry._logs import Logger +from opentelemetry._logs import LogRecord +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_INPUT_MESSAGES +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OUTPUT_MESSAGES +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_RESPONSE_FINISH_REASONS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM_INSTRUCTIONS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_INPUT_TOKENS +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_OUTPUT_TOKENS +from opentelemetry.trace import Span +from opentelemetry.util.types import AttributeValue + +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse + +try: + from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_TOOL_DEFINITIONS +except ImportError: + GEN_AI_TOOL_DEFINITIONS = 'gen_ai.tool_definitions' + +OTEL_SEMCONV_STABILITY_OPT_IN = 'OTEL_SEMCONV_STABILITY_OPT_IN' + +OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = ( + 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT' +) + +FUNCTION_TOOL_DEFINITION_TYPE = 'function' + + +class Text(TypedDict): + content: str + type: Literal['text'] + + +class Blob(TypedDict): + mime_type: str + data: bytes + type: Literal['blob'] + + +class FileData(TypedDict): + mime_type: str + uri: str + type: Literal['file_data'] + + +class ToolCall(TypedDict): + id: str | None + name: str + arguments: Any + type: Literal['tool_call'] + + +class ToolCallResponse(TypedDict): + id: str | None + response: Any + type: Literal['tool_call_response'] + + +Part = Text | Blob | FileData | ToolCall | ToolCallResponse + + +class InputMessage(TypedDict): + role: str + parts: list[Part] + + +class OutputMessage(TypedDict): + role: str + parts: list[Part] + finish_reason: str + + +class FunctionToolDefinition(TypedDict): + name: str + description: str | None + parameters: Any + type: Literal['function'] + + +class GenericToolDefinition(TypedDict): + name: str + type: str + + +ToolDefinition = FunctionToolDefinition | GenericToolDefinition + + +def _safe_json_serialize_no_whitespaces(obj) -> str: + """Convert any Python object to a JSON-serializable type or string. + + Args: + obj: The object to serialize. + + Returns: + The JSON-serialized object string or if the object cannot be serialized. + """ + + try: + # Try direct JSON serialization first + return json.dumps( + obj, + separators=(',', ':'), + ensure_ascii=False, + default=lambda o: '', + ) + except (TypeError, OverflowError): + return '' + + +def is_experimental_semconv() -> bool: + opt_ins = os.getenv(OTEL_SEMCONV_STABILITY_OPT_IN) + if not opt_ins: + return False + opt_ins_list = [s.strip() for s in opt_ins.split(',')] + return 'gen_ai_latest_experimental' in opt_ins_list + + +def get_content_capturing_mode() -> str: + return os.getenv( + OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, '' + ).upper() + + +def _model_dump_to_tool_definition(tool: Any) -> dict[str, Any]: + model_dump = tool.model_dump(exclude_none=True) + + name = ( + model_dump.get('name') + or getattr(tool, 'name', None) + or type(tool).__name__ + ) + description = model_dump.get('description') or getattr( + tool, 'description', None + ) + parameters = model_dump.get('parameters') or model_dump.get('inputSchema') + return FunctionToolDefinition( + name=name, + description=description, + parameters=parameters, + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +def _clean_parameters(params: Any) -> Any: + """Converts parameter objects into plain dicts.""" + if params is None: + return None + if isinstance(params, dict): + return params + if hasattr(params, 'to_dict'): + return params.to_dict() + if hasattr(params, 'model_dump'): + return params.model_dump(exclude_none=True) + + try: + # Check if it's already a standard JSON type. + json.dumps(params) + return params + + except (TypeError, ValueError): + return { + 'type': 'object', + 'properties': { + 'serialization_error': { + 'type': 'string', + 'description': ( + f'Failed to serialize parameters: {type(params).__name__}' + ), + } + }, + } + + +def _tool_to_tool_definition(tool: types.Tool) -> list[dict[str, Any]]: + definitions = [] + if tool.function_declarations: + for fd in tool.function_declarations: + definitions.append( + FunctionToolDefinition( + name=getattr(fd, 'name', type(fd).__name__), + description=getattr(fd, 'description', None), + parameters=_clean_parameters(getattr(fd, 'parameters', None)), + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + ) + + # Generic types + if hasattr(tool, 'model_dump'): + exclude_fields = {'function_declarations'} + fields = { + k: v + for k, v in tool.model_dump().items() + if v is not None and k not in exclude_fields + } + + for tool_type, _ in fields.items(): + definitions.append( + GenericToolDefinition( + name=tool_type, + type=tool_type, + ) + ) + + return definitions + + +def _tool_definition_from_callable_tool(tool: Any) -> dict[str, Any]: + doc = getattr(tool, '__doc__', '') or '' + return FunctionToolDefinition( + name=getattr(tool, '__name__', type(tool).__name__), + description=doc.strip(), + parameters=None, + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +def _tool_definition_from_mcp_tool(tool: McpTool) -> dict[str, Any]: + if hasattr(tool, 'model_dump'): + return _model_dump_to_tool_definition(tool) + + return FunctionToolDefinition( + name=getattr(tool, 'name', type(tool).__name__), + description=getattr(tool, 'description', None), + parameters=getattr(tool, 'input_schema', None), + type=FUNCTION_TOOL_DEFINITION_TYPE, + ) + + +async def _to_tool_definitions( + tool: types.ToolUnionDict, +) -> list[dict[str, Any]]: + + if isinstance(tool, types.Tool): + return _tool_to_tool_definition(tool) + + if callable(tool): + return [_tool_definition_from_callable_tool(tool)] + + if isinstance(tool, McpTool): + return [_tool_definition_from_mcp_tool(tool)] + + if isinstance(tool, McpClientSession): + result = await tool.list_tools() + return [_model_dump_to_tool_definition(t) for t in result.tools] + + return [ + GenericToolDefinition( + name='UnserializableTool', + type=type(tool).__name__, + ) + ] + + +def _operation_details_attributes_no_content( + operation_details_attributes: Mapping[str, AttributeValue], +) -> dict[str, AttributeValue]: + tool_def = operation_details_attributes.get(GEN_AI_TOOL_DEFINITIONS) + if not tool_def: + return {} + + return { + GEN_AI_TOOL_DEFINITIONS: [ + FunctionToolDefinition( + name=td['name'], + description=td['description'], + parameters=None, + type=td['type'], + ) + if 'parameters' in td + else td + for td in tool_def + ] + } + + +def _to_input_message( + content: types.Content, +) -> InputMessage: + parts = (_to_part(part, idx) for idx, part in enumerate(content.parts or [])) + return InputMessage( + role=_to_role(content.role), + parts=[part for part in parts if part is not None], + ) + + +def _to_output_message( + llm_response: LlmResponse, +) -> OutputMessage | None: + if not llm_response.content: + return None + + message = _to_input_message(llm_response.content) + return OutputMessage( + role=message['role'], + parts=message['parts'], + finish_reason=_to_finish_reason(llm_response.finish_reason), + ) + + +def _to_finish_reason( + finish_reason: types.FinishReason | None, +) -> str: + if finish_reason is None: + return '' + if ( + # Mapping unspecified and other to error, + # as JSON schema for finish_reason does not support them. + finish_reason is types.FinishReason.FINISH_REASON_UNSPECIFIED + or finish_reason is types.FinishReason.OTHER + ): + return 'error' + if finish_reason is types.FinishReason.STOP: + return 'stop' + if finish_reason is types.FinishReason.MAX_TOKENS: + return 'length' + + return finish_reason.name.lower() + + +def _to_part(part: types.Part, idx: int) -> Part | None: + def tool_call_id_fallback(name: str | None) -> str: + if name: + return f'{name}_{idx}' + return f'{idx}' + + if part is None: + return None + + if (text := part.text) is not None: + return Text(content=text, type='text') + + if data := part.inline_data: + return Blob( + mime_type=data.mime_type or '', data=data.data or b'', type='blob' + ) + + if data := part.file_data: + return FileData( + mime_type=data.mime_type or '', + uri=data.file_uri or '', + type='file_data', + ) + + if call := part.function_call: + return ToolCall( + id=call.id or tool_call_id_fallback(call.name), + name=call.name or '', + arguments=call.args, + type='tool_call', + ) + + if response := part.function_response: + return ToolCallResponse( + id=response.id or tool_call_id_fallback(response.name), + response=response.response, + type='tool_call_response', + ) + + return None + + +def _to_role(role: str | None) -> str: + if role == 'user': + return 'user' + if role == 'model': + return 'assistant' + return '' + + +def _to_input_messages(contents: list[types.Content]) -> list[InputMessage]: + return [_to_input_message(content) for content in contents] + + +def _to_system_instructions( + config: types.GenerateContentConfig, +) -> list[Part]: + + if not config.system_instruction: + return [] + + transformed_contents = transformers.t_contents(config.system_instruction) + if not transformed_contents: + return [] + + sys_instr = transformed_contents[0] + + parts = ( + _to_part(part, idx) for idx, part in enumerate(sys_instr.parts or []) + ) + return [part for part in parts if part is not None] + + +def set_operation_details_common_attributes( + operation_details_common_attributes: MutableMapping[str, AttributeValue], + attributes: Mapping[str, AttributeValue], +): + operation_details_common_attributes.update(attributes) + + +async def set_operation_details_attributes_from_request( + operation_details_attributes: MutableMapping[str, AttributeValue], + llm_request: LlmRequest, +): + + input_messages = _to_input_messages( + transformers.t_contents(llm_request.contents) + ) + + system_instructions = _to_system_instructions(llm_request.config) + + tool_definitions = [] + if tools := llm_request.config.tools: + for tool in tools: + definitions = await _to_tool_definitions(tool) + for de in definitions: + if de: + tool_definitions.append(de) + + operation_details_attributes[GEN_AI_INPUT_MESSAGES] = input_messages + operation_details_attributes[GEN_AI_SYSTEM_INSTRUCTIONS] = system_instructions + operation_details_attributes[GEN_AI_TOOL_DEFINITIONS] = tool_definitions + + +def set_operation_details_attributes_from_response( + llm_response: LlmResponse, + operation_details_attributes: MutableMapping[str, AttributeValue], + operation_details_common_attributes: MutableMapping[str, AttributeValue], +): + if finish_reason := llm_response.finish_reason: + operation_details_common_attributes[GEN_AI_RESPONSE_FINISH_REASONS] = [ + _to_finish_reason(finish_reason) + ] + if usage_metadata := llm_response.usage_metadata: + if usage_metadata.prompt_token_count is not None: + operation_details_common_attributes[GEN_AI_USAGE_INPUT_TOKENS] = ( + usage_metadata.prompt_token_count + ) + if usage_metadata.candidates_token_count is not None: + operation_details_common_attributes[GEN_AI_USAGE_OUTPUT_TOKENS] = ( + usage_metadata.candidates_token_count + ) + + output_message = _to_output_message(llm_response) + if output_message is not None: + operation_details_attributes[GEN_AI_OUTPUT_MESSAGES] = [output_message] + + +def maybe_log_completion_details( + span: Span | None, + otel_logger: Logger, + operation_details_attributes: Mapping[str, AttributeValue], + operation_details_common_attributes: Mapping[str, AttributeValue], +): + """Logs completion details based on the experimental semantic convention capturing mode.""" + if span is None: + return + + if not is_experimental_semconv(): + return + + capturing_mode = get_content_capturing_mode() + final_attributes = operation_details_common_attributes + + if capturing_mode in ['EVENT_ONLY', 'SPAN_AND_EVENT']: + final_attributes = final_attributes | operation_details_attributes + else: + final_attributes = ( + final_attributes + | _operation_details_attributes_no_content(operation_details_attributes) + ) + + otel_logger.emit( + LogRecord( + event_name='gen_ai.client.inference.operation.details', + attributes=final_attributes, + ) + ) + + if capturing_mode in ['SPAN_ONLY', 'SPAN_AND_EVENT']: + for key, value in operation_details_attributes.items(): + span.set_attribute(key, _safe_json_serialize_no_whitespaces(value)) + else: + for key, value in _operation_details_attributes_no_content( + operation_details_attributes + ).items(): + span.set_attribute(key, _safe_json_serialize_no_whitespaces(value)) diff --git a/src/google/adk/telemetry/sqlite_span_exporter.py b/src/google/adk/telemetry/sqlite_span_exporter.py new file mode 100644 index 00000000..1d535908 --- /dev/null +++ b/src/google/adk/telemetry/sqlite_span_exporter.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQLite-backed OpenTelemetry span exporter for local development.""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import threading +from typing import Any +from typing import Iterable +from typing import Optional +from typing import Sequence + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter +from opentelemetry.sdk.trace.export import SpanExportResult +from opentelemetry.trace import SpanContext +from opentelemetry.trace import TraceFlags +from opentelemetry.trace import TraceState + +logger = logging.getLogger("google_adk." + __name__) + +_CREATE_SPANS_TABLE = """ +CREATE TABLE IF NOT EXISTS spans ( + span_id TEXT PRIMARY KEY, + trace_id TEXT NOT NULL, + parent_span_id TEXT, + name TEXT NOT NULL, + start_time_unix_nano INTEGER, + end_time_unix_nano INTEGER, + session_id TEXT, + invocation_id TEXT, + attributes_json TEXT +); +""" + +_CREATE_SESSION_INDEX = """ +CREATE INDEX IF NOT EXISTS spans_session_id_idx ON spans(session_id); +""" + +_CREATE_TRACE_INDEX = """ +CREATE INDEX IF NOT EXISTS spans_trace_id_idx ON spans(trace_id); +""" + +_INSERT_SPAN = """ +INSERT OR REPLACE INTO spans ( + span_id, + trace_id, + parent_span_id, + name, + start_time_unix_nano, + end_time_unix_nano, + session_id, + invocation_id, + attributes_json +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); +""" + +_DEFAULT_TIMEOUT_SECONDS = 30.0 + + +class SqliteSpanExporter(SpanExporter): + """Exports spans to a local SQLite database. + + This is intended for local development (e.g. `adk web`) to allow reloading + traces for older sessions after process restart. + """ + + def __init__(self, *, db_path: str): + self._db_path = db_path + self._lock = threading.Lock() + self._conn: Optional[sqlite3.Connection] = None + self._ensure_schema() + + def _get_connection(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect( + self._db_path, + timeout=_DEFAULT_TIMEOUT_SECONDS, + check_same_thread=False, + ) + self._conn.row_factory = sqlite3.Row + return self._conn + + def _ensure_schema(self) -> None: + with self._lock: + conn = self._get_connection() + conn.execute(_CREATE_SPANS_TABLE) + conn.execute(_CREATE_SESSION_INDEX) + conn.execute(_CREATE_TRACE_INDEX) + conn.commit() + + def _serialize_attributes(self, attributes: dict[str, Any]) -> str: + try: + return json.dumps( + attributes, + ensure_ascii=False, + default=lambda o: "", + ) + except (TypeError, ValueError) as e: + logger.debug("Failed to serialize span attributes: %r", e) + return "{}" + + def _deserialize_attributes(self, attributes_json: Any) -> dict[str, Any]: + if not attributes_json: + return {} + try: + attributes = json.loads(attributes_json) + except (json.JSONDecodeError, TypeError) as e: + logger.debug("Failed to deserialize span attributes: %r", e) + return {} + return attributes if isinstance(attributes, dict) else {} + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + try: + with self._lock: + conn = self._get_connection() + rows: list[tuple[Any, ...]] = [] + for span in spans: + attributes = dict(span.attributes) if span.attributes else {} + session_id = attributes.get( + "gcp.vertex.agent.session_id" + ) or attributes.get("gen_ai.conversation.id") + invocation_id = attributes.get("gcp.vertex.agent.invocation_id") + + parent_span_id = None + if span.parent is not None: + parent_span_id = format(span.parent.span_id, "016x") + + rows.append(( + format(span.context.span_id, "016x"), + format(span.context.trace_id, "032x"), + parent_span_id, + span.name, + span.start_time, + span.end_time, + session_id, + invocation_id, + self._serialize_attributes(attributes), + )) + conn.executemany(_INSERT_SPAN, rows) + conn.commit() + return SpanExportResult.SUCCESS + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Failed to export spans to SQLite: %s", e) + return SpanExportResult.FAILURE + + def shutdown(self) -> None: + with self._lock: + if self._conn is not None: + self._conn.close() + self._conn = None + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def _query(self, sql: str, params: Iterable[Any]) -> list[sqlite3.Row]: + with self._lock: + conn = self._get_connection() + cur = conn.execute(sql, tuple(params)) + return list(cur.fetchall()) + + def _row_to_readable_span(self, row: sqlite3.Row) -> ReadableSpan: + trace_id_hex = row["trace_id"] + span_id_hex = row["span_id"] + trace_id = int(str(trace_id_hex), 16) + span_id = int(str(span_id_hex), 16) + trace_state = TraceState() + trace_flags = TraceFlags(TraceFlags.SAMPLED) + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=trace_flags, + trace_state=trace_state, + ) + + parent: SpanContext | None = None + parent_span_id_hex = row["parent_span_id"] + if parent_span_id_hex: + parent = SpanContext( + trace_id=trace_id, + span_id=int(str(parent_span_id_hex), 16), + is_remote=False, + trace_flags=trace_flags, + trace_state=trace_state, + ) + + attributes = self._deserialize_attributes(row["attributes_json"]) + return ReadableSpan( + name=row["name"] or "", + context=context, + parent=parent, + attributes=attributes, + start_time=row["start_time_unix_nano"], + end_time=row["end_time_unix_nano"], + ) + + def get_all_spans_for_session(self, session_id: str) -> list[ReadableSpan]: + """Returns all spans for a session (full trace trees). + + We first find trace_ids associated with the session, then return all spans + for those trace_ids. This works even if some spans are missing session_id + attributes (e.g. parent spans). + """ + trace_rows = self._query( + "SELECT DISTINCT trace_id FROM spans WHERE session_id = ?", + (session_id,), + ) + trace_ids = [r["trace_id"] for r in trace_rows if r["trace_id"]] + if not trace_ids: + return [] + + placeholders = ",".join("?" for _ in trace_ids) + rows = self._query( + f"SELECT * FROM spans WHERE trace_id IN ({placeholders}) " + "ORDER BY start_time_unix_nano", + trace_ids, + ) + return [self._row_to_readable_span(row) for row in rows] diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index fbb55ec9..707bc313 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -23,8 +23,11 @@ from __future__ import annotations +import asyncio +from collections.abc import AsyncIterator from collections.abc import Iterator from collections.abc import Mapping +from contextlib import asynccontextmanager from contextlib import contextmanager import json import logging @@ -58,9 +61,15 @@ from opentelemetry.trace import Span from opentelemetry.util.types import AnyValue from opentelemetry.util.types import AttributeValue from pydantic import BaseModel +from typing_extensions import deprecated from .. import version from ..utils.model_name_utils import is_gemini_model +from ._experimental_semconv import is_experimental_semconv +from ._experimental_semconv import maybe_log_completion_details +from ._experimental_semconv import set_operation_details_attributes_from_request +from ._experimental_semconv import set_operation_details_attributes_from_response +from ._experimental_semconv import set_operation_details_common_attributes # By default some ADK spans include attributes with potential PII data. # This env, when set to false, allows to disable populating those attributes. @@ -427,6 +436,7 @@ def _should_add_request_response_to_spans() -> bool: return not disabled_via_env_var +@deprecated('Replaced by use_inference_span to support experimental semconv.') @contextmanager def use_generate_content_span( llm_request: LlmRequest, @@ -453,11 +463,57 @@ def use_generate_content_span( with _use_extra_generate_content_attributes(common_attributes): yield else: - with _use_native_generate_content_span( + with _use_native_generate_content_span_stable_semconv( llm_request=llm_request, common_attributes=common_attributes, ) as span: - yield span + yield span.span + + +@asynccontextmanager +async def use_inference_span( + llm_request: LlmRequest, + invocation_context: InvocationContext, + model_response_event: Event, +) -> AsyncIterator[GenerateContentSpan | None]: + """Context manager encompassing `generate_content {model.name}` span. + + When an external library for inference instrumentation is installed (e.g. + opentelemetry-instrumentation-google-genai), + span creation is delegated to said library. + """ + + common_attributes = { + GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_CONVERSATION_ID: invocation_context.session.id, + USER_ID: invocation_context.session.user_id, + 'gcp.vertex.agent.event_id': model_response_event.id, + 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + } + if ( + _is_gemini_agent(invocation_context.agent) + and _instrumented_with_opentelemetry_instrumentation_google_genai() + ): + with _use_extra_generate_content_attributes(common_attributes): + yield + else: + async with _use_native_generate_content_span( + llm_request=llm_request, + common_attributes=common_attributes, + ) as gc_span: + if is_experimental_semconv(): + set_operation_details_common_attributes( + gc_span.operation_details_common_attributes, common_attributes + ) + try: + yield gc_span + finally: + maybe_log_completion_details( + gc_span.span, + otel_logger, + gc_span.operation_details_attributes, + gc_span.operation_details_common_attributes, + ) def _should_log_prompt_response_content() -> bool: @@ -467,6 +523,8 @@ def _should_log_prompt_response_content() -> bool: def _serialize_content(content: types.ContentUnion) -> AnyValue: + if content is None: + return None if isinstance(content, BaseModel): return content.model_dump() if isinstance(content, str): @@ -540,18 +598,29 @@ def _is_gemini_agent(agent: BaseAgent) -> bool: return isinstance(agent.model, Gemini) -@contextmanager -def _use_native_generate_content_span( +def _set_common_generate_content_attributes( + span: Span, llm_request: LlmRequest, common_attributes: Mapping[str, AttributeValue], -) -> Iterator[Span]: +): + span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') + span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') + span.set_attributes(common_attributes) + + +@contextmanager +def _use_native_generate_content_span_stable_semconv( + llm_request: LlmRequest, + common_attributes: Mapping[str, AttributeValue], +) -> Iterator[GenerateContentSpan]: with tracer.start_as_current_span( f"generate_content {llm_request.model or ''}" ) as span: span.set_attribute(GEN_AI_SYSTEM, _guess_gemini_system_name()) - span.set_attribute(GEN_AI_OPERATION_NAME, 'generate_content') - span.set_attribute(GEN_AI_REQUEST_MODEL, llm_request.model or '') - span.set_attributes(common_attributes) + _set_common_generate_content_attributes( + span, llm_request, common_attributes + ) + gc_span = GenerateContentSpan(span) otel_logger.emit( LogRecord( @@ -564,7 +633,6 @@ def _use_native_generate_content_span( attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, ) ) - for content in llm_request.contents: otel_logger.emit( LogRecord( @@ -574,9 +642,51 @@ def _use_native_generate_content_span( ) ) - yield span + yield gc_span +@asynccontextmanager +async def _use_native_generate_content_span( + llm_request: LlmRequest, + common_attributes: Mapping[str, AttributeValue], +) -> AsyncIterator[GenerateContentSpan]: + if not is_experimental_semconv(): + with _use_native_generate_content_span_stable_semconv( + llm_request, common_attributes + ) as gc_span: + yield gc_span + return + + with tracer.start_as_current_span( + f"generate_content {llm_request.model or ''}" + ) as span: + + _set_common_generate_content_attributes( + span, llm_request, common_attributes + ) + gc_span = GenerateContentSpan(span) + + await set_operation_details_attributes_from_request( + gc_span.operation_details_attributes, llm_request + ) + yield gc_span + + +class GenerateContentSpan: + """Manages tracing within a `generate_content` OpenTelemetry span. + + This class provides attributes for the experimental semantic convention. + """ + + def __init__(self, span: Span): + self.span = span + self.operation_details_attributes = {} + self.operation_details_common_attributes = {} + + +@deprecated( + 'Replaced by trace_inference_result to support experimental semconv.' +) def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): """Trace result of the inference in generate_content span.""" @@ -613,6 +723,61 @@ def trace_generate_content_result(span: Span | None, llm_response: LlmResponse): ) +def trace_inference_result( + span: Span | None | GenerateContentSpan, + llm_response: LlmResponse, +): + """Trace result of the inference in generate_content span.""" + gc_span = None + if isinstance(span, GenerateContentSpan): + gc_span = span + span = gc_span.span + + if span is None: + return + + if llm_response.partial: + return + + if finish_reason := llm_response.finish_reason: + span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [finish_reason.lower()]) + if usage_metadata := llm_response.usage_metadata: + if usage_metadata.prompt_token_count is not None: + span.set_attribute( + GEN_AI_USAGE_INPUT_TOKENS, usage_metadata.prompt_token_count + ) + if usage_metadata.candidates_token_count is not None: + span.set_attribute( + GEN_AI_USAGE_OUTPUT_TOKENS, usage_metadata.candidates_token_count + ) + + if is_experimental_semconv() and isinstance(gc_span, GenerateContentSpan): + set_operation_details_attributes_from_response( + llm_response, + gc_span.operation_details_attributes, + gc_span.operation_details_common_attributes, + ) + + else: + otel_logger.emit( + LogRecord( + event_name='gen_ai.choice', + body={ + 'content': _serialize_content_with_elision( + llm_response.content + ), + 'index': 0, # ADK always returns a single candidate + } + | ( + {'finish_reason': llm_response.finish_reason.value} + if llm_response.finish_reason is not None + else {} + ), + attributes={GEN_AI_SYSTEM: _guess_gemini_system_name()}, + ) + ) + + def _guess_gemini_system_name() -> str: return ( GenAiSystemValues.VERTEX_AI.name.lower() diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index a3097ad4..392e256b 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -151,9 +151,13 @@ def _remove_title(schema: Dict): def _get_pydantic_schema(func: Callable) -> Dict: + from ..utils.context_utils import find_context_parameter + fields_dict = _get_fields_dict(func) - if 'tool_context' in fields_dict.keys(): - fields_dict.pop('tool_context') + # Remove context parameter (detected by type or fallback to 'tool_context' name) + context_param = find_context_parameter(func) or 'tool_context' + if context_param in fields_dict.keys(): + fields_dict.pop(context_param) return pydantic.create_model(func.__name__, **fields_dict).model_json_schema() diff --git a/src/google/adk/tools/_gemini_schema_util.py b/src/google/adk/tools/_gemini_schema_util.py index 6a05f6c6..595b41a0 100644 --- a/src/google/adk/tools/_gemini_schema_util.py +++ b/src/google/adk/tools/_gemini_schema_util.py @@ -152,6 +152,13 @@ def _sanitize_schema_formats_for_gemini( ) for item in schema ] + # JSON Schema allows boolean schemas: `true` (accept any value) and `false` + # (reject all values). Gemini has no equivalent for either. `true` is + # approximated as an unconstrained object schema; `false` has no meaningful + # Gemini representation and is also mapped to an object schema as a safe + # fallback so that schema conversion does not crash. + if isinstance(schema, bool): + return {"type": "object"} if not isinstance(schema, dict): return schema diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 91135dce..f53c18df 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -28,6 +28,8 @@ from ..agents.common_configs import AgentRefConfig from ..features import FeatureName from ..features import is_feature_enabled from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..utils._schema_utils import SchemaType +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool @@ -64,7 +66,7 @@ def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: return None -def _get_output_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: +def _get_output_schema(agent: BaseAgent) -> Optional[SchemaType]: """Extracts the output_schema from an agent. For LlmAgent, returns its output_schema directly. @@ -268,9 +270,7 @@ class AgentTool(BaseTool): ) output_schema = _get_output_schema(self.agent) if output_schema: - tool_result = output_schema.model_validate_json(merged_text).model_dump( - exclude_none=True - ) + tool_result = validate_schema(output_schema, merged_text) else: tool_result = merged_text return tool_result diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py index feaf1c7a..d3483fc2 100644 --- a/src/google/adk/tools/api_registry.py +++ b/src/google/adk/tools/api_registry.py @@ -14,128 +14,13 @@ from __future__ import annotations -from typing import Any -from typing import Callable +import warnings -from google.adk.agents.readonly_context import ReadonlyContext -import google.auth -import google.auth.transport.requests -import httpx +from google.adk.integrations.api_registry import ApiRegistry -from .base_toolset import ToolPredicate -from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -from .mcp_tool.mcp_toolset import McpToolset - -API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" - - -class ApiRegistry: - """Registry that provides McpToolsets for MCP servers registered in API Registry.""" - - def __init__( - self, - api_registry_project_id: str, - location: str = "global", - header_provider: ( - Callable[[ReadonlyContext], dict[str, str]] | None - ) = None, - ): - """Initialize the API Registry. - - Args: - api_registry_project_id: The project ID for the Google Cloud API Registry. - location: The location of the API Registry resources. - header_provider: Optional function to provide additional headers for MCP - server calls. - """ - self.api_registry_project_id = api_registry_project_id - self.location = location - self._credentials, _ = google.auth.default() - self._mcp_servers: dict[str, dict[str, Any]] = {} - self._header_provider = header_provider - - url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" - - try: - headers = self._get_auth_headers() - headers["Content-Type"] = "application/json" - page_token = None - with httpx.Client() as client: - while True: - params = {} - if page_token: - params["pageToken"] = page_token - - response = client.get(url, headers=headers, params=params) - response.raise_for_status() - data = response.json() - mcp_servers_list = data.get("mcpServers", []) - for server in mcp_servers_list: - server_name = server.get("name", "") - if server_name: - self._mcp_servers[server_name] = server - - page_token = data.get("nextPageToken") - if not page_token: - break - except (httpx.HTTPError, ValueError) as e: - # Handle error in fetching or parsing tool definitions - raise RuntimeError( - f"Error fetching MCP servers from API Registry: {e}" - ) from e - - def get_toolset( - self, - mcp_server_name: str, - tool_filter: ToolPredicate | list[str] | None = None, - tool_name_prefix: str | None = None, - ) -> McpToolset: - """Return the MCP Toolset based on the params. - - Args: - mcp_server_name: Filter to select the MCP server name to get tools from. - tool_filter: Optional filter to select specific tools. Can be a list of - tool names or a ToolPredicate function. - tool_name_prefix: Optional prefix to prepend to the names of the tools - returned by the toolset. - - Returns: - McpToolset: A toolset for the MCP server specified. - """ - server = self._mcp_servers.get(mcp_server_name) - if not server: - raise ValueError( - f"MCP server {mcp_server_name} not found in API Registry." - ) - if not server.get("urls"): - raise ValueError(f"MCP server {mcp_server_name} has no URLs.") - - mcp_server_url = server["urls"][0] - headers = self._get_auth_headers() - - # Only prepend "https://" if the URL doesn't already have a scheme - if not mcp_server_url.startswith(("http://", "https://")): - mcp_server_url = "https://" + mcp_server_url - - return McpToolset( - connection_params=StreamableHTTPConnectionParams( - url=mcp_server_url, - headers=headers, - ), - tool_filter=tool_filter, - tool_name_prefix=tool_name_prefix, - header_provider=self._header_provider, - ) - - def _get_auth_headers(self) -> dict[str, str]: - """Refreshes credentials and returns authorization headers.""" - request = google.auth.transport.requests.Request() - self._credentials.refresh(request) - headers = { - "Authorization": f"Bearer {self._credentials.token}", - } - # Add quota project header if available in ADC - quota_project_id = getattr(self._credentials, "quota_project_id", None) - if quota_project_id: - headers["x-goog-user-project"] = quota_project_id - return headers +warnings.warn( + "google.adk.tools.api_registry is moved to" + " google.adk.integrations.api_registry", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/google/adk/tools/bash_tool.py b/src/google/adk/tools/bash_tool.py new file mode 100644 index 00000000..38e99643 --- /dev/null +++ b/src/google/adk/tools/bash_tool.py @@ -0,0 +1,150 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool to execute bash commands.""" + +from __future__ import annotations + +import dataclasses +import pathlib +import shlex +import subprocess +from typing import Any +from typing import Optional + +from google.genai import types + +from .. import features +from .base_tool import BaseTool +from .tool_context import ToolContext + + +@dataclasses.dataclass(frozen=True) +class BashToolPolicy: + """Configuration for allowed bash commands based on prefix matching. + + Set allowed_command_prefixes to ("*",) to allow all commands (default), + or explicitly list allowed prefixes. + """ + + allowed_command_prefixes: tuple[str, ...] = ("*",) + + +def _validate_command(command: str, policy: BashToolPolicy) -> Optional[str]: + """Validates a bash command against the permitted prefixes.""" + stripped = command.strip() + if not stripped: + return "Command is required." + + if "*" in policy.allowed_command_prefixes: + return None + + for prefix in policy.allowed_command_prefixes: + if stripped.startswith(prefix): + return None + + allowed = ", ".join(policy.allowed_command_prefixes) + return f"Command blocked. Permitted prefixes are: {allowed}" + + +@features.experimental(features.FeatureName.SKILL_TOOLSET) +class ExecuteBashTool(BaseTool): + """Tool to execute a validated bash command within a workspace directory.""" + + def __init__( + self, + *, + workspace: pathlib.Path | None = None, + policy: Optional[BashToolPolicy] = None, + ): + if workspace is None: + workspace = pathlib.Path.cwd() + policy = policy or BashToolPolicy() + allowed_hint = ( + "any command" + if "*" in policy.allowed_command_prefixes + else ( + "commands matching prefixes:" + f" {', '.join(policy.allowed_command_prefixes)}" + ) + ) + super().__init__( + name="execute_bash", + description=( + "Executes a bash command with the working directory set to the" + f" workspace. Allowed: {allowed_hint}. All commands require user" + " confirmation." + ), + ) + self._workspace = workspace + self._policy = policy + + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute.", + }, + }, + "required": ["command"], + }, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + command = args.get("command") + if not command: + return {"error": "Command is required."} + + # Static validation. + error = _validate_command(command, self._policy) + if error: + return {"error": error} + + # Always request user confirmation. + if not tool_context.tool_confirmation: + tool_context.request_confirmation( + hint=f"Please approve or reject the bash command: {command}", + ) + tool_context.actions.skip_summarization = True + return { + "error": ( + "This tool call requires confirmation, please approve or reject." + ) + } + elif not tool_context.tool_confirmation.confirmed: + return {"error": "This tool call is rejected."} + + try: + result = subprocess.run( + shlex.split(command), + shell=False, + cwd=str(self._workspace), + capture_output=True, + text=True, + timeout=30, + ) + return { + "stdout": result.stdout, + "stderr": result.stderr, + "returncode": result.returncode, + } + except subprocess.TimeoutExpired: + return {"error": "Command timed out after 30 seconds."} diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index fa23c74c..958ce9d7 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,6 +19,10 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" +BIGQUERY_SCOPES = [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", +] BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_DEFAULT_SCOPE - + self.scopes = BIGQUERY_SCOPES + # Set the token cache key self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY return self diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 1a748b71..dba5f8ee 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,6 +24,7 @@ from typing_extensions import override from . import data_insights_tool from . import metadata_tool from . import query_tool +from . import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset): query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, + search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d57c0c80..2cb4e67c 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -14,19 +14,22 @@ from __future__ import annotations +from typing import List from typing import Optional +from typing import Union import google.api_core.client_info +from google.api_core.gapic_v1 import client_info as gapic_client_info from google.auth.credentials import Credentials from google.cloud import bigquery +from google.cloud import dataplex_v1 from ... import version -USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}" - - -from typing import List -from typing import Union +USER_AGENT_BASE = f"google-adk/{version.__version__}" +BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}" +DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}" +USER_AGENT = BQ_USER_AGENT def get_bigquery_client( @@ -48,7 +51,7 @@ def get_bigquery_client( A BigQuery client. """ - user_agents = [USER_AGENT] + user_agents = [BQ_USER_AGENT] if user_agent: if isinstance(user_agent, str): user_agents.append(user_agent) @@ -67,3 +70,33 @@ def get_bigquery_client( ) return bigquery_client + + +def get_dataplex_catalog_client( + *, + credentials: Credentials, + user_agent: Optional[Union[str, List[str]]] = None, +) -> dataplex_v1.CatalogServiceClient: + """Get a Dataplex CatalogServiceClient with minimal necessary arguments. + + Args: + credentials: The credentials to use for the request. + user_agent: Additional user agent string(s) to append. + + Returns: + A Dataplex Client. + """ + + user_agents = [DP_USER_AGENT] + if user_agent: + if isinstance(user_agent, str): + user_agents.append(user_agent) + else: + user_agents.extend([ua for ua in user_agent if ua]) + + client_info = gapic_client_info.ClientInfo(user_agent=" ".join(user_agents)) + + return dataplex_v1.CatalogServiceClient( + credentials=credentials, + client_info=client_info, + ) diff --git a/src/google/adk/tools/bigquery/search_tool.py b/src/google/adk/tools/bigquery/search_tool.py new file mode 100644 index 00000000..0bf01d5a --- /dev/null +++ b/src/google/adk/tools/bigquery/search_tool.py @@ -0,0 +1,179 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any + +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + +from . import client +from .config import BigQueryToolConfig + + +def _construct_search_query_helper( + predicate: str, operator: str, items: list[str] +) -> str: + """Constructs a search query part for a specific predicate and items.""" + if not items: + return "" + + clauses = [f'{predicate}{operator}"{item}"' for item in items] + return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0] + + +def search_catalog( + prompt: str, + project_id: str, + *, + credentials: Credentials, + settings: BigQueryToolConfig, + location: str | None = None, + page_size: int = 10, + project_ids_filter: list[str] | None = None, + dataset_ids_filter: list[str] | None = None, + types_filter: list[str] | None = None, +) -> dict[str, Any]: + """Searches for BigQuery assets within Dataplex. + + Args: + prompt: The base search query (natural language or keywords). + project_id: The Google Cloud project ID to scope the search. + credentials: Credentials for the request. + settings: BigQuery tool settings. + location: The Dataplex location to use. + page_size: Maximum number of results. + project_ids_filter: Specific project IDs to include in the search results. + If None, defaults to the scoping project_id. + dataset_ids_filter: BigQuery dataset IDs to filter by. + types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE, + BigQueryEntryType.DATASET). + + Returns: + Search results or error. The "results" list contains items with: + - name: The Dataplex Entry name (e.g., + "projects/p/locations/l/entryGroups/g/entries/e"). + - linked_resource: The underlying BigQuery resource name (e.g., + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t"). + - display_name, entry_type, description, location, update_time. + + Examples: + Search for tables related to customer data: + + >>> search_catalog( + ... prompt="Search for tables related to customer data", + ... project_id="my-project", + ... credentials=creds, + ... settings=settings + ... ) + { + "status": "SUCCESS", + "results": [ + { + "name": + "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id", + "display_name": "customer_table", + "entry_type": + "projects/p/locations/l/entryTypes/bigquery-table", + "linked_resource": + "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table", + "description": "Table containing customer details.", + "location": "us", + "update_time": "2024-01-01 12:00:00+00:00" + } + ] + } + """ + + try: + if not project_id: + return { + "status": "ERROR", + "error_details": "project_id must be provided.", + } + + with client.get_dataplex_catalog_client( + credentials=credentials, + user_agent=[settings.application_name, "search_catalog"], + ) as dataplex_client: + query_parts = [] + if prompt: + query_parts.append(f"({prompt})") + + # Filter by project IDs + projects_to_filter = ( + project_ids_filter if project_ids_filter else [project_id] + ) + if projects_to_filter: + query_parts.append( + _construct_search_query_helper("projectid", "=", projects_to_filter) + ) + + # Filter by dataset IDs + if dataset_ids_filter: + dataset_resource_filters = [] + for pid in projects_to_filter: + for did in dataset_ids_filter: + dataset_resource_filters.append( + f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"' + ) + if dataset_resource_filters: + query_parts.append(f"({' OR '.join(dataset_resource_filters)})") + # Filter by entry types + if types_filter: + query_parts.append( + _construct_search_query_helper("type", "=", types_filter) + ) + + # Always scope to BigQuery system + query_parts.append("system=BIGQUERY") + + full_query = " AND ".join(filter(None, query_parts)) + + search_location = location or settings.location or "global" + search_scope = f"projects/{project_id}/locations/{search_location}" + + request = dataplex_v1.SearchEntriesRequest( + name=search_scope, + query=full_query, + page_size=page_size, + semantic_search=True, + ) + + response = dataplex_client.search_entries(request=request) + + results = [] + for result in response.results: + entry = result.dataplex_entry + source = entry.entry_source + results.append({ + "name": entry.name, + "display_name": source.display_name or "", + "entry_type": entry.entry_type, + "update_time": str(entry.update_time), + "linked_resource": source.resource or "", + "description": source.description or "", + "location": source.location or "", + }) + return {"status": "SUCCESS", "results": results} + + except api_exceptions.GoogleAPICallError as e: + logging.exception("search_catalog tool: API call failed") + return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"} + except Exception as e: + logging.exception("search_catalog tool: Unexpected error") + return {"status": "ERROR", "error_details": repr(e)} diff --git a/src/google/adk/tools/bigtable/bigtable_toolset.py b/src/google/adk/tools/bigtable/bigtable_toolset.py index 8e9f430f..97fc2eb0 100644 --- a/src/google/adk/tools/bigtable/bigtable_toolset.py +++ b/src/google/adk/tools/bigtable/bigtable_toolset.py @@ -44,6 +44,8 @@ class BigtableToolset(BaseToolset): - bigtable_get_instance_info - bigtable_list_tables - bigtable_get_table_info + - bigtable_list_clusters + - bigtable_get_cluster_info - bigtable_execute_sql """ @@ -95,6 +97,8 @@ class BigtableToolset(BaseToolset): metadata_tool.get_instance_info, metadata_tool.list_tables, metadata_tool.get_table_info, + metadata_tool.list_clusters, + metadata_tool.get_cluster_info, query_tool.execute_sql, ] ] diff --git a/src/google/adk/tools/bigtable/metadata_tool.py b/src/google/adk/tools/bigtable/metadata_tool.py index 703c3447..de4fea6a 100644 --- a/src/google/adk/tools/bigtable/metadata_tool.py +++ b/src/google/adk/tools/bigtable/metadata_tool.py @@ -14,12 +14,16 @@ from __future__ import annotations +import enum import logging from google.auth.credentials import Credentials +from google.cloud.bigtable import enums from . import client +logger = logging.getLogger(f"google_adk.{__name__}") + def list_instances(project_id: str, credentials: Credentials) -> dict: """List Bigtable instance ids in a Google Cloud project. @@ -29,7 +33,22 @@ def list_instances(project_id: str, credentials: Credentials) -> dict: credentials (Credentials): The credentials to use for the request. Returns: - dict: Dictionary with a list of the Bigtable instance ids present in the project. + dict: Dictionary with a list of dictionaries, each representing a Bigtable instance. + + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + ], + } """ try: bt_client = client.get_bigtable_admin_client( @@ -41,12 +60,27 @@ def list_instances(project_id: str, credentials: Credentials) -> dict: "Failed to list instances from the following locations: %s", failed_locations_list, ) - instance_ids = [instance.instance_id for instance in instances_list] - return {"status": "SUCCESS", "results": instance_ids} + result = [ + { + "project_id": project_id, + "instance_id": instance.instance_id, + "display_name": instance.display_name, + "state": _enum_name_from_value( + enums.Instance.State, instance.state, "UNKNOWN_STATE" + ), + "type": _enum_name_from_value( + enums.Instance.Type, instance.type_, "UNKNOWN_TYPE" + ), + "labels": instance.labels, + } + for instance in instances_list + ] + return {"status": "SUCCESS", "results": result} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } @@ -69,26 +103,33 @@ def get_instance_info( ) instance = bt_client.instance(instance_id) instance.reload() - instance_info = { - "project_id": project_id, - "instance_id": instance.instance_id, - "display_name": instance.display_name, - "state": instance.state, - "type": instance.type_, - "labels": instance.labels, + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance.instance_id, + "display_name": instance.display_name, + "state": _enum_name_from_value( + enums.Instance.State, instance.state, "UNKNOWN_STATE" + ), + "type": _enum_name_from_value( + enums.Instance.Type, instance.type_, "UNKNOWN_TYPE" + ), + "labels": instance.labels, + }, } - return {"status": "SUCCESS", "results": instance_info} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } def list_tables( project_id: str, instance_id: str, credentials: Credentials ) -> dict: - """List table ids in a Bigtable instance. + """List tables and their metadata in a Bigtable instance. Args: project_id (str): The Google Cloud project id containing the instance. @@ -96,7 +137,21 @@ def list_tables( credentials (Credentials): The credentials to use for the request. Returns: - dict: Dictionary with a list of the tables ids present in the instance. + dict: A dictionary with status and results, where results is a list of + table properties. + + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "table_name": "fake-table-name", + } + ], + } """ try: bt_client = client.get_bigtable_admin_client( @@ -104,17 +159,29 @@ def list_tables( ) instance = bt_client.instance(instance_id) tables = instance.list_tables() - table_ids = [table.table_id for table in tables] - return {"status": "SUCCESS", "results": table_ids} + result = [ + { + "project_id": project_id, + "instance_id": instance_id, + "table_id": table.table_id, + "table_name": table.name, + } + for table in tables + ] + return {"status": "SUCCESS", "results": result} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), } def get_table_info( - project_id: str, instance_id: str, table_id: str, credentials: Credentials + project_id: str, + instance_id: str, + table_id: str, + credentials: Credentials, ) -> dict: """Get metadata information about a Bigtable table. @@ -126,6 +193,17 @@ def get_table_info( Returns: dict: Dictionary representing the properties of the table. + + Example: + { + "status": "SUCCESS", + "results": { + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "column_families": ["cf1", "cf2"], + }, + } """ try: bt_client = client.get_bigtable_admin_client( @@ -134,15 +212,170 @@ def get_table_info( instance = bt_client.instance(instance_id) table = instance.table(table_id) column_families = table.list_column_families() - table_info = { - "project_id": project_id, - "instance_id": instance.instance_id, - "table_id": table.table_id, - "column_families": list(column_families.keys()), + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance.instance_id, + "table_id": table.table_id, + "column_families": list(column_families.keys()), + }, } - return {"status": "SUCCESS", "results": table_info} except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), + } + + +def _enum_name_from_value( + enum_class: type[enum.Enum], value: int, prefix: str = "UNKNOWN" +) -> str: + for attr_name in dir(enum_class): + if not attr_name.startswith("_"): + if getattr(enum_class, attr_name) == value: + return attr_name + return f"{prefix}_{value}" + + +def list_clusters( + project_id: str, instance_id: str, credentials: Credentials +) -> dict: + """List clusters and their metadata in a Bigtable instance. + + Args: + project_id (str): The Google Cloud project id containing the instance. + instance_id (str): The Bigtable instance id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the cluster. + + Example: + { + "status": "SUCCESS", + "results": [ + { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "cluster_name": "fake-cluster-name", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + } + ], + } + """ + try: + bt_client = client.get_bigtable_admin_client( + project=project_id, credentials=credentials + ) + instance = bt_client.instance(instance_id) + instance.reload() + clusters_list, failed_locations = instance.list_clusters() + if failed_locations: + logging.warning( + "Failed to list clusters from the following locations: %s", + failed_locations, + ) + + result = [ + { + "project_id": project_id, + "instance_id": instance_id, + "cluster_id": cluster.cluster_id, + "cluster_name": cluster.name, + "state": _enum_name_from_value( + enums.Cluster.State, cluster.state, "UNKNOWN_STATE" + ), + "serve_nodes": cluster.serve_nodes, + "default_storage_type": _enum_name_from_value( + enums.StorageType, + cluster.default_storage_type, + "UNKNOWN_STORAGE_TYPE", + ), + "location_id": cluster.location_id, + } + for cluster in clusters_list + ] + return {"status": "SUCCESS", "results": result} + except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) + return { + "status": "ERROR", + "error_details": repr(ex), + } + + +def get_cluster_info( + project_id: str, + instance_id: str, + cluster_id: str, + credentials: Credentials, +) -> dict: + """Get detailed metadata information about a Bigtable cluster. + + Args: + project_id (str): The Google Cloud project id containing the instance. + instance_id (str): The Bigtable instance id containing the cluster. + cluster_id (str): The Bigtable cluster id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the cluster. + + Example: + { + "status": "SUCCESS", + "results": { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + "min_serve_nodes": 1, + "max_serve_nodes": 10, + "cpu_utilization_percent": 80, + }, + } + """ + try: + bt_client = client.get_bigtable_admin_client( + project=project_id, credentials=credentials + ) + instance = bt_client.instance(instance_id) + instance.reload() + cluster = instance.cluster(cluster_id) + cluster.reload() + return { + "status": "SUCCESS", + "results": { + "project_id": project_id, + "instance_id": instance_id, + "cluster_id": cluster.cluster_id, + "state": _enum_name_from_value( + enums.Cluster.State, cluster.state, "UNKNOWN_STATE" + ), + "serve_nodes": cluster.serve_nodes, + "default_storage_type": _enum_name_from_value( + enums.StorageType, + cluster.default_storage_type, + "UNKNOWN_STORAGE_TYPE", + ), + "location_id": cluster.location_id, + "min_serve_nodes": cluster.min_serve_nodes, + "max_serve_nodes": cluster.max_serve_nodes, + "cpu_utilization_percent": cluster.cpu_utilization_percent, + }, + } + except Exception as ex: + logger.exception("Bigtable metadata tool failed: %s", ex) + return { + "status": "ERROR", + "error_details": repr(ex), } diff --git a/src/google/adk/tools/bigtable/query_tool.py b/src/google/adk/tools/bigtable/query_tool.py index a7a785a2..bf64b282 100644 --- a/src/google/adk/tools/bigtable/query_tool.py +++ b/src/google/adk/tools/bigtable/query_tool.py @@ -15,6 +15,7 @@ from __future__ import annotations """Tool to execute SQL queries against Bigtable.""" +import asyncio import json import logging from typing import Any @@ -22,7 +23,6 @@ from typing import Dict from typing import List from google.auth.credentials import Credentials -from google.cloud import bigtable from . import client from ..tool_context import ToolContext @@ -33,13 +33,15 @@ logger = logging.getLogger("google_adk." + __name__) DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50 -def execute_sql( +async def execute_sql( project_id: str, instance_id: str, query: str, credentials: Credentials, settings: BigtableToolSettings, tool_context: ToolContext, + parameters: Dict[str, Any] | None = None, + parameter_types: Dict[str, Any] | None = None, ) -> dict: """Execute a GoogleSQL query from a Bigtable table. @@ -51,6 +53,10 @@ def execute_sql( credentials (Credentials): The credentials to use for the request. settings (BigtableToolSettings): The configuration for the tool. tool_context (ToolContext): The context for the tool. + parameters (dict): properties for parameter replacement. Keys must match + the names used in ``query``. + parameter_types (dict): maps explicit types for one or more param values. + Returns: dict: Dictionary containing the status and the rows read. If the result contains the key "result_is_likely_truncated" with @@ -59,64 +65,70 @@ def execute_sql( Examples: Fetch data or insights from a table: - - >>> execute_sql("my_project", "my_instance", - ... "SELECT * from mytable", credentials, config, tool_context) - { - "status": "SUCCESS", - "rows": [ - { - "user_id": 1, - "user_name": "Alice" - } - ] - } + + >>> await execute_sql("my_project", "my_instance", + ... "SELECT * from mytable", credentials, config, tool_context) + { + "status": "SUCCESS", + "rows": [ + { + "user_id": 1, + "user_name": "Alice" + } + ] + } + """ del tool_context # Unused for now - try: - bt_client = client.get_bigtable_data_client( - project=project_id, credentials=credentials - ) - eqi = bt_client.execute_query( - query=query, - instance_id=instance_id, - ) - - rows: List[Dict[str, Any]] = [] - max_rows = ( - settings.max_query_result_rows - if settings and settings.max_query_result_rows > 0 - else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS - ) - counter = max_rows - truncated = False + def _execute_sql(): try: - for row in eqi: - if counter <= 0: - truncated = True - break - row_values = {} - for key, val in dict(row.fields).items(): - try: - # if the json serialization of the value succeeds, use it as is - json.dumps(val) - except (TypeError, ValueError, OverflowError): - val = str(val) - row_values[key] = val - rows.append(row_values) - counter -= 1 - finally: - eqi.close() + bt_client = client.get_bigtable_data_client( + project=project_id, credentials=credentials + ) + eqi = bt_client.execute_query( + query=query, + instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, + ) - result = {"status": "SUCCESS", "rows": rows} - if truncated: - result["result_is_likely_truncated"] = True - return result + rows: List[Dict[str, Any]] = [] + max_rows = ( + settings.max_query_result_rows + if settings and settings.max_query_result_rows > 0 + else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS + ) + counter = max_rows + truncated = False + try: + for row in eqi: + if counter <= 0: + truncated = True + break + row_values = {} + for key, val in dict(row.fields).items(): + try: + # if the json serialization of the value succeeds, use it as is + json.dumps(val) + except (TypeError, ValueError, OverflowError): + val = str(val) + row_values[key] = val + rows.append(row_values) + counter -= 1 + finally: + eqi.close() - except Exception as ex: - logger.error("Bigtable query failed: %s", ex) - return { - "status": "ERROR", - "error_details": str(ex), - } + result = {"status": "SUCCESS", "rows": rows} + if truncated: + result["result_is_likely_truncated"] = True + return result + + except Exception as ex: + logger.exception("Bigtable query failed") + return { + "status": "ERROR", + "error_details": str(ex), + } + + return await asyncio.to_thread(_execute_sql) diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index f8022e11..fca8ba9f 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -90,19 +90,19 @@ class CrewaiTool(FunctionTool): # remove arguments like `self` that are managed by the framework and not # intended to be passed through **kwargs. args_to_call.pop('self', None) - # We also remove `tool_context` that might have been passed in `args`, + # We also remove context param that might have been passed in `args`, # as it will be explicitly injected later if it's a valid parameter. - args_to_call.pop('tool_context', None) + args_to_call.pop(self._context_param_name, None) else: # For functions without **kwargs, use the original filtering. args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } - # Inject tool_context if it's an explicit parameter. This will add it + # Inject context if it's an explicit parameter. This will add it # or overwrite any value that might have been passed in `args`. - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if self._context_param_name in valid_params: + args_to_call[self._context_param_name] = tool_context # Check for missing mandatory arguments mandatory_args = self._get_mandatory_args() diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 4f7a0d7f..c114fdb4 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -54,14 +55,16 @@ class EnterpriseWebSearchTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Enterprise Web Search tool cannot be used with other tools in' ' Gemini 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( types.Tool(enterprise_web_search=types.EnterpriseWebSearch()) ) diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 6b8496dc..10e32a54 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -28,6 +28,7 @@ import pydantic from typing_extensions import override from ..utils.context_utils import Aclosing +from ..utils.context_utils import find_context_parameter from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool from .tool_context import ToolContext @@ -80,7 +81,9 @@ class FunctionTool(BaseTool): super().__init__(name=name, description=doc) self.func = func - self._ignore_params = ['tool_context', 'input_stream'] + # Detect context parameter by type annotation, fallback to 'tool_context' name + self._context_param_name = find_context_parameter(func) or 'tool_context' + self._ignore_params = [self._context_param_name, 'input_stream'] self._require_confirmation = require_confirmation @override @@ -162,8 +165,8 @@ class FunctionTool(BaseTool): signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + if self._context_param_name in valid_params: + args_to_call[self._context_param_name] = tool_context # Filter args_to_call to only include valid parameters for the function args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} @@ -195,8 +198,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th if require_confirmation: if not tool_context.tool_confirmation: args_to_show = args_to_call.copy() - if 'tool_context' in args_to_show: - args_to_show.pop('tool_context') + if self._context_param_name in args_to_show: + args_to_show.pop(self._context_param_name) tool_context.request_confirmation( hint=( @@ -254,8 +257,8 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th args_to_call['input_stream'] = invocation_context.active_streaming_tools[ self.name ].stream - if 'tool_context' in signature.parameters: - args_to_call['tool_context'] = tool_context + if self._context_param_name in signature.parameters: + args_to_call[self._context_param_name] = tool_context # TODO: support tool confirmation for live mode. async with Aclosing(self.func(**args_to_call)) as agen: diff --git a/src/google/adk/tools/google_maps_grounding_tool.py b/src/google/adk/tools/google_maps_grounding_tool.py index bade0a33..d4b105ec 100644 --- a/src/google/adk/tools/google_maps_grounding_tool.py +++ b/src/google/adk/tools/google_maps_grounding_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -49,13 +50,14 @@ class GoogleMapsGroundingTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError( 'Google Maps grounding tool cannot be used with Gemini 1.x models.' ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_maps=types.GoogleMaps()) ) diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index 56da204e..7ed09c79 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -23,6 +23,7 @@ from typing_extensions import override from ..agents.llm_agent import LlmAgent from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.base_llm import BaseLlm +from ..utils._schema_utils import validate_schema from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .agent_tool import AgentTool @@ -127,9 +128,7 @@ class GoogleSearchAgentTool(AgentTool): return '' merged_text = '\n'.join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - tool_result = self.agent.output_schema.model_validate_json( - merged_text - ).model_dump(exclude_none=True) + tool_result = validate_schema(self.agent.output_schema, merged_text) else: tool_result = merged_text diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index 406ad218..1c11e091 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -67,6 +68,7 @@ class GoogleSearchTool(BaseTool): if self.model is not None: llm_request.model = self.model + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): @@ -77,7 +79,7 @@ class GoogleSearchTool(BaseTool): llm_request.config.tools.append( types.Tool(google_search_retrieval=types.GoogleSearchRetrieval()) ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_search=types.GoogleSearch()) ) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index f31768a0..bf279f52 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -41,6 +41,7 @@ from ...auth.auth_schemes import AuthScheme from ...auth.auth_tool import AuthConfig from ...features import FeatureName from ...features import is_feature_enabled +from ...utils.context_utils import find_context_parameter from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool from ..tool_context import ToolContext @@ -242,14 +243,18 @@ class McpTool(BaseAuthenticatedTool): for param in signature.parameters.values() ) - if "tool_context" in valid_params or has_kwargs: - args_to_call["tool_context"] = tool_context + # Detect context parameter by type or fallback to 'tool_context' name + context_param = ( + find_context_parameter(self._require_confirmation) or "tool_context" + ) + if context_param in valid_params or has_kwargs: + args_to_call[context_param] = tool_context # Filter args_to_call only if there's no **kwargs if not has_kwargs: - # Add tool_context to valid_params if it was added to args_to_call - if "tool_context" in args_to_call: - valid_params.add("tool_context") + # Add context param to valid_params if it was added to args_to_call + if context_param in args_to_call: + valid_params.add(context_param) args_to_call = { k: v for k, v in args_to_call.items() if k in valid_params } @@ -264,10 +269,6 @@ class McpTool(BaseAuthenticatedTool): if require_confirmation: if not tool_context.tool_confirmation: - args_to_show = args.copy() - if "tool_context" in args_to_show: - args_to_show.pop("tool_context") - tool_context.request_confirmation( hint=( f"Please approve or reject the tool call {self.name}() by" diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 1dbe0fe4..2b79edf9 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -19,6 +19,7 @@ from __future__ import annotations from typing import Optional import google.auth +from google.auth import exceptions as google_auth_exceptions from google.auth.transport.requests import Request from google.oauth2 import service_account import google.oauth2.credentials @@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredential from .....auth.auth_credential import AuthCredentialTypes from .....auth.auth_credential import HttpAuth from .....auth.auth_credential import HttpCredentials +from .....auth.auth_credential import ServiceAccount from .....auth.auth_schemes import AuthScheme from .base_credential_exchanger import AuthCredentialMissingError from .base_credential_exchanger import BaseAuthCredentialExchanger @@ -38,6 +40,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): Uses the default service credential if `use_default_credential = True`. Otherwise, uses the service account credential provided in the auth credential. + + Supports exchanging for either an access token (default) or an ID token + when ``ServiceAccount.use_id_token`` is True. ID tokens are required for + service-to-service authentication with Cloud Run, Cloud Functions, and + other services that verify caller identity. """ def exchange_credential( @@ -45,52 +52,130 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): auth_scheme: AuthScheme, auth_credential: Optional[AuthCredential] = None, ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. + """Exchanges the service account auth credential for a token. If auth_credential contains a service account credential, it will be used - to fetch an access token. Otherwise, the default service credential will be - used for fetching an access token. + to fetch a token. Otherwise, the default service credential will be + used for fetching a token. + + When ``service_account.use_id_token`` is True, an ID token is fetched + using the configured ``audience``. This is required for authenticating + to Cloud Run, Cloud Functions, and similar services. Args: auth_scheme: The auth scheme. auth_credential: The auth credential. Returns: - An AuthCredential in HTTPBearer format, containing the access token. + An AuthCredential in HTTPBearer format, containing the token. """ - if ( - auth_credential is None - or auth_credential.service_account is None - or ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ) - ): + if auth_credential is None or auth_credential.service_account is None: raise AuthCredentialMissingError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" + "Service account credentials are missing. Please provide them, or" + " set `use_default_credential = True` to use application default" " credential in a hosted service like Cloud Run." ) + sa_config = auth_credential.service_account + + if sa_config.use_id_token: + return self._exchange_for_id_token(sa_config) + + return self._exchange_for_access_token(sa_config) + + def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential: + """Exchanges the service account credential for an ID token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the ID token. + + Raises: + AuthCredentialMissingError: If token exchange fails. + """ + # audience and credential presence are validated by the ServiceAccount + # model_validator at construction time. try: - if auth_credential.service_account.use_default_credential: - credentials, project_id = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - quota_project_id = ( - getattr(credentials, "quota_project_id", None) or project_id - ) + if sa_config.use_default_credential: + from google.oauth2 import id_token as oauth2_id_token + + request = Request() + token = oauth2_id_token.fetch_id_token(request, sa_config.audience) else: - config = auth_credential.service_account + # Guaranteed non-None by ServiceAccount model_validator. + assert sa_config.service_account_credential is not None + credentials = ( + service_account.IDTokenCredentials.from_service_account_info( + sa_config.service_account_credential.model_dump(), + target_audience=sa_config.audience, + ) + ) + credentials.refresh(Request()) + token = credentials.token + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=token), + ), + ) + + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key), or when + # fetch_id_token cannot determine credentials from the environment. + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: + raise AuthCredentialMissingError( + f"Failed to exchange service account for ID token: {e}" + ) from e + + def _exchange_for_access_token( + self, sa_config: ServiceAccount + ) -> AuthCredential: + """Exchanges the service account credential for an access token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the access token. + + Raises: + AuthCredentialMissingError: If scopes are missing for explicit + credentials or token exchange fails. + """ + if not sa_config.use_default_credential and not sa_config.scopes: + raise AuthCredentialMissingError( + "scopes are required when using explicit service account credentials" + " for access token exchange." + ) + + try: + if sa_config.use_default_credential: + scopes = ( + sa_config.scopes + if sa_config.scopes + else ["https://www.googleapis.com/auth/cloud-platform"] + ) + credentials, project_id = google.auth.default( + scopes=scopes, + ) + quota_project_id = credentials.quota_project_id or project_id + else: + # Guaranteed non-None by ServiceAccount model_validator. + assert sa_config.service_account_credential is not None credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes + sa_config.service_account_credential.model_dump(), + scopes=sa_config.scopes, ) quota_project_id = None credentials.refresh(Request()) - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), @@ -101,9 +186,10 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): else None, ), ) - return updated_credential - except Exception as e: + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key). + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: raise AuthCredentialMissingError( f"Failed to exchange service account token: {e}" ) from e diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 300c47e1..5f835489 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -24,6 +24,9 @@ from typing import Literal from typing import Optional from typing import Tuple from typing import Union +from urllib.parse import parse_qs +from urllib.parse import urlparse +from urllib.parse import urlunparse from fastapi.openapi.models import Operation from fastapi.openapi.models import Schema @@ -375,6 +378,14 @@ class RestApiTool(BaseTool): base_url = base_url[:-1] if base_url.endswith("/") else base_url url = f"{base_url}{self.endpoint.path.format(**path_params)}" + # Move query params embedded in the path into query_params, since httpx + # replaces (rather than merges) the URL query string when `params` is set. + parsed_url = urlparse(url) + if parsed_url.query or parsed_url.fragment: + for key, values in parse_qs(parsed_url.query).items(): + query_params.setdefault(key, values[0] if len(values) == 1 else values) + url = urlunparse(parsed_url._replace(query="", fragment="")) + # Construct body body_kwargs: Dict[str, Any] = {} request_body = self.operation.requestBody diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index 206819a9..4d564ca1 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -24,6 +24,7 @@ from google.genai import types from typing_extensions import override from ...utils.model_name_utils import is_gemini_2_or_above +from ...utils.model_name_utils import is_gemini_model_id_check_disabled from ..tool_context import ToolContext from .base_retrieval_tool import BaseRetrievalTool @@ -63,7 +64,8 @@ class VertexAiRagRetrieval(BaseRetrievalTool): llm_request: LlmRequest, ) -> None: # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = ( types.GenerateContentConfig() if not llm_request.config diff --git a/src/google/adk/tools/set_model_response_tool.py b/src/google/adk/tools/set_model_response_tool.py index 7a69ca1f..d1dc6ed5 100644 --- a/src/google/adk/tools/set_model_response_tool.py +++ b/src/google/adk/tools/set_model_response_tool.py @@ -16,19 +16,22 @@ from __future__ import annotations +import inspect from typing import Any from typing import Optional from google.genai import types -from pydantic import BaseModel +from pydantic import TypeAdapter from typing_extensions import override +from ..utils._schema_utils import get_list_inner_type +from ..utils._schema_utils import is_basemodel_schema +from ..utils._schema_utils import is_list_of_basemodel +from ..utils._schema_utils import SchemaType from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool from .tool_context import ToolContext -MODEL_JSON_RESPONSE_KEY = 'temp:__adk_model_response__' - class SetModelResponseTool(BaseTool): """Internal tool used for output schema workaround. @@ -38,14 +41,20 @@ class SetModelResponseTool(BaseTool): provide its final structured response instead of outputting text directly. """ - def __init__(self, output_schema: type[BaseModel]): + def __init__(self, output_schema: SchemaType): """Initialize the tool with the expected output schema. Args: - output_schema: The pydantic model class defining the expected output - structure. + output_schema: The output schema. Supports all types from SchemaUnion: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - list[type[BaseModel]]: A generic list type (e.g., list[MySchema]) + - list[primitive]: e.g., list[str], list[int] + - dict: Raw dict schemas + - Schema: Google's Schema type """ self.output_schema = output_schema + self._is_basemodel = is_basemodel_schema(output_schema) + self._is_list_of_basemodel = is_list_of_basemodel(output_schema) # Create a function that matches the output schema def set_model_response() -> str: @@ -57,17 +66,37 @@ class SetModelResponseTool(BaseTool): return 'Response set successfully.' # Add the schema fields as parameters to the function dynamically - import inspect - - schema_fields = output_schema.model_fields - params = [] - for field_name, field_info in schema_fields.items(): - param = inspect.Parameter( - field_name, - inspect.Parameter.KEYWORD_ONLY, - annotation=field_info.annotation, - ) - params.append(param) + if self._is_basemodel: + # For regular BaseModel, use the model's fields + schema_fields = output_schema.model_fields + params = [] + for field_name, field_info in schema_fields.items(): + param = inspect.Parameter( + field_name, + inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.annotation, + ) + params.append(param) + elif self._is_list_of_basemodel: + # For list[BaseModel], create a single 'items' parameter + inner_type = get_list_inner_type(output_schema) + params = [ + inspect.Parameter( + 'items', + inspect.Parameter.KEYWORD_ONLY, + annotation=list[inner_type], + ) + ] + else: + # For other schema types (list[str], dict, etc.), + # create a single parameter with the actual schema type + params = [ + inspect.Parameter( + 'response', + inspect.Parameter.KEYWORD_ONLY, + annotation=output_schema, + ) + ] # Create new signature with schema parameters new_sig = inspect.Signature(parameters=params) @@ -94,19 +123,31 @@ class SetModelResponseTool(BaseTool): @override async def run_async( - self, *, args: dict[str, Any], tool_context: ToolContext # pylint: disable=unused-argument - ) -> dict[str, Any]: - """Process the model's response and return the validated dict. + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Process the model's response and return the validated data. Args: args: The structured response data matching the output schema. tool_context: Tool execution context. Returns: - The validated response as dict. + The validated response. Type depends on the output_schema: + - dict for BaseModel + - list of dicts for list[BaseModel] + - raw value for other schema types (list[str], dict, etc.) """ - # Validate the input matches the expected schema - validated_response = self.output_schema.model_validate(args) - - # Return the validated dict directly - return validated_response.model_dump() + if self._is_basemodel: + # For regular BaseModel, validate directly + validated_response = self.output_schema.model_validate(args) + return validated_response.model_dump(exclude_none=True) + elif self._is_list_of_basemodel: + # For list[BaseModel], extract and validate the 'items' field + items = args.get('items', []) + type_adapter = TypeAdapter(self.output_schema) + validated_response = type_adapter.validate_python(items) + return [item.model_dump(exclude_none=True) for item in validated_response] + else: + # For other schema types (list[str], dict, etc.), + # return the value directly without pydantic validation + return args.get('response') diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index 34cad5c5..81ce0c45 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -12,39 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=g-import-not-at-top,protected-access + """Toolset for discovering, viewing, and executing agent skills.""" from __future__ import annotations +import asyncio +import json +import logging from typing import Any +from typing import Optional from typing import TYPE_CHECKING +import warnings from google.genai import types from ..agents.readonly_context import ReadonlyContext +from ..code_executors.base_code_executor import BaseCodeExecutor +from ..code_executors.code_execution_utils import CodeExecutionInput from ..features import experimental from ..features import FeatureName from ..skills import models from ..skills import prompt from .base_tool import BaseTool from .base_toolset import BaseToolset +from .function_tool import FunctionTool from .tool_context import ToolContext if TYPE_CHECKING: + from ..agents.llm_agent import ToolUnion from ..models.llm_request import LlmRequest -DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. +logger = logging.getLogger("google_adk." + __name__) + +_DEFAULT_SCRIPT_TIMEOUT = 300 +_MAX_SKILL_PAYLOAD_BYTES = 16 * 1024 * 1024 # 16 MB + +_DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. - **references/** (Optional): Additional documentation or examples for skill usage. - **assets/** (Optional): Templates, scripts or other resources used by the skill. +- **scripts/** (Optional): Executable scripts that can be run via bash. This is very important: 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `name=""` to read its full instructions before proceeding. 2. Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. -3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`). Do NOT use other tools to access these files. +3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. +4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. """ @@ -74,8 +92,8 @@ class ListSkillsTool(BaseTool): async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext ) -> Any: - skill_frontmatters = self._toolset._list_skills() - return prompt.format_skills_as_xml(skill_frontmatters) + skills = self._toolset._list_skills() + return prompt.format_skills_as_xml(skills) @experimental(FeatureName.SKILL_TOOLSET) @@ -122,6 +140,15 @@ class LoadSkillTool(BaseTool): "error_code": "SKILL_NOT_FOUND", } + # Record skill activation in agent state for tool resolution. + agent_name = tool_context.agent_name + state_key = f"_adk_activated_skill_{agent_name}" + + activated_skills = list(tool_context.state.get(state_key, [])) + if skill_name not in activated_skills: + activated_skills.append(skill_name) + tool_context.state[state_key] = activated_skills + return { "skill_name": skill_name, "instructions": skill.instructions, @@ -131,14 +158,14 @@ class LoadSkillTool(BaseTool): @experimental(FeatureName.SKILL_TOOLSET) class LoadSkillResourceTool(BaseTool): - """Tool to load resources (references or assets) from a skill.""" + """Tool to load resources (references, assets, or scripts) from a skill.""" def __init__(self, toolset: "SkillToolset"): super().__init__( name="load_skill_resource", description=( - "Loads a resource file (from references/ or assets/) from within a" - " skill." + "Loads a resource file (from references/, assets/, or" + " scripts/) from within a skill." ), ) self._toolset = toolset @@ -158,7 +185,8 @@ class LoadSkillResourceTool(BaseTool): "type": "string", "description": ( "The relative path to the resource (e.g.," - " 'references/my_doc.md' or 'assets/template.txt')." + " 'references/my_doc.md', 'assets/template.txt'," + " or 'scripts/setup.sh')." ), }, }, @@ -197,9 +225,16 @@ class LoadSkillResourceTool(BaseTool): elif resource_path.startswith("assets/"): asset_name = resource_path[len("assets/") :] content = skill.resources.get_asset(asset_name) + elif resource_path.startswith("scripts/"): + script_name = resource_path[len("scripts/") :] + script = skill.resources.get_script(script_name) + if script is not None: + content = script.src else: return { - "error": "Path must start with 'references/' or 'assets/'.", + "error": ( + "Path must start with 'references/', 'assets/', or 'scripts/'." + ), "error_code": "INVALID_RESOURCE_PATH", } @@ -218,37 +253,470 @@ class LoadSkillResourceTool(BaseTool): } +class _SkillScriptCodeExecutor: + """A helper that materializes skill files and executes scripts.""" + + _base_executor: BaseCodeExecutor + _script_timeout: int + + def __init__(self, base_executor: BaseCodeExecutor, script_timeout: int): + self._base_executor = base_executor + self._script_timeout = script_timeout + + async def execute_script_async( + self, + invocation_context: Any, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> dict[str, Any]: + """Prepares and executes the script using the base executor.""" + code = self._build_wrapper_code(skill, script_path, script_args) + if code is None: + if "." in script_path: + ext_msg = f"'.{script_path.rsplit('.', 1)[-1]}'" + else: + ext_msg = "(no extension)" + return { + "error": ( + f"Unsupported script type {ext_msg}." + " Supported types: .py, .sh, .bash" + ), + "error_code": "UNSUPPORTED_SCRIPT_TYPE", + } + + try: + # Execute the self-contained script using the underlying executor + result = await asyncio.to_thread( + self._base_executor.execute_code, + invocation_context, + CodeExecutionInput(code=code), + ) + + stdout = result.stdout + stderr = result.stderr + + # Shell scripts serialize both streams as JSON + # through stdout; parse the envelope if present. + rc = 0 + is_shell = "." in script_path and script_path.rsplit(".", 1)[ + -1 + ].lower() in ("sh", "bash") + if is_shell and stdout: + try: + parsed = json.loads(stdout) + if isinstance(parsed, dict) and parsed.get("__shell_result__"): + stdout = parsed.get("stdout", "") + stderr = parsed.get("stderr", "") + rc = parsed.get("returncode", 0) + if rc != 0 and not stderr: + stderr = f"Exit code {rc}" + except (json.JSONDecodeError, ValueError): + pass + + status = "success" + if rc != 0: + status = "error" + elif stderr and not stdout: + status = "error" + elif stderr: + status = "warning" + + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": stdout, + "stderr": stderr, + "status": status, + } + except SystemExit as e: + if e.code in (None, 0): + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": "", + "stderr": "", + "status": "success", + } + return { + "error": ( + f"Failed to execute script '{script_path}':" + f" exited with code {e.code}" + ), + "error_code": "EXECUTION_ERROR", + } + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception( + "Error executing script '%s' from skill '%s'", + script_path, + skill.name, + ) + short_msg = str(e) + if len(short_msg) > 200: + short_msg = short_msg[:200] + "..." + return { + "error": ( + "Failed to execute script" + f" '{script_path}':\n{type(e).__name__}:" + f" {short_msg}" + ), + "error_code": "EXECUTION_ERROR", + } + + def _build_wrapper_code( + self, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> str | None: + """Builds a self-extracting Python script.""" + ext = "" + if "." in script_path: + ext = script_path.rsplit(".", 1)[-1].lower() + + if not script_path.startswith("scripts/"): + script_path = f"scripts/{script_path}" + + files_dict = {} + for ref_name in skill.resources.list_references(): + content = skill.resources.get_reference(ref_name) + if content is not None: + files_dict[f"references/{ref_name}"] = content + + for asset_name in skill.resources.list_assets(): + content = skill.resources.get_asset(asset_name) + if content is not None: + files_dict[f"assets/{asset_name}"] = content + + for scr_name in skill.resources.list_scripts(): + scr = skill.resources.get_script(scr_name) + if scr is not None and scr.src is not None: + files_dict[f"scripts/{scr_name}"] = scr.src + + total_size = sum( + len(v) if isinstance(v, (str, bytes)) else 0 + for v in files_dict.values() + ) + if total_size > _MAX_SKILL_PAYLOAD_BYTES: + logger.warning( + "Skill '%s' resources total %d bytes, exceeding" + " the recommended limit of %d bytes.", + skill.name, + total_size, + _MAX_SKILL_PAYLOAD_BYTES, + ) + + # Build the boilerplate extract string + code_lines = [ + "import os", + "import tempfile", + "import sys", + "import json as _json", + "import subprocess", + "import runpy", + f"_files = {files_dict!r}", + "def _materialize_and_run():", + " _orig_cwd = os.getcwd()", + " with tempfile.TemporaryDirectory() as td:", + " for rel_path, content in _files.items():", + " full_path = os.path.join(td, rel_path)", + " os.makedirs(os.path.dirname(full_path), exist_ok=True)", + " mode = 'wb' if isinstance(content, bytes) else 'w'", + " with open(full_path, mode) as f:", + " f.write(content)", + " os.chdir(td)", + " try:", + ] + + if ext == "py": + argv_list = [script_path] + for k, v in script_args.items(): + argv_list.extend([f"--{k}", str(v)]) + code_lines.extend([ + f" sys.argv = {argv_list!r}", + " try:", + f" runpy.run_path({script_path!r}, run_name='__main__')", + " except SystemExit as e:", + " if e.code is not None and e.code != 0:", + " raise e", + ]) + elif ext in ("sh", "bash"): + arr = ["bash", script_path] + for k, v in script_args.items(): + arr.extend([f"--{k}", str(v)]) + timeout = self._script_timeout + code_lines.extend([ + " try:", + " _r = subprocess.run(", + f" {arr!r},", + " capture_output=True, text=True,", + f" timeout={timeout!r}, cwd=td,", + " )", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _r.stdout,", + " 'stderr': _r.stderr,", + " 'returncode': _r.returncode,", + " }))", + " except subprocess.TimeoutExpired as _e:", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _e.stdout or '',", + f" 'stderr': 'Timed out after {timeout}s',", + " 'returncode': -1,", + " }))", + ]) + else: + return None + + code_lines.extend([ + " finally:", + " os.chdir(_orig_cwd)", + ]) + + code_lines.append("_materialize_and_run()") + return "\n".join(code_lines) + + +@experimental(FeatureName.SKILL_TOOLSET) +class RunSkillScriptTool(BaseTool): + """Tool to execute scripts from a skill's scripts/ directory.""" + + def __init__(self, toolset: "SkillToolset"): + super().__init__( + name="run_skill_script", + description="Executes a script from a skill's scripts/ directory.", + ) + self._toolset = toolset + + def _get_declaration(self) -> types.FunctionDeclaration | None: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "skill_name": { + "type": "string", + "description": "The name of the skill.", + }, + "script_path": { + "type": "string", + "description": ( + "The relative path to the script (e.g.," + " 'scripts/setup.py')." + ), + }, + "args": { + "type": "object", + "description": ( + "Optional arguments to pass to the script as key-value" + " pairs." + ), + }, + }, + "required": ["skill_name", "script_path"], + }, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + skill_name = args.get("skill_name") + script_path = args.get("script_path") + script_args = args.get("args", {}) + if not isinstance(script_args, dict): + return { + "error": ( + "'args' must be a JSON object (key-value pairs)," + f" got {type(script_args).__name__}." + ), + "error_code": "INVALID_ARGS_TYPE", + } + + if not skill_name: + return { + "error": "Skill name is required.", + "error_code": "MISSING_SKILL_NAME", + } + if not script_path: + return { + "error": "Script path is required.", + "error_code": "MISSING_SCRIPT_PATH", + } + + skill = self._toolset._get_skill(skill_name) + if not skill: + return { + "error": f"Skill '{skill_name}' not found.", + "error_code": "SKILL_NOT_FOUND", + } + + script = None + if script_path.startswith("scripts/"): + script = skill.resources.get_script(script_path[len("scripts/") :]) + else: + script = skill.resources.get_script(script_path) + + if script is None: + return { + "error": f"Script '{script_path}' not found in skill '{skill_name}'.", + "error_code": "SCRIPT_NOT_FOUND", + } + + # Resolve code executor: toolset-level first, then agent fallback + code_executor = self._toolset._code_executor + if code_executor is None: + agent = tool_context._invocation_context.agent + if hasattr(agent, "code_executor"): + code_executor = agent.code_executor + if code_executor is None: + return { + "error": ( + "No code executor configured. A code executor is" + " required to run scripts." + ), + "error_code": "NO_CODE_EXECUTOR", + } + + script_executor = _SkillScriptCodeExecutor( + code_executor, self._toolset._script_timeout # pylint: disable=protected-access + ) + return await script_executor.execute_script_async( + tool_context._invocation_context, skill, script_path, script_args # pylint: disable=protected-access + ) + + @experimental(FeatureName.SKILL_TOOLSET) class SkillToolset(BaseToolset): """A toolset for managing and interacting with agent skills.""" - def __init__(self, skills: list[models.Skill]): + def __init__( + self, + skills: list[models.Skill], + *, + code_executor: Optional[BaseCodeExecutor] = None, + script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT, + additional_tools: list[ToolUnion] | None = None, + ): + """Initializes the SkillToolset. + + Args: + skills: List of skills to register. + code_executor: Optional code executor for script execution. + script_timeout: Timeout in seconds for shell script execution via + subprocess.run. Defaults to 300 seconds. Does not apply to Python + scripts executed via exec(). + """ super().__init__() + + # Check for duplicate skill names + seen: set[str] = set() + for skill in skills: + if skill.name in seen: + raise ValueError(f"Duplicate skill name '{skill.name}'.") + seen.add(skill.name) + self._skills = {skill.name: skill for skill in skills} + self._code_executor = code_executor + self._script_timeout = script_timeout + + self._provided_tools_by_name = {} + for tool_union in additional_tools or []: + if isinstance(tool_union, BaseTool): + self._provided_tools_by_name[tool_union.name] = tool_union + elif callable(tool_union): + ft = FunctionTool(tool_union) + self._provided_tools_by_name[ft.name] = ft + + # Initialize core skill tools self._tools = [ ListSkillsTool(self), LoadSkillTool(self), LoadSkillResourceTool(self), + RunSkillScriptTool(self), ] async def get_tools( self, readonly_context: ReadonlyContext | None = None ) -> list[BaseTool]: """Returns the list of tools in this toolset.""" - return self._tools + dynamic_tools = await self._resolve_additional_tools_from_state( + readonly_context + ) + return self._tools + dynamic_tools + + async def _resolve_additional_tools_from_state( + self, readonly_context: ReadonlyContext | None + ) -> list[BaseTool]: + """Resolves tools listed in the "adk_additional_tools" metadata of skills.""" + + if not readonly_context: + return [] + + agent_name = readonly_context.agent_name + state_key = f"_adk_activated_skill_{agent_name}" + activated_skills = readonly_context.state.get(state_key, []) + + if not activated_skills: + return [] + + additional_tool_names = set() + for skill_name in activated_skills: + skill = self._skills.get(skill_name) + if skill: + additional_tools = skill.frontmatter.metadata.get( + "adk_additional_tools" + ) + if additional_tools: + additional_tool_names.update(additional_tools) + + if not additional_tool_names: + return [] + + resolved_tools = [] + existing_tool_names = {t.name for t in self._tools} + for name in additional_tool_names: + if name in self._provided_tools_by_name: + tool = self._provided_tools_by_name[name] + if tool.name in existing_tool_names: + logger.error( + "Tool name collision: tool '%s' already exists.", tool.name + ) + continue + resolved_tools.append(tool) + existing_tool_names.add(tool.name) + + return resolved_tools def _get_skill(self, name: str) -> models.Skill | None: """Retrieves a skill by name.""" return self._skills.get(name) - def _list_skills(self) -> list[models.Frontmatter]: - """Lists the frontmatter of all available skills.""" - return [s.frontmatter for s in self._skills.values()] + def _list_skills(self) -> list[models.Skill]: + """Lists all available skills.""" + return list(self._skills.values()) async def process_llm_request( self, *, tool_context: ToolContext, llm_request: LlmRequest ) -> None: """Processes the outgoing LLM request to include available skills.""" - skill_frontmatters = self._list_skills() - skills_xml = prompt.format_skills_as_xml(skill_frontmatters) - llm_request.append_instructions([skills_xml]) + skills = self._list_skills() + skills_xml = prompt.format_skills_as_xml(skills) + instructions = [] + instructions.append(_DEFAULT_SKILL_SYSTEM_INSTRUCTION) + instructions.append(skills_xml) + llm_request.append_instructions(instructions) + + +def __getattr__(name: str) -> Any: + if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION": + warnings.warn( + "DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental. Its content " + "is internal implementation and will change in minor/patch releases " + "to tune agent performance.", + UserWarning, + stacklevel=2, + ) + return _DEFAULT_SKILL_SYSTEM_INSTRUCTION + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/google/adk/tools/spanner/query_tool.py b/src/google/adk/tools/spanner/query_tool.py index 3cdede43..24c1be60 100644 --- a/src/google/adk/tools/spanner/query_tool.py +++ b/src/google/adk/tools/spanner/query_tool.py @@ -14,9 +14,9 @@ from __future__ import annotations +import asyncio import functools import textwrap -import types from typing import Callable from google.auth.credentials import Credentials @@ -27,7 +27,7 @@ from .settings import QueryResultMode from .settings import SpannerToolSettings -def execute_sql( +async def execute_sql( project_id: str, instance_id: str, database_id: str, @@ -82,7 +82,8 @@ def execute_sql( Note: This is running with Read-Only Transaction for query that only read data. """ - return utils.execute_sql( + return await asyncio.to_thread( + utils.execute_sql, project_id, instance_id, database_id, @@ -179,15 +180,10 @@ def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]: if settings and settings.query_result_mode is QueryResultMode.DICT_LIST: - execute_sql_wrapper = types.FunctionType( - execute_sql.__code__, - execute_sql.__globals__, - execute_sql.__name__, - execute_sql.__defaults__, - execute_sql.__closure__, - ) - functools.update_wrapper(execute_sql_wrapper, execute_sql) - # Update with the new docstring + @functools.wraps(execute_sql) + async def execute_sql_wrapper(*args, **kwargs) -> dict: + return await execute_sql(*args, **kwargs) + execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING return execute_sql_wrapper diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index 03f695b8..6fb4a93f 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import json from typing import Any from typing import Dict @@ -230,7 +231,7 @@ def _generate_sql_for_ann( """ -def similarity_search( +async def similarity_search( project_id: str, instance_id: str, database_id: str, @@ -462,13 +463,16 @@ def similarity_search( # Generate embedding for the query according to the embedding options. if vertex_ai_embedding_model_name: - embedding = utils.embed_contents( - vertex_ai_embedding_model_name, - [query], - output_dimensionality, + embedding = ( + await utils.embed_contents_async( + vertex_ai_embedding_model_name, + [query], + output_dimensionality, + ) )[0] else: - embedding = _get_embedding_for_query( + embedding = await asyncio.to_thread( + _get_embedding_for_query, database, database.database_dialect, spanner_gsql_embedding_model_name, @@ -507,22 +511,20 @@ def similarity_search( else: params = {_GOOGLESQL_PARAMETER_QUERY_EMBEDDING: embedding} - with database.snapshot() as snapshot: - result_set = snapshot.execute_sql(sql, params=params) - rows = [] - result = {} - for row in result_set: - try: - # if the json serialization of the row succeeds, use it as is - json.dumps(row) - except (TypeError, ValueError, OverflowError): - row = str(row) + def _execute_sql(): + with database.snapshot() as snapshot: + result_set = snapshot.execute_sql(sql, params=params) + rows = [] + for row in result_set: + try: + # If the json serialization of the row succeeds, use it as is + json.dumps(row) + except (TypeError, ValueError, OverflowError): + row = str(row) + rows.append(row) + return {"status": "SUCCESS", "rows": rows} - rows.append(row) - - result["status"] = "SUCCESS" - result["rows"] = rows - return result + return await asyncio.to_thread(_execute_sql) except Exception as ex: return { "status": "ERROR", @@ -530,7 +532,7 @@ def similarity_search( } -def vector_store_similarity_search( +async def vector_store_similarity_search( query: str, credentials: Credentials, settings: SpannerToolSettings, @@ -605,7 +607,7 @@ def vector_store_similarity_search( settings.vector_store_settings.num_leaves_to_search ) - return similarity_search( + return await similarity_search( project_id=settings.vector_store_settings.project_id, instance_id=settings.vector_store_settings.instance_id, database_id=settings.vector_store_settings.database_id, diff --git a/src/google/adk/tools/url_context_tool.py b/src/google/adk/tools/url_context_tool.py index fcdf76da..5e923e74 100644 --- a/src/google/adk/tools/url_context_tool.py +++ b/src/google/adk/tools/url_context_tool.py @@ -21,6 +21,7 @@ from typing_extensions import override from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -46,11 +47,12 @@ class UrlContextTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError('Url context tool cannot be used in Gemini 1.x.') - elif is_gemini_2_or_above(llm_request.model): + elif is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(url_context=types.UrlContext()) ) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 91fe60e5..46104c5e 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -24,6 +24,7 @@ from typing_extensions import override from ..agents.readonly_context import ReadonlyContext from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -141,14 +142,16 @@ class VertexAiSearchTool(BaseTool): tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Vertex AI search tool cannot be used with other tools in Gemini' ' 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] # Build the search config (can be overridden by subclasses) vertex_ai_search_config = self._build_vertex_ai_search_config( diff --git a/src/google/adk/utils/_schema_utils.py b/src/google/adk/utils/_schema_utils.py new file mode 100644 index 00000000..3bb74df9 --- /dev/null +++ b/src/google/adk/utils/_schema_utils.py @@ -0,0 +1,119 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General schema utilities. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from __future__ import annotations + +import json +from typing import Any +from typing import get_args +from typing import get_origin +from typing import Optional + +from google.genai import types +from pydantic import BaseModel +from pydantic import TypeAdapter + +# Use SchemaUnion from google.genai.types to support all schema types +# that the underlying API supports. +SchemaType = types.SchemaUnion +"""Type for schema fields (e.g., output_schema, input_schema). + +Supports all schema types that the underlying Google GenAI API supports: + - type[BaseModel]: A pydantic model class (e.g., MySchema) + - GenericAlias: Generic types like list[str], list[MySchema], dict[str, int] + - dict: Raw dict schemas + - Schema: Google's Schema type +""" + + +def is_basemodel_schema(schema: SchemaType) -> bool: + """Check if the schema is a BaseModel type (not a generic alias). + + Args: + schema: The schema to check. + + Returns: + True if schema is a BaseModel class, False otherwise. + """ + return isinstance(schema, type) and issubclass(schema, BaseModel) + + +def is_list_of_basemodel(schema: SchemaType) -> bool: + """Check if the schema is a list of BaseModel type. + + Args: + schema: The schema to check. + + Returns: + True if schema is list[SomeBaseModel], False otherwise. + """ + origin = get_origin(schema) + if origin is not list: + return False + + args = get_args(schema) + if not args: + return False + + inner_type = args[0] + return isinstance(inner_type, type) and issubclass(inner_type, BaseModel) + + +def get_list_inner_type(schema: SchemaType) -> Optional[type[BaseModel]]: + """Get the inner BaseModel type from a list[BaseModel] schema. + + Args: + schema: The schema (expected to be list[SomeBaseModel]). + + Returns: + The inner BaseModel type, or None if not a list of BaseModel. + """ + if not is_list_of_basemodel(schema): + return None + + args = get_args(schema) + return args[0] + + +def validate_schema(schema: SchemaType, json_text: str) -> Any: + """Validate JSON text against a schema and return the result. + + Args: + schema: The schema to validate against. + json_text: The JSON text to validate. + + Returns: + The validated result. Type depends on the schema: + - dict for BaseModel + - list of dicts for list[BaseModel] + - raw value for other schema types (list[str], dict, etc.) + """ + if is_basemodel_schema(schema): + # For regular BaseModel, use model_validate_json + return schema.model_validate_json(json_text).model_dump(exclude_none=True) + elif is_list_of_basemodel(schema): + # For list[BaseModel], use TypeAdapter to validate + type_adapter = TypeAdapter(schema) + validated = type_adapter.validate_json(json_text) + return [item.model_dump(exclude_none=True) for item in validated] + else: + # For other schema types (list[str], dict, Schema, etc.), + # just parse JSON without pydantic validation + return json.loads(json_text) diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index b47180cd..cb68d800 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -21,6 +21,66 @@ Please do not rely on the implementation details. from __future__ import annotations from contextlib import aclosing +import inspect +from typing import Any +from typing import Callable +from typing import get_args +from typing import get_origin +from typing import Union # Re-export aclosing for backward compatibility Aclosing = aclosing + + +def _is_context_type(annotation: Any) -> bool: + """Check if an annotation is the Context type. + + This checks if the annotation is exactly Context or a type alias of Context + (e.g., ToolContext, CallbackContext). Also handles Optional[Context] types. + + Args: + annotation: The type annotation to check. + + Returns: + True if the annotation is the Context type, False otherwise. + """ + from ..agents.context import Context + + if annotation is inspect.Parameter.empty: + return False + + # Handle Optional[Context] and Union types + origin = get_origin(annotation) + if origin is Union: + args = get_args(annotation) + return any( + _is_context_type(arg) for arg in args if not isinstance(arg, type(None)) + ) + + # Check if it's exactly the Context type (or an alias like ToolContext) + return annotation is Context + + +def find_context_parameter(func: Callable[..., Any]) -> str | None: + """Find the parameter name that has a Context type annotation. + + This function inspects the signature of a callable and returns the name + of the first parameter that is annotated with Context or a type alias of + Context (e.g., ToolContext, CallbackContext). + + Args: + func: The callable to inspect. + + Returns: + The parameter name if found, None otherwise. + """ + if func is None: + return None + try: + signature = inspect.signature(func) + except (ValueError, TypeError): + return None + for name, param in signature.parameters.items(): + if _is_context_type(param.annotation): + return name + return None diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 4960b0b7..57103fb2 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -22,6 +22,19 @@ from typing import Optional from packaging.version import InvalidVersion from packaging.version import Version +from .env_utils import is_env_enabled + +_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK' + + +def is_gemini_model_id_check_disabled() -> bool: + """Returns True when Gemini model-id validation should be bypassed. + + This opt-in environment variable is intended for internal usage where model + ids may not follow the public ``gemini-*`` naming convention. + """ + return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR) + def extract_model_name(model_string: str) -> str: """Extract the actual model name from either simple or path-based format. diff --git a/src/google/adk/utils/output_schema_utils.py b/src/google/adk/utils/output_schema_utils.py index 7c494f92..bb31d098 100644 --- a/src/google/adk/utils/output_schema_utils.py +++ b/src/google/adk/utils/output_schema_utils.py @@ -30,6 +30,17 @@ from .variant_utils import GoogleLLMVariant def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool: """Returns True if output schema with tools is supported.""" + # LiteLLM handles tools + response_format compatibility per-provider: + # - Providers with native support (OpenAI, Azure): both passed directly + # - Providers without (Fireworks): auto-converted to json_tool_call + + # tool_choice enforcement + # This is strictly more reliable than the SetModelResponseTool + # prompt-based workaround. + from ..models.lite_llm import LiteLlm + + if isinstance(model, LiteLlm): + return True + model_string = model if isinstance(model, str) else model.model return ( diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 1ce0bf5e..2e373f50 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.25.1" +__version__ = "1.26.0" diff --git a/tests/unittests/a2a/converters/test_event_round_trip.py b/tests/unittests/a2a/converters/test_event_round_trip.py new file mode 100644 index 00000000..00036f6a --- /dev/null +++ b/tests/unittests/a2a/converters/test_event_round_trip.py @@ -0,0 +1,208 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Round trip tests for ADK and A2A event converters.""" + +from __future__ import annotations + +from typing import Dict +from unittest.mock import Mock + +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskStatusUpdateEvent +from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events +from google.adk.a2a.converters.from_adk_event import create_error_status_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.genai import types as genai_types + + +def test_round_trip_text_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="model", + parts=[genai_types.Part.from_text(text="Hello world!")], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + assert isinstance(a2a_event, TaskArtifactUpdateEvent) + + mock_context = Mock( + spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert restored_event.partial == original_event.partial + assert len(restored_event.content.parts) == len(original_event.content.parts) + assert ( + restored_event.content.parts[0].text + == original_event.content.parts[0].text + ) + + +def test_round_trip_error_status_event(): + original_event = Event( + invocation_id="error_inv", + author="error_agent", + branch="main", + error_message="Test Error", + ) + + a2a_event = create_error_status_event( + event=original_event, + task_id="task2", + context_id="ctx2", + ) + + assert isinstance(a2a_event, TaskStatusUpdateEvent) + + mock_context = Mock( + spec=InvocationContext, invocation_id="error_inv", branch="main" + ) + + restored_event = convert_a2a_status_update_to_event( + a2a_status_update=a2a_event, + author="error_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].text == "Test Error" + + +def test_round_trip_function_call_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="model", + parts=[ + genai_types.Part.from_function_call( + name="my_function", + args={"arg1": "value1"}, + ) + ], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + + mock_context = Mock( + spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].function_call.name == "my_function" + assert restored_event.content.parts[0].function_call.args == { + "arg1": "value1" + } + + +def test_round_trip_function_response_event(): + original_event = Event( + invocation_id="test_invocation", + author="test_agent", + branch="main", + content=genai_types.Content( + role="user", + parts=[ + genai_types.Part.from_function_response( + name="my_function", + response={"result": "success"}, + ) + ], + ), + partial=False, + ) + agents_artifacts: Dict[str, str] = {} + + a2a_events = convert_event_to_a2a_events( + event=original_event, + agents_artifacts=agents_artifacts, + task_id="task1", + context_id="context1", + ) + + assert len(a2a_events) == 1 + a2a_event = a2a_events[0] + + mock_context = Mock( + spec=InvocationContext, invocation_id="test_invocation", branch="main" + ) + + restored_event = convert_a2a_artifact_update_to_event( + a2a_artifact_update=a2a_event, + author="test_agent", + invocation_context=mock_context, + ) + + assert restored_event is not None + assert restored_event.author == original_event.author + assert restored_event.invocation_id == original_event.invocation_id + assert restored_event.branch == original_event.branch + assert len(restored_event.content.parts) == 1 + assert restored_event.content.parts[0].function_response.name == "my_function" + assert restored_event.content.parts[0].function_response.response == { + "result": "success" + } diff --git a/tests/unittests/a2a/converters/test_from_adk.py b/tests/unittests/a2a/converters/test_from_adk.py new file mode 100644 index 00000000..23546c58 --- /dev/null +++ b/tests/unittests/a2a/converters/test_from_adk.py @@ -0,0 +1,108 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock +from unittest.mock import patch +import uuid + +from a2a.types import Part as A2APart +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events +from google.adk.events.event import Event +from google.genai import types as genai_types +import pytest + + +class TestFromAdk: + """Test suite for from_adk functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_event = Mock(spec=Event) + self.mock_event.id = "test-event-id" + self.mock_event.invocation_id = "test-invocation-id" + self.mock_event.author = "test-author" + self.mock_event.branch = None + self.mock_event.content = None + self.mock_event.error_code = None + self.mock_event.error_message = None + self.mock_event.grounding_metadata = None + self.mock_event.citation_metadata = None + self.mock_event.custom_metadata = None + self.mock_event.usage_metadata = None + self.mock_event.actions = None + self.mock_event.partial = True + self.mock_event.long_running_tool_ids = None + + def test_convert_event_to_a2a_events_artifact_update(self): + """Test conversion of event to TaskArtifactUpdateEvent.""" + # Setup event with content + self.mock_event.content = genai_types.Content( + parts=[genai_types.Part(text="hello")], role="model" + ) + self.mock_event.author = "agent-1" + + agents_artifacts = {} + + # Mock part converter to return a standard text part + mock_a2a_part = A2APart(root=TextPart(text="hello")) + mock_a2a_part.root.metadata = {} + mock_convert_part = Mock(return_value=[mock_a2a_part]) + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + part_converter=mock_convert_part, + ) + + assert len(result) == 1 + assert isinstance(result[0], TaskArtifactUpdateEvent) + assert result[0].task_id == "task-123" + assert result[0].context_id == "context-456" + assert result[0].artifact.parts == [mock_a2a_part] + assert "agent-1" in agents_artifacts # Artifact ID should be stored + + def test_convert_event_to_a2a_events_error(self): + """Test conversion of event with error to TaskStatusUpdateEvent.""" + self.mock_event.error_code = "ERR001" + self.mock_event.error_message = "Something went wrong" + + agents_artifacts = {} + + result = convert_event_to_a2a_events( + self.mock_event, + agents_artifacts, + task_id="task-123", + context_id="context-456", + ) + + # Should not return any artifact events + assert len(result) == 0 + + def test_convert_event_to_a2a_events_none_event(self): + """Test convert_event_to_a2a_events with None event.""" + with pytest.raises(ValueError, match="Event cannot be None"): + convert_event_to_a2a_events(None, {}) + + def test_convert_event_to_a2a_events_none_artifacts(self): + """Test convert_event_to_a2a_events with None agents_artifacts.""" + with pytest.raises(ValueError, match="Agents artifacts cannot be None"): + convert_event_to_a2a_events(self.mock_event, None) diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 647caa5b..446e1185 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json from unittest.mock import Mock from unittest.mock import patch @@ -54,7 +55,9 @@ class TestConvertA2aPartToGenaiPart: a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithUri( - uri="gs://bucket/file.txt", mime_type="text/plain" + uri="gs://bucket/file.txt", + mime_type="text/plain", + name="my_file.txt", ) ) ) @@ -68,19 +71,21 @@ class TestConvertA2aPartToGenaiPart: assert result.file_data is not None assert result.file_data.file_uri == "gs://bucket/file.txt" assert result.file_data.mime_type == "text/plain" + assert result.file_data.display_name == "my_file.txt" def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" # Arrange test_bytes = b"test file content" # A2A FileWithBytes expects base64-encoded string - import base64 base64_encoded = base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=base64_encoded, mime_type="text/plain" + bytes=base64_encoded, + mime_type="text/plain", + name="my_bytes.txt", ) ) ) @@ -95,6 +100,7 @@ class TestConvertA2aPartToGenaiPart: # The converter decodes base64 back to original bytes assert result.inline_data.data == test_bytes assert result.inline_data.mime_type == "text/plain" + assert result.inline_data.display_name == "my_bytes.txt" def test_convert_data_part_function_call(self): """Test conversion of A2A DataPart with function call metadata.""" @@ -289,14 +295,16 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root, a2a_types.TextPart) assert result.root.text == "Hello, world!" assert result.root.metadata is not None - assert result.root.metadata[_get_adk_metadata_key("thought")] == True + assert result.root.metadata[_get_adk_metadata_key("thought")] def test_convert_file_data_part(self): """Test conversion of GenAI file_data Part to A2A Part.""" # Arrange genai_part = genai_types.Part( file_data=genai_types.FileData( - file_uri="gs://bucket/file.txt", mime_type="text/plain" + file_uri="gs://bucket/file.txt", + mime_type="text/plain", + display_name="my_file.txt", ) ) @@ -310,13 +318,18 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root.file, a2a_types.FileWithUri) assert result.root.file.uri == "gs://bucket/file.txt" assert result.root.file.mime_type == "text/plain" + assert result.root.file.name == "my_file.txt" def test_convert_inline_data_part(self): """Test conversion of GenAI inline_data Part to A2A Part.""" # Arrange test_bytes = b"test file content" genai_part = genai_types.Part( - inline_data=genai_types.Blob(data=test_bytes, mime_type="text/plain") + inline_data=genai_types.Blob( + data=test_bytes, + mime_type="text/plain", + display_name="my_bytes.txt", + ) ) # Act @@ -328,11 +341,11 @@ class TestConvertGenaiPartToA2aPart: assert isinstance(result.root, a2a_types.FilePart) assert isinstance(result.root.file, a2a_types.FileWithBytes) # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility - import base64 expected_base64 = base64.b64encode(test_bytes).decode("utf-8") assert result.root.file.bytes == expected_base64 assert result.root.file.mime_type == "text/plain" + assert result.root.file.name == "my_bytes.txt" def test_convert_inline_data_part_with_video_metadata(self): """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" @@ -516,6 +529,22 @@ class TestRoundTripConversions: assert isinstance(result_a2a_part.root, a2a_types.TextPart) assert result_a2a_part.root.text == original_text + def test_text_part_with_thought_round_trip(self): + """Test round-trip conversion for text parts with thought.""" + # Arrange + original_text = "Thinking..." + genai_part = genai_types.Part(text=original_text, thought=True) + + # Act + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.text == original_text + assert result_genai_part.thought + def test_file_uri_round_trip(self): """Test round-trip conversion for file parts with URI.""" # Arrange @@ -825,3 +854,204 @@ class TestNewConstants: assert result.executable_code is not None assert result.executable_code.language == genai_types.Language.PYTHON assert result.executable_code.code == "print('Hello, World!')" + + +class TestThoughtSignaturePreservation: + """Tests for thought_signature preservation in function call conversions.""" + + def test_genai_function_call_with_thought_signature_to_a2a(self): + """Test that thought_signature is preserved when converting GenAI to A2A.""" + # Arrange + function_call = genai_types.FunctionCall( + id="fc_gemini3", + name="my_tool", + args={"document": "test content"}, + ) + genai_part = genai_types.Part( + function_call=function_call, + thought_signature=b"gemini3_signature_bytes", + ) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result.root, a2a_types.DataPart) + assert ( + result.root.metadata[ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ] + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + ) + # thought_signature should be base64 encoded in metadata + thought_sig_key = _get_adk_metadata_key("thought_signature") + assert thought_sig_key in result.root.metadata + assert ( + base64.b64decode(result.root.metadata[thought_sig_key]) + == b"gemini3_signature_bytes" + ) + + def test_genai_function_call_without_thought_signature_to_a2a(self): + """Test function call without thought_signature doesn't add metadata key.""" + # Arrange + function_call = genai_types.FunctionCall( + id="fc_regular", + name="regular_tool", + args={}, + ) + genai_part = genai_types.Part(function_call=function_call) + + # Act + result = convert_genai_part_to_a2a_part(genai_part) + + # Assert + assert result is not None + assert isinstance(result.root, a2a_types.DataPart) + # thought_signature key should not be present + thought_sig_key = _get_adk_metadata_key("thought_signature") + assert thought_sig_key not in result.root.metadata + + def test_a2a_function_call_with_thought_signature_to_genai(self): + """Test that thought_signature is restored when converting A2A to GenAI.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_gemini3", + "name": "my_tool", + "args": {"document": "test content"}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key("thought_signature"): ( + base64.b64encode(b"restored_signature").decode("utf-8") + ), + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "my_tool" + # thought_signature should be decoded back to bytes + assert result.thought_signature == b"restored_signature" + + def test_a2a_function_call_without_thought_signature_to_genai(self): + """Test function call without thought_signature returns None for it.""" + # Arrange + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_regular", + "name": "regular_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "regular_tool" + # thought_signature should be None + assert result.thought_signature is None + + def test_function_call_with_thought_signature_round_trip(self): + """Test thought_signature is preserved in GenAI -> A2A -> GenAI round trip.""" + # Arrange + original_signature = b"round_trip_signature_test" + function_call = genai_types.FunctionCall( + id="fc_round_trip", + name="round_trip_tool", + args={"key": "value"}, + ) + original_part = genai_types.Part( + function_call=function_call, + thought_signature=original_signature, + ) + + # Act - Convert GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(original_part) + restored_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert restored_part is not None + assert restored_part.function_call is not None + assert restored_part.function_call.name == "round_trip_tool" + assert restored_part.thought_signature == original_signature + + def test_a2a_function_call_with_bytes_thought_signature_to_genai(self): + """Test that bytes thought_signature is used directly without decoding.""" + # Arrange - metadata contains raw bytes (not base64 encoded) + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_bytes", + "name": "bytes_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key( + "thought_signature" + ): b"raw_bytes_signature", + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + # bytes should be used directly + assert result.thought_signature == b"raw_bytes_signature" + + def test_a2a_function_call_with_invalid_base64_thought_signature(self): + """Test that invalid base64 thought_signature logs warning and returns None.""" + # Arrange - metadata contains invalid base64 string + a2a_part = a2a_types.Part( + root=a2a_types.DataPart( + data={ + "id": "fc_invalid", + "name": "invalid_sig_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key( + "thought_signature" + ): "not_valid_base64!!!", + }, + ) + ) + + # Act + result = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result is not None + assert result.function_call is not None + assert result.function_call.name == "invalid_sig_tool" + # thought_signature should be None due to decode failure + assert result.thought_signature is None diff --git a/tests/unittests/a2a/converters/test_to_adk.py b/tests/unittests/a2a/converters/test_to_adk.py new file mode 100644 index 00000000..90651956 --- /dev/null +++ b/tests/unittests/a2a/converters/test_to_adk.py @@ -0,0 +1,195 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import Mock + +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Part as A2APart +from a2a.types import Task +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY +from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_status_update_to_event +from google.adk.a2a.converters.to_adk_event import convert_a2a_task_to_event +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.agents.invocation_context import InvocationContext +from google.genai import types as genai_types +import pytest + + +class TestToAdk: + """Test suite for to_adk functions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_context = Mock(spec=InvocationContext) + self.mock_context.invocation_id = "test-invocation" + self.mock_context.branch = "test-branch" + + def test_convert_a2a_message_to_event_success(self): + """Test successful conversion of A2A message to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + message = Message(message_id="msg-1", role="user", parts=[a2a_part]) + + mock_genai_part = genai_types.Part.from_text(text="hello") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_message_to_event( + message, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert event.branch == "test-branch" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_message_to_event_none(self): + """Test convert_a2a_message_to_event with None.""" + with pytest.raises(ValueError, match="A2A message cannot be None"): + convert_a2a_message_to_event(None) + + def test_convert_a2a_task_to_event_success(self): + """Test successful conversion of A2A task to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + task = Task( + id="task-1", + status=TaskStatus( + state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + ), + context_id="context-1", + history=[Message(message_id="msg-1", role="agent", parts=[a2a_part])], + artifacts=[ + Artifact( + artifact_id="art-1", artifact_type="message", parts=[a2a_part] + ) + ], + ) + + mock_genai_part = genai_types.Part.from_text(text="task artifact text") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_task_to_event( + task, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_task_to_event_none(self): + """Test convert_a2a_task_to_event with None.""" + with pytest.raises(ValueError, match="A2A task cannot be None"): + convert_a2a_task_to_event(None) + + def test_convert_a2a_status_update_to_event_success(self): + """Test successful conversion of A2A status update to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = { + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True + } + update = TaskStatusUpdateEvent( + task_id="task-1", + status=TaskStatus( + state=TaskState.input_required, + timestamp="now", + message=Message( + message_id="m1", + role="agent", + parts=[a2a_part], + ), + ), + context_id="context-1", + final=False, + ) + + mock_genai_part = genai_types.Part( + function_call=genai_types.FunctionCall( + name="status update text", args={"arg": "value"}, id="call-1" + ) + ) + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_status_update_to_event( + update, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_status_update_to_event_none(self): + """Test convert_a2a_status_update_to_event with None.""" + with pytest.raises(ValueError, match="A2A status update cannot be None"): + convert_a2a_status_update_to_event(None) + + def test_convert_a2a_artifact_update_to_event_success(self): + """Test successful conversion of A2A artifact update to Event.""" + a2a_part = Mock(spec=A2APart) + a2a_part.root = Mock() + a2a_part.root.metadata = {} + update = TaskArtifactUpdateEvent( + task_id="task-1", + artifact=Artifact( + artifact_id="art-1", artifact_type="message", parts=[a2a_part] + ), + append=True, + context_id="context-1", + last_chunk=False, + ) + + mock_genai_part = genai_types.Part.from_text(text="artifact chunk text") + mock_part_converter = Mock(return_value=[mock_genai_part]) + + event = convert_a2a_artifact_update_to_event( + update, + author="test-author", + invocation_context=self.mock_context, + part_converter=mock_part_converter, + ) + + assert event.author == "test-author" + assert event.invocation_id == "test-invocation" + assert event.partial is True + assert len(event.content.parts) == 1 + assert event.content.parts[0] == mock_genai_part + + def test_convert_a2a_artifact_update_to_event_none(self): + """Test convert_a2a_artifact_update_to_event with None.""" + with pytest.raises(ValueError, match="A2A artifact update cannot be None"): + convert_a2a_artifact_update_to_event(None) diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 40736d95..787b260f 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -17,13 +17,17 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.server.agent_execution.context import RequestContext +from a2a.server.events import Event as A2AEvent from a2a.server.events.event_queue import EventQueue from a2a.types import Message +from a2a.types import Part +from a2a.types import Role from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor from google.adk.events.event import Event from google.adk.runners import RunConfig from google.adk.runners import Runner @@ -959,3 +963,111 @@ class TestA2aAgentExecutor: assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" + + @pytest.mark.asyncio + async def test_after_event_interceptors_receive_correct_arguments_and_can_modify_event( + self, + ): + """Test that after_event interceptors receive correct arguments and can modify the event.""" + # Create distinct mock objects for ADK event and A2A event + adk_event = Mock(spec=Event, name="ADK_EVENT") + a2a_event = Mock(spec=A2AEvent, name="A2A_EVENT") + modified_a2a_event = Mock(spec=A2AEvent, name="MODIFIED_A2A_EVENT") + + # Mocks for conversion + self.mock_event_converter.return_value = [a2a_event] + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Setup Interceptor + mock_interceptor = Mock(spec=ExecuteInterceptor) + + # after_event should return the modified event + async def side_effect_after_event(context, event, original_event): + return modified_a2a_event + + mock_interceptor.after_event = AsyncMock( + side_effect=side_effect_after_event + ) + mock_interceptor.before_agent = None + mock_interceptor.after_agent = None + + # Update config with interceptor + self.mock_config.execute_interceptors = [mock_interceptor] + # Re-initialize executor with updated config - but we can just update + # the config in place if it's mutable + # The executor uses self._config which is this mock_config basically. + # self.executor was initialized in setup_method with self.mock_config. + + # However, A2aAgentExecutor constructor does: self._config = config or ... + # So updating self.mock_config properties should work as + # it is the same object reference. + + # Mock context + self.mock_context.task_id = "task-1" + self.mock_context.context_id = "ctx-1" + # Ensure current_task is set so we skip the initial + # submitted event creation logic + # which might complicate this specific test if we don't care about it. + self.mock_context.current_task = Mock() + + # Mock runner.run_async to yield our ADK event + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([adk_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Configure session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + self.mock_runner._new_invocation_context.return_value = Mock() + + # We patch TaskResultAggregator just to avoid other errors and simplfy + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_agg_class: + mock_agg = Mock() + mock_agg.task_status_message = None + mock_agg.task_state = TaskState.working + mock_agg_class.return_value = mock_agg + + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify aggregator processed the MODIFIED event + mock_agg.process_event.assert_called_with(modified_a2a_event) + + # Verification of arguments passed to interceptor + assert mock_interceptor.after_event.called + call_args = mock_interceptor.after_event.call_args + # call_args.args should be (executor_context, a2a_event, adk_event) + + passed_a2a_event = call_args.args[1] + passed_adk_event = call_args.args[2] + + # These assertions verify the bug fix + assert ( + passed_a2a_event is a2a_event + ), f"Expected A2A event to be passed as 2nd arg, but got {passed_a2a_event}" + assert ( + passed_adk_event is adk_event + ), f"Expected ADK event to be passed as 3rd arg, but got {passed_adk_event}" + + # Verify that the modified event was enqueued + # We check if enqueue_event was called with modified_a2a_event + # Note: enqueue_event is called multiple times. + + enqueued_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + ] + assert ( + modified_a2a_event in enqueued_events + ), "The modified event should have been enqueued" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py new file mode 100644 index 00000000..9acae2dc --- /dev/null +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -0,0 +1,808 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.executor.a2a_agent_executor_impl import _A2aAgentExecutor as A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor_impl import A2aAgentExecutorConfig +from google.adk.a2a.executor.config import ExecuteInterceptor +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content +import pytest + + +class TestA2aAgentExecutor: + """Test suite for A2aAgentExecutor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_runner = Mock(spec=Runner) + self.mock_runner.app_name = "test-app" + self.mock_runner.session_service = Mock() + self.mock_runner._new_invocation_context = Mock() + self.mock_runner.run_async = AsyncMock() + + self.mock_a2a_part_converter = Mock() + self.mock_gen_ai_part_converter = Mock() + self.mock_request_converter = Mock() + self.mock_event_converter = Mock() + self.mock_config = A2aAgentExecutorConfig( + a2a_part_converter=self.mock_a2a_part_converter, + gen_ai_part_converter=self.mock_gen_ai_part_converter, + request_converter=self.mock_request_converter, + adk_event_converter=self.mock_event_converter, + ) + self.executor = A2aAgentExecutor( + runner=self.mock_runner, config=self.mock_config + ) + + self.mock_context = Mock(spec=RequestContext) + self.mock_context.message = Mock(spec=Message) + self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.current_task = None + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_event_queue = Mock(spec=EventQueue) + + self.expected_metadata = { + _get_adk_metadata_key("app_name"): "test-app", + _get_adk_metadata_key("user_id"): "test-user", + _get_adk_metadata_key("session_id"): "test-session", + _get_adk_metadata_key("agent_executor_v2"): True, + } + + async def _create_async_generator(self, items): + """Helper to create async generator from items.""" + for item in items: + yield item + + @pytest.mark.asyncio + async def test_execute_success_new_task(self): + """Test successful execution of a new task.""" + # Setup + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return a working status update + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify request converter was called with proper arguments + self.mock_request_converter.assert_called_once_with( + self.mock_context, self.mock_a2a_part_converter + ) + + # Verify event converter was called with proper arguments + self.mock_event_converter.assert_called_once_with( + mock_event, + {}, # agents_artifact (initially empty) + self.mock_context.task_id, + self.mock_context.context_id, + self.mock_gen_ai_part_converter, + ) + + # Verify task submitted event was enqueued + # call 0: submitted + # call 1: working (from converter) + # call 2: completed (final) + assert self.mock_event_queue.enqueue_event.call_count >= 3 + + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ + 0 + ] + assert isinstance(submitted_event, Task) + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.metadata == self.expected_metadata + + # Verify working event was enqueued + enqueued_working_event = self.mock_event_queue.enqueue_event.call_args_list[ + 1 + ][0][0] + assert isinstance(enqueued_working_event, TaskStatusUpdateEvent) + assert enqueued_working_event.status.state == TaskState.working + assert enqueued_working_event.metadata == self.expected_metadata + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[2][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event was enqueued + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_execute_no_message_error(self): + """Test execution fails when no message is provided.""" + self.mock_context.message = None + + with pytest.raises(ValueError, match="A2A request must have a message"): + await self.executor.execute(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_existing_task(self): + """Test execution with existing task (no submitted event).""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "existing-task-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with proper async generator + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="existing-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify submitted event was NOT enqueued for existing task + # So we check first event is working state + first_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + assert first_event.metadata == self.expected_metadata + + # Verify manual working event is FIRST + assert isinstance(first_event, TaskStatusUpdateEvent) + assert first_event.status.state == TaskState.working + + # Verify converted event was enqueued + converted_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][ + 0 + ] + assert converted_event == working_event + assert converted_event.metadata == self.expected_metadata + + # Verify final event + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.final == True + assert final_event.status.state == TaskState.completed + assert final_event.metadata == self.expected_metadata + + def test_constructor_with_callable_runner(self): + """Test constructor with callable runner.""" + callable_runner = Mock() + executor = A2aAgentExecutor(runner=callable_runner, config=self.mock_config) + + assert executor._runner == callable_runner + assert executor._config == self.mock_config + + @pytest.mark.asyncio + async def test_resolve_runner_direct_instance(self): + """Test _resolve_runner with direct Runner instance.""" + # Setup - already using direct runner instance in setup_method + runner = await self.executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_sync_callable(self): + """Test _resolve_runner with sync callable that returns Runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_callable(self): + """Test _resolve_runner with async callable that returns Runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_invalid_type(self): + """Test _resolve_runner with invalid runner type.""" + executor = A2aAgentExecutor(runner="invalid", config=self.mock_config) + + with pytest.raises( + TypeError, match="Runner must be a Runner instance or a callable" + ): + await executor._resolve_runner() + + @pytest.mark.asyncio + async def test_handle_request_integration(self): + """Test the complete request handling flow.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Setup detailed mocks + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [ + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ), + ] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return events + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Initialize executor context attributes as they would be in execute() + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify events enqueued + # Should check for working events + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + # Each ADK event generates 1 working event in this mock setup + assert len(working_events) >= len(mock_events) + + # Verify final event is completed + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.completed + + @pytest.mark.asyncio + async def test_cancel_with_task_id(self): + """Test cancellation with a task ID.""" + self.mock_context.task_id = "test-task-id" + + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_with_exception_handling(self): + """Test execution with exception handling.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.current_task = None + + self.mock_request_converter.side_effect = Exception("Test error") + + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True + assert "Test error" in failure_event.status.message.parts[0].root.text + + @pytest.mark.asyncio + async def test_handle_request_with_non_working_state(self): + """Test handle request when a non-working state is encountered.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Mock agent run event + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = "ERROR" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter to return a FAILED event + failed_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.failed, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [failed_event] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Initialize executor context attributes + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + # Execute + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + # Verify final event is FAILED, not COMPLETED + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + # The last event should be the synthesized final event + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + + @pytest.mark.asyncio + async def test_handle_request_with_error_message(self): + """Test handle request when an error message is present without an error code.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Mock agent run event with only error_message + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + mock_event.error_code = None + mock_event.error_message = "Test Error Message" + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + run_request = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + executor_context = Mock() + executor_context.app_name = "test-app" + executor_context.user_id = "test-user" + executor_context.session_id = "test-session" + + await self.executor._handle_request( + self.mock_context, + executor_context, + self.mock_event_queue, + self.mock_runner, + run_request, + ) + + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "final") and call[0][0].final == True + ] + assert len(final_events) >= 1 + final_event = final_events[-1] + assert final_event.status.state == TaskState.failed + assert final_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_interceptors(self): + """Test interceptors execution.""" + # Setup interceptors + before_interceptor = AsyncMock(return_value=self.mock_context) + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: a2a + after_agent_interceptor = AsyncMock() + after_agent_interceptor.side_effect = lambda ctx, event: event + + interceptor = ExecuteInterceptor( + before_agent=before_interceptor, + after_event=after_event_interceptor, + after_agent=after_agent_interceptor, + ) + + self.mock_config.execute_interceptors = [interceptor] + + # Mock run + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Mock event converter + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + # Pre-setup request converter + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Mock session + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify interceptors called + before_interceptor.assert_called_once_with(self.mock_context) + # after_event called for each event + assert after_event_interceptor.call_count >= 1 + after_agent_interceptor.assert_called_once() + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.handle_user_input") + async def test_execute_missing_user_input(self, mock_handle_user_input): + """Test when handle_user_input returns a missing user input event.""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up handle_user_input to return an event + missing_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_handle_user_input.return_value = missing_event + + self.mock_runner.session_service.get_session = AsyncMock( + return_value=Mock(id="test-session") + ) + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify that the missing_event was enqueued + self.mock_event_queue.enqueue_event.assert_called_once_with(missing_event) + + # Verify that metadata was injected + enqueued_event = self.mock_event_queue.enqueue_event.call_args[0][0] + assert enqueued_event.metadata == self.expected_metadata + + @pytest.mark.asyncio + async def test_resolve_session_creates_new_session(self): + """Test that _resolve_session creates a new session if it doesn't exist.""" + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) + + new_session = Mock() + new_session.id = "new-session-id" + self.mock_runner.session_service.create_session = AsyncMock( + return_value=new_session + ) + + run_request = AgentRunRequest( + user_id="test-user", + session_id="old-session-id", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + await self.executor._resolve_session(run_request, self.mock_runner) + + self.mock_runner.session_service.get_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + session_id="old-session-id", + ) + self.mock_runner.session_service.create_session.assert_called_once_with( + app_name=self.mock_runner.app_name, + user_id="test-user", + state={}, + session_id="old-session-id", + ) + assert run_request.session_id == "new-session-id" + + @pytest.mark.asyncio + async def test_execute_enqueue_error_in_exception_handler(self): + """Test failure event publishing handles exception during enqueue.""" + self.mock_context.task_id = "test-task-id" + self.mock_request_converter.side_effect = Exception("Test error") + + # Make enqueue_event raise an exception + self.mock_event_queue.enqueue_event.side_effect = Exception("Enqueue error") + + # This should not raise an exception itself + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify enqueue_event was called to publish the error event + assert self.mock_event_queue.enqueue_event.call_count == 1 + + @pytest.mark.asyncio + @patch("google.adk.a2a.executor.a2a_agent_executor_impl.LongRunningFunctions") + async def test_long_running_functions_final_event(self, mock_lrf_class): + """Test _handle_request when there are long running function calls.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + # Set up mock LongRunningFunctions + mock_lrf = mock_lrf_class.return_value + mock_lrf.process_event.side_effect = lambda e: e + mock_lrf.has_long_running_function_calls.return_value = True + + lrf_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.input_required, timestamp="now"), + context_id="test-context-id", + final=False, + ) + mock_lrf.create_long_running_function_call_event.return_value = lrf_event + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + self.mock_event_converter.return_value = [] + + self.executor._invocation_metadata = {} + self.executor._executor_context = Mock() + + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Verify final event is the long running function call event + final_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if call[0][0] == lrf_event + ] + assert len(final_events) >= 1 + + @pytest.mark.asyncio + async def test_after_event_interceptor_returns_none(self): + """Test after_event_interceptor returning None drops the event.""" + # Setup interceptor returning None + after_event_interceptor = AsyncMock() + after_event_interceptor.side_effect = lambda ctx, a2a, adk: None + + interceptor = ExecuteInterceptor( + after_event=after_event_interceptor, + ) + self.mock_config.execute_interceptors = [interceptor] + + self.mock_context.task_id = "test-task-id" + self.mock_context.context_id = "test-context-id" + + self.mock_request_converter.return_value = AgentRunRequest( + user_id="test-user", + session_id="test-session", + new_message=Mock(spec=Content), + run_config=Mock(spec=RunConfig), + ) + + mock_event = Event( + invocation_id="invocation-id", + author="test-agent", + branch="main", + partial=False, + ) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + # Event converter returns one event + working_event = TaskStatusUpdateEvent( + task_id="test-task-id", + status=TaskStatus(state=TaskState.working, timestamp="now"), + context_id="test-context-id", + final=False, + ) + self.mock_event_converter.return_value = [working_event] + + self.executor._executor_context = Mock() + await self.executor._handle_request( + self.mock_context, + self.executor._executor_context, + self.mock_event_queue, + self.mock_runner, + self.mock_request_converter.return_value, + ) + + # Since the interceptor returns None, working_event should NOT be enqueued + # The only event enqueued by _handle_request should be the final event + assert self.mock_event_queue.enqueue_event.call_count == 1 + final_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] + assert final_event.status.state == TaskState.completed diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9ff6d01..21c96d7e 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import ANY from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCard from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor @@ -77,7 +79,9 @@ class TestToA2A: mock_task_store_class.assert_called_once() mock_agent_executor_class.assert_called_once() mock_request_handler_class.assert_called_once_with( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=mock_task_store, ) mock_card_builder_class.assert_called_once_with( agent=self.mock_agent, rpc_url="http://localhost:8000/" @@ -122,7 +126,9 @@ class TestToA2A: mock_task_store_class.assert_called_once() mock_agent_executor_class.assert_called_once_with(runner=custom_runner) mock_request_handler_class.assert_called_once_with( - agent_executor=mock_agent_executor, task_store=mock_task_store + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=mock_task_store, ) mock_card_builder_class.assert_called_once_with( agent=self.mock_agent, rpc_url="http://localhost:8000/" @@ -131,6 +137,42 @@ class TestToA2A: "startup", mock_app.add_event_handler.call_args[0][1] ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_passes_custom_push_config_store( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a forwards a custom push config store.""" + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + mock_task_store = Mock(spec=InMemoryTaskStore) + mock_task_store_class.return_value = mock_task_store + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + mock_request_handler = Mock(spec=DefaultRequestHandler) + mock_request_handler_class.return_value = mock_request_handler + mock_card_builder = Mock(spec=AgentCardBuilder) + mock_card_builder_class.return_value = mock_card_builder + + custom_push_store = InMemoryPushNotificationConfigStore() + + result = to_a2a(self.mock_agent, push_config_store=custom_push_store) + + assert result == mock_app + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, + push_config_store=custom_push_store, + task_store=mock_task_store, + ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 8a3623cb..df543db9 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -451,6 +451,27 @@ class TestCanonicalTools: assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + async def test_multiple_tools_resolution(self): + """Test that multiple tools are resolved correctly.""" + + def _tool_1(): + pass + + def _tool_2(): + pass + + agent = LlmAgent( + name='test_agent', + model='gemini-pro', + tools=[_tool_1, _tool_2], + ) + ctx = await _create_readonly_context(agent) + tools = await agent.canonical_tools(ctx) + + assert len(tools) == 2 + assert tools[0].name == '_tool_1' + assert tools[1].name == '_tool_2' + # Tests for multi-provider model support via string model names @pytest.mark.parametrize( diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 7643125d..0f1ce896 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -21,7 +21,6 @@ from unittest.mock import Mock from unittest.mock import patch from a2a.client.client import ClientConfig -from a2a.client.client import Consumer from a2a.client.client_factory import ClientFactory from a2a.client.middleware import ClientCallContext from a2a.types import AgentCapabilities @@ -29,13 +28,16 @@ from a2a.types import AgentCard from a2a.types import AgentSkill from a2a.types import Artifact from a2a.types import Message as A2AMessage -from a2a.types import SendMessageSuccessResponse from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus as A2ATaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.adk.a2a.agent import ParametersConfig +from google.adk.a2a.agent import RequestInterceptor +from google.adk.a2a.agent.utils import execute_after_request_interceptors +from google.adk.a2a.agent.utils import execute_before_request_interceptors from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX from google.adk.agents.remote_a2a_agent import AgentCardResolutionError @@ -1683,6 +1685,236 @@ class TestRemoteA2aAgentMessageHandlingFromFactory: assert result is None +class TestRemoteA2aAgentMessageHandlingV2: + """Test _handle_a2a_response_impl functionality.""" + + def setup_method(self): + """Setup test fixtures.""" + from google.adk.a2a.agent.config import A2aRemoteAgentConfig + + self.agent_card = create_test_agent_card() + self.mock_config = Mock(spec=A2aRemoteAgentConfig) + self.mock_config.a2a_part_converter = Mock() + self.mock_config.a2a_task_converter = Mock() + self.mock_config.a2a_status_update_converter = Mock() + self.mock_config.a2a_artifact_update_converter = Mock() + self.mock_config.a2a_message_converter = Mock() + + self.agent = RemoteA2aAgent( + name="test_agent", + agent_card=self.agent_card, + config=self.mock_config, + ) + + # Mock session and context + self.mock_session = Mock(spec=Session) + self.mock_session.id = "session-123" + self.mock_session.events = [] + + self.mock_context = Mock(spec=InvocationContext) + self.mock_context.session = self.mock_session + self.mock_context.invocation_id = "invocation-123" + self.mock_context.branch = "main" + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_message(self): + """Test _handle_a2a_response_impl with A2AMessage.""" + mock_a2a_message = Mock(spec=A2AMessage) + mock_a2a_message.metadata = {} + mock_a2a_message.metadata = {} + mock_a2a_message.context_id = "context-123" + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_message_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + mock_a2a_message, self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_message_converter.assert_called_once_with( + mock_a2a_message, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_and_no_update(self): + """Test _handle_a2a_response_impl with Task and no update.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = "context-123" + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_task_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, None), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_task_converter.assert_called_once_with( + mock_a2a_task, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_status_update(self): + """Test _handle_a2a_response_impl with TaskStatusUpdateEvent.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = None + + mock_update = Mock(spec=TaskStatusUpdateEvent) + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_status_update_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_status_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" not in result.custom_metadata + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_with_task_artifact_update(self): + """Test _handle_a2a_response_impl with TaskArtifactUpdateEvent.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + mock_a2a_task.context_id = "context-123" + + mock_update = Mock(spec=TaskArtifactUpdateEvent) + + mock_event = Event( + author=self.agent.name, + invocation_id=self.mock_context.invocation_id, + branch=self.mock_context.branch, + ) + self.mock_config.a2a_artifact_update_converter.return_value = mock_event + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result == mock_event + self.mock_config.a2a_artifact_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + assert result.custom_metadata is not None + assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata + assert result.custom_metadata[A2A_METADATA_PREFIX + "task_id"] == "task-123" + assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata + assert ( + result.custom_metadata[A2A_METADATA_PREFIX + "context_id"] + == "context-123" + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_update_converter_returns_none(self): + """Test _handle_a2a_response_impl when converter returns None.""" + mock_a2a_task = Mock(spec=A2ATask) + mock_a2a_task.id = "task-123" + + mock_update = Mock(spec=TaskArtifactUpdateEvent) + + self.mock_config.a2a_artifact_update_converter.return_value = None + + result = await self.agent._handle_a2a_response_v2( + (mock_a2a_task, mock_update), self.mock_context + ) + + assert result is None + self.mock_config.a2a_artifact_update_converter.assert_called_once_with( + mock_update, + self.agent.name, + self.mock_context, + self.mock_config.a2a_part_converter, + ) + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_unknown_response_type(self): + """Test _handle_a2a_response_impl with unknown response type.""" + unknown_response = object() + + result = await self.agent._handle_a2a_response_v2( + unknown_response, self.mock_context + ) + + assert result is not None + assert result.author == self.agent.name + assert result.error_message == "Unknown A2A response type" + assert result.invocation_id == self.mock_context.invocation_id + assert result.branch == self.mock_context.branch + + @pytest.mark.asyncio + async def test_handle_a2a_response_impl_handles_client_error(self): + """Test _handle_a2a_response_impl catches A2AClientError.""" + mock_a2a_message = Mock(spec=A2AMessage) + mock_a2a_message.metadata = {} + mock_a2a_message.metadata = {} + + from google.adk.agents.remote_a2a_agent import A2AClientError + + self.mock_config.a2a_message_converter.side_effect = A2AClientError( + "Test client error" + ) + + result = await self.agent._handle_a2a_response_v2( + mock_a2a_message, self.mock_context + ) + + assert result is not None + assert result.author == self.agent.name + assert ( + "Failed to process A2A response: Test client error" + in result.error_message + ) + assert result.invocation_id == self.mock_context.invocation_id + assert result.branch == self.mock_context.branch + + class TestRemoteA2aAgentExecution: """Test agent execution functionality.""" @@ -1771,7 +2003,7 @@ class TestRemoteA2aAgentExecution: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -1910,7 +2142,7 @@ class TestRemoteA2aAgentExecution: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2047,7 +2279,7 @@ class TestRemoteA2aAgentExecutionFromFactory: # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock() + mock_response = Mock(metadata={}) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2293,6 +2525,7 @@ class TestRemoteA2aAgentIntegration: with patch.object(agent, "_a2a_client") as mock_a2a_client: mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) mock_a2a_message.context_id = "context-123" + mock_a2a_message.metadata = {} mock_response = mock_a2a_message mock_send_message = AsyncMock() @@ -2389,6 +2622,7 @@ class TestRemoteA2aAgentIntegration: with patch.object(agent, "_a2a_client") as mock_a2a_client: mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) mock_a2a_message.context_id = "context-123" + mock_a2a_message.metadata = {} mock_response = mock_a2a_message mock_send_message = AsyncMock() @@ -2432,3 +2666,203 @@ class TestRemoteA2aAgentIntegration: # Verify A2A client was called mock_a2a_client.send_message.assert_called_once() + + +class TestRemoteA2aAgentInterceptors: + + @pytest.fixture + def mock_context(self): + ctx = Mock(spec=InvocationContext) + ctx.session = Mock() + ctx.session.state = {"key": "value"} + return ctx + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_none(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + None, mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_empty(self, mock_context): + request = Mock(spec=A2AMessage) + result_req, params = await execute_before_request_interceptors( + [], mock_context, request + ) + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_success( + self, mock_context + ): + request = Mock(spec=A2AMessage) + new_request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + new_request, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is new_request + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_returns_event( + self, mock_context + ): + request = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = AsyncMock( + return_value=( + event, + ParametersConfig( + client_call_context=ClientCallContext(state={"updated": "true"}) + ), + ) + ) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.before_request = AsyncMock() + + result, params = await execute_before_request_interceptors( + [interceptor1, interceptor2], mock_context, request + ) + + assert result is event + assert params.client_call_context.state == {"updated": "true"} + interceptor1.before_request.assert_called_once() + interceptor2.before_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_before_request_interceptors_no_before_request( + self, mock_context + ): + request = Mock(spec=A2AMessage) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.before_request = None + + result_req, params = await execute_before_request_interceptors( + [interceptor1], mock_context, request + ) + + assert result_req is request + assert params.client_call_context.state == {"key": "value"} + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_none(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + None, mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_empty(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + result = await execute_after_request_interceptors( + [], mock_context, response, event + ) + assert result is event + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_success(self, mock_context): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + new_event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock(return_value=new_event) + + result = await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is new_event + interceptor1.after_request.assert_called_once_with( + mock_context, response, event + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_reverse_order( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + event1 = Mock(spec=Event) + event2 = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock(return_value=event1) + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=event2) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is event1 + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_called_once_with( + mock_context, response, event2 + ) + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_returns_none( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = AsyncMock() + + interceptor2 = Mock(spec=RequestInterceptor) + interceptor2.after_request = AsyncMock(return_value=None) + + result = await execute_after_request_interceptors( + [interceptor1, interceptor2], mock_context, response, event + ) + + assert result is None + interceptor2.after_request.assert_called_once_with( + mock_context, response, event + ) + interceptor1.after_request.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_after_request_interceptors_no_after_request( + self, mock_context + ): + response = Mock(spec=A2AMessage) + event = Mock(spec=Event) + + interceptor1 = Mock(spec=RequestInterceptor) + interceptor1.after_request = None + + result = await execute_after_request_interceptors( + [interceptor1], mock_context, response, event + ) + + assert result is event diff --git a/tests/unittests/apps/test_compaction.py b/tests/unittests/apps/test_compaction.py index fadcd39d..6960c8d4 100644 --- a/tests/unittests/apps/test_compaction.py +++ b/tests/unittests/apps/test_compaction.py @@ -50,6 +50,7 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): invocation_id: str, text: str, prompt_token_count: int | None = None, + thought: bool = False, ) -> Event: usage_metadata = None if prompt_token_count is not None: @@ -60,7 +61,60 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): timestamp=timestamp, invocation_id=invocation_id, author='user', - content=Content(role='user', parts=[Part(text=text)]), + content=Content(role='user', parts=[Part(text=text, thought=thought)]), + usage_metadata=usage_metadata, + ) + + def _create_function_call_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + ) -> Event: + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id=function_call_id, name='tool', args={} + ) + ) + ], + ), + ) + + def _create_function_response_event( + self, + timestamp: float, + invocation_id: str, + function_call_id: str, + prompt_token_count: int | None = None, + ) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id=function_call_id, + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), usage_metadata=usage_metadata, ) @@ -249,9 +303,21 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, event_retention_size=5, ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) self.assertEqual(config.token_threshold, 50_000) self.assertEqual(config.event_retention_size, 5) + def test_events_compaction_config_accepts_sliding_window_fields(self): + config = EventsCompactionConfig( + compaction_interval=2, + overlap_size=1, + ) + self.assertEqual(config.compaction_interval, 2) + self.assertEqual(config.overlap_size, 1) + self.assertIsNone(config.token_threshold) + self.assertIsNone(config.event_retention_size) + def test_events_compaction_config_rejects_partial_token_fields( self, ): @@ -262,6 +328,23 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): token_threshold=50_000, ) + def test_events_compaction_config_rejects_partial_sliding_fields( + self, + ): + with pytest.raises(ValidationError): + EventsCompactionConfig( + compaction_interval=2, + ) + + with pytest.raises(ValidationError): + EventsCompactionConfig( + overlap_size=0, + ) + + def test_events_compaction_config_rejects_missing_modes(self): + with pytest.raises(ValidationError): + EventsCompactionConfig() + def test_latest_prompt_token_count_fallback_applies_compaction(self): events = [ self._create_event(1.0, 'inv1', 'a' * 40), @@ -275,6 +358,25 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): # Visible text after compaction is: 'S' + ('c' * 20) = 21 chars. self.assertEqual(estimated_token_count, 21 // 4) + def test_latest_prompt_token_count_fallback_uses_effective_contents(self): + events = [ + self._create_event(1.0, 'inv1', 'visible'), + Event( + timestamp=2.0, + invocation_id='inv2', + author='model', + content=Content( + role='model', + parts=[Part(text='hidden-thought', thought=True)], + ), + ), + ] + + estimated_token_count = compaction_module._latest_prompt_token_count(events) + + # Thought-only events are filtered by contents processing. + self.assertEqual(estimated_token_count, len('visible') // 4) + async def test_run_compaction_for_token_threshold_keeps_retention_events( self, ): @@ -324,6 +426,136 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_keeps_tool_call_pair( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_function_call_event(2.0, 'inv2', 'tool-call-1'), + self._create_function_response_event( + 3.0, + 'inv2', + 'tool-call-1', + prompt_token_count=100, + ), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_for_token_threshold_equal_threshold_compacts( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=100, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', prompt_token_count=100), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 1.0, 'Summary inv1' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + [e.invocation_id for e in compacted_events_arg], + ['inv1'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + + async def test_run_compaction_skip_token_compaction(self): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2', prompt_token_count=100), + ], + ) + + await _run_compaction_for_sliding_window( + app, + session, + self.mock_session_service, + skip_token_compaction=True, + ) + + self.mock_compactor.maybe_summarize_events.assert_not_called() + self.mock_session_service.append_event.assert_not_called() + async def test_run_compaction_for_token_threshold_seeds_previous_compaction( self, ): @@ -482,6 +714,68 @@ class TestCompaction(unittest.IsolatedAsyncioTestCase): session=session, event=mock_compacted_event ) + async def test_run_compaction_for_token_threshold_uses_latest_ordered_seed( + self, + ): + app = App( + name='test', + root_agent=Mock(spec=BaseAgent), + events_compaction_config=EventsCompactionConfig( + summarizer=self.mock_compactor, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + session = Session( + app_name='test', + user_id='u1', + id='s1', + events=[ + self._create_event(1.0, 'inv1', 'e1'), + self._create_event(2.0, 'inv2', 'e2'), + self._create_event(3.0, 'inv3', 'e3'), + self._create_event(4.0, 'inv4', 'e4'), + self._create_event(5.0, 'inv5', 'e5'), + self._create_event(15.0, 'inv6', 'e6'), + self._create_event(20.0, 'inv7', 'e7'), + self._create_compacted_event( + 15.0, 20.0, 'Summary 15-20', appended_ts=21.0 + ), + self._create_compacted_event( + 1.0, 5.0, 'Summary 1-5', appended_ts=22.0 + ), + self._create_event(23.0, 'inv8', 'e8'), + self._create_event(24.0, 'inv9', 'e9', prompt_token_count=120), + ], + ) + + mock_compacted_event = self._create_compacted_event( + 1.0, 23.0, 'Summary 1-23' + ) + self.mock_compactor.maybe_summarize_events.return_value = ( + mock_compacted_event + ) + + await _run_compaction_for_sliding_window( + app, session, self.mock_session_service + ) + + compacted_events_arg = self.mock_compactor.maybe_summarize_events.call_args[ + 1 + ]['events'] + self.assertEqual( + compacted_events_arg[0].content.parts[0].text, 'Summary 1-5' + ) + self.assertEqual( + [e.invocation_id for e in compacted_events_arg[1:]], + ['inv6', 'inv7', 'inv8'], + ) + self.mock_session_service.append_event.assert_called_once_with( + session=session, event=mock_compacted_event + ) + def test_get_contents_with_multiple_compactions(self): # Event timestamps: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index ec74f8ab..f3e7380b 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -29,6 +29,7 @@ from urllib.parse import unquote from urllib.parse import urlparse from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk.artifacts.base_artifact_service import ensure_part from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -766,3 +767,132 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): filename=str(absolute_in_scope), artifact=part, ) + + +class TestEnsurePart: + """Tests for the ensure_part normalization helper.""" + + def test_returns_part_unchanged(self): + """A types.Part instance passes through without modification.""" + part = types.Part.from_bytes(data=b"hello", mime_type="text/plain") + result = ensure_part(part) + assert result is part + + def test_converts_camel_case_dict(self): + """A camelCase dict (Agentspace format) is converted to types.Part.""" + raw = {"inlineData": {"mimeType": "image/png", "data": "dGVzdA=="}} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.inline_data is not None + assert result.inline_data.mime_type == "image/png" + + def test_converts_snake_case_dict(self): + """A snake_case dict is converted to types.Part.""" + raw = {"inline_data": {"mime_type": "text/plain", "data": "aGVsbG8="}} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.inline_data is not None + assert result.inline_data.mime_type == "text/plain" + + def test_converts_text_dict(self): + """A dict with 'text' key is converted to types.Part.""" + raw = {"text": "hello world"} + result = ensure_part(raw) + assert isinstance(result, types.Part) + assert result.text == "hello world" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_artifact_with_camel_case_dict( + service_type, artifact_service_factory +): + """Artifact services accept camelCase dicts (Agentspace format). + + Regression test for https://github.com/google/adk-python/issues/2886 + """ + artifact_service = artifact_service_factory(service_type) + app_name = "app0" + user_id = "user0" + session_id = "sess0" + filename = "uploaded.png" + + # Simulate what Agentspace sends: a plain dict with camelCase keys. + raw_artifact = { + "inlineData": { + "mimeType": "image/png", + "data": "dGVzdF9pbWFnZV9kYXRh", + } + } + + version = await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=raw_artifact, + ) + assert version == 0 + + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.mime_type == "image/png" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_artifact_with_snake_case_dict( + service_type, artifact_service_factory +): + """Artifact services accept snake_case dicts.""" + artifact_service = artifact_service_factory(service_type) + app_name = "app0" + user_id = "user0" + session_id = "sess0" + filename = "uploaded.txt" + + raw_artifact = { + "inline_data": { + "mime_type": "text/plain", + "data": "aGVsbG8=", + } + } + + version = await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=raw_artifact, + ) + assert version == 0 + + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.mime_type == "text/plain" diff --git a/tests/unittests/auth/test_auth_preprocessor.py b/tests/unittests/auth/test_auth_preprocessor.py index 04a64fc5..fb45cc34 100644 --- a/tests/unittests/auth/test_auth_preprocessor.py +++ b/tests/unittests/auth/test_auth_preprocessor.py @@ -79,7 +79,9 @@ class TestAuthLlmRequestProcessor: @pytest.fixture def mock_auth_config(self): """Create a mock AuthConfig.""" - return Mock(spec=AuthConfig) + config = Mock(spec=AuthConfig) + config.credential_key = None + return config @pytest.fixture def mock_function_response_with_auth(self, mock_auth_config): @@ -347,10 +349,12 @@ class TestAuthLlmRequestProcessor: auth_response_1, auth_response_2, ] + user_event_with_multiple_responses.get_function_calls.return_value = [] # Create system function call events system_function_call_1 = Mock() system_function_call_1.id = 'auth_id_1' + system_function_call_1.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_1.args = { 'function_call_id': 'tool_id_1', 'auth_config': mock_auth_config, @@ -358,6 +362,7 @@ class TestAuthLlmRequestProcessor: system_function_call_2 = Mock() system_function_call_2.id = 'auth_id_2' + system_function_call_2.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_2.args = { 'function_call_id': 'tool_id_2', 'auth_config': mock_auth_config, diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 913e11ae..0ea28e66 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -15,7 +15,6 @@ import asyncio import json import logging -import os from pathlib import Path import signal import tempfile @@ -33,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.cli import fast_api as fast_api_module from google.adk.cli.fast_api import get_fast_api_app from google.adk.errors.input_validation_error import InputValidationError +from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult @@ -452,18 +452,28 @@ def mock_eval_set_results_manager(): return MockEvalSetResultsManager() -@pytest.fixture -def test_app( +def _create_test_client( mock_session_service, mock_artifact_service, mock_memory_service, mock_agent_loader, mock_eval_sets_manager, mock_eval_set_results_manager, + **app_kwargs, ): - """Create a TestClient for the FastAPI app without starting a server.""" - - # Patch multiple services and signal handlers + """Helper to create a TestClient with the given get_fast_api_app overrides.""" + defaults = dict( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=False, + host="127.0.0.1", + port=8000, + ) + defaults.update(app_kwargs) with ( patch.object(signal, "signal", autospec=True, return_value=None), patch.object( @@ -503,23 +513,28 @@ def test_app( return_value=mock_eval_set_results_manager, ), ): - # Get the FastAPI app, but don't actually run it - app = get_fast_api_app( - agents_dir=".", - web=True, - session_service_uri="", - artifact_service_uri="", - memory_service_uri="", - allow_origins=["*"], - a2a=False, # Disable A2A for most tests - host="127.0.0.1", - port=8000, - ) + app = get_fast_api_app(**defaults) + return TestClient(app) - # Create a TestClient that doesn't start a real server - client = TestClient(app) - return client +@pytest.fixture +def test_app( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Create a TestClient for the FastAPI app without starting a server.""" + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + ) @pytest.fixture @@ -677,6 +692,7 @@ def test_app_with_a2a( mock_eval_sets_manager, mock_eval_set_results_manager, temp_agents_dir_with_a2a, + monkeypatch, ): """Create a TestClient for the FastAPI app with A2A enabled.""" # Mock A2A related classes @@ -728,26 +744,22 @@ def test_app_with_a2a( mock_a2a_app.return_value = mock_app_instance # Change to temp directory - original_cwd = os.getcwd() - os.chdir(temp_agents_dir_with_a2a) + monkeypatch.chdir(temp_agents_dir_with_a2a) - try: - app = get_fast_api_app( - agents_dir=".", - web=True, - session_service_uri="", - artifact_service_uri="", - memory_service_uri="", - allow_origins=["*"], - a2a=True, - host="127.0.0.1", - port=8000, - ) + app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) - client = TestClient(app) - yield client - finally: - os.chdir(original_cwd) + client = TestClient(app) + yield client ################################################# @@ -1104,13 +1116,13 @@ def test_agent_run_sse_splits_artifact_delta( assert sse_events[1]["actions"]["artifactDelta"] == {"artifact.txt": 0} -def test_agent_run_sse_yields_error_object_on_exception( +def test_agent_run_sse_does_not_split_artifact_delta_for_function_resume( test_app, create_test_session, monkeypatch ): - """Test /run_sse streams an error object if streaming raises.""" + """Test /run_sse keeps artifactDelta with content for function resume flow.""" info = create_test_session - async def run_async_raises( + async def run_async_with_artifact_delta( self, *, user_id: str, @@ -1121,9 +1133,49 @@ def test_agent_run_sse_yields_error_object_on_exception( run_config: Optional[RunConfig] = None, ): del user_id, session_id, invocation_id, new_message, state_delta, run_config + yield Event( + author="dummy agent", + invocation_id="invocation_id", + content=types.Content( + role="model", parts=[types.Part(text="LLM reply")] + ), + actions=EventActions(artifact_delta={"artifact.txt": 0}), + ) + + monkeypatch.setattr(Runner, "run_async", run_async_with_artifact_delta) + + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello agent"}]}, + "streaming": True, + "functionCallEventId": "function-call-event-id", + } + + response = test_app.post("/run_sse", json=payload) + assert response.status_code == 200 + + sse_events = [ + json.loads(line.removeprefix("data: ")) + for line in response.text.splitlines() + if line.startswith("data: ") + ] + + assert len(sse_events) == 1 + assert sse_events[0]["content"]["parts"][0]["text"] == "LLM reply" + assert sse_events[0]["actions"]["artifactDelta"] == {"artifact.txt": 0} + + +def test_agent_run_sse_yields_error_object_on_exception( + test_app, create_test_session, monkeypatch +): + """Test /run_sse streams an error object if streaming raises.""" + info = create_test_session + + async def run_async_raises(self, **kwargs): raise ValueError("boom") - if False: # pylint: disable=using-constant-test - yield _event_1() + yield # make it an async generator # pylint: disable=unreachable monkeypatch.setattr(Runner, "run_async", run_async_raises) @@ -1406,6 +1458,86 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") +def test_a2a_request_handler_uses_push_config_store( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test A2A request handler gets push config store when supported.""" + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "a2a.server.tasks.InMemoryPushNotificationConfigStore" + ) as mock_push_config_store_class, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_task_store_instance = MagicMock() + mock_task_store.return_value = mock_task_store_instance + mock_push_config_store = MagicMock() + mock_push_config_store_class.return_value = mock_push_config_store + mock_executor_instance = MagicMock() + mock_executor.return_value = mock_executor_instance + mock_handler.return_value = MagicMock() + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance + + monkeypatch.chdir(temp_agents_dir_with_a2a) + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) + + mock_handler.assert_called_once_with( + agent_executor=mock_executor_instance, + push_config_store=mock_push_config_store, + task_store=mock_task_store_instance, + ) + + def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False @@ -1561,5 +1693,80 @@ def test_version_endpoint(test_app): assert "language_version" in data +@pytest.fixture +def test_app_auto_session( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Create a TestClient with auto_create_session=True.""" + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + web=False, + auto_create_session=True, + ) + + +@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"]) +def test_auto_creates_session( + test_app_auto_session, test_session_info, endpoint +): + """Test /run and /run_sse auto-create sessions when auto_create_session=True.""" + payload = { + "app_name": test_session_info["app_name"], + "user_id": test_session_info["user_id"], + "session_id": "nonexistent_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + + response = test_app_auto_session.post(endpoint, json=payload) + assert response.status_code == 200 + + if endpoint == "/run": + data = response.json() + assert isinstance(data, list) + assert len(data) > 0 + else: + sse_events = [ + json.loads(line.removeprefix("data: ")) + for line in response.text.splitlines() + if line.startswith("data: ") + ] + assert len(sse_events) > 0 + assert not any("error" in e for e in sse_events) + + +@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"]) +def test_returns_404_without_auto_create( + test_app, test_session_info, monkeypatch, endpoint +): + """Test /run and /run_sse return 404 for missing sessions without auto_create.""" + + async def run_async_session_not_found(self, **kwargs): + raise SessionNotFoundError(f"Session not found: {kwargs['session_id']}") + yield # make it an async generator # pylint: disable=unreachable + + monkeypatch.setattr(Runner, "run_async", run_async_session_not_found) + + payload = { + "app_name": test_session_info["app_name"], + "user_id": test_session_info["user_id"], + "session_id": "nonexistent_session", + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + } + + response = test_app.post(endpoint, json=payload) + assert response.status_code == 404 + assert "Session not found" in response.json()["detail"] + + if __name__ == "__main__": pytest.main(["-xvs", __file__]) diff --git a/tests/unittests/cli/test_service_registry.py b/tests/unittests/cli/test_service_registry.py index 37c6e7c2..dd33e006 100644 --- a/tests/unittests/cli/test_service_registry.py +++ b/tests/unittests/cli/test_service_registry.py @@ -165,6 +165,13 @@ def test_create_memory_service_agentengine_full(registry, mock_services): ) +def test_create_memory_service_memory(registry): + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + memory_service = registry.create_memory_service("memory://") + assert isinstance(memory_service, InMemoryMemoryService) + + # General Tests def test_unsupported_scheme(registry, mock_services): session_service = registry.create_session_service("unsupported://foo") diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py index f3eb3396..0a7f9fc0 100644 --- a/tests/unittests/cli/utils/test_agent_loader.py +++ b/tests/unittests/cli/utils/test_agent_loader.py @@ -993,3 +993,48 @@ class TestAgentLoader: assert len(detailed_list) == 1 assert detailed_list[0]["name"] == agent_name assert not detailed_list[0]["is_computer_use"] + + def test_list_agents_excludes_non_agent_directories(self): + """Test that list_agents filters out directories without agent definitions.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + valid_package = temp_path / "valid_agent" + valid_package.mkdir() + (valid_package / "__init__.py").write_text(dedent(""" + from google.adk.agents.base_agent import BaseAgent + + class ValidAgent(BaseAgent): + def __init__(self): + super().__init__(name="valid_agent") + + root_agent = ValidAgent() + """)) + + valid_module = temp_path / "module_agent" + valid_module.mkdir() + (valid_module / "agent.py").write_text(dedent(""" + from google.adk.agents.base_agent import BaseAgent + + class ModuleAgent(BaseAgent): + def __init__(self): + super().__init__(name="module_agent") + + root_agent = ModuleAgent() + """)) + + valid_yaml = temp_path / "yaml_agent" + valid_yaml.mkdir() + (valid_yaml / "root_agent.yaml").write_text("name: yaml_agent\n") + + (temp_path / "random_folder").mkdir() + (temp_path / "data").mkdir() + (temp_path / "tmp").mkdir() + + loader = AgentLoader(str(temp_path)) + agents = loader.list_agents() + + assert agents == ["module_agent", "valid_agent", "yaml_agent"] + assert "random_folder" not in agents + assert "data" not in agents + assert "tmp" not in agents diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 6814ef97..f7df1bf1 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -354,9 +354,145 @@ async def test_run_cli_accepts_memory_scheme( save_session=False, session_service_uri="memory://", artifact_service_uri="memory://", + memory_service_uri="memory://", ) +@pytest.mark.asyncio +async def test_run_cli_invalid_memory_uri_surfaces_value_error( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should let ValueError propagate for invalid memory service URIs.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "invalid_memory_uri.json" + input_path.write_text(json.dumps(input_json)) + + def _raise_invalid_memory_uri( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + del base_dir, memory_service_uri + raise ValueError("Unsupported memory service URI: unknown://x") + + monkeypatch.setattr( + cli, "create_memory_service_from_options", _raise_invalid_memory_uri + ) + + with pytest.raises(ValueError, match="Unsupported memory service URI"): + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="unknown://x", + ) + + +@pytest.mark.asyncio +async def test_run_cli_passes_memory_service_to_input_file( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should construct and pass the configured memory service.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "memory_input.json" + input_path.write_text(json.dumps(input_json)) + + memory_service_sentinel = object() + captured_factory_args: dict[str, Any] = {} + captured_memory_service: dict[str, Any] = {} + + def _memory_factory( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + captured_factory_args["base_dir"] = base_dir + captured_factory_args["memory_service_uri"] = memory_service_uri + return memory_service_sentinel + + async def _run_input_file( + app_name: str, + user_id: str, + agent_or_app: BaseAgent | App, + artifact_service: Any, + session_service: Any, + credential_service: InMemoryCredentialService, + input_path: str, + memory_service: Any = None, + ) -> object: + del app_name, user_id, agent_or_app, artifact_service + del session_service, credential_service, input_path + captured_memory_service["value"] = memory_service + return object() + + monkeypatch.setattr( + cli, "create_memory_service_from_options", _memory_factory + ) + monkeypatch.setattr(cli, "run_input_file", _run_input_file) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="memory://", + ) + + assert Path(captured_factory_args["base_dir"]) == parent_dir.resolve() + assert captured_factory_args["memory_service_uri"] == "memory://" + assert captured_memory_service["value"] is memory_service_sentinel + + +@pytest.mark.asyncio +async def test_run_cli_loads_dotenv_before_memory_service_creation( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should load agent .env values before creating memory service.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "dotenv_order_input.json" + input_path.write_text(json.dumps(input_json)) + + call_order: list[str] = [] + + def _load_dotenv_for_agent(agent_name: str, agents_dir: str) -> None: + del agent_name, agents_dir + call_order.append("load_dotenv") + + def _memory_factory( + *, + base_dir: Path | str, + memory_service_uri: str | None = None, + ) -> object: + del base_dir, memory_service_uri + call_order.append("create_memory") + return object() + + monkeypatch.setenv("ADK_DISABLE_LOAD_DOTENV", "0") + monkeypatch.setattr(cli.envs, "load_dotenv_for_agent", _load_dotenv_for_agent) + monkeypatch.setattr( + cli, "create_memory_service_from_options", _memory_factory + ) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + memory_service_uri="memory://", + ) + + assert "create_memory" in call_order + assert "load_dotenv" in call_order + assert call_order.index("load_dotenv") < call_order.index("create_memory") + + @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index 61b1468c..7c642dbb 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -23,6 +23,7 @@ from types import SimpleNamespace from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Tuple from unittest import mock @@ -129,7 +130,7 @@ def test_cli_create_cmd_invokes_run_cmd( # cli run @pytest.mark.parametrize( - "cli_args,expected_session_uri,expected_artifact_uri", + "cli_args,expected_session_uri,expected_artifact_uri,expected_memory_uri", [ pytest.param( [ @@ -137,15 +138,19 @@ def test_cli_create_cmd_invokes_run_cmd( "memory://", "--artifact_service_uri", "memory://", + "--memory_service_uri", + "memory://", ], "memory://", "memory://", + "memory://", id="memory_scheme_uris", ), pytest.param( [], None, None, + None, id="default_uris_none", ), ], @@ -154,8 +159,9 @@ def test_cli_run_service_uris( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cli_args: list, - expected_session_uri: str, - expected_artifact_uri: str, + expected_session_uri: Optional[str], + expected_artifact_uri: Optional[str], + expected_memory_uri: Optional[str], ) -> None: """`adk run` should forward service URIs correctly to run_cli.""" agent_dir = tmp_path / "agent" @@ -186,6 +192,7 @@ def test_cli_run_service_uris( coro_locals = captured_locals[0] assert coro_locals.get("session_service_uri") == expected_session_uri assert coro_locals.get("artifact_service_uri") == expected_artifact_uri + assert coro_locals.get("memory_service_uri") == expected_memory_uri assert coro_locals["agent_folder_name"] == "agent" diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 910bf906..d6f1426a 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -252,6 +252,15 @@ def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): assert isinstance(service, InMemoryMemoryService) +def test_create_memory_service_supports_memory_uri(tmp_path: Path): + service = service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="memory://", + ) + + assert isinstance(service, InMemoryMemoryService) + + def test_create_memory_service_raises_on_unknown_scheme( tmp_path: Path, monkeypatch ): diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index c9480601..9b27b82c 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -19,6 +19,7 @@ from unittest.mock import patch from google.adk.agents.invocation_context import InvocationContext from google.adk.code_executors.agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionInput +from google.adk.sessions.session import Session import pytest @@ -27,6 +28,10 @@ def mock_invocation_context() -> InvocationContext: """Fixture for a mock InvocationContext.""" mock = MagicMock(spec=InvocationContext) mock.invocation_id = "test-invocation-123" + session = MagicMock(spec=Session) + mock.session = session + session.state = {} + return mock @@ -71,7 +76,7 @@ class TestAgentEngineSandboxCodeExecutor: mock_json_output = MagicMock() mock_json_output.mime_type = "application/json" mock_json_output.data = json.dumps( - {"stdout": "hello world", "stderr": ""} + {"msg_out": "hello world", "msg_err": ""} ).encode("utf-8") mock_json_output.metadata = None @@ -118,3 +123,129 @@ class TestAgentEngineSandboxCodeExecutor: name="projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789", input_data={"code": 'print("hello world")'}, ) + + @patch("vertexai.Client") + def test_execute_code_recreates_sandbox_when_get_returns_none( + self, + mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Existing sandbox name stored in session, but get() will return None + existing_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/old" + mock_invocation_context.session.state = { + "sandbox_name": existing_sandbox_name + } + + # Mock get to return None (simulating missing/expired sandbox) + mock_api_client.agent_engines.sandboxes.get.return_value = None + + # Mock create operation to return a new sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "recreated sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Execute using agent_engine_resource_name so a sandbox can be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + ) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert get was called for the existing sandbox + mock_api_client.agent_engines.sandboxes.get.assert_called_once_with( + name=existing_sandbox_name + ) + + # Assert create was called and session updated with new sandbox + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) + + @patch("vertexai.Client") + def test_execute_code_creates_sandbox_if_missing( + self, + mock_vertexai_client, + mock_invocation_context, + ): + # Setup Mocks + mock_api_client = MagicMock() + mock_vertexai_client.return_value = mock_api_client + + # Mock create operation to return a sandbox resource name + operation_mock = MagicMock() + created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + operation_mock.response.name = created_sandbox_name + mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock + + # Mock execute_code response + mock_response = MagicMock() + mock_json_output = MagicMock() + mock_json_output.mime_type = "application/json" + mock_json_output.data = json.dumps( + {"stdout": "created sandbox run", "stderr": ""} + ).encode("utf-8") + mock_json_output.metadata = None + mock_response.outputs = [mock_json_output] + mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( + mock_response + ) + + # Ensure session.state behaves like a dict for storing sandbox_name + mock_invocation_context.session.state = {} + + # Execute using agent_engine_resource_name so a sandbox will be created + executor = AgentEngineSandboxCodeExecutor( + agent_engine_resource_name=( + "projects/123/locations/us-central1/reasoningEngines/456" + ), + sandbox_resource_name=None, + ) + code_input = CodeExecutionInput(code='print("hello world")') + result = executor.execute_code(mock_invocation_context, code_input) + + # Assert sandbox creation was called and session state updated + mock_api_client.agent_engines.sandboxes.create.assert_called_once() + create_call_kwargs = ( + mock_api_client.agent_engines.sandboxes.create.call_args.kwargs + ) + assert create_call_kwargs["name"] == ( + "projects/123/locations/us-central1/reasoningEngines/456" + ) + assert ( + mock_invocation_context.session.state["sandbox_name"] + == created_sandbox_name + ) + + # Assert execute_code used the created sandbox name + mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( + name=created_sandbox_name, + input_data={"code": 'print("hello world")'}, + ) diff --git a/tests/unittests/code_executors/test_built_in_code_executor.py b/tests/unittests/code_executors/test_built_in_code_executor.py index 58f54c7c..cbf128fb 100644 --- a/tests/unittests/code_executors/test_built_in_code_executor.py +++ b/tests/unittests/code_executors/test_built_in_code_executor.py @@ -97,6 +97,22 @@ def test_process_llm_request_non_gemini_2_model( ) +def test_process_llm_request_non_gemini_2_model_with_disabled_check( + built_in_executor: BuiltInCodeExecutor, + monkeypatch, +): + """Tests non-Gemini models pass when model-id check is disabled.""" + monkeypatch.setenv("ADK_DISABLE_GEMINI_MODEL_ID_CHECK", "true") + llm_request = LlmRequest(model="internal-model-v1") + + built_in_executor.process_llm_request(llm_request) + + assert llm_request.config is not None + assert llm_request.config.tools == [ + types.Tool(code_execution=types.ToolCodeExecution()) + ] + + def test_process_llm_request_no_model_name( built_in_executor: BuiltInCodeExecutor, ): diff --git a/tests/unittests/code_executors/test_gke_code_executor.py b/tests/unittests/code_executors/test_gke_code_executor.py index 3d62fd8d..300780ca 100644 --- a/tests/unittests/code_executors/test_gke_code_executor.py +++ b/tests/unittests/code_executors/test_gke_code_executor.py @@ -71,19 +71,74 @@ class TestGkeCodeExecutor: assert executor.timeout_seconds == 300 assert executor.cpu_requested == "200m" assert executor.mem_limit == "512Mi" + assert executor.executor_type == "job" - def test_init_with_overrides(self): + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_init_with_overrides(self, mock_sandbox_client): """Tests that class attributes can be overridden at instantiation.""" executor = GkeCodeExecutor( namespace="test-ns", image="custom-python:latest", timeout_seconds=60, cpu_limit="1000m", + executor_type="sandbox", ) assert executor.namespace == "test-ns" assert executor.image == "custom-python:latest" assert executor.timeout_seconds == 60 assert executor.cpu_limit == "1000m" + assert executor.executor_type == "sandbox" + assert executor.sandbox_template == "python-sandbox-template" + + def test_init_backward_compatibility(self): + """Tests that the executor can be initialized with positional arguments.""" + executor = GkeCodeExecutor( + "/path/to/kubeconfig", + "test-context", + namespace="test-ns", + image="test-image", + timeout_seconds=100, + executor_type="job", + cpu_requested="100m", + mem_requested="128Mi", + cpu_limit="200m", + mem_limit="256Mi", + ) + assert executor.namespace == "test-ns" + assert executor.image == "test-image" + assert executor.timeout_seconds == 100 + assert executor.executor_type == "job" + assert executor.cpu_requested == "100m" + assert executor.mem_requested == "128Mi" + assert executor.cpu_limit == "200m" + assert executor.mem_limit == "256Mi" + assert executor.kubeconfig_path == "/path/to/kubeconfig" + assert executor.kubeconfig_context == "test-context" + + def test_init_partial_positional_args(self): + """Tests initialization with partial positional arguments.""" + executor = GkeCodeExecutor("/path/to/kubeconfig") + assert executor.kubeconfig_path == "/path/to/kubeconfig" + assert executor.kubeconfig_context is None + + def test_init_mixed_args(self): + """Tests initialization with mixed positional and keyword arguments.""" + executor = GkeCodeExecutor( + "/path/to/kubeconfig", + kubeconfig_context="test-context", + namespace="test-ns", + ) + assert executor.kubeconfig_path == "/path/to/kubeconfig" + + def test_init_sandbox_missing_dependency(self): + """Tests that init raises ImportError if k8s-agent-sandbox is missing.""" + with patch( + "google.adk.code_executors.gke_code_executor.SandboxClient", None + ): + with pytest.raises(ImportError, match="k8s-agent-sandbox not found"): + GkeCodeExecutor(executor_type="sandbox") + + GkeCodeExecutor(executor_type="sandbox") @patch("google.adk.code_executors.gke_code_executor.Watch") def test_execute_code_success( @@ -225,3 +280,170 @@ class TestGkeCodeExecutor: assert sec_context.allow_privilege_escalation is False assert sec_context.read_only_root_filesystem is True assert sec_context.capabilities.drop == ["ALL"] + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_forks_to_sandbox( + self, + mock_sandbox_client, + mock_invocation_context, + mock_k8s_clients, + ): + """Tests execute_code with executor_type='sandbox'. + + Verifies that execute_code uses SandboxClient when executor_type is set to + 'sandbox'. + """ + # Setup Sandbox mock + mock_sandbox_instance = ( + mock_sandbox_client.return_value.__enter__.return_value + ) + mock_run_result = MagicMock() + mock_run_result.stdout = "sandbox stdout" + mock_run_result.stderr = None + mock_sandbox_instance.run.return_value = mock_run_result + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + # Execute + result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "sandbox stdout" + + # Verify SandboxClient was used + mock_sandbox_client.assert_called_once() + mock_sandbox_instance.run.assert_called_once() + + # Verify Job path was NOT taken + mock_k8s_clients["batch_v1"].create_namespaced_job.assert_not_called() + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_connection_error( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests handling of exceptions from SandboxClient.""" + # Setup Sandbox mock to raise exception + mock_sandbox_client.return_value.__enter__.side_effect = Exception( + "Connection failed" + ) + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + # Execute & Assert + with pytest.raises(Exception, match="Connection failed"): + executor.execute_code(mock_invocation_context, code_input) + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_runtime_error( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests handling of RuntimeError from SandboxClient.""" + mock_sandbox_client.return_value.__enter__.side_effect = RuntimeError( + "Gateway not found" + ) + + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + with pytest.raises( + RuntimeError, match="Sandbox infrastructure error: Gateway not found" + ): + executor.execute_code(mock_invocation_context, code_input) + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_code_sandbox_timeout_error( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests handling of TimeoutError from SandboxClient.""" + mock_sandbox_client.return_value.__enter__.side_effect = TimeoutError( + "Execution timed out" + ) + + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput(code='print("sandbox")') + + result = executor.execute_code(mock_invocation_context, code_input) + + assert result.stdout == "" + assert "Sandbox timed out: Execution timed out" in result.stderr + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + @patch("google.adk.code_executors.gke_code_executor.Watch") + def test_execute_code_forks_to_job( + self, + mock_watch, + mock_sandbox_client, + mock_invocation_context, + mock_k8s_clients, + ): + """Tests that execute_code uses K8s Job when executor_type='job'.""" + # Setup K8s Job mocks (success path) + mock_job = MagicMock() + mock_job.status.succeeded = True + mock_watch.return_value.stream.return_value = [{"object": mock_job}] + + mock_pod = MagicMock() + mock_pod.metadata.name = "pod-1" + mock_k8s_clients["core_v1"].list_namespaced_pod.return_value.items = [ + mock_pod + ] + mock_k8s_clients["core_v1"].read_namespaced_pod_log.return_value = ( + "job stdout" + ) + + # Instantiate with job type + executor = GkeCodeExecutor(executor_type="job") + code_input = CodeExecutionInput(code='print("job")') + + # Execute + result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "job stdout" + + # Verify Job path WAS taken + mock_k8s_clients["batch_v1"].create_namespaced_job.assert_called_once() + + # Verify SandboxClient was NOT used + mock_sandbox_client.assert_not_called() + + @patch("google.adk.code_executors.gke_code_executor.SandboxClient") + def test_execute_in_sandbox_returns_stderr( + self, + mock_sandbox_client, + mock_invocation_context, + ): + """Tests that stderr from the sandbox run is propagated to the result.""" + # Setup Sandbox mock + mock_sandbox_instance = ( + mock_sandbox_client.return_value.__enter__.return_value + ) + mock_run_result = MagicMock() + mock_run_result.stdout = "" + mock_run_result.stderr = "oops\n" + mock_sandbox_instance.run.return_value = mock_run_result + + # Instantiate with sandbox type + executor = GkeCodeExecutor(executor_type="sandbox") + code_input = CodeExecutionInput( + code="import sys; print('oops', file=sys.stderr)" + ) + + # Execute + result = executor.execute_code(mock_invocation_context, code_input) + + # Assertions + assert result.stdout == "" + assert result.stderr == "oops\n" + mock_sandbox_instance.write.assert_called_with("script.py", code_input.code) + mock_sandbox_instance.run.assert_called_with("python3 script.py") diff --git a/tests/unittests/flows/llm_flows/test_compaction_processor.py b/tests/unittests/flows/llm_flows/test_compaction_processor.py new file mode 100644 index 00000000..9f747c4b --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_compaction_processor.py @@ -0,0 +1,346 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for request-phase token compaction processor.""" + +from unittest.mock import AsyncMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.apps.app import EventsCompactionConfig +from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.events.event import Event +from google.adk.flows.llm_flows import compaction +from google.adk.flows.llm_flows import contents +from google.adk.flows.llm_flows.single_flow import SingleFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.session import Session +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +def _create_event( + *, + timestamp: float, + invocation_id: str, + text: str, + prompt_token_count: int | None = None, +) -> Event: + usage_metadata = None + if prompt_token_count is not None: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_token_count + ) + return Event( + timestamp=timestamp, + invocation_id=invocation_id, + author='user', + content=Content(role='user', parts=[Part(text=text)]), + usage_metadata=usage_metadata, + ) + + +def test_single_flow_includes_compaction_before_contents(): + flow = SingleFlow() + + compaction_index = flow.request_processors.index(compaction.request_processor) + contents_index = flow.request_processors.index(contents.request_processor) + + assert compaction_index < contents_index + + +@pytest.mark.asyncio +async def test_compaction_request_processor_no_token_config(): + session = Session(app_name='app', user_id='user', id='session', events=[]) + session_service = AsyncMock(spec=BaseSessionService) + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + compaction_interval=2, + overlap_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + session_service.append_event.assert_not_called() + + +@pytest.mark.asyncio +async def test_compaction_request_processor_runs_token_compaction(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + _create_event( + timestamp=3.0, + invocation_id='inv3', + text='e3', + prompt_token_count=100, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_compacts_with_latest_tool_response(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event(timestamp=2.0, invocation_id='inv2', text='e2'), + Event( + timestamp=3.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='model', + parts=[ + Part( + function_call=types.FunctionCall( + id='call-1', name='tool', args={} + ) + ) + ], + ), + ), + Event( + timestamp=4.0, + invocation_id='current-inv', + author='agent', + content=Content( + role='user', + parts=[ + Part( + function_response=types.FunctionResponse( + id='call-1', + name='tool', + response={'result': 'ok'}, + ) + ) + ], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'inv2', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_can_compact_current_user_event(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + Event( + timestamp=2.0, + invocation_id='current-inv', + author='user', + content=Content( + role='user', + parts=[Part(text='latest user message')], + ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100 + ), + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + compacted_event = Event(author='compactor', invocation_id=Event.new_id()) + mock_summarizer.maybe_summarize_events.return_value = compacted_event + + invocation_context = InvocationContext( + invocation_id='current-inv', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=0, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert invocation_context.token_compaction_checked + compacted_events_arg = mock_summarizer.maybe_summarize_events.call_args[1][ + 'events' + ] + assert [event.invocation_id for event in compacted_events_arg] == [ + 'inv1', + 'current-inv', + ] + session_service.append_event.assert_called_once_with( + session=session, event=compacted_event + ) + + +@pytest.mark.asyncio +async def test_compaction_request_processor_not_marked_when_not_compacted(): + session = Session( + app_name='app', + user_id='user', + id='session', + events=[ + _create_event(timestamp=1.0, invocation_id='inv1', text='e1'), + _create_event( + timestamp=2.0, + invocation_id='inv2', + text='e2', + prompt_token_count=40, + ), + ], + ) + session_service = AsyncMock(spec=BaseSessionService) + mock_summarizer = AsyncMock(spec=LlmEventSummarizer) + mock_summarizer.maybe_summarize_events.return_value = Event( + author='compactor', + invocation_id=Event.new_id(), + ) + + invocation_context = InvocationContext( + invocation_id='invocation', + agent=LlmAgent(name='agent'), + session=session, + session_service=session_service, + events_compaction_config=EventsCompactionConfig( + summarizer=mock_summarizer, + compaction_interval=999, + overlap_size=0, + token_threshold=50, + event_retention_size=1, + ), + ) + + llm_request = LlmRequest() + processor = compaction.CompactionRequestProcessor() + + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + assert not events + assert not invocation_context.token_compaction_checked + mock_summarizer.maybe_summarize_events.assert_not_called() + session_service.append_event.assert_not_called() diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 93f8c151..7aacb237 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -19,7 +19,10 @@ from typing import Callable from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types @@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses(): @pytest.mark.asyncio async def test_function_call_args_not_modified(): """Test that function_call.args is not modified when making a copy.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -455,8 +456,6 @@ async def test_function_call_args_not_modified(): @pytest.mark.asyncio async def test_function_call_args_none_handling(): """Test that function_call.args=None is handled correctly.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(**kwargs) -> dict: return {'result': 'test'} @@ -504,8 +503,6 @@ async def test_function_call_args_none_handling(): @pytest.mark.asyncio async def test_function_call_args_copy_behavior(): """Test that modifying the copied args doesn't affect the original.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(test_param: str, other_param: int) -> dict: # Modify the args to test that the copy prevents affecting the original @@ -565,8 +562,6 @@ async def test_function_call_args_copy_behavior(): @pytest.mark.asyncio async def test_function_call_args_deep_copy_behavior(): """Test that deep copy behavior works correctly with nested structures.""" - from google.adk.flows.llm_flows.functions import handle_function_calls_async - from google.adk.flows.llm_flows.functions import handle_function_calls_live def simple_fn(nested_dict: dict, list_param: list) -> dict: # Modify the nested structures to test deep copy @@ -1141,3 +1136,62 @@ async def test_mixed_function_types_execution_order(): 'yield_E', 'yield_F', ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'handle_function_calls', + [ + (handle_function_calls_async), + (handle_function_calls_live), + ], +) +async def test_computer_use_tool_decoding_behavior(handle_function_calls): + """Tests that computer use tools automatically decode base64 images.""" + valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7' + + # make the tool return a dictionary with the image + async def mock_run(*args, **kwargs): + return { + 'image': {'data': valid_b64, 'mimetype': 'image/png'}, + 'url': 'https://example.com', + } + + # create a ComputerUseTool + tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768)) + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create function call + function_call = types.FunctionCall(name=tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result = await handle_function_calls( + invocation_context, + event, + tools_dict, + ) + + assert result is not None + response_part = result.content.parts[0].function_response + + # Verify original image data is removed from the dict response + assert 'image' not in response_part.response + assert 'url' in response_part.response + # Verify the image was converted to a blob + assert len(response_part.parts) == 1 + assert response_part.parts[0].inline_data is not None diff --git a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py index caab8f3f..016e9b49 100644 --- a/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_live_tool_callbacks.py @@ -386,3 +386,80 @@ async def test_live_callback_compatibility_with_async(): async_response = async_result.content.parts[0].function_response.response live_response = live_result.content.parts[0].function_response.response assert async_response == live_response == {"bypassed": "by_before_callback"} + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_noop(): + """Test that on_tool_error_callback is a no-op when the tool is not found.""" + + def noop_on_tool_error_callback(tool, args, tool_context, error): + return None + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=noop_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + with pytest.raises(ValueError): + await handle_function_calls_live(invocation_context, event, tools_dict) + + +@pytest.mark.asyncio +async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response(): + """Test that on_tool_error_callback modifies tool response when tool is not found.""" + + def mock_on_tool_error_callback(tool, args, tool_context, error): + return {"result": "on_tool_error_callback_response"} + + def simple_fn(**kwargs) -> Dict[str, Any]: + return {"initial": "response"} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[tool], + on_tool_error_callback=mock_on_tool_error_callback, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="" + ) + function_call = types.FunctionCall(name="nonexistent_function", args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + result_event = await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == { + "result": "on_tool_error_callback_response" + } diff --git a/tests/unittests/flows/llm_flows/test_output_schema_processor.py b/tests/unittests/flows/llm_flows/test_output_schema_processor.py index 3e95dea1..23c741bc 100644 --- a/tests/unittests/flows/llm_flows/test_output_schema_processor.py +++ b/tests/unittests/flows/llm_flows/test_output_schema_processor.py @@ -199,7 +199,6 @@ async def test_output_schema_request_processor( @pytest.mark.asyncio async def test_set_model_response_tool(): """Test the set_model_response tool functionality.""" - from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY from google.adk.tools.set_model_response_tool import SetModelResponseTool from google.adk.tools.tool_context import ToolContext @@ -215,18 +214,12 @@ async def test_set_model_response_tool(): tool_context=tool_context, ) - # Verify the tool now returns dict directly + # Verify the tool returns dict directly assert result is not None assert result['name'] == 'John Doe' assert result['age'] == 30 assert result['city'] == 'New York' - # Check that the response is no longer stored in session state - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_output_schema_helper_functions(): @@ -328,6 +321,48 @@ async def test_get_structured_model_response_with_non_ascii(): assert extracted_json == expected_json +@pytest.mark.asyncio +async def test_get_structured_model_response_with_wrapped_result(): + """Test get_structured_model_response with wrapped list result. + + When a tool returns a non-dict (e.g., list), it gets wrapped as + {'result': [...]}. This test ensures we correctly unwrap the result. + """ + from google.adk.events.event import Event + from google.adk.flows.llm_flows._output_schema_processor import get_structured_model_response + from google.genai import types + + # Simulate a list result wrapped by ADK's functions.py + wrapped_response = { + 'result': [ + {'name': 'Alice', 'age': 30}, + {'name': 'Bob', 'age': 25}, + ] + } + expected_json = '[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]' + + # Create a function response event with wrapped result + function_response_event = Event( + author='test_agent', + content=types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='set_model_response', response=wrapped_response + ) + ) + ], + ), + ) + + # Get the structured response + extracted_json = get_structured_model_response(function_response_event) + + # Should extract the unwrapped list, not the wrapped dict + assert extracted_json == expected_json + + @pytest.mark.asyncio async def test_end_to_end_integration(): """Test the complete output schema with tools integration.""" diff --git a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py index cc375ad0..3c39e284 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py @@ -19,6 +19,7 @@ from typing import Optional from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.flows.llm_flows.functions import handle_function_calls_live from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_tool import BaseTool from google.adk.tools.function_tool import FunctionTool @@ -185,5 +186,159 @@ async def test_async_on_tool_error_fallback_to_runner( assert e == mock_error +async def invoke_tool_with_plugin_live( + mock_tool, mock_plugin +) -> Optional[Event]: + """Invokes a tool with a plugin using the live path.""" + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + # Build function call event + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + return await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + +@pytest.mark.asyncio +async def test_live_before_tool_callback(mock_tool, mock_plugin): + mock_plugin.enable_before_tool_callback = True + + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.before_tool_response + + +@pytest.mark.asyncio +async def test_live_after_tool_callback(mock_tool, mock_plugin): + mock_plugin.enable_after_tool_callback = True + + result_event = await invoke_tool_with_plugin_live(mock_tool, mock_plugin) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.after_tool_response + + +@pytest.mark.asyncio +async def test_live_on_tool_error_use_plugin_response( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = True + + result_event = await invoke_tool_with_plugin_live( + mock_error_tool, mock_plugin + ) + + assert result_event is not None + part = result_event.content.parts[0] + assert part.function_response.response == mock_plugin.on_tool_error_response + + +@pytest.mark.asyncio +async def test_live_on_tool_error_fallback_to_runner( + mock_error_tool, mock_plugin +): + mock_plugin.enable_on_tool_error_callback = False + + try: + await invoke_tool_with_plugin_live(mock_error_tool, mock_plugin) + except Exception as e: + assert e == mock_error + + +@pytest.mark.asyncio +async def test_live_plugin_before_tool_callback_takes_priority( + mock_tool, mock_plugin +): + """Plugin before_tool_callback should run before agent canonical callbacks.""" + mock_plugin.enable_before_tool_callback = True + + def agent_before_cb(tool, args, tool_context): + return {"agent": "should_not_be_called"} + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + before_tool_callback=agent_before_cb, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + result_event = await handle_function_calls_live( + invocation_context, event, tools_dict + ) + + assert result_event is not None + part = result_event.content.parts[0] + # Plugin response should win, not the agent callback + assert part.function_response.response == mock_plugin.before_tool_response + + +@pytest.mark.asyncio +async def test_live_plugin_after_tool_callback_takes_priority( + mock_tool, mock_plugin +): + """Plugin after_tool_callback should run before agent canonical callbacks.""" + mock_plugin.enable_after_tool_callback = True + + def agent_after_cb(tool, args, tool_context, tool_response): + return {"agent": "should_not_be_called"} + + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name="agent", + model=model, + tools=[mock_tool], + after_tool_callback=agent_after_cb, + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content="", plugins=[mock_plugin] + ) + function_call = types.FunctionCall(name=mock_tool.name, args={}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {mock_tool.name: mock_tool} + result_event = await handle_function_calls_live( + invocation_context, event, tools_dict + ) + + assert result_event is not None + part = result_event.content.parts[0] + # Plugin response should win, not the agent callback + assert part.function_response.response == mock_plugin.after_tool_response + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/unittests/integrations/agent_registry/__init__.py b/tests/unittests/integrations/agent_registry/__init__.py new file mode 100644 index 00000000..58d482ea --- /dev/null +++ b/tests/unittests/integrations/agent_registry/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py new file mode 100644 index 00000000..fc680869 --- /dev/null +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -0,0 +1,250 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock +from unittest.mock import patch + +from a2a.types import TransportProtocol as A2ATransport +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent +from google.adk.integrations.agent_registry import _ProtocolType +from google.adk.integrations.agent_registry import AgentRegistry +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import httpx +import pytest + + +class TestAgentRegistry: + + @pytest.fixture + def registry(self): + with patch("google.auth.default", return_value=(MagicMock(), "project-id")): + return AgentRegistry(project_id="test-project", location="global") + + def test_init_raises_value_error_if_params_missing(self): + with pytest.raises( + ValueError, match="project_id and location must be provided" + ): + AgentRegistry(project_id=None, location=None) + + def test_get_connection_uri_mcp_interfaces_top_level(self, registry): + resource_details = { + "interfaces": [ + {"url": "https://mcp-v1main.com", "protocolBinding": "JSONRPC"} + ] + } + uri = registry._get_connection_uri( + resource_details, protocol_binding=A2ATransport.jsonrpc + ) + assert uri == "https://mcp-v1main.com" + + def test_get_connection_uri_agent_nested_protocols(self, registry): + resource_details = { + "protocols": [{ + "type": _ProtocolType.A2A_AGENT, + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": A2ATransport.jsonrpc, + }], + }] + } + uri = registry._get_connection_uri( + resource_details, protocol_type=_ProtocolType.A2A_AGENT + ) + assert uri == "https://my-agent.com" + + def test_get_connection_uri_filtering(self, registry): + resource_details = { + "protocols": [ + { + "type": "CUSTOM", + "interfaces": [{"url": "https://custom.com"}], + }, + { + "type": _ProtocolType.A2A_AGENT, + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": A2ATransport.http_json, + }], + }, + ] + } + # Filter by type + uri = registry._get_connection_uri( + resource_details, protocol_type=_ProtocolType.A2A_AGENT + ) + assert uri == "https://my-agent.com" + + # Filter by binding + uri = registry._get_connection_uri( + resource_details, protocol_binding=A2ATransport.http_json + ) + assert uri == "https://my-agent.com" + + # No match + uri = registry._get_connection_uri( + resource_details, + protocol_type=_ProtocolType.A2A_AGENT, + protocol_binding=A2ATransport.jsonrpc, + ) + assert uri is None + + def test_get_connection_uri_returns_none_if_no_interfaces(self, registry): + resource_details = {} + uri = registry._get_connection_uri(resource_details) + assert uri is None + + def test_get_connection_uri_returns_none_if_no_url_in_interfaces( + self, registry + ): + resource_details = {"interfaces": [{"protocolBinding": "HTTP"}]} + uri = registry._get_connection_uri(resource_details) + assert uri is None + + @patch("httpx.Client") + def test_list_agents(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = {"agents": []} + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + # Mock auth refresh + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + agents = registry.list_agents() + assert agents == {"agents": []} + + @patch("httpx.Client") + def test_get_mcp_server(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = {"name": "test-mcp"} + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + server = registry.get_mcp_server("test-mcp") + assert server == {"name": "test-mcp"} + + @patch("httpx.Client") + def test_get_mcp_toolset(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = { + "displayName": "TestPrefix", + "interfaces": [{ + "url": "https://mcp.com", + "protocolBinding": A2ATransport.jsonrpc, + }], + } + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + toolset = registry.get_mcp_toolset("test-mcp") + assert isinstance(toolset, McpToolset) + assert toolset.tool_name_prefix == "TestPrefix" + + @patch("httpx.Client") + def test_get_remote_a2a_agent(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.json.return_value = { + "displayName": "TestAgent", + "description": "Test Desc", + "version": "1.0", + "protocols": [{ + "type": _ProtocolType.A2A_AGENT, + "interfaces": [{ + "url": "https://my-agent.com", + "protocolBinding": A2ATransport.jsonrpc, + }], + }], + "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], + } + mock_response.raise_for_status = MagicMock() + mock_httpx.return_value.__enter__.return_value.get.return_value = ( + mock_response + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + agent = registry.get_remote_a2a_agent("test-agent") + assert isinstance(agent, RemoteA2aAgent) + assert agent.name == "TestAgent" + assert agent.description == "Test Desc" + assert agent._agent_card.url == "https://my-agent.com" + assert agent._agent_card.version == "1.0" + assert len(agent._agent_card.skills) == 1 + assert agent._agent_card.skills[0].name == "Skill 1" + + def test_get_auth_headers(self, registry): + registry._credentials.token = "fake-token" + registry._credentials.refresh = MagicMock() + registry._credentials.quota_project_id = "quota-project" + + headers = registry._get_auth_headers() + assert headers["Authorization"] == "Bearer fake-token" + assert headers["x-goog-user-project"] == "quota-project" + + @patch("httpx.Client") + def test_make_request_raises_http_status_error(self, mock_httpx, registry): + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + error = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=mock_response + ) + mock_httpx.return_value.__enter__.return_value.get.side_effect = error + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises( + RuntimeError, match="API request failed with status 404" + ): + registry._make_request("test-path") + + @patch("httpx.Client") + def test_make_request_raises_request_error(self, mock_httpx, registry): + error = httpx.RequestError("Connection failed", request=MagicMock()) + mock_httpx.return_value.__enter__.return_value.get.side_effect = error + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises( + RuntimeError, match="API request failed \(network error\)" + ): + registry._make_request("test-path") + + @patch("httpx.Client") + def test_make_request_raises_generic_exception(self, mock_httpx, registry): + mock_httpx.return_value.__enter__.return_value.get.side_effect = Exception( + "Generic error" + ) + + registry._credentials.token = "token" + registry._credentials.refresh = MagicMock() + + with pytest.raises(RuntimeError, match="API request failed: Generic error"): + registry._make_request("test-path") diff --git a/tests/unittests/integrations/api_registry/__init__.py b/tests/unittests/integrations/api_registry/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/tests/unittests/integrations/api_registry/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/integrations/api_registry/test_api_registry.py similarity index 96% rename from tests/unittests/tools/test_api_registry.py rename to tests/unittests/integrations/api_registry/test_api_registry.py index 59612434..7edaee9f 100644 --- a/tests/unittests/tools/test_api_registry.py +++ b/tests/unittests/integrations/api_registry/test_api_registry.py @@ -18,8 +18,8 @@ from unittest.mock import create_autospec from unittest.mock import MagicMock from unittest.mock import patch -from google.adk.tools import api_registry -from google.adk.tools.api_registry import ApiRegistry +from google.adk.integrations import api_registry +from google.adk.integrations.api_registry import ApiRegistry from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams import httpx @@ -218,7 +218,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) mock_response.raise_for_status.assert_called_once() - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): mock_response = MagicMock() @@ -245,7 +248,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(toolset, MockMcpToolset.return_value) - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_with_quota_project_id_success( self, MockHttpClient, MockMcpToolset @@ -277,7 +283,10 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(toolset, MockMcpToolset.return_value) - @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch( + "google.adk.integrations.api_registry.api_registry.McpToolset", + autospec=True, + ) @patch("httpx.Client", autospec=True) async def test_get_toolset_with_filter_and_prefix( self, MockHttpClient, MockMcpToolset @@ -321,7 +330,7 @@ class TestApiRegistry(unittest.IsolatedAsyncioTestCase): with ( patch.object(httpx, "Client", autospec=True) as MockHttpClient, patch.object( - api_registry, "McpToolset", autospec=True + api_registry.api_registry, "McpToolset", autospec=True ) as MockMcpToolset, ): mock_response = create_autospec(httpx.Response, instance=True) diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 6f342a08..c498b833 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -230,6 +230,14 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +def test_initialize_without_agent_engine_id_error(): + with pytest.raises( + ValueError, + match='agent_engine_id is required for VertexAiMemoryBankService', + ): + mock_vertex_ai_memory_bank_service(agent_engine_id=None) + + @pytest.mark.asyncio async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() @@ -481,6 +489,7 @@ async def test_add_memory_calls_create( ), ], custom_metadata={ + 'enable_consolidation': False, 'ttl': '6000s', 'source': 'agent', }, @@ -518,6 +527,139 @@ async def test_add_memory_calls_create( vertex_common_types.AgentEngineMemoryConfig(**create_config) +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_calls_generate_direct_source( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + 'source': 'agent', + }, + ) + + expected_config = {'wait_for_completion': False} + if _supports_generate_memories_metadata(): + expected_config['metadata'] = {'source': {'string_value': 'agent'}} + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_batches_generate_calls( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact three')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact four')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact five')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact six')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + }, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_has_awaits([ + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + {'fact': 'fact three'}, + {'fact': 'fact four'}, + {'fact': 'fact five'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact six'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + ]) + assert mock_vertexai_client.agent_engines.memories.generate.await_count == 2 + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_invalid_enable_consolidation_type_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + TypeError, + match=r'custom_metadata\["enable_consolidation"\] must be a bool', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ) + ], + custom_metadata={'enable_consolidation': 'yes'}, + ) + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_memory_calls_create_with_memory_entry_metadata( mock_vertexai_client, diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index b0a282d1..e28efdb1 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys from unittest import mock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from anthropic import types as anthropic_types from google.adk import version as adk_version @@ -23,6 +26,7 @@ from google.adk.models.anthropic_llm import AnthropicLlm from google.adk.models.anthropic_llm import Claude from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param +from google.adk.models.anthropic_llm import part_to_message_block from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types @@ -661,3 +665,354 @@ def test_content_to_message_param_with_images( ) else: mock_logger.warning.assert_not_called() + + +# --- Tests for Bug #2: json.dumps for dict/list function results --- + + +def test_part_to_message_block_dict_result_serialized_as_json(): + """Dict results should be serialized with json.dumps, not str().""" + response_part = types.Part.from_function_response( + name="get_topic", + response={"result": {"topic": "travel", "active": True, "count": None}}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + content = result["content"] + + # Must be valid JSON (json.dumps produces "true"/"null", not "True"/"None") + parsed = json.loads(content) + assert parsed["topic"] == "travel" + assert parsed["active"] is True + assert parsed["count"] is None + + +def test_part_to_message_block_list_result_serialized_as_json(): + """List results should be serialized with json.dumps.""" + response_part = types.Part.from_function_response( + name="get_items", + response={"result": ["item1", "item2", "item3"]}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + content = result["content"] + + parsed = json.loads(content) + assert parsed == ["item1", "item2", "item3"] + + +def test_part_to_message_block_empty_dict_result_not_dropped(): + """Empty dict results should produce '{}', not empty string.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"result": {}}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "{}" + + +def test_part_to_message_block_empty_list_result_not_dropped(): + """Empty list results should produce '[]', not empty string.""" + response_part = types.Part.from_function_response( + name="some_tool", + response={"result": []}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "[]" + + +def test_part_to_message_block_string_result_unchanged(): + """String results should still work as before (backward compat).""" + response_part = types.Part.from_function_response( + name="simple_tool", + response={"result": "plain text result"}, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + assert result["content"] == "plain text result" + + +def test_part_to_message_block_nested_dict_result(): + """Nested dict with arrays should produce valid JSON.""" + response_part = types.Part.from_function_response( + name="search", + response={ + "result": { + "results": [ + {"id": 1, "tags": ["a", "b"]}, + {"id": 2, "meta": {"key": "val"}}, + ], + "has_more": False, + } + }, + ) + response_part.function_response.id = "test_id" + + result = part_to_message_block(response_part) + parsed = json.loads(result["content"]) + assert parsed["has_more"] is False + assert parsed["results"][0]["tags"] == ["a", "b"] + + +# --- Tests for Bug #1: Streaming support --- + + +def _make_mock_stream_events(events): + """Helper to create an async iterable from a list of events.""" + + async def _stream(): + for event in events: + yield event + + return _stream() + + +@pytest.mark.asyncio +async def test_streaming_text_yields_partial_and_final(): + """Streaming text should yield partial chunks then a final response.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=10, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Hello ", type="text_delta"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="world!", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="end_turn"), + usage=MagicMock(output_tokens=5), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="You are helpful", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=True) + ] + + # 2 partial text chunks + 1 final aggregated + assert len(responses) == 3 + assert responses[0].partial is True + assert responses[0].content.parts[0].text == "Hello " + assert responses[1].partial is True + assert responses[1].content.parts[0].text == "world!" + assert responses[2].partial is False + assert responses[2].content.parts[0].text == "Hello world!" + assert responses[2].usage_metadata.prompt_token_count == 10 + assert responses[2].usage_metadata.candidates_token_count == 5 + + +@pytest.mark.asyncio +async def test_streaming_tool_use_yields_function_call(): + """Streaming tool_use should accumulate args and yield in final.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=20, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Checking.", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="content_block_start", + index=1, + content_block=anthropic_types.ToolUseBlock( + id="toolu_abc", + name="get_weather", + input={}, + type="tool_use", + ), + ), + MagicMock( + type="content_block_delta", + index=1, + delta=anthropic_types.InputJSONDelta( + partial_json='{"city": "Paris"}', + type="input_json_delta", + ), + ), + MagicMock(type="content_block_stop", index=1), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="tool_use"), + usage=MagicMock(output_tokens=12), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[ + Content( + role="user", + parts=[Part.from_text(text="Weather?")], + ) + ], + config=types.GenerateContentConfig( + system_instruction="You are helpful", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=True) + ] + + # 1 text partial + 1 final + assert len(responses) == 2 + + final = responses[-1] + assert final.partial is False + assert len(final.content.parts) == 2 + assert final.content.parts[0].text == "Checking." + assert final.content.parts[1].function_call.name == "get_weather" + assert final.content.parts[1].function_call.args == {"city": "Paris"} + assert final.content.parts[1].function_call.id == "toolu_abc" + + +@pytest.mark.asyncio +async def test_streaming_passes_stream_true_to_create(): + """When stream=True, messages.create should be called with stream=True.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + events = [ + MagicMock( + type="message_start", + message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)), + ), + MagicMock( + type="content_block_start", + index=0, + content_block=anthropic_types.TextBlock(text="", type="text"), + ), + MagicMock( + type="content_block_delta", + index=0, + delta=anthropic_types.TextDelta(text="Hi", type="text_delta"), + ), + MagicMock(type="content_block_stop", index=0), + MagicMock( + type="message_delta", + delta=MagicMock(stop_reason="end_turn"), + usage=MagicMock(output_tokens=1), + ), + MagicMock(type="message_stop"), + ] + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock( + return_value=_make_mock_stream_events(events) + ) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="Test", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + _ = [r async for r in llm.generate_content_async(llm_request, stream=True)] + + mock_client.messages.create.assert_called_once() + _, kwargs = mock_client.messages.create.call_args + assert kwargs["stream"] is True + + +@pytest.mark.asyncio +async def test_non_streaming_does_not_pass_stream_param(): + """When stream=False, messages.create should NOT get stream param.""" + llm = AnthropicLlm(model="claude-sonnet-4-20250514") + + mock_message = anthropic_types.Message( + id="msg_test", + content=[ + anthropic_types.TextBlock(text="Hello!", type="text", citations=None) + ], + model="claude-sonnet-4-20250514", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=anthropic_types.Usage( + input_tokens=5, + output_tokens=2, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + server_tool_use=None, + service_tier=None, + ), + ) + + mock_client = MagicMock() + mock_client.messages.create = AsyncMock(return_value=mock_message) + + llm_request = LlmRequest( + model="claude-sonnet-4-20250514", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="Test", + ), + ) + + with mock.patch.object(llm, "_anthropic_client", mock_client): + responses = [ + r async for r in llm.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + mock_client.messages.create.assert_called_once() + _, kwargs = mock_client.messages.create.call_args + assert "stream" not in kwargs diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index 67894ea8..c57bc9fc 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -19,6 +19,7 @@ from unittest import mock from unittest.mock import AsyncMock from google.adk.models.apigee_llm import ApigeeLlm +from google.adk.models.apigee_llm import CompletionsHTTPClient from google.adk.models.llm_request import LlmRequest from google.genai import types from google.genai.types import Content @@ -441,7 +442,6 @@ async def test_model_string_parsing_and_client_initialization( @pytest.mark.parametrize( 'invalid_model_string', [ - 'apigee/openai/v1/gpt', 'apigee/', # Missing model_id 'apigee', # Invalid format 'gemini-pro', # Invalid format @@ -455,3 +455,175 @@ async def test_invalid_model_strings_raise_value_error(invalid_model_string): ValueError, match=f'Invalid model string: {invalid_model_string}' ): ApigeeLlm(model=invalid_model_string, proxy_url=PROXY_URL) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'model', + [ + 'apigee/openai/gpt-4o', + 'apigee/openai/v1/gpt-4o', + 'apigee/openai/v1/gpt-3.5-turbo', + ], +) +async def test_validate_model_for_chat_completion_providers(model): + """Tests that new providers like OpenAI are accepted.""" + # Should not raise ValueError + ApigeeLlm(model=model, proxy_url=PROXY_URL) + + +@pytest.mark.parametrize( + ('model', 'api_type', 'expected_api_type'), + [ + # Default case (input defaults to UNKNOWN) + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/v1/gpt-3.5-turbo', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/v1/gemini-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/vertex_ai/gemini-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/vertex_ai/v1beta/gemini-1.5-pro', + ApigeeLlm.ApiType.UNKNOWN, + ApigeeLlm.ApiType.GENAI, + ), + # Override by setting the ApiType + ( + 'apigee/gemini/pro', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/pro', + ApigeeLlm.ApiType.GENAI, + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/gpt-4o', + ApigeeLlm.ApiType.GENAI, + ApigeeLlm.ApiType.GENAI, + ), + # Override by setting the ApiType as a string + ( + 'apigee/gemini/pro', + 'chat_completions', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/gemini/pro', + 'genai', + ApigeeLlm.ApiType.GENAI, + ), + ( + 'apigee/openai/gpt-4o', + 'chat_completions', + ApigeeLlm.ApiType.CHAT_COMPLETIONS, + ), + ( + 'apigee/openai/gpt-4o', + 'genai', + ApigeeLlm.ApiType.GENAI, + ), + ], +) +def test_api_type_resolution(model, api_type, expected_api_type): + """Tests that api_type is resolved correctly.""" + llm = ApigeeLlm( + model=model, + proxy_url=PROXY_URL, + api_type=api_type, + ) + assert llm._api_type == expected_api_type + + +@pytest.mark.parametrize( + ('input_value', 'expected_type'), + [ + ('chat_completions', ApigeeLlm.ApiType.CHAT_COMPLETIONS), + ('genai', ApigeeLlm.ApiType.GENAI), + ('unknown', ApigeeLlm.ApiType.UNKNOWN), + ('', ApigeeLlm.ApiType.UNKNOWN), + (None, ApigeeLlm.ApiType.UNKNOWN), + ], +) +def test_apitype_creation(input_value, expected_type): + """Tests the creation of ApiType enum members.""" + assert ApigeeLlm.ApiType(input_value) == expected_type + + +def test_apitype_creation_invalid(): + """Tests that invalid ApiType raises ValueError.""" + with pytest.raises(ValueError): + ApigeeLlm.ApiType('invalid') + + +def test_invalid_api_type_raises_error(): + """Tests that invalid string for api_type raises ValueError.""" + with pytest.raises(ValueError): + ApigeeLlm( + model='apigee/gemini-pro', + proxy_url=PROXY_URL, + api_type='invalid_type', + ) + + +@pytest.mark.asyncio +async def test_generate_content_async_dispatch_to_completions_client( + llm_request, +): + """Tests that generate_content_async uses CompletionsHTTPClient for OpenAI models.""" + llm_request.model = 'apigee/openai/gpt-4o' + with ( + mock.patch.object( + CompletionsHTTPClient, + 'generate_content_async', + ) as mock_completions_generate_content, + mock.patch('google.genai.Client') as mock_genai_client, + ): + apigee_llm = ApigeeLlm(model='apigee/openai/gpt-4o', proxy_url=PROXY_URL) + _ = [ + r + async for r in apigee_llm.generate_content_async( + llm_request, stream=False + ) + ] + mock_completions_generate_content.assert_called_once() + mock_genai_client.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'model', + [ + 'apigee/openai/gpt-4o', + 'apigee/openai/v1/gpt-3.5-turbo', + ], +) +async def test_api_key_injection_openai(model): + """Tests that api_key is injected for OpenAI models.""" + apigee_llm = ApigeeLlm( + model=model, + proxy_url=PROXY_URL, + custom_headers={'Authorization': 'Bearer sk-test-key'}, + ) + client = apigee_llm._completions_http_client + assert client._headers['Authorization'] == 'Bearer sk-test-key' diff --git a/tests/unittests/models/test_completions_http_client.py b/tests/unittests/models/test_completions_http_client.py new file mode 100644 index 00000000..615871eb --- /dev/null +++ b/tests/unittests/models/test_completions_http_client.py @@ -0,0 +1,773 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock +from unittest.mock import AsyncMock + +from google.adk.models.apigee_llm import CompletionsHTTPClient +from google.adk.models.llm_request import LlmRequest +from google.genai import types +import httpx +import pytest + + +@pytest.fixture +def client(): + return CompletionsHTTPClient(base_url='https://localhost') + + +@pytest.fixture(name='llm_request') +def fixture_llm_request(): + return LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='Hello')]) + ], + ) + + +@pytest.mark.asyncio +async def test_construct_payload_basic_payload(client, llm_request): + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + call_args = mock_post.call_args + url = call_args[0][0] + kwargs = call_args[1] + + assert url == 'https://localhost/chat/completions' + payload = kwargs['json'] + assert payload['model'] == 'open_llama' + assert payload['stream'] is False + assert len(payload['messages']) == 1 + assert payload['messages'][0]['role'] == 'user' + assert payload['messages'][0]['content'] == 'Hello' + + +@pytest.mark.asyncio +async def test_construct_payload_with_config(client, llm_request): + llm_request.config = types.GenerateContentConfig( + temperature=0.7, + top_p=0.9, + max_output_tokens=100, + stop_sequences=['STOP'], + frequency_penalty=0.5, + presence_penalty=0.5, + seed=42, + candidate_count=2, + response_mime_type='application/json', + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + + assert payload['temperature'] == 0.7 + assert payload['top_p'] == 0.9 + assert payload['max_tokens'] == 100 + assert payload['stop'] == ['STOP'] + assert payload['frequency_penalty'] == 0.5 + assert payload['presence_penalty'] == 0.5 + assert payload['seed'] == 42 + assert payload['n'] == 2 + assert payload['response_format'] == {'type': 'json_object'} + + +@pytest.mark.asyncio +async def test_construct_payload_with_tools(client, llm_request): + tool = types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name='get_weather', + description='Get weather', + parameters=types.Schema( + type=types.Type.OBJECT, + properties={'location': types.Schema(type=types.Type.STRING)}, + ), + ) + ] + ) + llm_request.config = types.GenerateContentConfig(tools=[tool]) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert 'tools' in payload + assert payload['tools'][0]['function']['name'] == 'get_weather' + + +@pytest.mark.asyncio +async def test_construct_payload_system_instruction(client, llm_request): + llm_request.config = types.GenerateContentConfig( + system_instruction='You are a helpful assistant.' + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': 'Hi'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + payload = mock_post.call_args[1]['json'] + assert payload['messages'][0]['role'] == 'system' + assert payload['messages'][0]['content'] == 'You are a helpful assistant.' + # Ensure user message follows system + assert payload['messages'][1]['role'] == 'user' + + +@pytest.mark.asyncio +async def test_construct_payload_multimodal_content(client): + # Mock inline_data for image + image_data = b'fake_image_bytes' + llm_request = LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_text(text='What is this?'), + types.Part.from_bytes( + data=image_data, mime_type='image/jpeg' + ), + ], + ) + ], + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [ + {'message': {'role': 'assistant', 'content': 'It is an image'}} + ] + } + + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert len(payload['messages']) == 1 + message = payload['messages'][0] + assert message['role'] == 'user' + assert isinstance(message['content'], list) + assert len(message['content']) == 2 + assert message['content'][0] == {'type': 'text', 'text': 'What is this?'} + assert message['content'][1]['type'] == 'image_url' + # Base64 encoding of b'fake_image_bytes' is 'ZmFrZV9pbWFnZV9ieXRlcw==' + assert message['content'][1]['image_url']['url'] == ( + 'data:image/jpeg;base64,ZmFrZV9pbWFnZV9ieXRlcw==' + ) + + +@pytest.mark.asyncio +async def test_construct_payload_image_file_uri(client): + llm_request = LlmRequest( + model='apigee/open_llama', + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_uri( + file_uri='https://localhost/image.jpg', + mime_type='image/jpeg', + ) + ], + ) + ], + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [ + {'message': {'role': 'assistant', 'content': 'It is an image'}} + ] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + assert len(payload['messages']) == 1 + message = payload['messages'][0] + assert message['role'] == 'user' + assert isinstance(message['content'], list) + assert message['content'][0] == { + 'type': 'image_url', + 'image_url': {'url': 'https://localhost/image.jpg'}, + } + + +@pytest.mark.asyncio +async def test_generate_content_async_function_call_response( + client, llm_request +): + # Mock response with tool call + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'role': 'assistant', + 'content': None, + 'tool_calls': [{ + 'id': 'call_123', + 'type': 'function', + 'function': { + 'name': 'get_weather', + 'arguments': '{"location": "London"}', + }, + }], + } + }] + } + mock_response.status_code = 200 + + with mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response): + responses = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + part = responses[0].content.parts[0] + assert part.function_call + assert part.function_call.name == 'get_weather' + assert part.function_call.args == {'location': 'London'} + assert part.function_call.id == 'call_123' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('response_json_schema', 'response_mime_type', 'expected_response_format'), + [ + # Case 1: Only response_json_schema is provided + ( + {'type': 'object', 'properties': {'name': {'type': 'string'}}}, + None, + { + 'type': 'json_schema', + 'json_schema': { + 'type': 'object', + 'properties': {'name': {'type': 'string'}}, + }, + }, + ), + # Case 2: Both provided, schema takes precedence + ( + {'type': 'object', 'properties': {'name': {'type': 'string'}}}, + 'application/json', + { + 'type': 'json_schema', + 'json_schema': { + 'type': 'object', + 'properties': {'name': {'type': 'string'}}, + }, + }, + ), + # Case 3: Only response_mime_type is provided + ( + None, + 'application/json', + {'type': 'json_object'}, + ), + ], +) +async def test_construct_payload_response_format( + client, + llm_request, + response_json_schema, + response_mime_type, + expected_response_format, +): + llm_request.config = types.GenerateContentConfig( + response_json_schema=response_json_schema, + response_mime_type=response_mime_type, + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{'message': {'role': 'assistant', 'content': '{}'}}] + } + mock_response.status_code = 200 + + with mock.patch.object( + httpx.AsyncClient, 'post', return_value=mock_response + ) as mock_post: + _ = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + mock_post.assert_called_once() + payload = mock_post.call_args[1]['json'] + + assert payload['response_format'] == expected_response_format + + +@pytest.mark.asyncio +async def test_generate_content_async_invalid_tool_call_type_raises_error( + client, llm_request +): + # Mock response with invalid tool call type + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'role': 'assistant', + 'content': None, + 'tool_calls': [{ + 'id': 'call_123', + # Invalid type + 'type': 'custom', + 'custom': { + 'name': 'read_string', + 'input': 'Hi! The this is a custom tool call!', + }, + }], + } + }] + } + mock_response.status_code = 200 + + with mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response): + with pytest.raises(ValueError, match='Unsupported tool_call type: custom'): + _ = [ + r + async for r in client.generate_content_async( + llm_request, stream=False + ) + ] + + +@pytest.mark.asyncio +async def test_generate_content_async_function_call_response( + client, llm_request +): + # Mock response with deprecated function call + mock_response = AsyncMock(spec=httpx.Response) + mock_response.json.return_value = { + 'choices': [{ + 'message': { + 'role': 'assistant', + 'content': None, + 'function_call': { + 'name': 'get_weather', + 'arguments': '{"location": "London"}', + }, + } + }] + } + mock_response.status_code = 200 + + with mock.patch.object(httpx.AsyncClient, 'post', return_value=mock_response): + responses = [ + r + async for r in client.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + part = responses[0].content.parts[0] + assert part.function_call + assert part.function_call.name == 'get_weather' + assert part.function_call.args == {'location': 'London'} + assert part.function_call.id is None + + +@pytest.mark.asyncio +async def test_generate_content_async_streaming_function_call(): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + # Mock chunks simulating split arguments + chunk_data_0 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'id': 'call_123', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': ''}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_1 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"location": "London"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_2 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"country": "UK"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_3 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}], + 'usage': { + 'prompt_tokens': 10, + 'completion_tokens': 20, + 'total_tokens': 30, + }, + } + + chunks = [ + f'{json.dumps(chunk_data_0)}\n', + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + # Check that we get 5 responses (one per chunk + extra final accumulated) + assert len(responses) == 5 + + # Check 1st response: partial tool call, empty args + assert responses[0].partial is True + assert responses[0].content.parts[0].function_call.name == 'get_weather' + assert responses[0].content.parts[0].function_call.id == 'call_123' + + # Check 2nd response: full args for first update + assert responses[1].partial is True + assert responses[1].content.parts[0].function_call.args == { + 'location': 'London' + } + + # Check 3rd response: full args for second update (merged) + assert responses[2].partial is True + assert responses[2].content.parts[0].function_call.args == {'country': 'UK'} + + # Check 4th response: last delta (empty) + assert responses[3].partial is True + assert responses[3].content.parts == [] + + # Check 5th response: final accumulated + assert responses[4].finish_reason == types.FinishReason.STOP + # Full accumulated args + assert responses[4].content.parts[0].function_call.args == { + 'location': 'London', + 'country': 'UK', + } + + # Check metadata and usage + assert responses[4].model_version == 'gpt-3.5-turbo' + assert responses[4].custom_metadata['id'] == 'chatcmpl-123' + assert responses[4].custom_metadata['created'], 1234567890 + assert responses[4].custom_metadata['object'], 'chat.completion.chunk' + assert responses[4].custom_metadata['service_tier'], 'default' + assert responses[4].usage_metadata is not None + assert responses[4].usage_metadata.prompt_token_count == 10 + assert responses[4].usage_metadata.candidates_token_count == 20 + assert responses[4].usage_metadata.total_token_count == 30 + + +@pytest.mark.asyncio +async def test_generate_content_async_streaming_multiple_function_calls(): + # Mock streaming response with multiple tool calls + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + chunk_data_1 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + { + 'index': 0, + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'func_1', 'arguments': ''}, + }, + { + 'index': 1, + 'id': 'call_2', + 'type': 'function', + 'function': {'name': 'func_2', 'arguments': ''}, + }, + ] + }, + 'finish_reason': None, + }] + } + # the tool_call type is optional in chunk responses. + chunk_data_2 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + {'index': 0, 'function': {'arguments': '{"arg": 1}'}}, + {'index': 1, 'function': {'arguments': '{"arg": 2}'}}, + ] + }, + 'finish_reason': None, + }] + } + chunk_data_3 = { + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}] + } + + chunks = [ + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + + assert len(responses) == 4 + parts = responses[-1].content.parts + assert len(parts) == 2 + + assert parts[0].function_call.name == 'func_1' + assert parts[0].function_call.args == {'arg': 1} + assert parts[0].function_call.id == 'call_1' + + assert parts[1].function_call.name == 'func_2' + assert parts[1].function_call.args == {'arg': 2} + + assert parts[1].function_call.id == 'call_2' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('chunks', 'expected_response_count'), + [ + ( + [ + '\n', + ' \n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ], + 1, + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + '[DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ' [DONE] \n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + 'data: [DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ], +) +async def test_generate_content_async_streaming_parse_lines( + chunks, expected_response_count +): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + assert len(responses) == expected_response_count + assert responses[0].content.parts[0].text == 'Hello' diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 3e3ecce0..2bd5f7d2 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -26,6 +26,8 @@ import warnings from google.adk.models.lite_llm import _append_fallback_user_content_if_missing from google.adk.models.lite_llm import _content_to_message_param +from google.adk.models.lite_llm import _enforce_strict_openai_schema +from google.adk.models.lite_llm import _extract_reasoning_value from google.adk.models.lite_llm import _FILE_ID_REQUIRED_PROVIDERS from google.adk.models.lite_llm import _FINISH_REASON_MAPPING from google.adk.models.lite_llm import _function_declaration_to_tool_param @@ -45,6 +47,7 @@ from google.adk.models.lite_llm import _to_litellm_role from google.adk.models.lite_llm import FunctionChunk from google.adk.models.lite_llm import LiteLlm from google.adk.models.lite_llm import LiteLLMClient +from google.adk.models.lite_llm import ReasoningChunk from google.adk.models.lite_llm import TextChunk from google.adk.models.lite_llm import UsageMetadataChunk from google.adk.models.llm_request import LlmRequest @@ -57,6 +60,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall from litellm.types.utils import Choices from litellm.types.utils import Delta from litellm.types.utils import ModelResponse +from litellm.types.utils import ModelResponseStream from litellm.types.utils import StreamingChoices from pydantic import BaseModel from pydantic import Field @@ -129,7 +133,7 @@ FILE_BYTES_TEST_CASES = [ ] STREAMING_MODEL_RESPONSE = [ - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -141,7 +145,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -153,7 +157,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -165,7 +169,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -187,7 +191,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -209,7 +213,7 @@ STREAMING_MODEL_RESPONSE = [ ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -392,6 +396,145 @@ def test_to_litellm_response_format_with_dict_schema_for_openai(): assert formatted["json_schema"]["schema"]["additionalProperties"] is False +class _InnerModel(BaseModel): + value: str = Field(description="A value") + optional_field: str | None = Field(default=None, description="Optional") + + +class _OuterModel(BaseModel): + inner: _InnerModel = Field(description="Nested model") + name: str + + +class _WithList(BaseModel): + items: list[_InnerModel] = Field(description="List of items") + label: str + + +def test_enforce_strict_openai_schema_adds_additional_properties_recursively(): + """additionalProperties: false must appear on all object schemas.""" + schema = _OuterModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + # Root level + assert schema["additionalProperties"] is False + # Nested model in $defs + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + + +def test_enforce_strict_openai_schema_marks_all_properties_required(): + """All properties must appear in 'required', including optional fields.""" + schema = _InnerModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + assert sorted(schema["required"]) == ["optional_field", "value"] + + +def test_enforce_strict_openai_schema_strips_ref_sibling_keywords(): + """$ref nodes must have no sibling keywords like 'description'.""" + schema = _OuterModel.model_json_schema() + # Pydantic v2 generates {"$ref": "...", "description": "..."} for nested models + inner_prop = schema["properties"]["inner"] + assert "$ref" in inner_prop, "Expected Pydantic to generate a $ref property" + assert len(inner_prop) > 1, "Expected sibling keywords alongside $ref" + + _enforce_strict_openai_schema(schema) + + inner_prop = schema["properties"]["inner"] + assert list(inner_prop.keys()) == ["$ref"] + + +def test_enforce_strict_openai_schema_handles_array_items(): + """Array item schemas should also be recursively transformed.""" + schema = _WithList.model_json_schema() + + _enforce_strict_openai_schema(schema) + + assert schema["additionalProperties"] is False + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_enforce_strict_openai_schema_preserves_anyof_and_default(): + """anyOf structure and default value for Optional fields must be preserved.""" + schema = _InnerModel.model_json_schema() + + _enforce_strict_openai_schema(schema) + + opt_prop = schema["properties"]["optional_field"] + assert opt_prop["anyOf"] == [{"type": "string"}, {"type": "null"}] + assert opt_prop["default"] is None + + +def test_to_litellm_response_format_dict_input_not_mutated(): + """Passing a raw dict should not mutate the caller's original dict.""" + schema = { + "type": "object", + "properties": { + "nested": { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + }, + } + import copy + + original = copy.deepcopy(schema) + + _to_litellm_response_format(schema, model="gpt-4o") + + assert schema == original, "Caller's input dict was mutated" + + +def test_to_litellm_response_format_instance_input_for_openai(): + """Passing a BaseModel instance should produce a valid strict schema.""" + instance = _OuterModel( + inner=_InnerModel(value="test", optional_field=None), name="foo" + ) + + formatted = _to_litellm_response_format(instance, model="gpt-4o") + + assert formatted["type"] == "json_schema" + schema = formatted["json_schema"]["schema"] + assert schema["additionalProperties"] is False + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_to_litellm_response_format_nested_pydantic_for_openai(): + """Nested Pydantic model should produce a valid OpenAI strict schema.""" + formatted = _to_litellm_response_format(_OuterModel, model="gpt-4o") + + assert formatted["type"] == "json_schema" + assert formatted["json_schema"]["strict"] is True + + schema = formatted["json_schema"]["schema"] + assert schema["additionalProperties"] is False + assert sorted(schema["required"]) == ["inner", "name"] + + # $defs inner model must also be strict + inner_def = schema["$defs"]["_InnerModel"] + assert inner_def["additionalProperties"] is False + assert sorted(inner_def["required"]) == ["optional_field", "value"] + + +def test_to_litellm_response_format_nested_pydantic_for_gemini_unchanged(): + """Gemini models should NOT get the strict OpenAI transformations.""" + formatted = _to_litellm_response_format( + _OuterModel, model="gemini/gemini-2.0-flash" + ) + + assert formatted["type"] == "json_object" + schema = formatted["response_schema"] + # Gemini path should pass through the raw Pydantic schema untouched + assert schema == _OuterModel.model_json_schema() + + async def test_get_completion_inputs_uses_openai_format_for_openai_model(): """Test that _get_completion_inputs produces OpenAI-compatible format.""" llm_request = LlmRequest( @@ -532,7 +675,7 @@ def test_schema_to_dict_filters_none_enum_values(): MULTIPLE_FUNCTION_CALLS_STREAM = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -553,7 +696,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -574,7 +717,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -595,7 +738,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -616,7 +759,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason="tool_calls", @@ -627,7 +770,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [ STREAM_WITH_EMPTY_CHUNK = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -648,7 +791,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -670,7 +813,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ] ), # This is the problematic empty chunk that should be ignored. - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -691,7 +834,7 @@ STREAM_WITH_EMPTY_CHUNK = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[StreamingChoices(finish_reason="tool_calls", delta=Delta())] ), ] @@ -727,7 +870,7 @@ def mock_response(): # indices all 0 # finish_reason stop instead of tool_calls NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -748,7 +891,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -769,7 +912,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -790,7 +933,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -811,7 +954,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [ ) ] ), - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason="stop", @@ -2143,6 +2286,139 @@ def test_model_response_to_generate_content_response_reasoning_content(): assert response.content.parts[1].text == "Answer" +def test_message_to_generate_content_response_reasoning_field(): + """Test that the 'reasoning' field is supported (LM Studio, vLLM).""" + message = { + "role": "assistant", + "content": "Final answer", + "reasoning": "Thinking process", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + text_part = response.content.parts[1] + assert thought_part.text == "Thinking process" + assert thought_part.thought is True + assert text_part.text == "Final answer" + + +def test_model_response_to_generate_content_response_reasoning_field(): + """Test that 'reasoning' field is supported in ModelResponse.""" + model_response = ModelResponse( + model="test-model", + choices=[{ + "message": { + "role": "assistant", + "content": "Result", + "reasoning": "Chain of thought", + }, + "finish_reason": "stop", + }], + ) + + response = _model_response_to_generate_content_response(model_response) + + assert response.content.parts[0].text == "Chain of thought" + assert response.content.parts[0].thought is True + assert response.content.parts[1].text == "Result" + + +def test_reasoning_content_takes_precedence_over_reasoning(): + """Test that 'reasoning_content' is prioritized over 'reasoning'.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "LiteLLM standard reasoning", + "reasoning": "Alternative reasoning", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + assert thought_part.text == "LiteLLM standard reasoning" + assert thought_part.thought is True + + +def test_extract_reasoning_value_from_reasoning_content(): + """Test extraction from reasoning_content (LiteLLM standard).""" + message = ChatCompletionAssistantMessage( + role="assistant", + content="Answer", + reasoning_content="LiteLLM reasoning", + ) + result = _extract_reasoning_value(message) + assert result == "LiteLLM reasoning" + + +def test_extract_reasoning_value_from_reasoning(): + """Test extraction from reasoning (LM Studio, vLLM).""" + + class MockMessage: + + def __init__(self): + self.role = "assistant" + self.content = "Answer" + self.reasoning = "Alternative reasoning" + + def get(self, key, default=None): + return getattr(self, key, default) + + message = MockMessage() + result = _extract_reasoning_value(message) + assert result == "Alternative reasoning" + + +def test_extract_reasoning_value_dict_reasoning_content(): + """Test extraction from dict with reasoning_content field.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Dict reasoning content", + } + result = _extract_reasoning_value(message) + assert result == "Dict reasoning content" + + +def test_extract_reasoning_value_dict_reasoning(): + """Test extraction from dict with reasoning field.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning": "Dict reasoning", + } + result = _extract_reasoning_value(message) + assert result == "Dict reasoning" + + +def test_extract_reasoning_value_dict_prefers_reasoning_content(): + """Test that reasoning_content takes precedence over reasoning in dicts.""" + message = { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Primary", + "reasoning": "Secondary", + } + result = _extract_reasoning_value(message) + assert result == "Primary" + + +def test_extract_reasoning_value_none_message(): + """Test that None message returns None.""" + result = _extract_reasoning_value(None) + assert result is None + + +def test_extract_reasoning_value_no_reasoning_fields(): + """Test that None is returned when no reasoning fields exist.""" + message = { + "role": "assistant", + "content": "Answer only", + } + result = _extract_reasoning_value(message) + assert result is None + + def test_parse_tool_calls_from_text_multiple_calls(): text = ( '{"name":"alpha","arguments":{"value":1}}\n' @@ -2677,7 +2953,12 @@ def test_to_litellm_role(): "content": "this is a test", } } - ] + ], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, ), [TextChunk(text="this is a test")], UsageMetadataChunk( @@ -2707,7 +2988,7 @@ def test_to_litellm_role(): "stop", ), ( - ModelResponse( + ModelResponseStream( choices=[ StreamingChoices( finish_reason=None, @@ -2729,13 +3010,20 @@ def test_to_litellm_role(): ] ), [FunctionChunk(id="1", name="test_function", args='{"key": "va')], - UsageMetadataChunk( - prompt_tokens=0, completion_tokens=0, total_tokens=0 - ), None, + # LiteLLM 1.81+ defaults finish_reason to "stop" for partial chunks, + # older versions return None. Both are valid for streaming chunks. + (None, "stop"), ), ( - ModelResponse(choices=[{"finish_reason": "tool_calls"}]), + ModelResponse( + choices=[{"finish_reason": "tool_calls"}], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + ), [None], UsageMetadataChunk( prompt_tokens=0, completion_tokens=0, total_tokens=0 @@ -2743,7 +3031,14 @@ def test_to_litellm_role(): "tool_calls", ), ( - ModelResponse(choices=[{}]), + ModelResponse( + choices=[{}], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + ), [None], UsageMetadataChunk( prompt_tokens=0, completion_tokens=0, total_tokens=0 @@ -2813,6 +3108,40 @@ def test_to_litellm_role(): ), "tool_calls", ), + ( + ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason=None, + delta=Delta(role="assistant", content="Hello"), + ) + ], + usage=None, + ), + [TextChunk(text="Hello")], + None, + (None, "stop"), + ), + ( + ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason="stop", + delta=Delta( + role="assistant", reasoning_content="thinking..." + ), + ) + ], + usage=None, + ), + [ + ReasoningChunk( + parts=[types.Part(text="thinking...", thought=True)] + ) + ], + None, + "stop", + ), ], ) def test_model_response_to_chunk( @@ -2836,7 +3165,10 @@ def test_model_response_to_chunk( else: assert isinstance(chunk, type(expected_chunk)) assert chunk == expected_chunk - assert finished == expected_finished + if isinstance(expected_finished, tuple): + assert finished in expected_finished + else: + assert finished == expected_finished if expected_usage_chunk is None: assert usage_chunk is None @@ -2845,6 +3177,38 @@ def test_model_response_to_chunk( assert usage_chunk == expected_usage_chunk +def test_model_response_to_chunk_does_not_mutate_delta_object(): + """Verify that _model_response_to_chunk doesn't mutate the Delta object. + + In real streaming responses, LiteLLM's StreamingChoices only has 'delta' + (message is explicitly popped in StreamingChoices constructor). The delta + object itself carries reasoning_content when present. + """ + delta = Delta( + role="assistant", content="Hello", reasoning_content="thinking..." + ) + response = ModelResponseStream( + choices=[StreamingChoices(delta=delta, finish_reason=None)] + ) + + chunks = [chunk for chunk, _ in _model_response_to_chunk(response) if chunk] + + assert ( + ReasoningChunk(parts=[types.Part(text="thinking...", thought=True)]) + in chunks + ) + assert TextChunk(text="Hello") in chunks + + # Verify we don't accidentally mutate the original delta object. + assert delta.content == "Hello" + assert delta.reasoning_content == "thinking..." + + +def test_model_response_to_chunk_rejects_dict_response(): + with pytest.raises(TypeError): + list(_model_response_to_chunk({"choices": []})) + + @pytest.mark.asyncio async def test_acompletion_additional_args(mock_acompletion, mock_client): lite_llm_instance = LiteLlm( @@ -3056,7 +3420,7 @@ async def test_generate_content_async_stream_sets_finish_reason( mock_completion, lite_llm_instance ): mock_completion.return_value = iter([ - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -3065,7 +3429,7 @@ async def test_generate_content_async_stream_sets_finish_reason( ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[ StreamingChoices( @@ -3074,7 +3438,7 @@ async def test_generate_content_async_stream_sets_finish_reason( ) ], ), - ModelResponse( + ModelResponseStream( model="test_model", choices=[StreamingChoices(finish_reason="stop", delta=Delta())], ), @@ -3107,7 +3471,7 @@ async def test_generate_content_async_stream_with_usage_metadata( streaming_model_response_with_usage_metadata = [ *STREAMING_MODEL_RESPONSE, - ModelResponse( + ModelResponseStream( usage={ "prompt_tokens": 10, "completion_tokens": 5, @@ -3176,7 +3540,7 @@ async def test_generate_content_async_stream_with_usage_metadata( """Tests that cached prompt tokens are propagated in streaming mode.""" streaming_model_response_with_usage_metadata = [ *STREAMING_MODEL_RESPONSE, - ModelResponse( + ModelResponseStream( usage={ "prompt_tokens": 10, "completion_tokens": 5, @@ -3657,7 +4021,7 @@ async def test_finish_reason_propagation( async def test_finish_reason_unknown_maps_to_other( mock_acompletion, lite_llm_instance ): - """Test that unknown finish_reason values map to FinishReason.OTHER.""" + """Test that unmapped finish_reason values map to FinishReason.OTHER.""" mock_response = ModelResponse( choices=[ Choices( @@ -3665,7 +4029,9 @@ async def test_finish_reason_unknown_maps_to_other( role="assistant", content="Test response", ), - finish_reason="unknown_reason_type", + # LiteLLM validates finish_reason to a known set. Use a value that + # LiteLLM accepts but ADK does not explicitly map. + finish_reason="eos", ) ] ) diff --git a/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py b/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py new file mode 100644 index 00000000..bd5da524 --- /dev/null +++ b/tests/unittests/optimization/gepa_root_agent_prompt_optimizer_test.py @@ -0,0 +1,264 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import sys + +from google.adk.agents.llm_agent import Agent +from google.adk.optimization.data_types import UnstructuredSamplingResult +from google.adk.optimization.gepa_root_agent_prompt_optimizer import _create_agent_gepa_adapter_class +from google.adk.optimization.gepa_root_agent_prompt_optimizer import GEPARootAgentPromptOptimizer +from google.adk.optimization.gepa_root_agent_prompt_optimizer import GEPARootAgentPromptOptimizerConfig +from google.adk.optimization.sampler import Sampler +import pytest + + +class MockEvaluationBatch: + + def __init__(self, outputs, scores, trajectories): + self.outputs = outputs + self.scores = scores + self.trajectories = trajectories + + +class MockGEPAAdapter: + """Mock that supports generic type hints.""" + + def __class_getitem__(cls, item): + return cls + + +@pytest.fixture(name="mock_gepa") +def fixture_mock_gepa(mocker): + # mock gepa before it gets imported by the optimizer module + mock_gepa_module = mocker.MagicMock() + mock_gepa_adapter = mocker.MagicMock() + + mock_gepa_adapter.EvaluationBatch = MockEvaluationBatch + mock_gepa_adapter.GEPAAdapter = MockGEPAAdapter + + mock_gepa_module.core = mocker.MagicMock() + mock_gepa_module.core.adapter = mock_gepa_adapter + + mocker.patch.dict( + sys.modules, + { + "gepa": mock_gepa_module, + "gepa.core": mock_gepa_module.core, + "gepa.core.adapter": mock_gepa_adapter, + }, + ) + return mock_gepa_module + + +@pytest.fixture +def mock_sampler(mocker): + sampler = mocker.MagicMock(spec=Sampler) + sampler.get_train_example_ids.return_value = ["train1", "train2"] + sampler.get_validation_example_ids.return_value = ["val1", "val2"] + return sampler + + +@pytest.fixture +def mock_agent(mocker): + agent = mocker.MagicMock(spec=Agent) + agent.instruction = "Initial instruction" + agent.sub_agents = {} + agent.clone.return_value = agent + return agent + + +def test_adapter_init(mock_gepa, mock_sampler, mock_agent): + del mock_gepa # only needed to mock gepa in background + loop = asyncio.new_event_loop() + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + assert adapter._initial_agent == mock_agent + assert adapter._sampler == mock_sampler + assert adapter._main_loop == loop + assert adapter._train_example_ids == {"train1", "train2"} + assert adapter._validation_example_ids == {"val1", "val2"} + loop.close() + + +def test_adapter_evaluate_train(mocker, mock_gepa, mock_sampler, mock_agent): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "New prompt"} + batch = ["train1"] + + # mock the future returned by run_coroutine_threadsafe + mock_future = mocker.MagicMock() + expected_result = UnstructuredSamplingResult( + scores={"train1": 0.8}, + data={"train1": {"output": "result"}}, + ) + mock_future.result.return_value = expected_result + + mock_rct = mocker.patch( + "asyncio.run_coroutine_threadsafe", return_value=mock_future + ) + eval_batch = adapter.evaluate(batch, candidate, capture_traces=True) + + mock_rct.assert_called_once() + mock_sampler.sample_and_score.assert_called_once_with( + mocker.ANY, + example_set="train", + batch=batch, + capture_full_eval_data=True, + ) + + mock_agent.clone.assert_called_once_with(update={"instruction": "New prompt"}) + + assert isinstance(eval_batch, MockEvaluationBatch) + assert eval_batch.scores == [0.8] + assert eval_batch.outputs == [{"output": "result"}] + assert eval_batch.trajectories == [{"output": "result"}] + + +def test_adapter_evaluate_validation( + mocker, mock_gepa, mock_sampler, mock_agent +): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "New prompt"} + batch = ["val1"] + + mock_future = mocker.MagicMock() + expected_result = UnstructuredSamplingResult(scores={"val1": 0.5}, data={}) + mock_future.result.return_value = expected_result + + mocker.patch("asyncio.run_coroutine_threadsafe", return_value=mock_future) + adapter.evaluate(batch, candidate) + + mock_sampler.sample_and_score.assert_called_once_with( + mocker.ANY, + example_set="validation", + batch=batch, + capture_full_eval_data=False, + ) + + +def test_adapter_make_reflective_dataset( + mocker, mock_gepa, mock_sampler, mock_agent +): + del mock_gepa # only needed to mock gepa in background + loop = mocker.MagicMock(spec=asyncio.AbstractEventLoop) + _AdapterClass = _create_agent_gepa_adapter_class() + adapter = _AdapterClass(mock_agent, mock_sampler, loop) + + candidate = {"agent_prompt": "Prompt"} + eval_batch = MockEvaluationBatch( + outputs=[{"o": 1}, {"o": 2}], + scores=[0.9, 0.1], + trajectories=[{"t": 1}, {"t": 2}], + ) + components = ["component1"] + + dataset = adapter.make_reflective_dataset(candidate, eval_batch, components) + + assert "component1" in dataset + assert len(dataset["component1"]) == 2 + assert dataset["component1"][0] == { + "agent_prompt": "Prompt", + "score": 0.9, + "eval_data": {"t": 1}, + } + assert dataset["component1"][1] == { + "agent_prompt": "Prompt", + "score": 0.1, + "eval_data": {"t": 2}, + } + + +@pytest.mark.asyncio +async def test_optimize(mocker, mock_gepa, mock_sampler, mock_agent): + config = GEPARootAgentPromptOptimizerConfig() + optimizer = GEPARootAgentPromptOptimizer(config) + + # mock LLM + mock_llm_class = mocker.MagicMock() + mock_llm = mocker.MagicMock() + mock_llm_class.return_value = mock_llm + optimizer._llm_class = mock_llm_class + + # mock gepa.optimize return value + mock_gepa_result = mocker.MagicMock() + mock_gepa_result.candidates = [{"agent_prompt": "Optimized instruction"}] + mock_gepa_result.val_aggregate_scores = [0.95] + mock_gepa_result.to_dict.return_value = {"full": "result"} + mock_gepa.optimize.return_value = mock_gepa_result + + result = await optimizer.optimize(mock_agent, mock_sampler) + + mock_gepa.optimize.assert_called_once() + call_kwargs = mock_gepa.optimize.call_args[1] + + assert call_kwargs["seed_candidate"] == { + "agent_prompt": "Initial instruction" + } + assert call_kwargs["trainset"] == ["train1", "train2"] + assert call_kwargs["valset"] == ["val1", "val2"] + + assert len(result.optimized_agents) == 1 + assert result.optimized_agents[0].overall_score == 0.95 + mock_agent.clone.assert_called_with( + update={"instruction": "Optimized instruction"} + ) + assert result.gepa_result == {"full": "result"} + + +@pytest.mark.asyncio +async def test_optimize_logs_warning_on_overlapping_ids( + mocker, mock_gepa, mock_sampler, mock_agent +): + # Setup overlapping IDs + mock_sampler.get_train_example_ids.return_value = ["id1", "id2"] + mock_sampler.get_validation_example_ids.return_value = ["id2", "id3"] + + config = GEPARootAgentPromptOptimizerConfig() + optimizer = GEPARootAgentPromptOptimizer(config) + + # Mock LLM class + mock_llm_class = mocker.MagicMock() + optimizer._llm_class = mock_llm_class + + # Mock gepa.optimize return value + mock_gepa_result = mocker.MagicMock() + mock_gepa_result.candidates = [] + mock_gepa_result.val_aggregate_scores = [] + mock_gepa_result.to_dict.return_value = {} + mock_gepa.optimize.return_value = mock_gepa_result + + mock_logger = mocker.patch( + "google.adk.optimization.gepa_root_agent_prompt_optimizer._logger" + ) + + # Run optimization + await optimizer.optimize(mock_agent, mock_sampler) + + # Verify warning + mock_logger.warning.assert_called_with( + "The training and validation example UIDs overlap. This WILL cause" + " aliasing issues unless each common UID refers to the same example" + " in both sets." + ) diff --git a/tests/unittests/optimization/local_eval_sampler_test.py b/tests/unittests/optimization/local_eval_sampler_test.py new file mode 100644 index 00000000..6ebd99cb --- /dev/null +++ b/tests/unittests/optimization/local_eval_sampler_test.py @@ -0,0 +1,383 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.agents.llm_agent import Agent +from google.adk.evaluation.base_eval_service import EvaluateConfig +from google.adk.evaluation.base_eval_service import EvaluateRequest +from google.adk.evaluation.base_eval_service import InferenceConfig +from google.adk.evaluation.base_eval_service import InferenceRequest +from google.adk.evaluation.base_eval_service import InferenceResult +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_case import InvocationEvent +from google.adk.evaluation.eval_case import InvocationEvents +from google.adk.evaluation.eval_config import EvalConfig +from google.adk.evaluation.eval_config import EvalMetric +from google.adk.evaluation.eval_metrics import EvalMetricResult +from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation +from google.adk.evaluation.eval_metrics import EvalStatus +from google.adk.evaluation.eval_result import EvalCaseResult +from google.adk.evaluation.eval_sets_manager import EvalSetsManager +from google.adk.optimization.local_eval_sampler import _log_eval_summary +from google.adk.optimization.local_eval_sampler import extract_single_invocation_info +from google.adk.optimization.local_eval_sampler import extract_tool_call_data +from google.adk.optimization.local_eval_sampler import LocalEvalSampler +from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig +from google.genai import types +import pytest + + +def test_log_eval_summary(mocker): + statuses = ( + [EvalStatus.PASSED] * 3 + + [EvalStatus.FAILED] * 2 + + [EvalStatus.NOT_EVALUATED] + ) + expected_log = "Evaluation summary: 3 PASSED, 2 FAILED, 1 OTHER" + + eval_results = [ + mocker.MagicMock(spec=EvalCaseResult, final_eval_status=status) + for status in statuses + ] + mock_logger = mocker.patch( + "google.adk.optimization.local_eval_sampler.logger" + ) + + _log_eval_summary(eval_results) + + mock_logger.info.assert_called_once_with(expected_log) + + +def test_extract_tool_call_data(): + # omitting IntermediateData tests as it is no longer used + # case 1: empty invocation events + assert not extract_tool_call_data(InvocationEvents()) + # case 2: multi call invocation events + multi_call_invocation_events = InvocationEvents( + invocation_events=[ + InvocationEvent( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + id="call_1", + name="tool_1", + args={"a": 1}, + ) + ), + types.Part( + function_call=types.FunctionCall( + id="call_2", + name="tool_2", + args={"b": 2}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_1", + name="tool_1", + response={"result_1": "done"}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_2", + name="tool_2", + response={"result_2": "done"}, + ) + ), + ] + ), + ) + ] + ) + expected_entries = [ + { + "name": "tool_1", + "args": {"a": 1}, + "response": {"result_1": "done"}, + }, + { + "name": "tool_2", + "args": {"b": 2}, + "response": {"result_2": "done"}, + }, + ] + result = extract_tool_call_data(multi_call_invocation_events) + # order is not guaranteed + for expected_entry in expected_entries: + assert expected_entry in result + assert len(result) == len(expected_entries) + + +def test_extract_single_invocation_info(): + invocation = Invocation( + user_content=types.Content( + parts=[ + types.Part(text="user thought", thought=True), + types.Part(text="Hello agent!"), + ] + ), + final_response=types.Content( + parts=[ + types.Part(text="agent thought", thought=True), + types.Part(text="Hello user!"), + ] + ), + ) + + result = extract_single_invocation_info(invocation) + + assert result == { + "user_prompt": "Hello agent!", + "agent_response": "Hello user!", + } + + +@pytest.mark.parametrize( + "config_kwargs, expected_attrs", + [ + ( + {"train_eval_set": "train_set"}, + { + "_train_eval_set": "train_set", + "_train_eval_case_ids": ["train_set_1", "train_set_2"], + "_validation_eval_set": "train_set", + "_validation_eval_case_ids": ["train_set_1", "train_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "train_eval_case_ids": ["t1"]}, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_case_ids": ["t1"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_set": "val_set"}, + { + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["val_set_1", "val_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_case_ids": ["v1"]}, + { + "_validation_eval_case_ids": ["v1"], + }, + ), + ( + { + "train_eval_set": "train_set", + "train_eval_case_ids": ["t1"], + "validation_eval_set": "val_set", + "validation_eval_case_ids": ["v1"], + }, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["v1"], + }, + ), + ], +) +def test_local_eval_service_interface_init( + mocker, config_kwargs, expected_attrs +): + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + + def mock_get_eval_case_ids(self, eval_set_id): + return [f"{eval_set_id}_1", f"{eval_set_id}_2"] + + mocker.patch.object( + LocalEvalSampler, + "_get_eval_case_ids", + autospec=True, + side_effect=mock_get_eval_case_ids, + ) + + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), app_name="test_app", **config_kwargs + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + for attr, expected_value in expected_attrs.items(): + assert getattr(interface, attr) == expected_value + + +@pytest.mark.asyncio +async def test_evaluate_agent(mocker): + # Mocking LocalEvalService and its methods + mock_eval_service_cls = mocker.patch( + "google.adk.optimization.local_eval_sampler.LocalEvalService" + ) + mock_eval_service = mock_eval_service_cls.return_value + + # mocking inference + mock_inference_result = mocker.MagicMock(spec=InferenceResult) + + async def mock_perform_inference(*args, **kwargs): + yield mock_inference_result + + mock_eval_service.perform_inference.side_effect = mock_perform_inference + + # mocking evaluate + mock_eval_case_result = mocker.MagicMock(spec=EvalCaseResult) + + async def mock_evaluate(*args, **kwargs): + yield mock_eval_case_result + + mock_eval_service.evaluate.side_effect = mock_evaluate + + # mocking get_eval_metrics_from_config + mock_metrics = [EvalMetric(metric_name="test_metric")] + mocker.patch( + "google.adk.optimization.local_eval_sampler.get_eval_metrics_from_config", + return_value=mock_metrics, + ) + + mocker.patch("google.adk.evaluation.base_eval_service.EvaluateConfig") + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Call _evaluate_agent + results = await interface._evaluate_agent( + mocker.MagicMock(spec=Agent), "train_set", ["t1"] + ) + + # Assertions + mock_eval_service.perform_inference.assert_called_once_with( + inference_request=InferenceRequest( + app_name="test_app", + eval_set_id="train_set", + eval_case_ids=["t1"], + inference_config=InferenceConfig(), + ) + ) + mock_eval_service.evaluate.assert_called_once_with( + evaluate_request=EvaluateRequest( + inference_results=[mock_inference_result], + evaluate_config=EvaluateConfig(eval_metrics=mock_metrics), + ) + ) + assert results == [mock_eval_case_result] + + +@pytest.mark.asyncio +async def test_extract_eval_data(mocker): + # Mock components + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + mock_eval_case = mocker.MagicMock() + mock_eval_case.conversation_scenario = "test_scenario" + mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case + + # Mock per invocation result + mock_actual_invocation = mocker.MagicMock(spec=Invocation) + mock_expected_invocation = mocker.MagicMock(spec=Invocation) + mock_metric_result = mocker.MagicMock(spec=EvalMetricResult) + mock_metric_result.metric_name = "test_metric" + mock_metric_result.score = 0.854 # should be rounded to 0.85 + mock_metric_result.eval_status = EvalStatus.PASSED + + mock_per_inv_result = mocker.MagicMock(spec=EvalMetricResultPerInvocation) + mock_per_inv_result.actual_invocation = mock_actual_invocation + mock_per_inv_result.expected_invocation = mock_expected_invocation + mock_per_inv_result.eval_metric_results = [mock_metric_result] + + mock_eval_result = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result.eval_id = "t1" + mock_eval_result.eval_metric_result_per_invocation = [mock_per_inv_result] + + # Mock extract_single_invocation_info + mocker.patch( + "google.adk.optimization.local_eval_sampler.extract_single_invocation_info", + side_effect=[{"info": "actual"}, {"info": "expected"}], + ) + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + # Call _extract_eval_data + eval_data = interface._extract_eval_data("train_set", [mock_eval_result]) + + # Assertions + assert "t1" in eval_data + assert eval_data["t1"]["conversation_scenario"] == "test_scenario" + assert len(eval_data["t1"]["invocations"]) == 1 + inv = eval_data["t1"]["invocations"][0] + assert inv["actual_invocation"] == {"info": "actual"} + assert inv["expected_invocation"] == {"info": "expected"} + assert inv["eval_metric_results"] == [ + {"metric_name": "test_metric", "score": 0.85, "eval_status": "PASSED"} + ] + + +@pytest.mark.asyncio +async def test_sample_and_score(mocker): + # Mock results + mock_eval_result_1 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_1.eval_id = "t1" + mock_eval_result_1.final_eval_status = EvalStatus.PASSED + + mock_eval_result_2 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_2.eval_id = "t2" + mock_eval_result_2.final_eval_status = EvalStatus.FAILED + + eval_results = [mock_eval_result_1, mock_eval_result_2] + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1", "t2"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Patch internal methods + mocker.patch.object(interface, "_evaluate_agent", return_value=eval_results) + mock_log_summary = mocker.patch( + "google.adk.optimization.local_eval_sampler._log_eval_summary" + ) + mock_extract_data = mocker.patch.object( + interface, "_extract_eval_data", return_value={"t1": {}, "t2": {}} + ) + + # Call sample_and_score + result = await interface.sample_and_score( + mocker.MagicMock(spec=Agent), + example_set="train", + capture_full_eval_data=True, + ) + + # Assertions + assert result.scores == {"t1": 1.0, "t2": 0.0} + assert result.data == {"t1": {}, "t2": {}} + mock_log_summary.assert_called_once_with(eval_results) + mock_extract_data.assert_called_once_with("train_set", eval_results) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index e9f617c4..a39eb932 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -17,11 +17,12 @@ import asyncio import contextlib import dataclasses import json +import os from unittest import mock from google.adk.agents import base_agent -from google.adk.agents import callback_context as callback_context_lib -from google.adk.agents import invocation_context as invocation_context_lib +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.events import event as event_lib from google.adk.events import event_actions as event_actions_lib from google.adk.models import llm_request as llm_request_lib @@ -83,7 +84,7 @@ def invocation_context(mock_agent, mock_session): mock_plugin_manager = mock.create_autospec( plugin_manager_lib.PluginManager, instance=True, spec_set=True ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_agent, session=mock_session, invocation_id="inv-789", @@ -94,9 +95,7 @@ def invocation_context(mock_agent, mock_session): @pytest.fixture def callback_context(invocation_context): - return callback_context_lib.CallbackContext( - invocation_context=invocation_context - ) + return CallbackContext(invocation_context=invocation_context) @pytest.fixture @@ -1736,6 +1735,7 @@ class TestBigQueryAgentAnalyticsPlugin: _assert_common_fields(log_entry, "LLM_ERROR") assert log_entry["content"] is None assert log_entry["error_message"] == "LLM failed" + assert log_entry["status"] == "ERROR" @pytest.mark.asyncio async def test_on_tool_error_callback_logs_correctly( @@ -1763,6 +1763,7 @@ class TestBigQueryAgentAnalyticsPlugin: assert content_dict["tool"] == "MyTool" assert content_dict["args"] == {"param": "value"} assert log_entry["error_message"] == "Tool timed out" + assert log_entry["status"] == "ERROR" @pytest.mark.asyncio async def test_table_creation_options( @@ -2152,7 +2153,7 @@ class TestBigQueryAgentAnalyticsPlugin: span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - mock_tracer.start_span.assert_called_with("test_span") + mock_tracer.start_span.assert_called_with("test_span", context=None) assert span_id == format(span_id_int, "016x") # Test get_trace_id # We need to mock trace.get_current_span() to return our mock span @@ -3018,81 +3019,221 @@ class TestDuplicateLabels: assert "labels" not in attributes -class TestResolveSpanIds: - """Tests for the _resolve_span_ids static helper.""" +class TestResolveIds: + """Tests for the _resolve_ids static helper.""" - def test_uses_trace_manager_defaults(self): - """Should use TraceManager values when no overrides provided.""" + def _resolve(self, ed, callback_context): + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_ids( + ed, callback_context + ) + + def test_uses_trace_manager_defaults(self, callback_context): + """Should use TraceManager values when no overrides and no ambient.""" ed = bigquery_agent_analytics_plugin.EventData( extra_attributes={"some_key": "value"} ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + assert trace_id == "trace-1" assert span_id == "span-1" assert parent_id == "parent-1" - def test_span_id_override(self): + def test_span_id_override(self, callback_context): """Should use span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override="custom-span" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "custom-span" assert parent_id == "parent-1" - def test_parent_span_id_override(self): + def test_parent_span_id_override(self, callback_context): """Should use parent_span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( parent_span_id_override="custom-parent" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "custom-parent" - def test_none_override_keeps_default(self): + def test_none_override_keeps_default(self, callback_context): """None overrides should keep the TraceManager defaults.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override=None, parent_span_id_override=None ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "parent-1" + def test_ambient_otel_span_takes_priority(self, callback_context): + """When an ambient OTel span is valid, its IDs take priority.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData() + + with real_tracer.start_as_current_span("invocation") as parent_span: + with real_tracer.start_as_current_span("agent") as agent_span: + ambient_ctx = agent_span.get_span_context() + expected_trace = format(ambient_ctx.trace_id, "032x") + expected_span = format(ambient_ctx.span_id, "016x") + expected_parent = format(parent_span.get_span_context().span_id, "016x") + + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == expected_trace + assert span_id == expected_span + assert parent_id == expected_parent + provider.shutdown() + + def test_override_beats_ambient(self, callback_context): + """EventData overrides take priority over ambient OTel span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData( + trace_id_override="forced-trace", + span_id_override="forced-span", + parent_span_id_override="forced-parent", + ) + + with real_tracer.start_as_current_span("invocation"): + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == "forced-trace" + assert span_id == "forced-span" + assert parent_id == "forced-parent" + provider.shutdown() + + def test_ambient_root_span_no_self_parent(self, callback_context): + """Ambient root span (no parent) must not produce self-parent.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + # Seed the plugin stack with a span so there's a stale parent. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "plugin-child" + ) + + ed = bigquery_agent_analytics_plugin.EventData() + + # Single root ambient span — no parent. + with real_tracer.start_as_current_span("root_invocation") as root: + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + root_span_id = format(root.get_span_context().span_id, "016x") + + # span_id should be the ambient root's span_id + assert span_id == root_span_id + # parent must be None — not the stale plugin parent, not self + assert parent_id is None + assert span_id != parent_id + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + provider.shutdown() + + def test_ambient_span_used_for_completed_event(self, callback_context): + """Completed event with overrides should use ambient when present. + + When an ambient OTel span is valid, passing None overrides lets + _resolve_ids Layer 2 pick the ambient span — matching the + STARTING event's span_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with real_tracer.start_as_current_span("invoke_agent") as agent_span: + expected_span = format(agent_span.get_span_context().span_id, "016x") + + # Simulate STARTING: no overrides → ambient Layer 2 wins. + ed_starting = bigquery_agent_analytics_plugin.EventData() + _, span_starting, _ = self._resolve(ed_starting, callback_context) + + # Simulate COMPLETED: None overrides (ambient check passed). + ed_completed = bigquery_agent_analytics_plugin.EventData( + span_id_override=None, + parent_span_id_override=None, + latency_ms=42, + ) + _, span_completed, _ = self._resolve(ed_completed, callback_context) + + assert span_starting == expected_span + assert span_completed == expected_span + assert span_starting == span_completed + + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" @@ -3282,7 +3423,7 @@ class TestMultiSubagentToolLogging: instance=True, spec_set=True, ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_a, session=session, invocation_id=invocation_id, @@ -3488,7 +3629,7 @@ class TestMultiSubagentToolLogging: """ session = self._make_session() inv_ctx = self._make_invocation_context("schema_explorer", session) - cb_ctx = callback_context_lib.CallbackContext(invocation_context=inv_ctx) + cb_ctx = CallbackContext(invocation_context=inv_ctx) tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) mock_agent = inv_ctx.agent tool = self._make_tool("get_table_info") @@ -3766,9 +3907,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t1_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t1" ) - cb_ctx_t1_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_orch - ) + cb_ctx_t1_orch = CallbackContext(invocation_context=inv_ctx_t1_orch) # Orchestrator agent_starting await plugin.before_agent_callback( @@ -3781,9 +3920,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t1_sub = self._make_invocation_context( "schema_explorer", session, invocation_id="inv-t1" ) - cb_ctx_t1_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_sub - ) + cb_ctx_t1_sub = CallbackContext(invocation_context=inv_ctx_t1_sub) tool_ctx_t1 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t1_sub ) @@ -3831,9 +3968,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t2_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t2" ) - cb_ctx_t2_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_orch - ) + cb_ctx_t2_orch = CallbackContext(invocation_context=inv_ctx_t2_orch) await plugin.before_agent_callback( agent=inv_ctx_t2_orch.agent, @@ -3845,9 +3980,7 @@ class TestMultiSubagentToolLogging: inv_ctx_t2_sub = self._make_invocation_context( "image_describer", session, invocation_id="inv-t2" ) - cb_ctx_t2_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_sub - ) + cb_ctx_t2_sub = CallbackContext(invocation_context=inv_ctx_t2_sub) tool_ctx_t2 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t2_sub ) @@ -3949,3 +4082,2373 @@ class TestMultiSubagentToolLogging: # All rows share the same session for row in rows: assert row["session_id"] == "session-multi" + + +class TestSchemaAutoUpgrade: + """Tests for _ensure_schema_exists with auto_schema_upgrade.""" + + def _make_plugin(self, auto_schema_upgrade=False): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + auto_schema_upgrade=auto_schema_upgrade, + ) + with mock.patch("google.cloud.bigquery.Client"): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_create_table_sets_version_label(self): + """New tables get the schema version label.""" + plugin = self._make_plugin() + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin._ensure_schema_exists() + plugin.client.create_table.assert_called_once() + tbl = plugin.client.create_table.call_args[0][0] + assert ( + tbl.labels[bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_no_upgrade_when_disabled(self): + """Auto-upgrade disabled: existing table is not modified.""" + plugin = self._make_plugin(auto_schema_upgrade=False) + existing = mock.MagicMock() + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_adds_missing_columns(self): + """Auto-upgrade adds columns missing from existing table.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {"other": "label"} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = {f.name for f in updated_table.schema} + assert "event_type" in updated_names + assert "agent" in updated_names + assert "content" in updated_names + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_skip_upgrade_when_version_matches(self): + """No update when stored version matches current.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_error_is_logged_not_raised(self): + """Schema upgrade errors are logged, not propagated.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin.client.update_table.side_effect = Exception("boom") + # Should not raise + plugin._ensure_schema_exists() + + def test_upgrade_preserves_existing_columns(self): + """Existing columns are never dropped or altered during upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + # Simulate a table with a subset of canonical columns plus a + # user-added custom column that is NOT in the canonical schema. + custom_field = bigquery.SchemaField("my_custom_col", "STRING") + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + custom_field, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = [f.name for f in updated_table.schema] + # Original columns are still present and in original order. + assert updated_names[0] == "timestamp" + assert updated_names[1] == "event_type" + assert updated_names[2] == "my_custom_col" + # New canonical columns were appended after existing ones. + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_from_no_label_treats_as_outdated(self): + """A table with no version label is treated as needing upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = list(plugin._schema) # All columns present + existing.labels = {} # No version label + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + # update_table should be called to stamp the version label even + # though no new columns were needed. + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_upgrade_from_older_version_label(self): + """A table with an older version label triggers upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + ] + # Simulate a table stamped with an older version. + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: "0", + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + # Version label should be updated to current. + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + # Missing columns should have been added. + updated_names = {f.name for f in updated_table.schema} + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_is_idempotent(self): + """Calling _ensure_schema_exists twice doesn't double-update.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + + # First call: table exists with old schema. + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + assert plugin.client.update_table.call_count == 1 + + # Second call: table now has current version label. + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.update_table.reset_mock() + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_update_table_receives_schema_and_labels_fields(self): + """update_table is called with update_fields=['schema', 'labels'].""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + call_args = plugin.client.update_table.call_args + update_fields = call_args[0][1] + assert "schema" in update_fields + assert "labels" in update_fields + + def test_auto_schema_upgrade_defaults_to_true(self): + """Default config has auto_schema_upgrade enabled.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.auto_schema_upgrade is True + + def test_create_table_conflict_is_ignored(self): + """Race condition (Conflict) during create_table is silently handled.""" + plugin = self._make_plugin() + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = cloud_exceptions.Conflict( + "already exists" + ) + # Should not raise. + plugin._ensure_schema_exists() + + +class TestToolProvenance: + """Tests for _get_tool_origin helper.""" + + def test_function_tool_returns_local(self): + from google.adk.tools.function_tool import FunctionTool + + def dummy(): + pass + + tool = FunctionTool(dummy) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "LOCAL" + + def test_agent_tool_returns_sub_agent(self): + from google.adk.tools.agent_tool import AgentTool + + agent = mock.MagicMock() + agent.name = "sub" + tool = AgentTool.__new__(AgentTool) + tool.agent = agent + tool._name = "sub" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "SUB_AGENT" + + def test_transfer_tool_returns_transfer_agent(self): + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + tool = TransferToAgentTool(agent_names=["other"]) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "TRANSFER_AGENT" + + def test_mcp_tool_returns_mcp(self): + try: + from google.adk.tools.mcp_tool.mcp_tool import McpTool + except ImportError: + pytest.skip("MCP not installed") + tool = McpTool.__new__(McpTool) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "MCP" + + def test_a2a_agent_tool_returns_a2a(self): + from google.adk.tools.agent_tool import AgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote" + remote_agent.description = "remote a2a agent" + tool = AgentTool.__new__(AgentTool) + tool.agent = remote_agent + tool._name = "remote" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "A2A" + + def test_unknown_tool_returns_unknown(self): + tool = mock.MagicMock(spec=base_tool_lib.BaseTool) + tool.name = "mystery" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "UNKNOWN" + + +class TestHITLTracing: + """Tests for HITL-specific event emission via on_event_callback. + + HITL events (``adk_request_credential``, ``adk_request_confirmation``, + ``adk_request_input``) are synthetic function calls injected by the + framework — they never pass through ``before_tool_callback`` / + ``after_tool_callback``. Detection therefore lives in + ``on_event_callback``, which inspects the event stream for these + function calls and their corresponding function responses. + """ + + def _make_fc_event(self, fc_name, args=None): + """Build a mock Event containing a function call.""" + event = mock.MagicMock(spec=event_lib.Event) + fc = types.FunctionCall(name=fc_name, args=args or {}) + part = types.Part(function_call=fc) + event.content = types.Content(role="model", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + def _make_fr_event(self, fr_name, response=None): + """Build a mock Event containing a function response.""" + event = mock.MagicMock(spec=event_lib.Event) + fr = types.FunctionResponse(name=fr_name, response=response or {}) + part = types.Part(function_response=fr) + event.content = types.Content(role="user", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + @pytest.mark.asyncio + async def test_hitl_confirmation_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_confirmation", {"confirm": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_credential_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_credential", {"auth": "oauth2"}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CREDENTIAL_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_completion_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fr_event("adk_request_confirmation", {"confirmed": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + + @pytest.mark.asyncio + async def test_regular_tool_no_hitl_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("regular_tool", {"x": 1}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + # No HITL events should be emitted for non-HITL function calls. + # on_event_callback only logs STATE_DELTA and HITL events; a regular + # function call produces neither. + assert mock_write_client.append_rows.call_count == 0 + + +# ============================================================================== +# TEST CLASS: Span Hierarchy Isolation (Issue #4561) +# ============================================================================== + + +class TestSpanHierarchyIsolation: + """Regression tests for https://github.com/google/adk-python/issues/4561. + + ``push_span()`` must NOT attach its span to the ambient OTel context. + If it does, any subsequent ``tracer.start_as_current_span()`` in the + framework (e.g. ``call_llm``, ``execute_tool``) will be incorrectly + re-parented under the plugin's span. + """ + + def test_push_span_does_not_change_ambient_context(self, callback_context): + """push_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_attach_current_span_does_not_change_ambient_context( + self, callback_context + ): + """attach_current_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.attach_current_span( + callback_context + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_pop_span_does_not_change_ambient_context(self, callback_context): + """pop_span must not mutate the current OTel span.""" + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + span_after = trace.get_current_span() + assert span_after is span_before + + def test_push_span_with_real_tracer_does_not_reparent(self, callback_context): + """With a real OTel tracer, plugin spans must not become parents + + of subsequently created framework spans. + """ + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + provider.add_span_processor(SimpleSpanProcessor(exporter)) + framework_tracer = provider.get_tracer("test-framework") + + # Simulate: plugin pushes a span BEFORE the framework span + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "llm_request" + ) + + # Framework creates its own span via start_as_current_span + with framework_tracer.start_as_current_span("call_llm") as fw_span: + fw_context = fw_span.get_span_context() + + # Pop the plugin span + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + provider.shutdown() + + # Verify the framework span was NOT re-parented under the + # plugin's llm_request span + finished = exporter.get_finished_spans() + call_llm_spans = [s for s in finished if s.name == "call_llm"] + assert len(call_llm_spans) == 1 + fw_finished = call_llm_spans[0] + + # The framework span's parent should NOT be the plugin's + # llm_request span. With the fix, the plugin never + # attaches to the ambient context, so ``call_llm`` will + # have whatever parent existed before (None in this test). + assert fw_finished.parent is None + + def test_multiple_push_pop_cycles_leave_context_clean(self, callback_context): + """Multiple push/pop cycles must not leak context changes.""" + original_span = trace.get_current_span() + + for _ in range(5): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "cycle_span" + ) + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + assert trace.get_current_span() is original_span + + +# ============================================================================== +# TEST CLASS: End-to-End HITL Tracing via Runner +# ============================================================================== + + +def _hitl_my_action( + tool_context: tool_context_lib.ToolContext, +) -> dict[str, str]: + """Tool function used by HITL end-to-end tests.""" + return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"} + + +class TestHITLTracingEndToEnd: + """End-to-end tests that run the full Runner + Plugin pipeline with + + ``FunctionTool(require_confirmation=True)`` and verify that HITL events + are logged alongside normal TOOL_* events in the BQ analytics plugin. + """ + + @pytest.fixture + def _mock_bq_infra( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Bundle all BQ mocking fixtures.""" + yield mock_write_client + + @pytest.mark.asyncio + async def test_confirmation_flow_emits_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """Full Runner pipeline: tool with require_confirmation emits + + HITL_CONFIRMATION_REQUEST and HITL_CONFIRMATION_REQUEST_COMPLETED. + """ + from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import FunctionResponse + from google.genai.types import Part + + from .. import testing_utils + + mock_write_client = _mock_bq_infra + + tool = FunctionTool(func=_hitl_my_action, require_confirmation=True) + + # -- Mock LLM: first response calls the tool, second is final text -- + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[Part(text="Done, action confirmed.")] + ) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + # -- Build the plugin -- + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # -- Build agent + runner WITH the plugin -- + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="hitl_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + # -- Turn 1: user query → LLM calls tool → HITL pause -- + events_turn1 = await runner.run_async( + testing_utils.UserContent("run my_action") + ) + + # Find the adk_request_confirmation function call + confirmation_fc_id = None + for ev in events_turn1: + if ev.content and ev.content.parts: + for part in ev.content.parts: + if ( + hasattr(part, "function_call") + and part.function_call + and part.function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + confirmation_fc_id = part.function_call.id + break + if confirmation_fc_id: + break + + assert ( + confirmation_fc_id is not None + ), "Expected adk_request_confirmation function call in turn 1" + + # -- Turn 2: user sends confirmation → tool re-executes -- + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=confirmation_fc_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events_turn2 = await runner.run_async(user_confirmation) + + # -- Give the async BQ writer a moment to flush -- + await asyncio.sleep(0.2) + + # -- Collect all BQ rows -- + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # -- Verify standard events are present -- + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # -- Verify HITL-specific events are present -- + assert ( + "HITL_CONFIRMATION_REQUEST" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST in {event_types}" + assert ( + "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST_COMPLETED in {event_types}" + + # -- Verify HITL events have correct tool name in content -- + hitl_rows = [r for r in rows if r["event_type"].startswith("HITL_")] + for row in hitl_rows: + content = json.loads(row["content"]) if row["content"] else {} + assert content.get("tool") == "adk_request_confirmation", ( + "HITL event should reference 'adk_request_confirmation'," + f" got {content.get('tool')}" + ) + + await bq_plugin.shutdown() + + @pytest.mark.asyncio + async def test_regular_tool_does_not_emit_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """A tool WITHOUT require_confirmation should not produce HITL events.""" + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import Part + + from .. import testing_utils + + mock_write_client = _mock_bq_infra + + def regular_tool() -> str: + return "done" + + tool = FunctionTool(func=regular_tool) + + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent(parts=[Part(text="All done.")]) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="regular_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + await runner.run_async(testing_utils.UserContent("run regular_tool")) + await asyncio.sleep(0.2) + + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # Standard tool events should be present + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # No HITL events + hitl_events = [et for et in event_types if et.startswith("HITL_")] + assert ( + hitl_events == [] + ), f"Expected no HITL events for regular tool, got {hitl_events}" + + await bq_plugin.shutdown() + + +# ============================================================================== +# Fork-Safety Tests +# ============================================================================== +class TestForkSafety: + """Tests for fork-safety via PID tracking.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + return plugin + + @pytest.mark.asyncio + async def test_pid_change_triggers_reinit( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Simulating a fork by changing _init_pid forces re-init.""" + plugin = self._make_plugin() + await plugin._ensure_started() + assert plugin._started is True + + # Simulate a fork: set _init_pid to a stale value + plugin._init_pid = -1 + assert plugin._started is True # still True before check + + # _ensure_started should detect PID mismatch and reset + await plugin._ensure_started() + # After reset + re-init, _init_pid should match current + + assert plugin._init_pid == os.getpid() + assert plugin._started is True + await plugin.shutdown() + + @pytest.mark.asyncio + async def test_pid_unchanged_skips_reset( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Same PID should not trigger a reset.""" + plugin = self._make_plugin() + await plugin._ensure_started() + + # Save references to verify they are not recreated + original_client = plugin.client + original_parser = plugin.parser + + await plugin._ensure_started() + assert plugin.client is original_client + assert plugin.parser is original_parser + await plugin.shutdown() + + def test_reset_runtime_state_clears_fields(self): + """_reset_runtime_state clears all runtime fields.""" + plugin = self._make_plugin() + # Fake some runtime state + plugin._started = True + plugin._is_shutting_down = True + plugin.client = mock.MagicMock() + plugin._loop_state_by_loop = {"fake": "state"} + plugin._write_stream_name = "some/stream" + plugin._executor = mock.MagicMock() + plugin.offloader = mock.MagicMock() + plugin.parser = mock.MagicMock() + plugin._setup_lock = mock.MagicMock() + # Keep pure-data fields + plugin._schema = ["kept"] + plugin.arrow_schema = "kept_arrow" + + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._is_shutting_down is False + assert plugin.client is None + assert plugin._loop_state_by_loop == {} + assert plugin._write_stream_name is None + assert plugin._executor is None + assert plugin.offloader is None + assert plugin.parser is None + assert plugin._setup_lock is None + # Pure-data fields are preserved + assert plugin._schema == ["kept"] + assert plugin.arrow_schema == "kept_arrow" + + assert plugin._init_pid == os.getpid() + + def test_getstate_resets_pid(self): + """Pickle state should have _init_pid = 0 to force re-init.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + assert state["_init_pid"] == 0 + assert state["_started"] is False + + @pytest.mark.asyncio + async def test_unpickle_legacy_state_missing_init_pid( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Unpickling state from older code without _init_pid should not crash.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + # Simulate legacy pickle state that lacks _init_pid entirely + del state["_init_pid"] + + new_plugin = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin.__new__( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin + ) + ) + new_plugin.__setstate__(state) + + # _init_pid should be backfilled to 0, triggering re-init + assert new_plugin._init_pid == 0 + # _ensure_started should not raise AttributeError + await new_plugin._ensure_started() + assert new_plugin._started is True + await new_plugin.shutdown() + + +class TestForkGrpcSafety: + """Tests for gRPC fork safety enhancements.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + + def test_grpc_fork_env_var_set(self): + """GRPC_ENABLE_FORK_SUPPORT should be '1' after import.""" + + assert os.environ.get("GRPC_ENABLE_FORK_SUPPORT") == "1" + + def test_register_at_fork_resets_all_instances(self): + """_after_fork_in_child resets all living plugin instances.""" + p1 = self._make_plugin() + p2 = self._make_plugin() + p1._started = True + p2._started = True + p1._init_pid = -1 + p2._init_pid = -1 + + bigquery_agent_analytics_plugin._after_fork_in_child() + + assert p1._started is False + assert p2._started is False + assert p1._init_pid == os.getpid() + assert p2._init_pid == os.getpid() + + def test_dead_plugin_removed_from_live_set(self): + """WeakSet should not hold dead plugin references.""" + p = self._make_plugin() + assert p in bigquery_agent_analytics_plugin._LIVE_PLUGINS + pid = id(p) + del p + # After deletion, the WeakSet should no longer contain it. + for alive in bigquery_agent_analytics_plugin._LIVE_PLUGINS: + assert id(alive) != pid + + def test_reset_closes_inherited_sync_transports(self): + """_reset_runtime_state closes inherited sync gRPC channels.""" + plugin = self._make_plugin() + mock_channel = mock.MagicMock() + mock_channel.close.return_value = None # sync close + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + plugin._reset_runtime_state() + + mock_channel.close.assert_called_once() + + def test_reset_discards_async_channel_close_coroutine(self): + """Async channel close() returns a coroutine; must not warn.""" + import warnings + + plugin = self._make_plugin() + + async def _async_close(): + pass + + mock_channel = mock.MagicMock() + mock_channel.close.return_value = _async_close() + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + # Must not raise RuntimeWarning for unawaited coroutine + plugin._reset_runtime_state() + + mock_channel.close.assert_called_once() + + def test_transport_close_exception_swallowed(self): + """close() raising should not prevent reset from completing.""" + plugin = self._make_plugin() + mock_channel = mock.MagicMock() + mock_channel.close.side_effect = RuntimeError("broken channel") + mock_transport = mock.MagicMock() + mock_transport._grpc_channel = mock_channel + mock_wc = mock.MagicMock() + mock_wc.transport = mock_transport + + mock_loop_state = mock.MagicMock() + mock_loop_state.write_client = mock_wc + + plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state} + plugin._init_pid = -1 + + # Should not raise + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._loop_state_by_loop == {} + + def test_reset_logs_fork_warning(self): + """_reset_runtime_state logs a warning with 'Fork detected'.""" + plugin = self._make_plugin() + plugin._init_pid = -1 + + with mock.patch.object( + bigquery_agent_analytics_plugin.logger, "warning" + ) as mock_warn: + plugin._reset_runtime_state() + + mock_warn.assert_called_once() + assert "Fork detected" in mock_warn.call_args[0][0] + + +# ============================================================================== +# Analytics Views Tests +# ============================================================================== +class TestAnalyticsViews: + """Tests for auto-created per-event-type BigQuery views.""" + + def _make_plugin(self, create_views=True): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=create_views, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_views_created_on_new_table(self): + """NotFound path creates all views.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_created_for_existing_table(self): + """Existing table path also creates views.""" + plugin = self._make_plugin(create_views=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_not_created_when_disabled(self): + """create_views=False skips view creation.""" + plugin = self._make_plugin(create_views=False) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + + plugin._ensure_schema_exists() + + plugin.client.query.assert_not_called() + + def test_view_creation_error_logged_not_raised(self): + """Errors during view creation don't crash the plugin.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.query.side_effect = Exception("BQ error") + + # Should not raise + plugin._ensure_schema_exists() + + # Verify it tried to create views (and failed gracefully) + assert plugin.client.query.call_count > 0 + + def test_view_sql_contains_correct_event_filter(self): + """Each SQL has correct WHERE clause and view name.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + calls = plugin.client.query.call_args_list + for call in calls: + sql = call[0][0] + # Each SQL should have CREATE OR REPLACE VIEW + assert "CREATE OR REPLACE VIEW" in sql + # Each SQL should filter by event_type + assert "WHERE" in sql + assert "event_type = " in sql + # View name should start with v_ + assert ".v_" in sql + + # Verify specific views exist + all_sql = " ".join(c[0][0] for c in calls) + for event_type in bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS: + view_name = "v_" + event_type.lower() + assert view_name in all_sql, f"View {view_name} not found in SQL" + + def test_config_create_views_default_true(self): + """Config create_views defaults to True.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.create_views is True + + @pytest.mark.asyncio + async def test_create_analytics_views_ensures_started( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Public create_analytics_views() initializes plugin first.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + assert plugin._started is False + + await plugin.create_analytics_views() + + # Plugin should be started after the call + assert plugin._started is True + # Views should have been created (query called) + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + # _ensure_schema_exists also creates views, so total calls + # = schema-creation views + explicit views + assert mock_bq_client.query.call_count >= expected_count + await plugin.shutdown() + + def test_views_not_created_after_table_creation_failure(self): + """View creation is skipped when create_table raises a non-Conflict error.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = RuntimeError("BQ down") + + plugin._ensure_schema_exists() + + # Views should NOT be attempted since table creation failed + plugin.client.query.assert_not_called() + + @pytest.mark.asyncio + async def test_create_analytics_views_raises_on_startup_failure( + self, mock_auth_default, mock_write_client + ): + """create_analytics_views() raises if plugin init fails.""" + # Make the BQ Client constructor raise so _lazy_setup fails + # before _started is set to True. + with mock.patch.object( + bigquery, "Client", side_effect=Exception("client boom") + ): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + with pytest.raises( + RuntimeError, match="Plugin initialization failed" + ) as exc_info: + await plugin.create_analytics_views() + # Root cause should be chained for debuggability + assert exc_info.value.__cause__ is not None + assert "client boom" in str(exc_info.value.__cause__) + + +# ============================================================================== +# Trace-ID Continuity Tests (Issue #4645) +# ============================================================================== +class TestTraceIdContinuity: + """Tests for trace_id continuity across all events in an invocation. + + Regression tests for https://github.com/google/adk-python/issues/4645. + + When there is no ambient OTel span (e.g. Agent Engine, custom runners), + early events (USER_MESSAGE_RECEIVED, INVOCATION_STARTING) used to fall + back to ``invocation_id`` while AGENT_STARTING got a new OTel hex + trace_id from ``push_span()``. The ``ensure_invocation_span()`` fix + guarantees a root span is always on the stack before any events fire. + """ + + @pytest.mark.asyncio + async def test_trace_id_continuity_no_ambient_span(self, callback_context): + """All events share one trace_id when no ambient OTel span exists. + + Simulates the #4645 scenario: OTel IS configured (real TracerProvider) + but the Runner's ambient span is NOT present (e.g. Agent Engine, + custom runners). + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Create a real TracerProvider and patch the plugin's module-level + # tracer so push_span creates valid spans with proper trace_ids. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span — we do NOT start_as_current_span. + ambient = trace.get_current_span() + assert not ambient.get_span_context().is_valid + + # ensure_invocation_span should push a new span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early is not None + # Should NOT fall back to invocation_id — it should be + # a 32-char hex OTel trace_id. + assert trace_id_early != callback_context.invocation_id + assert len(trace_id_early) == 32 + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + + # Both trace_ids must be identical. + assert trace_id_early == trace_id_agent + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_completed_trace_continuity_no_ambient( + self, callback_context + ): + """INVOCATION_COMPLETED must share trace_id with earlier events. + + Reproduces the completion-event fracture: after_run_callback pops + the invocation span, then _log_event would resolve trace_id via + the fallback to invocation_id. The trace_id_override ensures the + completion event keeps the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset for a clean invocation; no ambient span. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + assert not trace.get_current_span().get_span_context().is_valid + + # --- Simulate the full callback lifecycle --- + # 1. before_run / on_user_message: ensure invocation span + TM.ensure_invocation_span(callback_context) + trace_id_start = TM.get_trace_id(callback_context) + + # 2. before_agent: push agent span + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_start + + # 3. after_agent: pop agent span + TM.pop_span() + + # 4. after_run: capture trace_id THEN pop invocation span + trace_id_before_pop = TM.get_trace_id(callback_context) + assert trace_id_before_pop == trace_id_start + + TM.pop_span() + + # After popping, get_trace_id falls back to invocation_id + trace_id_after_pop = TM.get_trace_id(callback_context) + assert trace_id_after_pop == callback_context.invocation_id + + # The trace_id_override preserves continuity + assert trace_id_before_pop == trace_id_start + assert trace_id_before_pop != trace_id_after_pop + + provider.shutdown() + + @pytest.mark.asyncio + async def test_callbacks_emit_same_trace_id_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Full callback path: all emitted rows share one trace_id. + + Exercises the real before_run → before_agent → after_agent → + after_run callback chain via the plugin instance, then checks + every emitted BQ row has the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset span records for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span — simulates Agent Engine / custom runner. + assert not trace.get_current_span().get_span_context().is_valid + + # Run the full callback lifecycle. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + # Collect all emitted rows. + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "INVOCATION_STARTING" in event_types + assert "INVOCATION_COMPLETED" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + # Should be a 32-char hex OTel trace, not the invocation_id. + sole_trace_id = trace_ids.pop() + assert sole_trace_id != invocation_context.invocation_id + assert len(sole_trace_id) == 32 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_trace_id_continuity_with_ambient_span(self, callback_context): + """All events share one trace_id when an ambient OTel span exists.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Set up a real OTel tracer. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + with real_tracer.start_as_current_span("runner_invocation"): + ambient = trace.get_current_span() + assert ambient.get_span_context().is_valid + ambient_trace_id = format(ambient.get_span_context().trace_id, "032x") + + # ensure_invocation_span should attach the ambient span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early == ambient_trace_id + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + assert trace_id_agent == ambient_trace_id + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation (attached, not owned) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_root_span_isolated_across_turns( + self, callback_context + ): + """Each invocation gets its own root span; turns don't leak.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Turn 1 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn1 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn1 + TM.pop_span() # agent + TM.pop_span() # invocation + + # After popping, the stack should be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert not records + + # --- Turn 2 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn2 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn2 + TM.pop_span() # agent + TM.pop_span() # invocation + + # The two turns must have DIFFERENT trace_ids (different + # root spans). + assert trace_id_turn1 != trace_id_turn2 + + provider.shutdown() + + +class TestSpanIdConsistency: + """Tests that STARTING/COMPLETED event pairs share span IDs. + + Span-ID resolution contract: + - When OTel is active: BQ rows use the same trace/span/parent IDs as + Cloud Trace (ambient framework spans). STARTING and COMPLETED events + in the same lifecycle share the same span_id. + - When OTel is not active: BQ rows use the plugin's internal span + stack. STARTING gets the current top-of-stack; COMPLETED gets the + popped span. + """ + + @pytest.mark.asyncio + async def test_starting_completed_same_span_with_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """With ambient OTel, STARTING and COMPLETED get the same span_id.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # Simulate the framework's ambient spans. + with real_tracer.start_as_current_span("invocation"): + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + with real_tracer.start_as_current_span("invoke_agent"): + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # Both events must share the same span_id (the ambient + # invoke_agent span) — no plugin-synthetic override. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + assert ( + agent_starting[0]["parent_span_id"] + == agent_completed[0]["parent_span_id"] + ) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_starting_completed_use_plugin_span_without_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Without ambient OTel, COMPLETED gets the popped plugin span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # AGENT_STARTING gets the top-of-stack span; AGENT_COMPLETED + # gets the popped span via override — they should match. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_tool_error_captures_span_id( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + dummy_arrow_schema, + ): + """on_tool_error_callback uses the popped span_id (bonus fix).""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_tool = mock.create_autospec(base_tool_lib.BaseTool, instance=True) + type(mock_tool).name = mock.PropertyMock(return_value="my_tool") + tool_ctx = tool_context_lib.ToolContext( + invocation_context=invocation_context + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel — plugin span stack provides IDs. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push tool span via before_tool_callback + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + ) + # Error callback should pop the tool span and use its ID + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + error=RuntimeError("boom"), + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + tool_starting = [r for r in rows if r["event_type"] == "TOOL_STARTING"] + tool_error = [r for r in rows if r["event_type"] == "TOOL_ERROR"] + + assert len(tool_starting) == 1 + assert len(tool_error) == 1 + + # The TOOL_ERROR event must have the same span_id as + # TOOL_STARTING (both correspond to the same tool span). + assert tool_starting[0]["span_id"] == tool_error[0]["span_id"] + assert tool_error[0]["span_id"] is not None + + provider.shutdown() + + +class TestStackLeakSafety: + """Tests for stack leak safety (P2). + + Ensures the plugin's internal span stack doesn't leak records + across invocations when after_run_callback is skipped. + """ + + def test_ensure_invocation_span_clears_stale_records(self, callback_context): + """Pre-populated stack from a different invocation is cleared.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Simulate stale records from incomplete previous invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + # Mark the stale records as belonging to a different invocation. + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set( + "old-inv-stale" + ) + TM.push_span(callback_context, "stale-invocation") + TM.push_span(callback_context, "stale-agent") + + stale_records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale_records) == 2 + + # ensure_invocation_span with the *current* invocation_id should + # detect the mismatch, clear stale records, and re-init. + TM.ensure_invocation_span(callback_context) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh entry (the new invocation span). + assert len(records) == 1 + # The fresh span should NOT be one of the stale ones. + assert records[0].span_id != stale_records[0].span_id + assert records[0].span_id != stale_records[1].span_id + + provider.shutdown() + + def test_clear_stack_ends_owned_spans(self, callback_context): + """clear_stack() ends all owned spans.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + exporter = InMemorySpanExporter() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.push_span(callback_context, "span-a") + TM.push_span(callback_context, "span-b") + + records = list(bigquery_agent_analytics_plugin._span_records_ctx.get()) + assert all(r.owns_span for r in records) + + TM.clear_stack() + + # Stack must be empty after clear. + result = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert result == [] + + # Both owned spans should have been ended (exported). + exported = exporter.get_finished_spans() + assert len(exported) == 2 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_after_run_callback_clears_remaining_stack( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """after_run_callback clears any leftover stack entries.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push an agent span but DON'T pop it (simulate missing + # after_agent_callback due to exception). + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Stack now has [invocation, agent]. + + # after_run_callback should pop invocation + clear remaining. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # Stack must be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_next_invocation_clean_after_incomplete_previous( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + mock_session, + ): + """Next invocation starts clean even if previous was incomplete.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # --- Incomplete invocation 1: no after_run_callback --- + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Skip after_agent and after_run — simulates exception. + + stale = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale) >= 2 # invocation + agent + + # --- Invocation 2 with a different invocation_id --- + mock_write_client.append_rows.reset_mock() + inv_ctx_2 = InvocationContext( + agent=mock_agent, + session=mock_session, + invocation_id="inv-NEW-002", + session_service=invocation_context.session_service, + plugin_manager=invocation_context.plugin_manager, + ) + await bq_plugin_inst.before_run_callback(invocation_context=inv_ctx_2) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh invocation span. + assert len(records) == 1 + + # Cleanup + await bq_plugin_inst.after_run_callback(invocation_context=inv_ctx_2) + + provider.shutdown() + + def test_ensure_invocation_span_idempotent_same_invocation( + self, callback_context + ): + """Calling ensure_invocation_span twice in the same invocation is a no-op.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # First call: creates invocation span. + TM.ensure_invocation_span(callback_context) + records_after_first = list( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_first) == 1 + first_span_id = records_after_first[0].span_id + + # Second call (same invocation): must be a no-op. + TM.ensure_invocation_span(callback_context) + records_after_second = ( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_second) == 1 + assert records_after_second[0].span_id == first_span_id + + # Cleanup + TM.pop_span() + + provider.shutdown() + + @pytest.mark.asyncio + async def test_user_message_then_before_run_same_trace_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Regression: on_user_message → before_run must share one trace_id. + + Without the invocation-ID guard, the second ensure_invocation_span() + call would clear the stack and create a new root span with a + different trace_id, fracturing USER_MESSAGE_RECEIVED from + INVOCATION_STARTING. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + user_msg = types.Content(parts=[types.Part(text="hello")], role="user") + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=user_msg, + ) + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "USER_MESSAGE_RECEIVED" in event_types + assert "INVOCATION_STARTING" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + + provider.shutdown() + + +class TestRootAgentNameAcrossInvocations: + """Regression: root_agent_name must refresh across invocations.""" + + @pytest.mark.asyncio + async def test_root_agent_name_updates_between_invocations( + self, + bq_plugin_inst, + mock_write_client, + mock_session, + dummy_arrow_schema, + ): + """Two invocations with different root agents must log correct names. + + Previously init_trace() only set _root_agent_name_ctx when it was + None, so the second invocation would inherit the first's root agent. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ) + + def _make_inv_ctx(agent_name, inv_id): + agent = mock.create_autospec( + base_agent.BaseAgent, instance=True, spec_set=True + ) + type(agent).name = mock.PropertyMock(return_value=agent_name) + type(agent).instruction = mock.PropertyMock(return_value="") + # root_agent returns itself (no parent). + agent.root_agent = agent + return InvocationContext( + agent=agent, + session=mock_session, + invocation_id=inv_id, + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Invocation 1: root agent = "RootA" --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + inv1 = _make_inv_ctx("RootA", "inv-001") + cb1 = CallbackContext(inv1) + await bq_plugin_inst.before_run_callback(invocation_context=inv1) + await bq_plugin_inst.before_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv1) + await asyncio.sleep(0.01) + + rows_inv1 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # --- Invocation 2: root agent = "RootB" --- + mock_write_client.append_rows.reset_mock() + + inv2 = _make_inv_ctx("RootB", "inv-002") + cb2 = CallbackContext(inv2) + await bq_plugin_inst.before_run_callback(invocation_context=inv2) + await bq_plugin_inst.before_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv2) + await asyncio.sleep(0.01) + + rows_inv2 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # Parse root_agent_name from the attributes JSON column. + def _get_root_names(rows): + names = set() + for r in rows: + attrs = r.get("attributes") + if attrs: + parsed = json.loads(attrs) if isinstance(attrs, str) else attrs + if "root_agent_name" in parsed: + names.add(parsed["root_agent_name"]) + return names + + names_inv1 = _get_root_names(rows_inv1) + names_inv2 = _get_root_names(rows_inv2) + + # Invocation 1 should only have "RootA". + assert names_inv1 == {"RootA"}, f"Expected {{'RootA'}}, got {names_inv1}" + # Invocation 2 must have "RootB", NOT stale "RootA". + assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}" + + provider.shutdown() + + +class TestAfterRunCleanupExceptionSafety: + """after_run_callback cleanup must execute even if _log_event fails.""" + + @pytest.mark.asyncio + async def test_cleanup_runs_when_log_event_raises( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + ): + """Stale state is cleared even when _log_event raises.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + # Run a normal before_run to initialise state. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + + # Verify state is populated. + assert bigquery_agent_analytics_plugin._span_records_ctx.get() + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is not None + ) + + # Make _log_event raise inside after_run_callback. + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("boom"), + ): + # _safe_callback swallows the exception, but cleanup in + # the finally block must still execute. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # All invocation state must be cleaned up despite the error. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] or records is None + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is None + ) + assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None + + provider.shutdown() + + +class TestStringSystemPromptTruncation: + """Tests that a string system prompt is truncated in parse().""" + + @pytest.mark.asyncio + async def test_long_string_system_prompt_is_truncated(self): + """A string system_instruction exceeding max_content_length is truncated.""" + parser = bigquery_agent_analytics_plugin.HybridContentParser( + offloader=None, + trace_id="test-trace", + span_id="test-span", + max_length=50, + ) + long_prompt = "A" * 200 + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction=long_prompt, + ), + ) + payload, _, is_truncated = await parser.parse(llm_request) + assert is_truncated + assert len(payload["system_prompt"]) < 200 + assert "TRUNCATED" in payload["system_prompt"] + + +class TestSessionStateTruncation: + """Tests that session state is truncated in _enrich_attributes.""" + + @pytest.mark.asyncio + async def test_oversized_session_state_is_truncated( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + mock_session, + invocation_context, + ): + """Session state with large values is truncated.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + max_content_length=30, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + await plugin._ensure_started() + + # Set a large session state value. + large_value = "X" * 200 + type(mock_session).state = mock.PropertyMock( + return_value={"big_key": large_value} + ) + + callback_ctx = CallbackContext(invocation_context=invocation_context) + event_data = bigquery_agent_analytics_plugin.EventData() + attrs = plugin._enrich_attributes(event_data, callback_ctx) + state = attrs["session_metadata"]["state"] + assert len(state["big_key"]) < 200 + assert "TRUNCATED" in state["big_key"] + await plugin.shutdown() + + +class TestSchemaUpgradeNestedFields: + """Tests for nested RECORD field detection in schema upgrade.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + auto_schema_upgrade=True, + ) + with mock.patch("google.cloud.bigquery.Client"): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + return plugin + + def test_nested_field_detected(self): + """A new sub-field in a RECORD triggers an upgrade.""" + plugin = self._make_plugin() + + existing_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", "STRING"), + ], + ) + desired_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", "STRING"), + bigquery.SchemaField("value", "STRING"), + ], + ) + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + desired_record, + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + existing_record, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + # Find the metadata field and check it has both sub-fields. + metadata_field = next( + f for f in updated_table.schema if f.name == "metadata" + ) + sub_names = {sf.name for sf in metadata_field.fields} + assert "key" in sub_names + assert "value" in sub_names + + def test_version_label_not_stamped_on_failure(self): + """A failed update_table does not persist the version label.""" + plugin = self._make_plugin() + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("new_col", "STRING"), + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin.client.update_table.side_effect = Exception("network error") + + # Should not raise. + plugin._ensure_schema_exists() + + # The label is set on the table object before update_table is + # called, but since update_table failed the label was never + # persisted remotely. On the next run the stored_version will + # still be None (from the real BQ table) so the upgrade retries. + # We verify that update_table was actually attempted. + plugin.client.update_table.assert_called_once() + + def test_nested_upgrade_preserves_policy_tags(self): + """RECORD field metadata (e.g. policy_tags) is preserved on upgrade.""" + from google.cloud.bigquery import schema as bq_schema + + plugin = self._make_plugin() + + existing_record = bigquery.SchemaField( + "metadata", + "RECORD", + policy_tags=bq_schema.PolicyTagList( + names=["projects/p/locations/us/taxonomies/t/policyTags/pt"] + ), + fields=[ + bigquery.SchemaField("key", "STRING"), + ], + ) + desired_record = bigquery.SchemaField( + "metadata", + "RECORD", + fields=[ + bigquery.SchemaField("key", "STRING"), + bigquery.SchemaField("value", "STRING"), + ], + ) + plugin._schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + desired_record, + ] + + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + existing_record, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + metadata_field = next( + f for f in updated_table.schema if f.name == "metadata" + ) + # Sub-fields were merged. + sub_names = {sf.name for sf in metadata_field.fields} + assert "key" in sub_names + assert "value" in sub_names + # policy_tags preserved from the existing field. + assert metadata_field.policy_tags is not None + assert ( + "projects/p/locations/us/taxonomies/t/policyTags/pt" + in metadata_field.policy_tags.names + ) + + +class TestMultiLoopShutdownDrainsOtherLoops: + """Tests that shutdown() drains batch processors on other loops.""" + + @pytest.mark.asyncio + async def test_other_loop_batch_processor_drained( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Shutdown drains batch_processor.shutdown on non-current loops.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await plugin._ensure_started() + + # Create a mock "other" loop with a mock batch processor. + other_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + other_loop.is_closed.return_value = False + + mock_other_bp = mock.AsyncMock() + mock_other_write_client = mock.MagicMock() + mock_other_write_client.transport = mock.AsyncMock() + + other_state = bigquery_agent_analytics_plugin._LoopState( + write_client=mock_other_write_client, + batch_processor=mock_other_bp, + ) + plugin._loop_state_by_loop[other_loop] = other_state + + # Patch run_coroutine_threadsafe to verify it's called for + # the other loop's batch_processor. Close the coroutine arg + # to avoid "coroutine was never awaited" RuntimeWarning. + mock_future = mock.MagicMock() + mock_future.result.return_value = None + + def _fake_run_coroutine_threadsafe(coro, loop): + coro.close() + return mock_future + + with mock.patch.object( + asyncio, + "run_coroutine_threadsafe", + side_effect=_fake_run_coroutine_threadsafe, + ) as mock_rcts: + await plugin.shutdown() + + # Verify run_coroutine_threadsafe was called with + # the other loop. + mock_rcts.assert_called() + call_args = mock_rcts.call_args + assert call_args[0][1] is other_loop diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index 08dfdd6f..6b12790d 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -502,6 +502,86 @@ class TestHITLConfirmationFlowWithResumableApp: == expected_parts_final ) + @pytest.mark.asyncio + async def test_pause_and_resume_on_request_confirmation_without_invocation_id( + self, + runner: testing_utils.InMemoryRunner, + agent: LlmAgent, + ): + """Tests HITL flow where all tool calls are confirmed.""" + events = runner.run("test user query") + + # Verify that the invocation is paused when tool confirmation is requested. + # The tool call returns error response, and summarization was skipped. + assert testing_utils.simplify_resumable_app_events( + copy.deepcopy(events) + ) == [ + ( + agent.name, + Part(function_call=FunctionCall(name=agent.tools[0].name, args={})), + ), + ( + agent.name, + Part( + function_call=FunctionCall( + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + args={ + "originalFunctionCall": { + "name": agent.tools[0].name, + "id": mock.ANY, + "args": {}, + }, + "toolConfirmation": { + "hint": HINT_TEXT, + "confirmed": False, + }, + }, + ) + ), + ), + ( + agent.name, + Part( + function_response=FunctionResponse( + name=agent.tools[0].name, response=TOOL_CALL_ERROR_RESPONSE + ) + ), + ), + ] + ask_for_confirmation_function_call_id = ( + events[1].content.parts[0].function_call.id + ) + invocation_id = events[1].invocation_id + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events = await runner.run_async(user_confirmation) + expected_parts_final = [ + ( + agent.name, + Part( + function_response=FunctionResponse( + name=agent.tools[0].name, + response={"result": "confirmed=True"}, + ) + ), + ), + (agent.name, "test llm response after tool call"), + (agent.name, testing_utils.END_OF_AGENT), + ] + for event in events: + assert event.invocation_id == invocation_id + assert ( + testing_utils.simplify_resumable_app_events(copy.deepcopy(events)) + == expected_parts_final + ) + class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp: """Tests the HITL confirmation flow with a resumable sequential agent app.""" diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 25530bed..5c5aa83e 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -418,16 +418,41 @@ async def test_temp_state_is_not_persisted_in_state_or_events(session_service): ) await session_service.append_event(session=session, event=event) - # Refetch session and check state and event - session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' - ) - # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + # Temp state IS available in the in-memory session (same invocation) + assert session.state.get('temp:k1') == 'v1' + assert session.state.get('sk') == 'v2' + # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert 'temp:k1' not in event.actions.state_delta + assert event.actions.state_delta.get('sk') == 'v2' + + +@pytest.mark.asyncio +async def test_temp_state_visible_across_sequential_events(session_service): + """Temp state set by one event should be readable before the next event. + + This simulates a SequentialAgent where agent-1 writes output_key='temp:out' + and agent-2 needs to read it from session.state within the same invocation. + """ + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s_seq' + ) + + # Agent-1 writes temp state + event1 = Event( + invocation_id='inv1', + author='agent1', + actions=EventActions(state_delta={'temp:output': 'result_from_a1'}), + ) + await session_service.append_event(session=session, event=event1) + + # Agent-2 should be able to read temp state from the same session object + assert session.state.get('temp:output') == 'result_from_a1' + + # But the event delta should NOT contain the temp key (not persisted) + assert 'temp:output' not in event1.actions.state_delta @pytest.mark.asyncio @@ -1153,3 +1178,92 @@ async def test_prepare_tables_idempotent_after_creation(): assert session.id == 's1' finally: await service.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'state_delta, expect_app_lock, expect_user_lock', + [ + pytest.param( + None, + False, + False, + id='no_state_delta', + ), + pytest.param( + {'session_key': 'v'}, + False, + False, + id='session_only_delta', + ), + pytest.param( + {'app:key': 'v'}, + True, + False, + id='app_delta_only', + ), + pytest.param( + {'user:key': 'v'}, + False, + True, + id='user_delta_only', + ), + pytest.param( + {'app:a': '1', 'user:b': '2', 'sk': '3'}, + True, + True, + id='all_scopes', + ), + ], +) +async def test_append_event_locks_only_scopes_with_deltas( + state_delta, expect_app_lock, expect_user_lock +): + """FOR UPDATE should only be requested for state scopes that have deltas.""" + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + + lock_requests = [] + original_fn = database_session_service._select_required_state + + async def tracking_fn(**kwargs): + lock_requests.append({ + 'model': kwargs['state_model'].__tablename__, + 'use_row_level_locking': kwargs['use_row_level_locking'], + }) + return await original_fn(**kwargs) + + try: + session = await service.create_session( + app_name='app', user_id='user', session_id='s1' + ) + + database_session_service._select_required_state = tracking_fn + lock_requests.clear() + + event_kwargs = {'invocation_id': 'inv', 'author': 'user'} + if state_delta is not None: + event_kwargs['actions'] = EventActions(state_delta=state_delta) + event = Event(**event_kwargs) + await service.append_event(session, event) + + app_req = next( + (r for r in lock_requests if r['model'] == 'app_states'), None + ) + user_req = next( + (r for r in lock_requests if r['model'] == 'user_states'), None + ) + + # SQLite doesn't support row-level locking so use_row_level_locking is + # always False. The important check is that locking is not requested + # when there is no delta (it must never be True without a delta). + if not expect_app_lock: + assert ( + app_req is None or not app_req['use_row_level_locking'] + ), 'app_states should not be locked without an app: delta' + if not expect_user_lock: + assert ( + user_req is None or not user_req['use_row_level_locking'] + ), 'user_states should not be locked without a user: delta' + finally: + database_session_service._select_required_state = original_fn + await service.close() diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 8c77f194..c095ddd9 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -27,6 +27,7 @@ from google.adk.auth import auth_schemes from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.events.event_actions import EventCompaction from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService @@ -826,3 +827,87 @@ async def test_append_event(): assert len(retrieved_session.events) == 2 event_to_append.id = retrieved_session.events[1].id assert retrieved_session.events[1] == event_to_append + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_compaction(): + """Compaction data round-trips through append_event and get_session.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session is not None + + compaction = EventCompaction( + start_timestamp=1000.0, + end_timestamp=2000.0, + compacted_content=genai_types.Content( + parts=[genai_types.Part(text='compacted summary')] + ), + ) + event_to_append = Event( + invocation_id='compaction_invocation', + author='model', + timestamp=1734005534.0, + actions=EventActions(compaction=compaction), + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert retrieved_session is not None + + appended_event = retrieved_session.events[-1] + assert appended_event.actions.compaction is not None + assert appended_event.actions.compaction.start_timestamp == 1000.0 + assert appended_event.actions.compaction.end_timestamp == 2000.0 + assert appended_event.actions.compaction.compacted_content.parts[0].text == ( + 'compacted summary' + ) + # custom_metadata should remain None when only compaction was stored + assert appended_event.custom_metadata is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_append_event_with_compaction_and_custom_metadata(): + """Both compaction and user custom_metadata survive the round-trip.""" + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert session is not None + + compaction = EventCompaction( + start_timestamp=100.0, + end_timestamp=200.0, + compacted_content=genai_types.Content( + parts=[genai_types.Part(text='summary')] + ), + ) + event_to_append = Event( + invocation_id='compaction_and_meta_invocation', + author='model', + timestamp=1734005535.0, + actions=EventActions(compaction=compaction), + custom_metadata={'user_key': 'user_value'}, + ) + + await session_service.append_event(session, event_to_append) + + retrieved_session = await session_service.get_session( + app_name='123', user_id='user', session_id='1' + ) + assert retrieved_session is not None + + appended_event = retrieved_session.events[-1] + # Compaction is restored + assert appended_event.actions.compaction is not None + assert appended_event.actions.compaction.start_timestamp == 100.0 + assert appended_event.actions.compaction.end_timestamp == 200.0 + # User custom_metadata is preserved without the internal _compaction key + assert appended_event.custom_metadata == {'user_key': 'user_value'} + assert '_compaction' not in (appended_event.custom_metadata or {}) diff --git a/tests/unittests/skills/test__utils.py b/tests/unittests/skills/test__utils.py new file mode 100644 index 00000000..5a65648d --- /dev/null +++ b/tests/unittests/skills/test__utils.py @@ -0,0 +1,182 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for skill utilities.""" + +from google.adk.skills import load_skill_from_dir as _load_skill_from_dir +from google.adk.skills._utils import _read_skill_properties +from google.adk.skills._utils import _validate_skill_dir +import pytest + + +def test__load_skill_from_dir(tmp_path): + """Tests loading a skill from a directory.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + + skill_md_content = """--- +name: test-skill +description: Test description +--- +Test instructions +""" + (skill_dir / "SKILL.md").write_text(skill_md_content) + + # Create references + ref_dir = skill_dir / "references" + ref_dir.mkdir() + (ref_dir / "ref1.md").write_text("ref1 content") + + # Create assets + assets_dir = skill_dir / "assets" + assets_dir.mkdir() + (assets_dir / "asset1.txt").write_text("asset1 content") + + # Create scripts + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir() + (scripts_dir / "script1.sh").write_text("echo hello") + + skill = _load_skill_from_dir(skill_dir) + + assert skill.name == "test-skill" + assert skill.description == "Test description" + assert skill.instructions == "Test instructions" + assert skill.resources.get_reference("ref1.md") == "ref1 content" + assert skill.resources.get_asset("asset1.txt") == "asset1 content" + assert skill.resources.get_script("script1.sh").src == "echo hello" + + +def test_allowed_tools_yaml_key(tmp_path): + """Tests that allowed-tools YAML key loads correctly.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +allowed-tools: "some-tool-*" +--- +Instructions here +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + skill = _load_skill_from_dir(skill_dir) + assert skill.frontmatter.allowed_tools == "some-tool-*" + + +def test_name_directory_mismatch(tmp_path): + """Tests that name-directory mismatch raises ValueError.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + with pytest.raises(ValueError, match="does not match directory"): + _load_skill_from_dir(skill_dir) + + +def test_validate_skill_dir_valid(tmp_path): + """Tests validate_skill_dir with a valid skill.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert problems == [] + + +def test_validate_skill_dir_missing_dir(tmp_path): + """Tests validate_skill_dir with missing directory.""" + problems = _validate_skill_dir(tmp_path / "nonexistent") + assert len(problems) == 1 + assert "does not exist" in problems[0] + + +def test_validate_skill_dir_missing_skill_md(tmp_path): + """Tests validate_skill_dir with missing SKILL.md.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + problems = _validate_skill_dir(skill_dir) + assert len(problems) == 1 + assert "SKILL.md not found" in problems[0] + + +def test_validate_skill_dir_name_mismatch(tmp_path): + """Tests validate_skill_dir catches name-directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert any("does not match" in p for p in problems) + + +def test_validate_skill_dir_unknown_fields(tmp_path): + """Tests validate_skill_dir detects unknown frontmatter fields.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A skill +unknown-field: something +--- +Body +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + problems = _validate_skill_dir(skill_dir) + assert any("Unknown frontmatter" in p for p in problems) + + +def test__read_skill_properties(tmp_path): + """Tests read_skill_properties basic usage.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + + skill_md = """--- +name: my-skill +description: A cool skill +license: MIT +--- +Body content +""" + (skill_dir / "SKILL.md").write_text(skill_md) + + fm = _read_skill_properties(skill_dir) + assert fm.name == "my-skill" + assert fm.description == "A cool skill" + assert fm.license == "MIT" diff --git a/tests/unittests/skills/test_models.py b/tests/unittests/skills/test_models.py index 6ecdd51f..5685e9d8 100644 --- a/tests/unittests/skills/test_models.py +++ b/tests/unittests/skills/test_models.py @@ -15,6 +15,7 @@ """Unit tests for skill models.""" from google.adk.skills import models +from pydantic import ValidationError import pytest @@ -68,3 +69,138 @@ def test_script_to_string(): """Tests Script model.""" script = models.Script(src="print('hello')") assert str(script) == "print('hello')" + + +# --- Name validation tests --- + + +def test_name_too_long(): + with pytest.raises(ValidationError, match="at most 64 characters"): + models.Frontmatter(name="a" * 65, description="desc") + + +def test_name_uppercase_rejected(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="My-Skill", description="desc") + + +def test_name_leading_hyphen(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="-my-skill", description="desc") + + +def test_name_trailing_hyphen(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my-skill-", description="desc") + + +def test_name_consecutive_hyphens(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my--skill", description="desc") + + +def test_name_invalid_chars_underscore(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="my_skill", description="desc") + + +def test_name_invalid_chars_ampersand(): + with pytest.raises(ValidationError, match="lowercase kebab-case"): + models.Frontmatter(name="skill&name", description="desc") + + +def test_name_valid_passes(): + fm = models.Frontmatter(name="my-skill-2", description="desc") + assert fm.name == "my-skill-2" + + +def test_name_single_word(): + fm = models.Frontmatter(name="skill", description="desc") + assert fm.name == "skill" + + +# --- Description validation tests --- + + +def test_description_empty(): + with pytest.raises(ValidationError, match="must not be empty"): + models.Frontmatter(name="my-skill", description="") + + +def test_description_too_long(): + with pytest.raises(ValidationError, match="at most 1024 characters"): + models.Frontmatter(name="my-skill", description="x" * 1025) + + +# --- Compatibility validation tests --- + + +def test_compatibility_too_long(): + with pytest.raises(ValidationError, match="at most 500 characters"): + models.Frontmatter( + name="my-skill", description="desc", compatibility="c" * 501 + ) + + +# --- Extra field rejected --- + + +def test_extra_field_allowed(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "unknown_field": "value", + }) + assert fm.name == "my-skill" + + +# --- allowed-tools alias --- + + +def test_allowed_tools_alias_via_model_validate(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "allowed-tools": "tool-pattern", + }) + assert fm.allowed_tools == "tool-pattern" + + +def test_allowed_tools_serialization_alias(): + fm = models.Frontmatter( + name="my-skill", description="desc", allowed_tools="tool-pattern" + ) + dumped = fm.model_dump(by_alias=True) + assert "allowed-tools" in dumped + assert dumped["allowed-tools"] == "tool-pattern" + + +def test_metadata_adk_additional_tools_list(): + fm = models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": ["tool1", "tool2"]}, + }) + assert fm.metadata["adk_additional_tools"] == ["tool1", "tool2"] + + +def test_metadata_adk_additional_tools_rejected_as_string(): + with pytest.raises( + ValidationError, match="adk_additional_tools must be a list of strings" + ): + models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": "tool1 tool2"}, + }) + + +def test_metadata_adk_additional_tools_invalid_type(): + with pytest.raises( + ValidationError, match="adk_additional_tools must be a list of strings" + ): + models.Frontmatter.model_validate({ + "name": "my-skill", + "description": "desc", + "metadata": {"adk_additional_tools": 123}, + }) diff --git a/tests/unittests/skills/test_prompt.py b/tests/unittests/skills/test_prompt.py index f5395f3c..aa48c7b8 100644 --- a/tests/unittests/skills/test_prompt.py +++ b/tests/unittests/skills/test_prompt.py @@ -42,8 +42,8 @@ class TestPrompt: def test_format_skills_as_xml_escaping(self): skills = [ - models.Frontmatter(name="skill&name", description="desc"), + models.Frontmatter(name="my-skill", description="desc"), ] xml = prompt.format_skills_as_xml(skills) - assert "skill&name" in xml + assert "my-skill" in xml assert "desc<ription>" in xml diff --git a/tests/unittests/skills/test_utils.py b/tests/unittests/skills/test_utils.py deleted file mode 100644 index d922719d..00000000 --- a/tests/unittests/skills/test_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for skill utilities.""" - -from google.adk.skills import load_skill_from_dir -import pytest - - -def test_load_skill_from_dir(tmp_path): - """Tests loading a skill from a directory.""" - skill_dir = tmp_path / "test-skill" - skill_dir.mkdir() - - skill_md_content = """--- -name: test-skill -description: Test description ---- -Test instructions -""" - (skill_dir / "SKILL.md").write_text(skill_md_content) - - # Create references - ref_dir = skill_dir / "references" - ref_dir.mkdir() - (ref_dir / "ref1.md").write_text("ref1 content") - - # Create assets - assets_dir = skill_dir / "assets" - assets_dir.mkdir() - (assets_dir / "asset1.txt").write_text("asset1 content") - - # Create scripts - scripts_dir = skill_dir / "scripts" - scripts_dir.mkdir() - (scripts_dir / "script1.sh").write_text("echo hello") - - skill = load_skill_from_dir(skill_dir) - - assert skill.name == "test-skill" - assert skill.description == "Test description" - assert skill.instructions == "Test instructions" - assert skill.resources.get_reference("ref1.md") == "ref1 content" - assert skill.resources.get_asset("asset1.txt") == "asset1 content" - assert skill.resources.get_script("script1.sh").src == "echo hello" diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index f7d4b0a3..3b7d93c4 100644 --- a/tests/unittests/telemetry/test_functional.py +++ b/tests/unittests/telemetry/test_functional.py @@ -97,6 +97,14 @@ async def test_tracer_start_as_current_span( def wrapped_firstiter(coro): nonlocal firstiter + # Skip check for specific async context managers in tracing.py, + # as their internal generators are not expected to be Aclosing-wrapped. + if ( + coro.__name__ == 'use_inference_span' + or coro.__name__ == '_use_native_generate_content_span' + ): + firstiter(coro) + return assert any( isinstance(referrer, Aclosing) or isinstance(indirect_referrer, Aclosing) diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index bb084676..3c061e42 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -26,25 +26,36 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.telemetry.tracing import ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS from google.adk.telemetry.tracing import trace_agent_invocation from google.adk.telemetry.tracing import trace_call_llm -from google.adk.telemetry.tracing import trace_generate_content_result +from google.adk.telemetry.tracing import trace_inference_result from google.adk.telemetry.tracing import trace_merged_tool_calls from google.adk.telemetry.tracing import trace_send_data from google.adk.telemetry.tracing import trace_tool_call -from google.adk.telemetry.tracing import use_generate_content_span +from google.adk.telemetry.tracing import use_inference_span from google.adk.tools.base_tool import BaseTool from google.genai import types +from mcp import ClientSession as McpClientSession +from mcp import ListToolsResult as McpListToolsResult +from mcp import Tool as McpTool from opentelemetry._logs import LogRecord from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_AGENT_NAME from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_CONVERSATION_ID +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_INPUT_MESSAGES from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OPERATION_NAME +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_OUTPUT_MESSAGES from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_REQUEST_MODEL from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_RESPONSE_FINISH_REASONS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM +from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_SYSTEM_INSTRUCTIONS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_INPUT_TOKENS from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_USAGE_OUTPUT_TOKENS from opentelemetry.semconv._incubating.attributes.user_attributes import USER_ID import pytest +try: + from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import GEN_AI_TOOL_DEFINITIONS +except ImportError: + GEN_AI_TOOL_DEFINITIONS = 'gen_ai.tool_definitions' + class Event: @@ -731,12 +742,12 @@ async def test_generate_content_span( ) # Act - with use_generate_content_span( + async with use_inference_span( llm_request, invocation_context, model_response_event - ) as span: - assert span is mock_span + ) as gc_span: + assert gc_span.span is mock_span - trace_generate_content_result(span, llm_response) + trace_inference_result(gc_span, llm_response) # Assert Span mock_tracer.start_as_current_span.assert_called_once_with( @@ -810,3 +821,357 @@ async def test_generate_content_span( assert choice_log is not None assert choice_log.body == expected_choice_body assert choice_log.attributes == {GEN_AI_SYSTEM: 'test_system'} + + +def _mock_callable_tool(): + """Description of some tool.""" + return 'result' + + +def _mock_mcp_client_session() -> McpClientSession: + mock_session = mock.create_autospec(spec=McpClientSession, instance=True) + + mock_tool_obj = McpTool( + name='mcp_tool', + description='Tool from session', + inputSchema={ + 'type': 'object', + 'properties': {'query': {'type': 'string'}}, + }, + ) + mock_result = mock.create_autospec(McpListToolsResult, instance=True) + mock_result.tools = [mock_tool_obj] + + mock_session.list_tools = mock.AsyncMock(return_value=mock_result) + + return mock_session + + +def _mock_mcp_tool(): + return McpTool( + name='mcp_tool', + description='A standalone mcp tool', + inputSchema={ + 'type': 'object', + 'properties': {'id': {'type': 'integer'}}, + }, + ) + + +def _mock_tool_dict() -> types.ToolDict: + return types.ToolDict( + function_declarations=[ + types.FunctionDeclarationDict( + name='mock_tool', description='Description of mock tool.' + ), + ], + google_maps=types.GoogleMaps(), + ) + + +@pytest.mark.asyncio +@mock.patch('google.adk.telemetry.tracing.otel_logger') +@mock.patch('google.adk.telemetry.tracing.tracer') +@mock.patch( + 'google.adk.telemetry.tracing._guess_gemini_system_name', + return_value='test_system', +) +@pytest.mark.parametrize( + 'capture_content', + ['SPAN_AND_EVENT', 'EVENT_ONLY', 'SPAN_ONLY', 'NO_CONTENT'], +) +async def test_generate_content_span_with_experimental_semconv( + mock_guess_system_name, + mock_tracer, + mock_otel_logger, + monkeypatch, + capture_content, +): + """Test native generate_content span creation with attributes and logs with experimental semconv enabled.""" + # Arrange + monkeypatch.setenv( + 'OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT', + str(capture_content).lower(), + ) + monkeypatch.setenv( + 'OTEL_SEMCONV_STABILITY_OPT_IN', + 'gen_ai_latest_experimental', + ) + monkeypatch.setattr( + 'google.adk.telemetry.tracing._instrumented_with_opentelemetry_instrumentation_google_genai', + lambda: False, + ) + + agent = LlmAgent(name='test_agent', model='not-a-gemini-model') + invocation_context = await _create_invocation_context(agent) + + system_instruction = types.Content( + parts=[types.Part.from_text(text='You are a helpful assistant.')], + ) + + user_content1 = types.Content(role='user', parts=[types.Part(text='Hello')]) + user_content2 = types.Content(role='user', parts=[types.Part(text='World')]) + + model_content = types.Content( + role='model', parts=[types.Part(text='Response')] + ) + + tools = [ + _mock_callable_tool, + _mock_tool_dict(), + _mock_mcp_client_session(), + _mock_mcp_tool(), + ] + + llm_request = LlmRequest( + model='some-model', + contents=[user_content1, user_content2], + config=types.GenerateContentConfig( + system_instruction=system_instruction, tools=tools + ), + ) + llm_response = LlmResponse( + content=model_content, + finish_reason=types.FinishReason.STOP, + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=20, + ), + ) + + model_response_event = mock.MagicMock() + model_response_event.id = 'event-123' + + mock_span = ( + mock_tracer.start_as_current_span.return_value.__enter__.return_value + ) + + # Act + async with use_inference_span( + llm_request, + invocation_context, + model_response_event, + ) as gc_span: + assert gc_span.span is mock_span + + trace_inference_result(gc_span, llm_response) + + # Expected attributes + expected_system_instructions = [ + { + 'content': 'You are a helpful assistant.', + 'type': 'text', + }, + ] + expected_input_messages = [ + { + 'role': 'user', + 'parts': [ + {'content': 'Hello', 'type': 'text'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'content': 'World', 'type': 'text'}, + ], + }, + ] + expected_output_messages = [{ + 'role': 'assistant', + 'parts': [ + {'content': 'Response', 'type': 'text'}, + ], + 'finish_reason': 'stop', + }] + expected_tool_definitions = [ + { + 'name': '_mock_callable_tool', + 'description': 'Description of some tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mock_tool', + 'description': 'Description of mock tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'google_maps', + 'type': 'google_maps', + }, + { + 'name': 'mcp_tool', + 'description': 'Tool from session', + 'parameters': { + 'type': 'object', + 'properties': {'query': {'type': 'string'}}, + }, + 'type': 'function', + }, + { + 'name': 'mcp_tool', + 'description': 'A standalone mcp tool', + 'parameters': { + 'type': 'object', + 'properties': {'id': {'type': 'integer'}}, + }, + 'type': 'function', + }, + ] + expected_tool_definitions_no_content = [ + { + 'name': '_mock_callable_tool', + 'description': 'Description of some tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mock_tool', + 'description': 'Description of mock tool.', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'google_maps', + 'type': 'google_maps', + }, + { + 'name': 'mcp_tool', + 'description': 'Tool from session', + 'parameters': None, + 'type': 'function', + }, + { + 'name': 'mcp_tool', + 'description': 'A standalone mcp tool', + 'parameters': None, + 'type': 'function', + }, + ] + expected_tool_definitions_json = ( + '[{"name":"_mock_callable_tool","description":"Description of some' + ' tool.","parameters":null,"type":"function"},{"name":"mock_tool","description":"Description' + ' of mock' + ' tool.","parameters":null,"type":"function"},{"name":"google_maps","type":"google_maps"},{"name":"mcp_tool","description":"Tool' + ' from' + ' session","parameters":{"type":"object","properties":{"query":{"type":"string"}}},"type":"function"},{"name":"mcp_tool","description":"A' + ' standalone mcp' + ' tool","parameters":{"type":"object","properties":{"id":{"type":"integer"}}},"type":"function"}]' + ) + + expected_tool_definitions_no_content_json = ( + '[{"name":"_mock_callable_tool","description":"Description of some' + ' tool.","parameters":null,"type":"function"},{"name":"mock_tool","description":"Description' + ' of mock' + ' tool.","parameters":null,"type":"function"},{"name":"google_maps","type":"google_maps"},{"name":"mcp_tool","description":"Tool' + ' from' + ' session","parameters":null,"type":"function"},{"name":"mcp_tool","description":"A' + ' standalone mcp tool","parameters":null,"type":"function"}]' + ) + # Assert Span + mock_tracer.start_as_current_span.assert_called_once_with( + 'generate_content some-model' + ) + + mock_span.set_attribute.assert_any_call( + GEN_AI_OPERATION_NAME, 'generate_content' + ) + mock_span.set_attribute.assert_any_call(GEN_AI_REQUEST_MODEL, 'some-model') + mock_span.set_attribute.assert_any_call( + GEN_AI_RESPONSE_FINISH_REASONS, ['stop'] + ) + mock_span.set_attribute.assert_any_call(GEN_AI_USAGE_INPUT_TOKENS, 10) + mock_span.set_attribute.assert_any_call(GEN_AI_USAGE_OUTPUT_TOKENS, 20) + + mock_span.set_attributes.assert_called_once_with({ + GEN_AI_AGENT_NAME: invocation_context.agent.name, + GEN_AI_CONVERSATION_ID: invocation_context.session.id, + USER_ID: invocation_context.session.user_id, + 'gcp.vertex.agent.event_id': 'event-123', + 'gcp.vertex.agent.invocation_id': invocation_context.invocation_id, + }) + + if capture_content in ['SPAN_AND_EVENT', 'SPAN_ONLY']: + mock_span.set_attribute.assert_any_call( + GEN_AI_SYSTEM_INSTRUCTIONS, + '[{"content":"You are a helpful assistant.","type":"text"}]', + ) + mock_span.set_attribute.assert_any_call( + GEN_AI_INPUT_MESSAGES, + '[{"role":"user","parts":[{"content":"Hello","type":"text"}]},{"role":"user","parts":[{"content":"World","type":"text"}]}]', + ) + mock_span.set_attribute.assert_any_call( + GEN_AI_OUTPUT_MESSAGES, + '[{"role":"assistant","parts":[{"content":"Response","type":"text"}],"finish_reason":"stop"}]', + ) + mock_span.set_attribute.assert_any_call( + GEN_AI_TOOL_DEFINITIONS, expected_tool_definitions_json + ) + else: + all_attribute_calls = mock_span.set_attribute.call_args_list + assert GEN_AI_SYSTEM_INSTRUCTIONS not in all_attribute_calls + assert GEN_AI_INPUT_MESSAGES not in all_attribute_calls + assert GEN_AI_OUTPUT_MESSAGES not in all_attribute_calls + mock_span.set_attribute.assert_any_call( + GEN_AI_TOOL_DEFINITIONS, expected_tool_definitions_no_content_json + ) + + # Assert Logs + assert mock_otel_logger.emit.call_count == 1 + + log_records: list[LogRecord] = [ + call.args[0] for call in mock_otel_logger.emit.call_args_list + ] + + operation_details_log = next( + ( + lr + for lr in log_records + if lr.event_name == 'gen_ai.client.inference.operation.details' + ), + None, + ) + + assert operation_details_log is not None + assert operation_details_log.attributes is not None + + attributes = operation_details_log.attributes + + if capture_content in ['SPAN_AND_EVENT', 'EVENT_ONLY']: + assert GEN_AI_SYSTEM_INSTRUCTIONS in attributes + assert ( + attributes[GEN_AI_SYSTEM_INSTRUCTIONS] == expected_system_instructions + ) + assert GEN_AI_INPUT_MESSAGES in attributes + assert attributes[GEN_AI_INPUT_MESSAGES] == expected_input_messages + assert GEN_AI_OUTPUT_MESSAGES in attributes + assert attributes[GEN_AI_OUTPUT_MESSAGES] == expected_output_messages + assert GEN_AI_TOOL_DEFINITIONS in attributes + assert attributes[GEN_AI_TOOL_DEFINITIONS] == expected_tool_definitions + else: + assert GEN_AI_SYSTEM_INSTRUCTIONS not in attributes + assert GEN_AI_INPUT_MESSAGES not in attributes + assert GEN_AI_OUTPUT_MESSAGES not in attributes + assert GEN_AI_TOOL_DEFINITIONS in attributes + assert ( + attributes[GEN_AI_TOOL_DEFINITIONS] + == expected_tool_definitions_no_content + ) + + assert GEN_AI_USAGE_INPUT_TOKENS in attributes + assert attributes[GEN_AI_USAGE_INPUT_TOKENS] == 10 + assert GEN_AI_USAGE_OUTPUT_TOKENS in attributes + assert attributes[GEN_AI_USAGE_OUTPUT_TOKENS] == 20 + assert 'gcp.vertex.agent.event_id' in attributes + assert attributes['gcp.vertex.agent.event_id'] == 'event-123' + assert 'gcp.vertex.agent.invocation_id' in attributes + assert ( + attributes['gcp.vertex.agent.invocation_id'] + == invocation_context.invocation_id + ) + assert GEN_AI_AGENT_NAME in attributes + assert attributes[GEN_AI_AGENT_NAME] == invocation_context.agent.name + assert GEN_AI_CONVERSATION_ID in attributes + assert attributes[GEN_AI_CONVERSATION_ID] == invocation_context.session.id diff --git a/tests/unittests/telemetry/test_sqlite_span_exporter.py b/tests/unittests/telemetry/test_sqlite_span_exporter.py new file mode 100644 index 00000000..21437175 --- /dev/null +++ b/tests/unittests/telemetry/test_sqlite_span_exporter.py @@ -0,0 +1,462 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from pathlib import Path + +from google.adk.telemetry.sqlite_span_exporter import SqliteSpanExporter +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExportResult +from opentelemetry.trace import SpanContext +from opentelemetry.trace import TraceFlags +from opentelemetry.trace import TraceState + + +def _create_span( + *, + span_id: int = 0x00000000000ABC12, + trace_id: int = 0x000000000000000000000000000DEF45, + parent_span_id: int | None = None, + name: str = "test_span", + attributes: dict | None = None, + start_time: int = 1000, + end_time: int = 2000, +) -> ReadableSpan: + """Helper to create ReadableSpan instances for testing.""" + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(), + ) + + parent = None + if parent_span_id is not None: + parent = SpanContext( + trace_id=trace_id, + span_id=parent_span_id, + is_remote=False, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(), + ) + + return ReadableSpan( + name=name, + context=context, + parent=parent, + attributes=attributes or {}, + start_time=start_time, + end_time=end_time, + ) + + +def test_export_single_span_returns_success(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span( + name="test_operation", + attributes={"gcp.vertex.agent.session_id": "session-123"}, + ) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + assert db_path.exists() + + +def test_export_empty_list_returns_success(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + result = exporter.export([]) + + assert result == SpanExportResult.SUCCESS + + +def test_get_all_spans_for_session_returns_matching_spans(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span1 = _create_span( + span_id=0x111, + trace_id=0xAAA111, # Different trace for session-123 + attributes={"gcp.vertex.agent.session_id": "session-123"}, + name="span1", + ) + span2 = _create_span( + span_id=0x222, + trace_id=0xAAA222, # Different trace for session-123 + attributes={"gcp.vertex.agent.session_id": "session-123"}, + name="span2", + ) + span3 = _create_span( + span_id=0x333, + trace_id=0xBBB333, # Different trace for session-456 + attributes={"gcp.vertex.agent.session_id": "session-456"}, + name="span3", + ) + + exporter.export([span1, span2, span3]) + + result = exporter.get_all_spans_for_session("session-123") + + assert len(result) == 2 + names = [span.name for span in result] + assert "span1" in names + assert "span2" in names + assert "span3" not in names + + +def test_get_all_spans_for_session_includes_sibling_spans_without_session_id( + tmp_path, +): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Parent span without session_id (e.g., invocation span) + parent_span = _create_span( + span_id=0x100, + trace_id=0xAAA, + name="invocation", + attributes={}, # No session_id + ) + + # Child span with session_id + child_span = _create_span( + span_id=0x200, + trace_id=0xAAA, # Same trace + parent_span_id=0x100, + name="call_llm", + attributes={"gcp.vertex.agent.session_id": "session-789"}, + ) + + # Sibling span without session_id (should be included) + sibling_span = _create_span( + span_id=0x300, + trace_id=0xAAA, # Same trace + parent_span_id=0x100, + name="tool_call", + attributes={}, # No session_id + ) + + # Unrelated span with different trace_id (should not be included) + unrelated_span = _create_span( + span_id=0x400, + trace_id=0xBBB, # Different trace + name="unrelated", + attributes={}, + ) + + exporter.export([parent_span, child_span, sibling_span, unrelated_span]) + + result = exporter.get_all_spans_for_session("session-789") + + assert len(result) == 3 + names = [span.name for span in result] + assert "invocation" in names + assert "call_llm" in names + assert "tool_call" in names + assert "unrelated" not in names + + +def test_get_all_spans_for_unknown_session_returns_empty_list(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span( + attributes={"gcp.vertex.agent.session_id": "session-123"}, + ) + exporter.export([span]) + + result = exporter.get_all_spans_for_session("unknown-session") + + assert result == [] + + +def test_round_trip_preserves_span_attributes(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + original_attributes = { + "gcp.vertex.agent.session_id": "session-123", + "gcp.vertex.agent.invocation_id": "invocation-456", + "gen_ai.conversation.id": "conv-789", + "custom.attribute": "test_value", + "numeric.value": 42, + "boolean.value": True, + "list.value": [1, 2, 3], + "dict.value": {"nested": "data"}, + } + + original_span = _create_span( + span_id=0x12345678, + trace_id=0xABCDEF123456789, + name="test_operation", + attributes=original_attributes, + start_time=1000000, + end_time=2000000, + ) + + exporter.export([original_span]) + + retrieved_spans = exporter.get_all_spans_for_session("session-123") + + assert len(retrieved_spans) == 1 + retrieved = retrieved_spans[0] + + assert retrieved.name == "test_operation" + assert retrieved.context.span_id == 0x12345678 + assert retrieved.context.trace_id == 0xABCDEF123456789 + assert retrieved.start_time == 1000000 + assert retrieved.end_time == 2000000 + assert retrieved.attributes == original_attributes + + +def test_spans_with_parent_context_exported_correctly(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + parent_span = _create_span( + span_id=0xAAA, + trace_id=0x123, + name="parent", + attributes={"gcp.vertex.agent.session_id": "session-001"}, + ) + + child_span = _create_span( + span_id=0xBBB, + trace_id=0x123, + parent_span_id=0xAAA, + name="child", + attributes={"gcp.vertex.agent.session_id": "session-001"}, + ) + + exporter.export([parent_span, child_span]) + + retrieved_spans = exporter.get_all_spans_for_session("session-001") + + assert len(retrieved_spans) == 2 + + # Find child span in results + child = next(s for s in retrieved_spans if s.name == "child") + assert child.parent is not None + assert child.parent.span_id == 0xAAA + assert child.parent.trace_id == 0x123 + + # Find parent span in results + parent = next(s for s in retrieved_spans if s.name == "parent") + assert parent.parent is None + + +def test_shutdown_closes_connection(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create a span to ensure connection is open + span = _create_span() + exporter.export([span]) + + # Verify connection exists + assert exporter._conn is not None + + exporter.shutdown() + + # Verify connection is closed + assert exporter._conn is None + + +def test_force_flush_returns_true(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + result = exporter.force_flush() + + assert result is True + + # Also test with timeout parameter + result_with_timeout = exporter.force_flush(timeout_millis=5000) + assert result_with_timeout is True + + +def test_export_handles_spans_with_none_attributes(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + span = _create_span(attributes=None) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + + # Verify the span was stored correctly + rows = exporter._query("SELECT attributes_json FROM spans", []) + assert len(rows) == 1 + attributes_json = rows[0]["attributes_json"] + assert json.loads(attributes_json) == {} + + +def test_duplicate_span_id_replaces_previous_row(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Export first version of span + span1 = _create_span( + span_id=0x999, + name="first_version", + attributes={"version": 1, "gcp.vertex.agent.session_id": "session-dup"}, + ) + exporter.export([span1]) + + # Export second version with same span_id + span2 = _create_span( + span_id=0x999, + name="second_version", + attributes={"version": 2, "gcp.vertex.agent.session_id": "session-dup"}, + ) + exporter.export([span2]) + + # Verify only one row exists with updated data + retrieved_spans = exporter.get_all_spans_for_session("session-dup") + assert len(retrieved_spans) == 1 + assert retrieved_spans[0].name == "second_version" + assert retrieved_spans[0].attributes["version"] == 2 + + +def test_non_serializable_attributes_use_fallback(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create a non-serializable object + class NonSerializable: + pass + + attributes = { + "gcp.vertex.agent.session_id": "session-nonser", + "normal_attr": "value", + "non_serializable": NonSerializable(), + } + + span = _create_span(attributes=attributes) + + result = exporter.export([span]) + + assert result == SpanExportResult.SUCCESS + + # Verify the span was stored and non-serializable attribute has fallback + retrieved_spans = exporter.get_all_spans_for_session("session-nonser") + assert len(retrieved_spans) == 1 + assert retrieved_spans[0].attributes["normal_attr"] == "value" + assert ( + retrieved_spans[0].attributes["non_serializable"] == "" + ) + + +def test_export_multiple_spans_in_batch(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + spans = [ + _create_span( + span_id=i, + name=f"span_{i}", + attributes={"gcp.vertex.agent.session_id": "batch-session"}, + ) + for i in range(10) + ] + + result = exporter.export(spans) + + assert result == SpanExportResult.SUCCESS + + retrieved_spans = exporter.get_all_spans_for_session("batch-session") + assert len(retrieved_spans) == 10 + names = {span.name for span in retrieved_spans} + assert names == {f"span_{i}" for i in range(10)} + + +def test_export_with_alternative_session_id_attribute(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Test using gen_ai.conversation.id as fallback for session_id + span = _create_span( + attributes={"gen_ai.conversation.id": "conv-session-123"}, + ) + + exporter.export([span]) + + # Should be queryable by the conversation id + result = exporter.get_all_spans_for_session("conv-session-123") + + assert len(result) == 1 + assert result[0].attributes["gen_ai.conversation.id"] == "conv-session-123" + + +def test_deserialize_handles_invalid_json(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Manually insert a row with invalid JSON + conn = exporter._get_connection() + conn.execute( + "INSERT INTO spans (span_id, trace_id, name, attributes_json) VALUES (?," + " ?, ?, ?)", + ("abc123", "def456", "test", "not valid json"), + ) + conn.commit() + + # Try to retrieve the span - should not raise, but attributes should be empty + rows = exporter._query("SELECT * FROM spans", []) + span = exporter._row_to_readable_span(rows[0]) + + assert span.name == "test" + assert span.attributes == {} + + +def test_get_spans_ordered_by_start_time(tmp_path): + db_path = tmp_path / "test.db" + exporter = SqliteSpanExporter(db_path=str(db_path)) + + # Create spans with different start times + spans = [ + _create_span( + span_id=0x300, + start_time=3000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + _create_span( + span_id=0x100, + start_time=1000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + _create_span( + span_id=0x200, + start_time=2000, + attributes={"gcp.vertex.agent.session_id": "session-order"}, + ), + ] + + exporter.export(spans) + + result = exporter.get_all_spans_for_session("session-order") + + # Verify spans are ordered by start_time + assert len(result) == 3 + assert result[0].context.span_id == 0x100 + assert result[1].context.span_id == 0x200 + assert result[2].context.span_id == 0x300 diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index ca7eb375..cc3abc65 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -29,6 +29,7 @@ from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.cli.utils.agent_loader import AgentLoader +from google.adk.errors.session_not_found_error import SessionNotFoundError from google.adk.events.event import Event from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner @@ -243,7 +244,7 @@ async def test_session_not_found_message_includes_alignment_hint(): new_message=types.Content(role="user", parts=[]), ) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(SessionNotFoundError) as excinfo: await agen.__anext__() await agen.aclose() diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index 80a97f8f..d8d5e726 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -18,9 +18,13 @@ import os from unittest import mock import google.adk +from google.adk.tools.bigquery.client import DP_USER_AGENT from google.adk.tools.bigquery.client import get_bigquery_client +from google.adk.tools.bigquery.client import get_dataplex_catalog_client +from google.api_core.gapic_v1 import client_info as gapic_client_info import google.auth from google.auth.exceptions import DefaultCredentialsError +from google.cloud import dataplex_v1 from google.cloud.bigquery import client as bigquery_client from google.oauth2.credentials import Credentials @@ -201,3 +205,74 @@ def test_bigquery_client_location_custom(): # Verify that the client has the desired project set assert client.project == "test-gcp-project" assert client.location == "us-central1" + + +# Tests for Dataplex Catalog Client +# ------------------------------------------------------------------------------ + + +# Mock the CatalogServiceClient class directly +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_default(mock_catalog_service_client): + """Test get_dataplex_catalog_client with default user agent.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + + client = get_dataplex_catalog_client(credentials=mock_creds) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + + assert kwargs["credentials"] == mock_creds + client_info = kwargs["client_info"] + assert isinstance(client_info, gapic_client_info.ClientInfo) + assert client_info.user_agent == DP_USER_AGENT + + # Ensure the function returns the mock instance + assert client == mock_catalog_service_client.return_value + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent string.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua = "catalog_ua/1.0" + expected_ua = f"{DP_USER_AGENT} {custom_ua}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client): + """Test get_dataplex_catalog_client with a custom user agent list.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} {' '.join(custom_ua_list)}" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua + + +@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True) +def test_dataplex_client_custom_user_agent_list_with_none( + mock_catalog_service_client, +): + """Test get_dataplex_catalog_client with a list containing None.""" + mock_creds = mock.create_autospec(Credentials, instance=True) + custom_ua_list = ["catalog_ua", None, "catalog_ua_2.0"] + expected_ua = f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0" + + get_dataplex_catalog_client(credentials=mock_creds, user_agent=custom_ua_list) + + mock_catalog_service_client.assert_called_once() + _, kwargs = mock_catalog_service_client.call_args + client_info = kwargs["client_info"] + assert client_info.user_agent == expected_ua diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 9cf8c9e4..e2066292 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -44,9 +44,11 @@ class TestBigQueryCredentials: # Verify that the credentials are properly stored and attributes are extracted assert config.credentials == auth_creds - assert config.client_id is None assert config.client_secret is None - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -86,7 +88,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. @@ -128,7 +133,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py new file mode 100644 index 00000000..0ccdc9e1 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py @@ -0,0 +1,448 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from typing import Any +import unittest +from unittest import mock + +from absl.testing import parameterized + +# Mock google.genai and pydantic if not available, before importing google.adk modules +try: + import google.genai +except ImportError: + m = mock.MagicMock() + m.__path__ = [] + sys.modules["google.genai"] = m + sys.modules["google.genai.types"] = mock.MagicMock() + sys.modules["google.genai.errors"] = mock.MagicMock() + +try: + import pydantic +except ImportError: + m_pydantic = mock.MagicMock() + + class MockBaseModel: + pass + + m_pydantic.BaseModel = MockBaseModel + sys.modules["pydantic"] = m_pydantic + +try: + import fastapi + import fastapi.openapi.models +except ImportError: + m_fastapi = mock.MagicMock() + m_fastapi.openapi.models = mock.MagicMock() + sys.modules["fastapi"] = m_fastapi + sys.modules["fastapi.openapi"] = mock.MagicMock() + sys.modules["fastapi.openapi.models"] = mock.MagicMock() + + +from google.adk.tools.bigquery import search_tool +from google.adk.tools.bigquery.config import BigQueryToolConfig +from google.api_core import exceptions as api_exceptions +from google.auth.credentials import Credentials +from google.cloud import dataplex_v1 + + +def _mock_creds(): + return mock.create_autospec(Credentials, instance=True) + + +def _mock_settings(app_name: str | None = "test-app"): + return BigQueryToolConfig(application_name=app_name) + + +def _mock_search_entries_response(results: list[dict[str, Any]]): + mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse) + mock_results = [] + for r in results: + mock_result = mock.create_autospec( + dataplex_v1.SearchEntriesResult, instance=True + ) + # Manually attach dataplex_entry since it's not visible in dir() of the proto class + mock_entry = mock.create_autospec(dataplex_v1.Entry, instance=True) + mock_result.dataplex_entry = mock_entry + + mock_entry.name = r.get("name") + mock_entry.entry_type = r.get("entry_type") + mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z") + + # Manually attach entry_source since it's not visible in dir() of the proto class + mock_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True) + mock_entry.entry_source = mock_source + + mock_source.display_name = r.get("display_name") + mock_source.resource = r.get("linked_resource") + mock_source.description = r.get("description") + mock_source.location = r.get("location") + mock_results.append(mock_result) + mock_response.results = mock_results + return mock_response + + +class TestSearchCatalog(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_dataplex_client = mock.create_autospec( + dataplex_v1.CatalogServiceClient, instance=True + ) + + # Patch get_dataplex_catalog_client + self.mock_get_dataplex_client = self.enter_context( + mock.patch( + "google.adk.tools.bigquery.client.get_dataplex_catalog_client", + autospec=True, + ) + ) + self.mock_get_dataplex_client.return_value = self.mock_dataplex_client + self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client + + # Patch SearchEntriesRequest + self.mock_search_request = self.enter_context( + mock.patch( + "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True + ) + ) + + def test_search_catalog_success(self): + """Test the successful path of search_catalog.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "customer data" + project_id = "test-project" + location = "us" + + mock_api_results = [{ + "name": "entry1", + "entry_type": "TABLE", + "display_name": "Cust Table", + "linked_resource": ( + "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1" + ), + "description": "Table 1", + "location": "us", + }] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + ) + + with self.subTest("Test result content"): + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 1) + self.assertEqual(result["results"][0]["name"], "entry1") + self.assertEqual(result["results"][0]["display_name"], "Cust Table") + + with self.subTest("Test mock calls"): + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=["test-app", "search_catalog"] + ) + + expected_query = ( + '(customer data) AND projectid="test-project" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/us", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once_with( + request=self.mock_search_request.return_value + ) + + def test_search_catalog_no_project_id(self): + """Test search_catalog with missing project_id.""" + result = search_tool.search_catalog( + prompt="test", + project_id="", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("project_id must be provided", result["error_details"]) + self.mock_get_dataplex_client.assert_not_called() + + def test_search_catalog_api_error(self): + """Test search_catalog handling API exceptions.""" + self.mock_dataplex_client.search_entries.side_effect = ( + api_exceptions.BadRequest("Invalid query") + ) + + result = search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn( + "Dataplex API Error: 400 Invalid query", result["error_details"] + ) + + def test_search_catalog_other_exception(self): + """Test search_catalog handling unexpected exceptions.""" + self.mock_get_dataplex_client.side_effect = Exception( + "Something went wrong" + ) + + result = search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + ) + self.assertEqual(result["status"], "ERROR") + self.assertIn("Something went wrong", result["error_details"]) + + @parameterized.named_parameters( + ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'), + ( + "multi_project_filter", + "p", + ["p1", "p2"], + None, + None, + '(projectid="p1" OR projectid="p2")', + ), + ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'), + ( + "multi_type_filter", + "p", + None, + None, + ["TABLE", "DATASET"], + '(type="TABLE" OR type="DATASET")', + ), + ( + "project_and_dataset_filters", + "inventory", + ["proj1", "proj2"], + ["dsetA"], + None, + ( + '(projectid="proj1" OR projectid="proj2") AND' + ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"' + ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")' + ), + ), + ) + def test_search_catalog_query_construction( + self, prompt, project_ids, dataset_ids, types, expected_query_part + ): + """Test different query constructions based on filters.""" + search_tool.search_catalog( + prompt=prompt, + project_id="test-project", + credentials=_mock_creds(), + settings=_mock_settings(), + location="us", + project_ids_filter=project_ids, + dataset_ids_filter=dataset_ids, + types_filter=types, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + query = kwargs["query"] + + if prompt: + assert f"({prompt})" in query + assert "system=BIGQUERY" in query + assert expected_query_part in query + + def test_search_catalog_no_app_name(self): + """Test search_catalog when settings.application_name is None.""" + creds = _mock_creds() + settings = _mock_settings(app_name=None) + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + location="us", + ) + + self.mock_get_dataplex_client.assert_called_once_with( + credentials=creds, user_agent=[None, "search_catalog"] + ) + + def test_search_catalog_multi_project_filter_semantic(self): + """Test semantic search with a multi-project filter.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "What datasets store user profiles?" + project_id = "main-project" + project_filters = ["user-data-proj", "shared-infra-proj"] + location = "global" + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + project_ids_filter=project_filters, + types_filter=["DATASET"], + ) + + expected_query = ( + f"({prompt}) AND " + '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND ' + 'type="DATASET" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + def test_search_catalog_natural_language_semantic(self): + """Test natural language prompts with semantic search enabled and check output.""" + creds = _mock_creds() + settings = _mock_settings() + prompt = "Find tables about football matches" + project_id = "sports-analytics" + location = "europe-west1" + + # Mock the results that the API would return for this semantic query + mock_api_results = [ + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1" + ), + "display_name": "uk_football_premiership", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership" + ), + "description": "Stats for UK Premier League matches.", + "location": "europe-west1", + }, + { + "name": ( + "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2" + ), + "display_name": "serie_a_matches", + "entry_type": ( + "projects/655216118709/locations/global/entryTypes/bigquery-table" + ), + "linked_resource": ( + "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a" + ), + "description": "Italian Serie A football results.", + "location": "europe-west1", + }, + ] + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response(mock_api_results) + ) + + result = search_tool.search_catalog( + prompt=prompt, + project_id=project_id, + credentials=creds, + settings=settings, + location=location, + ) + + with self.subTest("Query Construction"): + # Assert the request was made as expected + expected_query = ( + f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY' + ) + self.mock_search_request.assert_called_once_with( + name=f"projects/{project_id}/locations/{location}", + query=expected_query, + page_size=10, + semantic_search=True, + ) + self.mock_dataplex_client.search_entries.assert_called_once() + + with self.subTest("Response Processing"): + # Assert the output is processed correctly + self.assertEqual(result["status"], "SUCCESS") + self.assertLen(result["results"], 2) + self.assertEqual( + result["results"][0]["display_name"], "uk_football_premiership" + ) + self.assertEqual(result["results"][1]["display_name"], "serie_a_matches") + self.assertIn("UK Premier League", result["results"][0]["description"]) + + def test_search_catalog_default_location(self): + """Test search_catalog fallback to global location when None is provided.""" + creds = _mock_creds() + settings = _mock_settings() + # settings.location is None by default + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/global", name_arg) + + def test_search_catalog_settings_location(self): + """Test search_catalog uses settings.location when provided.""" + creds = _mock_creds() + settings = BigQueryToolConfig(location="eu") + + self.mock_dataplex_client.search_entries.return_value = ( + _mock_search_entries_response([]) + ) + + search_tool.search_catalog( + prompt="test", + project_id="test-project", + credentials=creds, + settings=settings, + ) + + self.mock_search_request.assert_called_once() + _, kwargs = self.mock_search_request.call_args + name_arg = kwargs["name"] + self.assertIn("locations/eu", name_arg) diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index f1f73aa6..0eced4b1 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 10 + assert len(tools) == 11 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default(): "forecast", "analyze_contribution", "detect_anomalies", + "search_catalog", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py b/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py index d7debf02..46904828 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_metadata_tool.py @@ -15,69 +15,69 @@ import logging from unittest import mock +from google.adk.tools.bigtable import client from google.adk.tools.bigtable import metadata_tool from google.auth.credentials import Credentials +from google.cloud.bigtable import enums +import pytest -def test_list_instances(): - """Test list_instances function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" +@pytest.fixture +def mock_get_client(): + with mock.patch.object( + client, "get_bigtable_admin_client" ) as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client + yield mock_get_client + + +def test_list_instances(mock_get_client): + mock_instance = mock.MagicMock() + mock_instance.instance_id = "test-instance" + mock_get_client.return_value.list_instances.return_value = ( + [mock_instance], + [], + ) + + mock_instance.display_name = "Test Instance" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION + mock_instance.labels = {"env": "test"} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_instances( + project_id="test-project", credentials=creds + ) + expected_result = { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + assert result == {"status": "SUCCESS", "results": [expected_result]} + + +def test_list_instances_failed_locations(mock_get_client): + with mock.patch.object(logging, "warning") as mock_warning: mock_instance = mock.MagicMock() mock_instance.instance_id = "test-instance" - mock_client.list_instances.return_value = ([mock_instance], []) + failed_locations = ["us-west1-a"] + mock_get_client.return_value.list_instances.return_value = ( + [mock_instance], + failed_locations, + ) - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_instances("test-project", creds) - assert result == {"status": "SUCCESS", "results": ["test-instance"]} - - -def test_list_instances_failed_locations(): - """Test list_instances function when some locations fail.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - with mock.patch.object(logging, "warning") as mock_warning: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_instance.instance_id = "test-instance" - failed_locations = ["us-west1-a"] - mock_client.list_instances.return_value = ( - [mock_instance], - failed_locations, - ) - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_instances("test-project", creds) - assert result == {"status": "SUCCESS", "results": ["test-instance"]} - mock_warning.assert_called_once_with( - "Failed to list instances from the following locations: %s", - failed_locations, - ) - - -def test_get_instance_info(): - """Test get_instance_info function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_instance.instance_id = "test-instance" mock_instance.display_name = "Test Instance" - mock_instance.state = "READY" - mock_instance.type_ = "PRODUCTION" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION mock_instance.labels = {"env": "test"} creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.get_instance_info( - "test-project", "test-instance", creds + result = metadata_tool.list_instances( + project_id="test-project", credentials=creds ) expected_result = { "project_id": "test-project", @@ -87,51 +87,186 @@ def test_get_instance_info(): "type": "PRODUCTION", "labels": {"env": "test"}, } - assert result == {"status": "SUCCESS", "results": expected_result} - mock_instance.reload.assert_called_once() - - -def test_list_tables(): - """Test list_tables function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_table = mock.MagicMock() - mock_table.table_id = "test-table" - mock_instance.list_tables.return_value = [mock_table] - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.list_tables("test-project", "test-instance", creds) - assert result == {"status": "SUCCESS", "results": ["test-table"]} - - -def test_get_table_info(): - """Test get_table_info function.""" - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_admin_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_instance = mock.MagicMock() - mock_client.instance.return_value = mock_instance - mock_table = mock.MagicMock() - mock_instance.table.return_value = mock_table - mock_table.table_id = "test-table" - mock_instance.instance_id = "test-instance" - mock_table.list_column_families.return_value = {"cf1": mock.MagicMock()} - - creds = mock.create_autospec(Credentials, instance=True) - result = metadata_tool.get_table_info( - "test-project", "test-instance", "test-table", creds + assert result == {"status": "SUCCESS", "results": [expected_result]} + mock_warning.assert_called_once_with( + "Failed to list instances from the following locations: %s", + failed_locations, ) - expected_result = { - "project_id": "test-project", - "instance_id": "test-instance", - "table_id": "test-table", - "column_families": ["cf1"], - } - assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_get_instance_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_instance.instance_id = "test-instance" + mock_instance.display_name = "Test Instance" + mock_instance.state = enums.Instance.State.READY + mock_instance.type_ = enums.Instance.Type.PRODUCTION + mock_instance.labels = {"env": "test"} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_instance_info( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = { + "project_id": "test-project", + "instance_id": "test-instance", + "display_name": "Test Instance", + "state": "READY", + "type": "PRODUCTION", + "labels": {"env": "test"}, + } + assert result == {"status": "SUCCESS", "results": expected_result} + mock_instance.reload.assert_called_once() + + +def test_list_tables(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_table = mock.MagicMock() + mock_table.table_id = "test-table" + mock_table.name = ( + "projects/test-project/instances/test-instance/tables/test-table" + ) + mock_instance.list_tables.return_value = [mock_table] + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_tables( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = [{ + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "table_name": ( + "projects/test-project/instances/test-instance/tables/test-table" + ), + }] + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_get_table_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_instance.instance_id = "test-instance" + mock_get_client.return_value.instance.return_value = mock_instance + mock_table = mock.MagicMock() + mock_instance.table.return_value = mock_table + mock_table.table_id = "test-table" + mock_table.list_column_families.return_value = {"cf1": mock.MagicMock()} + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_table_info( + project_id="test-project", + instance_id="test-instance", + table_id="test-table", + credentials=creds, + ) + expected_result = { + "project_id": "test-project", + "instance_id": "test-instance", + "table_id": "test-table", + "column_families": ["cf1"], + } + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_list_clusters(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_cluster = mock.MagicMock() + mock_cluster.cluster_id = "test-cluster" + mock_cluster.name = ( + "projects/test-project/instances/test-instance/clusters/test-cluster" + ) + mock_cluster.state = enums.Cluster.State.READY + mock_cluster.serve_nodes = 3 + mock_cluster.default_storage_type = enums.StorageType.SSD + mock_cluster.location_id = "us-central1-a" + mock_instance.list_clusters.return_value = ([mock_cluster], []) + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_clusters( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + expected_result = [{ + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "cluster_name": mock_cluster.name, + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + }] + assert result == {"status": "SUCCESS", "results": expected_result} + + +def test_list_clusters_error(mock_get_client): + mock_get_client.side_effect = Exception("test-error") + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.list_clusters( + project_id="test-project", + instance_id="test-instance", + credentials=creds, + ) + assert result == { + "status": "ERROR", + "error_details": "Exception('test-error')", + } + + +def test_get_cluster_info(mock_get_client): + mock_instance = mock.MagicMock() + mock_get_client.return_value.instance.return_value = mock_instance + mock_cluster = mock.MagicMock() + mock_instance.cluster.return_value = mock_cluster + mock_cluster.cluster_id = "test-cluster" + mock_cluster.state = enums.Cluster.State.READY + mock_cluster.serve_nodes = 3 + mock_cluster.default_storage_type = enums.StorageType.SSD + mock_cluster.location_id = "us-central1-a" + mock_cluster.min_serve_nodes = 3 + mock_cluster.max_serve_nodes = 10 + mock_cluster.cpu_utilization_percent = 50 + + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_cluster_info( + project_id="test-project", + instance_id="test-instance", + cluster_id="test-cluster", + credentials=creds, + ) + expected_results = { + "project_id": "test-project", + "instance_id": "test-instance", + "cluster_id": "test-cluster", + "state": "READY", + "serve_nodes": 3, + "default_storage_type": "SSD", + "location_id": "us-central1-a", + "min_serve_nodes": 3, + "max_serve_nodes": 10, + "cpu_utilization_percent": 50, + } + assert result == {"status": "SUCCESS", "results": expected_results} + mock_cluster.reload.assert_called_once() + + +def test_get_cluster_info_error(mock_get_client): + mock_get_client.side_effect = Exception("test-error") + creds = mock.create_autospec(Credentials, instance=True) + result = metadata_tool.get_cluster_info( + project_id="test-project", + instance_id="test-instance", + cluster_id="test-cluster", + credentials=creds, + ) + assert result == { + "status": "ERROR", + "error_details": "Exception('test-error')", + } diff --git a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py index 0bd0fedc..46b65a3a 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_query_tool.py +++ b/tests/unittests/tools/bigtable/test_bigtable_query_tool.py @@ -14,130 +14,191 @@ from __future__ import annotations -from typing import Optional from unittest import mock -from google.adk.tools.base_tool import BaseTool -from google.adk.tools.bigtable import BigtableCredentialsConfig -from google.adk.tools.bigtable.bigtable_toolset import BigtableToolset +from google.adk.tools.bigtable import client from google.adk.tools.bigtable.query_tool import execute_sql from google.adk.tools.bigtable.settings import BigtableToolSettings from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials -from google.cloud import bigtable from google.cloud.bigtable.data.execute_query import ExecuteQueryIterator import pytest -def test_execute_sql_basic(): - """Test execute_sql tool basic functionality.""" +@pytest.mark.asyncio +@pytest.mark.parametrize( + ( + "query", + "settings", + "parameters", + "parameter_types", + "execute_query_side_effect", + "iterator_yield_values", + "expected_result", + ), + [ + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + None, + [{"col1": "val1", "col2": 123}], + {"status": "SUCCESS", "rows": [{"col1": "val1", "col2": 123}]}, + id="basic", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(max_query_result_rows=1), + None, + None, + None, + [{"col1": "val1"}, {"col1": "val2"}], + { + "status": "SUCCESS", + "rows": [{"col1": "val1"}], + "result_is_likely_truncated": True, + }, + id="truncated", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + Exception("Test error"), + None, + {"status": "ERROR", "error_details": "Test error"}, + id="error", + ), + pytest.param( + "SELECT * FROM my_table WHERE col1 = @param1", + BigtableToolSettings(), + {"param1": "val1"}, + {"param1": "string"}, + None, + [{"col1": "val1"}], + {"status": "SUCCESS", "rows": [{"col1": "val1"}]}, + id="with_parameters", + ), + pytest.param( + "SELECT * FROM my_table WHERE 1=0", + BigtableToolSettings(), + None, + None, + None, + [], + {"status": "SUCCESS", "rows": []}, + id="empty_results", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(max_query_result_rows=10), + None, + None, + None, + [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + { + "status": "SUCCESS", + "rows": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + }, + id="multiple_rows", + ), + pytest.param( + "SELECT * FROM my_table", + None, + None, + None, + None, + [{"id": i} for i in range(51)], + { + "status": "SUCCESS", + "rows": [{"id": i} for i in range(50)], + "result_is_likely_truncated": True, + }, + id="settings_none_uses_default", + ), + pytest.param( + "SELECT * FROM my_table", + BigtableToolSettings(), + None, + None, + None, + Exception("Iteration failed"), + {"status": "ERROR", "error_details": "Iteration failed"}, + id="iteration_error_calls_close", + ), + ], +) +async def test_execute_sql( + query, + settings, + parameters, + parameter_types, + execute_query_side_effect, + iterator_yield_values, + expected_result, +): + """Test execute_sql tool functionality.""" project = "my_project" instance_id = "my_instance" - query = "SELECT * FROM my_table" credentials = mock.create_autospec(Credentials, instance=True) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: + with mock.patch.object(client, "get_bigtable_data_client") as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client - mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) - mock_client.execute_query.return_value = mock_iterator - # Mock row data - mock_row = mock.MagicMock() - mock_row.fields = {"col1": "val1", "col2": 123} - mock_iterator.__iter__.return_value = [mock_row] + if execute_query_side_effect: + mock_client.execute_query.side_effect = execute_query_side_effect + else: + mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) + mock_client.execute_query.return_value = mock_iterator - result = execute_sql( + if isinstance(iterator_yield_values, Exception): + + def raise_error(): + yield mock.MagicMock() + raise iterator_yield_values + + mock_iterator.__iter__.side_effect = raise_error + else: + mock_rows = [] + for fields in iterator_yield_values: + mock_row = mock.MagicMock() + mock_row.fields = fields + mock_rows.append(mock_row) + mock_iterator.__iter__.return_value = mock_rows + + result = await execute_sql( project_id=project, instance_id=instance_id, credentials=credentials, query=query, - settings=BigtableToolSettings(), + settings=settings, tool_context=tool_context, + parameters=parameters, + parameter_types=parameter_types, ) - expected_rows = [{"col1": "val1", "col2": 123}] - assert result == {"status": "SUCCESS", "rows": expected_rows} - mock_client.execute_query.assert_called_once_with( - query=query, instance_id=instance_id - ) - mock_iterator.close.assert_called_once() + if expected_result["status"] == "ERROR": + assert result["status"] == "ERROR" + assert expected_result["error_details"] in result["error_details"] + else: + assert result == expected_result + + if not execute_query_side_effect: + mock_client.execute_query.assert_called_once_with( + query=query, + instance_id=instance_id, + parameters=parameters, + parameter_types=parameter_types, + ) + mock_iterator.close.assert_called_once() -def test_execute_sql_truncated(): - """Test execute_sql tool truncation functionality.""" - project = "my_project" - instance_id = "my_instance" - query = "SELECT * FROM my_table" - credentials = mock.create_autospec(Credentials, instance=True) - tool_context = mock.create_autospec(ToolContext, instance=True) - - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) - mock_client.execute_query.return_value = mock_iterator - - # Mock row data - mock_row1 = mock.MagicMock() - mock_row1.fields = {"col1": "val1"} - mock_row2 = mock.MagicMock() - mock_row2.fields = {"col1": "val2"} - mock_iterator.__iter__.return_value = [mock_row1, mock_row2] - - result = execute_sql( - project_id=project, - instance_id=instance_id, - credentials=credentials, - query=query, - settings=BigtableToolSettings(max_query_result_rows=1), - tool_context=tool_context, - ) - - expected_rows = [{"col1": "val1"}] - assert result == { - "status": "SUCCESS", - "rows": expected_rows, - "result_is_likely_truncated": True, - } - mock_client.execute_query.assert_called_once_with( - query=query, instance_id=instance_id - ) - mock_iterator.close.assert_called_once() - - -def test_execute_sql_error(): - """Test execute_sql tool error handling.""" - project = "my_project" - instance_id = "my_instance" - query = "SELECT * FROM my_table" - credentials = mock.create_autospec(Credentials, instance=True) - tool_context = mock.create_autospec(ToolContext, instance=True) - - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: - mock_client = mock.MagicMock() - mock_get_client.return_value = mock_client - mock_client.execute_query.side_effect = Exception("Test error") - - result = execute_sql( - project_id=project, - instance_id=instance_id, - credentials=credentials, - query=query, - settings=BigtableToolSettings(), - tool_context=tool_context, - ) - assert result == {"status": "ERROR", "error_details": "Test error"} - - -def test_execute_sql_row_value_circular_reference_fallback(): +@pytest.mark.asyncio +async def test_execute_sql_row_value_circular_reference_fallback(): """Test execute_sql converts circular row values to strings.""" project = "my_project" instance_id = "my_instance" @@ -145,9 +206,7 @@ def test_execute_sql_row_value_circular_reference_fallback(): credentials = mock.create_autospec(Credentials, instance=True) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch( - "google.adk.tools.bigtable.client.get_bigtable_data_client" - ) as mock_get_client: + with mock.patch.object(client, "get_bigtable_data_client") as mock_get_client: mock_client = mock.MagicMock() mock_get_client.return_value = mock_client mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True) @@ -158,7 +217,7 @@ def test_execute_sql_row_value_circular_reference_fallback(): mock_row.fields = {"col1": circular_value} mock_iterator.__iter__.return_value = [mock_row] - result = execute_sql( + result = await execute_sql( project_id=project, instance_id=instance_id, credentials=credentials, diff --git a/tests/unittests/tools/bigtable/test_bigtable_toolset.py b/tests/unittests/tools/bigtable/test_bigtable_toolset.py index 53040395..b5698cfc 100644 --- a/tests/unittests/tools/bigtable/test_bigtable_toolset.py +++ b/tests/unittests/tools/bigtable/test_bigtable_toolset.py @@ -45,7 +45,7 @@ async def test_bigtable_toolset_tools_default(): tools = await toolset.get_tools() assert tools is not None - assert len(tools) == 5 + assert len(tools) == 7 assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ @@ -54,6 +54,8 @@ async def test_bigtable_toolset_tools_default(): "list_tables", "get_table_info", "execute_sql", + "list_clusters", + "get_cluster_info", ]) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index c4c85e77..6d7fa0a3 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -17,6 +17,7 @@ from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from google.adk.agents.context import Context from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import HttpAuth @@ -534,7 +535,9 @@ class TestMCPTool: ) # Create service account credential - service_account = ServiceAccount(scopes=["test"]) + service_account = ServiceAccount( + scopes=["test"], use_default_credential=True + ) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=service_account, @@ -970,3 +973,45 @@ class TestMCPTool: assert factory_calls[0][0] == "test_tool" # callback_context is the tool_context itself (ToolContext extends CallbackContext) assert factory_calls[0][1] is tool_context + + @pytest.mark.asyncio + async def test_run_async_require_confirmation_callable_with_context_type( + self, + ): + """Test require_confirmation callable with Context type annotation.""" + + async def _require_confirmation_func(param1: str, ctx: Context): + return True + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + require_confirmation=_require_confirmation_func, + ) + tool_context = Mock(spec=ToolContext) + tool_context.tool_confirmation = None + tool_context.request_confirmation = Mock() + args = {"param1": "test_value", "extra_arg": 123} + + with patch.object( + tool, "_invoke_callable", new_callable=AsyncMock + ) as mock_invoke_callable: + mock_invoke_callable.return_value = True + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Verify context is passed with detected parameter name 'ctx' + expected_args_to_call = { + "param1": "test_value", + "ctx": tool_context, + } + mock_invoke_callable.assert_called_once_with( + _require_confirmation_func, expected_args_to_call + ) + + assert result == { + "error": ( + "This tool call requires confirmation, please approve or reject." + ) + } + tool_context.request_confirmation.assert_called_once() diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index 0ca99444..fb35daf6 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -25,8 +25,23 @@ from google.adk.auth.auth_schemes import AuthSchemeType from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import google.auth +from google.auth import exceptions as google_auth_exceptions import pytest +_ACCESS_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.Credentials." + "from_service_account_info" +) + +_ID_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.IDTokenCredentials." + "from_service_account_info" +) + +_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token" + @pytest.fixture def service_account_exchanger(): @@ -41,50 +56,45 @@ def auth_scheme(): return scheme -def test_exchange_credential_success( - service_account_exchanger, auth_scheme, monkeypatch +@pytest.fixture +def sa_credential(): + """A minimal valid ServiceAccountCredential for testing.""" + return ServiceAccountCredential( + type_="service_account", + project_id="test_project_id", + private_key_id="test_private_key_id", + private_key="-----BEGIN PRIVATE KEY-----...", + client_email="test@test.iam.gserviceaccount.com", + client_id="test_client_id", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs", + client_x509_cert_url=( + "https://www.googleapis.com/robot/v1/metadata/x509/test" + ), + universe_domain="googleapis.com", + ) + + +_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + + +# --- Access token exchange tests --- + + +def test_exchange_access_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test successful exchange of service account credentials.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_from_sa_info = MagicMock(return_value=mock_credentials) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) - # Mock the from_service_account_info method - mock_from_service_account_info = MagicMock(return_value=mock_credentials) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, - ) - - # Mock the refresh method - mock_credentials.refresh = MagicMock() - - # Create a valid AuthCredential with service account info auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." - ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) @@ -95,7 +105,7 @@ def test_exchange_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() mock_credentials.refresh.assert_called_once() @@ -107,7 +117,7 @@ def test_exchange_credential_success( (None, None, None), ], ) -def test_exchange_credential_use_default_credential_success( +def test_exchange_access_token_with_adc_sets_quota_project( service_account_exchanger, auth_scheme, monkeypatch, @@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success( adc_project_id, expected_quota_project_id, ): - """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" mock_credentials.quota_project_id = cred_quota_project_id @@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], + scopes=["https://www.googleapis.com/auth/bigquery"], ), ) @@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success( ) else: assert not result.http.additional_headers - # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/bigquery"] ) mock_credentials.refresh.assert_called_once() -def test_exchange_credential_missing_auth_credential( +def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = None + mock_google_auth_default = MagicMock(return_value=(mock_credentials, None)) + monkeypatch.setattr(google.auth, "default", mock_google_auth_default) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES) + + +def test_exchange_raises_when_auth_credential_is_none( service_account_exchanger, auth_scheme ): - """Test missing auth credential during exchange.""" with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, None) assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_missing_service_account_info( +def test_exchange_raises_when_service_account_is_none( service_account_exchanger, auth_scheme ): - """Test missing service account info during exchange.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, ) @@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info( assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_exchange_failure( - service_account_exchanger, auth_scheme, monkeypatch +def test_exchange_wraps_google_auth_error_as_missing_error( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test failure during service account token exchange.""" - mock_from_service_account_info = MagicMock( - side_effect=Exception("Failed to load credentials") - ) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to load credentials") ) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." - ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) + with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, auth_credential) assert "Failed to exchange service account token" in str(exc_info.value) - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() + + +def test_exchange_raises_when_explicit_credentials_have_no_scopes( + service_account_exchanger, auth_scheme, sa_credential +): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "scopes are required" in str(exc_info.value) + + +# --- ID token exchange tests --- + + +def test_exchange_id_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_id_credentials = MagicMock() + mock_id_credentials.token = "mock_id_token" + mock_from_sa_info = MagicMock(return_value=mock_id_credentials) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert result.http.additional_headers is None + mock_from_sa_info.assert_called_once() + assert ( + mock_from_sa_info.call_args[1]["target_audience"] + == "https://my-service.run.app" + ) + mock_id_credentials.refresh.assert_called_once() + + +def test_exchange_id_token_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token") + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_adc_id_token" + assert result.http.additional_headers is None + mock_fetch_id_token.assert_called_once() + assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app" + + +def test_id_token_requires_audience(): + with pytest.raises( + ValueError, match="audience is required when use_id_token is True" + ): + ServiceAccount( + use_default_credential=True, + use_id_token=True, + ) + + +def test_exchange_id_token_wraps_error_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to create ID token credentials") + ) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +def test_exchange_id_token_wraps_error_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock( + side_effect=google_auth_exceptions.DefaultCredentialsError( + "Metadata service unavailable" + ) + ) + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +# --- Model validator tests --- + + +def test_model_validator_rejects_missing_credential_without_adc(): + with pytest.raises( + ValueError, + match="service_account_credential is required", + ): + ServiceAccount( + use_default_credential=False, + scopes=_DEFAULT_SCOPES, + ) + + +def test_model_validator_allows_adc_without_explicit_credential(): + sa = ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + ) + assert sa.service_account_credential is None + assert sa.use_default_credential is True diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 81d44f0b..1131181a 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -1268,6 +1268,162 @@ class TestRestApiTool: assert result == {"result": "success"} + def test_prepare_request_params_extracts_embedded_query_params( + self, sample_auth_credential, sample_auth_scheme + ): + """Test that query params embedded in the URL path are extracted. + + ApplicationIntegrationToolset embeds query params and fragments directly + in the OpenAPI path (e.g. '...execute?triggerId=api_trigger/Name#action'). + These must be moved into the explicit query_params dict so httpx does not + strip them when it replaces the URL query string with the `params` arg. + Regression test for https://github.com/google/adk-python/issues/4555. + """ + integration_path = ( + "/v2/projects/my-proj/locations/us-central1" + "/integrations/ExecuteConnection:execute" + "?triggerId=api_trigger/ExecuteConnection" + "#POST_files" + ) + endpoint = OperationEndpoint( + base_url="https://integrations.googleapis.com", + path=integration_path, + method="POST", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + # The embedded query param must appear in params + assert request_params["params"]["triggerId"] == ( + "api_trigger/ExecuteConnection" + ) + # The URL must NOT contain the query string or fragment + assert "?" not in request_params["url"] + assert "#" not in request_params["url"] + assert request_params["url"] == ( + "https://integrations.googleapis.com" + "/v2/projects/my-proj/locations/us-central1" + "/integrations/ExecuteConnection:execute" + ) + + def test_prepare_request_params_merges_embedded_and_explicit_query_params( + self, sample_auth_credential, sample_auth_scheme + ): + """Embedded URL query params merge with explicitly defined query params.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api?embedded_key=embedded_val", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [ + ApiParameter( + original_name="explicit_key", + py_name="explicit_key", + param_location="query", + param_schema=OpenAPISchema(type="string"), + ), + ] + kwargs = {"explicit_key": "explicit_val"} + + request_params = tool._prepare_request_params(params, kwargs) + + assert request_params["params"]["embedded_key"] == "embedded_val" + assert request_params["params"]["explicit_key"] == "explicit_val" + assert "?" not in request_params["url"] + + def test_prepare_request_params_explicit_query_param_takes_precedence( + self, sample_auth_credential, sample_auth_scheme + ): + """Explicitly defined query params take precedence over embedded ones.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api?key=embedded", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [ + ApiParameter( + original_name="key", + py_name="key", + param_location="query", + param_schema=OpenAPISchema(type="string"), + ), + ] + kwargs = {"key": "explicit"} + + request_params = tool._prepare_request_params(params, kwargs) + + # Explicit value wins over the embedded one + assert request_params["params"]["key"] == "explicit" + + def test_prepare_request_params_strips_fragment_only( + self, sample_auth_credential, sample_auth_scheme + ): + """Fragment-only paths (no query string) are also cleaned.""" + endpoint = OperationEndpoint( + base_url="https://example.com", + path="/api#fragment", + method="GET", + ) + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + assert "#" not in request_params["url"] + assert request_params["url"] == "https://example.com/api" + + def test_prepare_request_params_plain_url_unchanged( + self, sample_endpoint, sample_auth_credential, sample_auth_scheme + ): + """URLs without embedded query or fragment are not modified.""" + operation = Operation(operationId="test_op") + tool = RestApiTool( + name="test_tool", + description="test", + endpoint=sample_endpoint, + operation=operation, + auth_credential=sample_auth_credential, + auth_scheme=sample_auth_scheme, + ) + + request_params = tool._prepare_request_params([], {}) + + assert request_params["url"] == "https://example.com/test" + def test_snake_to_lower_camel(): assert snake_to_lower_camel("single") == "single" diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 3b5aa26f..0a86d07c 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -145,3 +145,43 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ) ] assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + + +def test_vertex_rag_retrieval_for_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + responses = [ + 'response1', + ] + mockModel = testing_utils.MockModel.create(responses=responses) + mockModel.model = 'internal-model-v1' + + agent = Agent( + name='root_agent', + model=mockModel, + tools=[ + VertexAiRagRetrieval( + name='rag_retrieval', + description='rag_retrieval', + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ], + ) + ], + ) + runner = testing_utils.InMemoryRunner(agent) + runner.run('test1') + + assert len(mockModel.requests) == 1 + assert len(mockModel.requests[0].config.tools) == 1 + assert mockModel.requests[0].config.tools == [ + types.Tool( + retrieval=types.Retrieval( + vertex_rag_store=types.VertexRagStore( + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ] + ) + ) + ) + ] + assert 'rag_retrieval' not in mockModel.requests[0].tools_dict diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py index 4532dd56..c6a6c742 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -54,11 +54,12 @@ def mock_spanner_ids(): ), ], ) -@mock.patch.object(utils, "embed_contents") +@pytest.mark.asyncio +@mock.patch.object(utils, "embed_contents_async", autospec=True) @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_knn_success( +async def test_similarity_search_knn_success( mock_get_spanner_client, - mock_embed_contents, + mock_embed_contents_async, mock_spanner_ids, mock_credentials, embedding_option_key, @@ -77,7 +78,7 @@ def test_similarity_search_knn_success( mock_get_spanner_client.return_value = mock_spanner_client if embedding_option_key == "vertex_ai_embedding_model_name": - mock_embed_contents.return_value = [expected_embedding] + mock_embed_contents_async.return_value = [expected_embedding] # execute_sql is called once for the kNN search mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)]) else: @@ -90,7 +91,7 @@ def test_similarity_search_knn_success( iter([("result1",), ("result2",)]), ] - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -111,13 +112,14 @@ def test_similarity_search_knn_success( assert "@embedding" in sql assert call_args.kwargs == {"params": {"embedding": expected_embedding}} if embedding_option_key == "vertex_ai_embedding_model_name": - mock_embed_contents.assert_called_once_with( + mock_embed_contents_async.assert_called_once_with( embedding_option_value, ["test query"], None ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_ann_success( +async def test_similarity_search_ann_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search function with ANN success.""" @@ -139,7 +141,7 @@ def test_similarity_search_ann_success( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -164,13 +166,14 @@ def test_similarity_search_ann_success( assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_error( +async def test_similarity_search_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search function with a generic error.""" mock_get_spanner_client.side_effect = Exception("Test Exception") - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -187,11 +190,12 @@ def test_similarity_search_error( assert "Test Exception" in result["error_details"] -@mock.patch.object(utils, "embed_contents") +@pytest.mark.asyncio +@mock.patch.object(utils, "embed_contents_async") @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_circular_row_fallback_to_string( +async def test_similarity_search_circular_row_fallback_to_string( mock_get_spanner_client, - mock_embed_contents, + mock_embed_contents_async, mock_spanner_ids, mock_credentials, ): @@ -202,7 +206,7 @@ def test_similarity_search_circular_row_fallback_to_string( mock_snapshot = MagicMock() circular_row = [] circular_row.append(circular_row) - mock_embed_contents.return_value = [[0.1, 0.2, 0.3]] + mock_embed_contents_async.return_value = [[0.1, 0.2, 0.3]] mock_snapshot.execute_sql.return_value = iter([circular_row]) mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL @@ -210,7 +214,7 @@ def test_similarity_search_circular_row_fallback_to_string( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -228,8 +232,9 @@ def test_similarity_search_circular_row_fallback_to_string( assert result["rows"] == [str(circular_row)] +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_postgresql_knn_success( +async def test_similarity_search_postgresql_knn_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with PostgreSQL dialect for kNN.""" @@ -249,7 +254,7 @@ def test_similarity_search_postgresql_knn_success( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -273,8 +278,9 @@ def test_similarity_search_postgresql_knn_success( assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}} +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_postgresql_ann_unsupported( +async def test_similarity_search_postgresql_ann_unsupported( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with unsupported ANN for PostgreSQL dialect.""" @@ -286,7 +292,7 @@ def test_similarity_search_postgresql_ann_unsupported( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -311,8 +317,9 @@ def test_similarity_search_postgresql_ann_unsupported( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_gsql_missing_embedding_model_error( +async def test_similarity_search_gsql_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with missing embedding_options for GoogleSQL dialect.""" @@ -324,7 +331,7 @@ def test_similarity_search_gsql_missing_embedding_model_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -348,8 +355,9 @@ def test_similarity_search_gsql_missing_embedding_model_error( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_pg_missing_embedding_model_error( +async def test_similarity_search_pg_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with missing embedding_options for PostgreSQL dialect.""" @@ -361,7 +369,7 @@ def test_similarity_search_pg_missing_embedding_model_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -427,8 +435,9 @@ def test_similarity_search_pg_missing_embedding_model_error( ), ], ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_multiple_embedding_options_error( +async def test_similarity_search_multiple_embedding_options_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials, @@ -443,7 +452,7 @@ def test_similarity_search_multiple_embedding_options_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -461,8 +470,9 @@ def test_similarity_search_multiple_embedding_options_error( ) +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_output_dimensionality_gsql_error( +async def test_similarity_search_output_dimensionality_gsql_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name.""" @@ -474,7 +484,7 @@ def test_similarity_search_output_dimensionality_gsql_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], @@ -492,8 +502,9 @@ def test_similarity_search_output_dimensionality_gsql_error( assert "is not supported when" in result["error_details"] +@pytest.mark.asyncio @mock.patch.object(client, "get_spanner_client") -def test_similarity_search_unsupported_algorithm_error( +async def test_similarity_search_unsupported_algorithm_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): """Test similarity_search with an unsupported nearest neighbors algorithm.""" @@ -505,7 +516,7 @@ def test_similarity_search_unsupported_algorithm_error( mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client - result = search_tool.similarity_search( + result = await search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], database_id=mock_spanner_ids["database_id"], diff --git a/tests/unittests/tools/spanner/test_spanner_query_tool.py b/tests/unittests/tools/spanner/test_spanner_query_tool.py index 6c75a3ea..928c207d 100644 --- a/tests/unittests/tools/spanner/test_spanner_query_tool.py +++ b/tests/unittests/tools/spanner/test_spanner_query_tool.py @@ -191,8 +191,9 @@ async def test_execute_sql_query_result( assert tool.description == expected_description +@pytest.mark.asyncio @mock.patch.object(query_tool.utils, "execute_sql", spec_set=True) -def test_execute_sql(mock_utils_execute_sql): +async def test_execute_sql(mock_utils_execute_sql): """Test execute_sql function in query result default mode.""" mock_credentials = mock.create_autospec( Credentials, instance=True, spec_set=True @@ -202,7 +203,7 @@ def test_execute_sql(mock_utils_execute_sql): ) mock_utils_execute_sql.return_value = {"status": "SUCCESS", "rows": [[1]]} - result = query_tool.execute_sql( + result = await query_tool.execute_sql( project_id="test-project", instance_id="test-instance", database_id="test-database", diff --git a/tests/unittests/tools/test_bash_tool.py b/tests/unittests/tools/test_bash_tool.py new file mode 100644 index 00000000..e35c32b6 --- /dev/null +++ b/tests/unittests/tools/test_bash_tool.py @@ -0,0 +1,229 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.tools import bash_tool +from google.adk.tools import tool_context +from google.adk.tools.tool_confirmation import ToolConfirmation +import pytest + + +@pytest.fixture +def workspace(tmp_path): + """Creates a workspace mirroring the anthropics/skills PDF skill layout.""" + # Skill: pdf/ + skill_dir = tmp_path / "pdf" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: pdf\n" + "description: Use this skill whenever the user wants to do" + " anything with PDF files.\n" + "---\n# PDF Processing Guide\n\n## Overview\n" + "This guide covers PDF processing operations." + ) + scripts = skill_dir / "scripts" + scripts.mkdir() + (scripts / "extract_form_structure.py").write_text( + "import sys; print(f'extracting from {sys.argv[1]}')" + ) + (scripts / "fill_pdf_form_with_annotations.py").write_text( + "print('filling form')" + ) + references = skill_dir / "references" + references.mkdir() + (references / "REFERENCE.md").write_text("# Reference\nDetailed docs.") + # A loose file at workspace root (not inside a skill). + (tmp_path / "sample.pdf").write_bytes(b"%PDF-1.4 fake") + return tmp_path + + +@pytest.fixture +def tool_context_no_confirmation(): + """ToolContext with no confirmation (initial call).""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + ctx.tool_confirmation = None + ctx.actions = mock.MagicMock() + return ctx + + +@pytest.fixture +def tool_context_confirmed(): + """ToolContext with confirmation approved.""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + confirmation = mock.create_autospec(ToolConfirmation, instance=True) + confirmation.confirmed = True + ctx.tool_confirmation = confirmation + ctx.actions = mock.MagicMock() + return ctx + + +@pytest.fixture +def tool_context_rejected(): + """ToolContext with confirmation rejected.""" + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + confirmation = mock.create_autospec(ToolConfirmation, instance=True) + confirmation.confirmed = False + ctx.tool_confirmation = confirmation + ctx.actions = mock.MagicMock() + return ctx + + +# --- _validate_command tests --- + + +class TestValidateCommand: + + def test_empty_command(self): + policy = bash_tool.BashToolPolicy() + assert bash_tool._validate_command("", policy) is not None + assert bash_tool._validate_command(" ", policy) is not None + + def test_default_policy_allows_everything(self): + policy = bash_tool.BashToolPolicy() + assert bash_tool._validate_command("rm -rf /", policy) is None + assert bash_tool._validate_command("cat /etc/passwd", policy) is None + assert bash_tool._validate_command("sudo curl", policy) is None + + def test_restricted_policy_allows_prefixes(self): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat")) + assert bash_tool._validate_command("ls -la", policy) is None + assert bash_tool._validate_command("cat file.txt", policy) is None + + def test_restricted_policy_blocks_others(self): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat")) + assert bash_tool._validate_command("rm -rf .", policy) is not None + assert bash_tool._validate_command("tree", policy) is not None + assert "Permitted prefixes are: ls, cat" in bash_tool._validate_command( + "tree", policy + ) + + +class TestExecuteBashTool: + + @pytest.mark.asyncio + async def test_requests_confirmation( + self, workspace, tool_context_no_confirmation + ): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, + tool_context=tool_context_no_confirmation, + ) + assert "error" in result + assert "requires confirmation" in result["error"] + tool_context_no_confirmation.request_confirmation.assert_called_once() + + @pytest.mark.asyncio + async def test_rejected(self, workspace, tool_context_rejected): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, tool_context=tool_context_rejected + ) + assert result == {"error": "This tool call is rejected."} + + @pytest.mark.asyncio + async def test_executes_when_confirmed( + self, workspace, tool_context_confirmed + ): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "ls"}, + tool_context=tool_context_confirmed, + ) + assert result["returncode"] == 0 + assert "pdf" in result["stdout"] + + @pytest.mark.asyncio + async def test_cat_skill_md(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "cat pdf/SKILL.md"}, + tool_context=tool_context_confirmed, + ) + assert "PDF Processing Guide" in result["stdout"] + + @pytest.mark.asyncio + async def test_python_script(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={ + "command": "python3 pdf/scripts/extract_form_structure.py test.pdf" + }, + tool_context=tool_context_confirmed, + ) + assert "extracting from test.pdf" in result["stdout"] + assert result["returncode"] == 0 + + @pytest.mark.asyncio + async def test_blocks_disallowed_by_policy( + self, workspace, tool_context_no_confirmation + ): + policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls",)) + tool = bash_tool.ExecuteBashTool(workspace=workspace, policy=policy) + result = await tool.run_async( + args={"command": "rm -rf ."}, + tool_context=tool_context_no_confirmation, + ) + assert "error" in result + assert "Permitted prefixes are: ls" in result["error"] + tool_context_no_confirmation.request_confirmation.assert_not_called() + + @pytest.mark.asyncio + async def test_captures_stderr(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'import sys; sys.stderr.write(\"err\")'"}, + tool_context=tool_context_confirmed, + ) + assert "err" in result["stderr"] + + @pytest.mark.asyncio + async def test_nonzero_returncode(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'exit(42)'"}, + tool_context=tool_context_confirmed, + ) + assert result["returncode"] == 42 + + @pytest.mark.asyncio + async def test_timeout(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + with mock.patch( + "google.adk.tools.bash_tool.subprocess.run", + side_effect=__import__("subprocess").TimeoutExpired("cmd", 30), + ): + result = await tool.run_async( + args={"command": "python scripts/do_thing.py"}, + tool_context=tool_context_confirmed, + ) + assert "error" in result + assert "timed out" in result["error"].lower() + + @pytest.mark.asyncio + async def test_cwd_is_workspace(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async( + args={"command": "python3 -c 'import os; print(os.getcwd())'"}, + tool_context=tool_context_confirmed, + ) + assert result["stdout"].strip() == str(workspace) + + @pytest.mark.asyncio + async def test_no_command(self, workspace, tool_context_confirmed): + tool = bash_tool.ExecuteBashTool(workspace=workspace) + result = await tool.run_async(args={}, tool_context=tool_context_confirmed) + assert "error" in result + assert "required" in result["error"].lower() diff --git a/tests/unittests/tools/test_crewai_tool.py b/tests/unittests/tools/test_crewai_tool.py index 9feb094b..a0028233 100644 --- a/tests/unittests/tools/test_crewai_tool.py +++ b/tests/unittests/tools/test_crewai_tool.py @@ -21,6 +21,7 @@ pytest.importorskip( "google.adk.tools.crewai_tool", reason="Requires Python 3.10+" ) +from google.adk.agents.context import Context from google.adk.agents.invocation_context import InvocationContext from google.adk.sessions.session import Session from google.adk.tools.crewai_tool import CrewaiTool @@ -52,6 +53,14 @@ def _crewai_tool_with_context(tool_context: ToolContext, *args, **kwargs): } +def _crewai_tool_with_context_type(ctx: Context, *args, **kwargs): + """CrewAI tool with Context type annotation.""" + return { + "search_query": kwargs.get("search_query"), + "context_present": bool(ctx), + } + + class MockCrewaiBaseTool: """Mock CrewAI BaseTool for testing.""" @@ -180,3 +189,26 @@ async def test_crewai_tool_get_declaration(): # Verify that the args_schema was used to build the declaration mock_crewai_tool.args_schema.model_json_schema.assert_called_once() + + +@pytest.mark.asyncio +async def test_crewai_tool_with_context_type_annotation(mock_tool_context): + """Test CrewaiTool with Context type annotation and custom parameter name.""" + mock_crewai_tool = MockCrewaiBaseTool(_crewai_tool_with_context_type) + tool = CrewaiTool( + mock_crewai_tool, + name="context_type_tool", + description="Context type tool", + ) + + # Verify the context parameter is detected by type + assert tool._context_param_name == "ctx" + + # Test that context is properly injected + result = await tool.run_async( + args={"search_query": "test query"}, + tool_context=mock_tool_context, + ) + + assert result["search_query"] == "test query" + assert result["context_present"] diff --git a/tests/unittests/tools/test_enterprise_web_search_tool.py b/tests/unittests/tools/test_enterprise_web_search_tool.py index ed471596..7b28d858 100644 --- a/tests/unittests/tools/test_enterprise_web_search_tool.py +++ b/tests/unittests/tools/test_enterprise_web_search_tool.py @@ -76,6 +76,25 @@ async def test_process_llm_request_failure_with_non_gemini_models(): assert 'is not supported for model' in str(exc_info.value) +@pytest.mark.asyncio +async def test_process_llm_request_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = EnterpriseWebSearchTool() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + tool_context = await _create_tool_context() + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert ( + llm_request.config.tools[0].enterprise_web_search + == types.EnterpriseWebSearch() + ) + + @pytest.mark.asyncio async def test_process_llm_request_failure_with_multiple_tools_gemini_1_models(): tool = EnterpriseWebSearchTool() diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 9b1d1abd..9c76529f 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock +from google.adk.agents.context import Context from google.adk.agents.invocation_context import InvocationContext from google.adk.sessions.session import Session from google.adk.tools.function_tool import FunctionTool @@ -440,3 +441,91 @@ async def test_run_async_parameter_filtering(mock_tool_context): assert result == {"arg1": "test", "arg2": 42} # Explicitly verify that unexpected_param was filtered out and not passed to the function assert "unexpected_param" not in result + + +def test_context_param_detection_with_context_type(): + """Test that FunctionTool detects context parameter by Context type annotation.""" + + def my_tool(query: str, ctx: Context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "ctx" + assert tool._ignore_params == ["ctx", "input_stream"] + + +def test_context_param_detection_with_tool_context_type(): + """Test that FunctionTool detects context parameter by ToolContext type annotation.""" + + def my_tool(query: str, tool_context: ToolContext) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +def test_context_param_detection_with_custom_name(): + """Test that FunctionTool detects context parameter with any name if type is Context.""" + + def my_tool(query: str, my_custom_context: Context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "my_custom_context" + assert tool._ignore_params == ["my_custom_context", "input_stream"] + + +def test_context_param_detection_fallback_to_name(): + """Test that FunctionTool falls back to 'tool_context' name when no type annotation.""" + + def my_tool(query: str, tool_context) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +def test_context_param_detection_no_context(): + """Test that FunctionTool defaults to 'tool_context' when no context param exists.""" + + def my_tool(query: str, count: int) -> str: + return query + + tool = FunctionTool(my_tool) + assert tool._context_param_name == "tool_context" + assert tool._ignore_params == ["tool_context", "input_stream"] + + +@pytest.mark.asyncio +async def test_run_async_with_custom_context_param_name(mock_tool_context): + """Test that run_async correctly injects context with custom parameter name.""" + + def my_tool(query: str, ctx: Context) -> dict: + return {"query": query, "has_context": ctx is not None} + + tool = FunctionTool(my_tool) + result = await tool.run_async( + args={"query": "test"}, + tool_context=mock_tool_context, + ) + + assert result == {"query": "test", "has_context": True} + + +@pytest.mark.asyncio +async def test_run_async_with_context_type_annotation(mock_tool_context): + """Test that run_async works with Context type annotation.""" + + async def async_tool(query: str, context: Context) -> dict: + return {"query": query, "context_type": type(context).__name__} + + tool = FunctionTool(async_tool) + result = await tool.run_async( + args={"query": "hello"}, + tool_context=mock_tool_context, + ) + + assert result["query"] == "hello" + assert result["context_type"] == "Context" diff --git a/tests/unittests/tools/test_gemini_schema_util.py b/tests/unittests/tools/test_gemini_schema_util.py index d8445ab8..b7091903 100644 --- a/tests/unittests/tools/test_gemini_schema_util.py +++ b/tests/unittests/tools/test_gemini_schema_util.py @@ -648,6 +648,88 @@ class TestToGeminiSchema: assert gemini_schema.type == Type.OBJECT assert gemini_schema.properties is None + def test_to_gemini_schema_boolean_true_property(self): + """Tests that a JSON Schema boolean `true` property is handled. + + JSON Schema allows `true` as a schema meaning "accept any value". + Some MCP servers use this pattern for fields whose content is not + further constrained. + """ + openapi_schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "refId": {"type": "string"}, + "model": True, # JSON Schema boolean schema + }, + }, + } + }, + } + gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + items_schema = gemini_schema.properties["items"] + assert items_schema.type == Type.ARRAY + # `model: true` should be converted to an object schema + model_schema = items_schema.items.properties["model"] + assert model_schema.type == Type.OBJECT + + def test_to_gemini_schema_boolean_false_property(self): + """Tests that a JSON Schema boolean `false` property does not raise. + + `false` means "no value is valid" in JSON Schema, which has no Gemini + equivalent. Conversion falls back to an object schema to avoid crashing; + the result is semantically imprecise but safe. + """ + openapi_schema = { + "type": "object", + "properties": { + "anything": False, # JSON Schema boolean schema (reject all) + }, + } + # Should not raise even though `false` has no Gemini equivalent. + gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + assert gemini_schema.properties["anything"] is not None + + def test_to_gemini_schema_boolean_true_in_array_items_properties(self): + """Regression test: boolean `true` schema inside array item properties. + + Some MCP servers use `"field": true` in an array item's properties to + indicate an unconstrained field, which is valid JSON Schema. + """ + openapi_schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "datasourceUid": {"type": "string"}, + "model": True, + "queryType": {"type": "string"}, + "refId": {"type": "string"}, + }, + }, + }, + }, + "required": ["title", "data"], + } + # Should not raise a ValidationError + gemini_schema = _to_gemini_schema(openapi_schema) + assert isinstance(gemini_schema, Schema) + assert gemini_schema.type == Type.OBJECT + data_schema = gemini_schema.properties["data"] + assert data_schema.type == Type.ARRAY + model_schema = data_schema.items.properties["model"] + assert model_schema.type == Type.OBJECT + class TestToSnakeCase: diff --git a/tests/unittests/tools/test_google_maps_grounding_tool.py b/tests/unittests/tools/test_google_maps_grounding_tool.py new file mode 100644 index 00000000..0cd2c4fa --- /dev/null +++ b/tests/unittests/tools/test_google_maps_grounding_tool.py @@ -0,0 +1,92 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.google_maps_grounding_tool import GoogleMapsGroundingTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +async def _create_tool_context() -> ToolContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=agent, + session=session, + session_service=session_service, + ) + return ToolContext(invocation_context=invocation_context) + + +class TestGoogleMapsGroundingTool: + """Tests for GoogleMapsGroundingTool.""" + + @pytest.mark.asyncio + async def test_process_llm_request_with_gemini_2_model(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='gemini-2.5-pro', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_raises_error(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='claude-3-sonnet', config=types.GenerateContentConfig() + ) + + with pytest.raises( + ValueError, + match='Google maps tool is not supported for model claude-3-sonnet', + ): + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_and_disabled_check( + self, monkeypatch + ): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py index ad5d46b5..d71061b8 100644 --- a/tests/unittests/tools/test_google_search_tool.py +++ b/tests/unittests/tools/test_google_search_tool.py @@ -268,6 +268,27 @@ class TestGoogleSearchTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleSearchTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_set_model_response_tool.py b/tests/unittests/tools/test_set_model_response_tool.py index 75fd40e9..89da394a 100644 --- a/tests/unittests/tools/test_set_model_response_tool.py +++ b/tests/unittests/tools/test_set_model_response_tool.py @@ -14,11 +14,12 @@ """Tests for SetModelResponseTool.""" +import inspect + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.sessions.in_memory_session_service import InMemorySessionService -from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY from google.adk.tools.set_model_response_tool import SetModelResponseTool from google.adk.tools.tool_context import ToolContext from pydantic import BaseModel @@ -83,8 +84,6 @@ def test_function_signature_generation(): """Test that function signature is correctly generated from schema.""" tool = SetModelResponseTool(PersonSchema) - import inspect - sig = inspect.signature(tool.func) # Check that parameters match schema fields @@ -129,12 +128,6 @@ async def test_run_async_valid_data(): assert result['age'] == 25 assert result['city'] == 'Seattle' - # Verify data is no longer stored in session state (old behavior) - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_run_async_complex_schema(): @@ -165,12 +158,6 @@ async def test_run_async_complex_schema(): assert result['metadata'] == {'key': 'value'} assert result['is_active'] is False - # Verify data is no longer stored in session state (old behavior) - stored_response = invocation_context.session.state.get( - MODEL_JSON_RESPONSE_KEY - ) - assert stored_response is None - @pytest.mark.asyncio async def test_run_async_validation_error(): @@ -220,15 +207,12 @@ async def test_session_state_storage_key(): tool_context=tool_context, ) - # Verify response is returned directly, not stored in session state + # Verify response is returned directly assert result is not None assert result['name'] == 'Diana' assert result['age'] == 35 assert result['city'] == 'Miami' - # Verify session state is no longer used - assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state - @pytest.mark.asyncio async def test_multiple_executions_return_latest(): @@ -260,9 +244,6 @@ async def test_multiple_executions_return_latest(): assert result2['age'] == 30 assert result2['city'] == 'City2' - # Verify session state is not used - assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state - def test_function_return_value_consistency(): """Test that function return value matches run_async return value.""" @@ -273,3 +254,216 @@ def test_function_return_value_consistency(): # Both should return the same value assert direct_result == 'Response set successfully.' + + +# Tests for list[BaseModel] schema support + + +class ItemSchema(BaseModel): + """Simple item schema for list testing.""" + + id: int = Field(description='Item ID') + name: str = Field(description='Item name') + + +def test_tool_initialization_list_schema(): + """Test tool initialization with a list schema.""" + tool = SetModelResponseTool(list[ItemSchema]) + + assert tool.output_schema == list[ItemSchema] + assert tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert 'Set your final response' in tool.description + assert tool.func is not None + + +def test_function_signature_generation_list_schema(): + """Test that function signature is correctly generated for list schema.""" + tool = SetModelResponseTool(list[ItemSchema]) + + sig = inspect.signature(tool.func) + + # Should have a single 'items' parameter + assert 'items' in sig.parameters + assert len(sig.parameters) == 1 + + # Parameter should be keyword-only with correct annotation + assert sig.parameters['items'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['items'].annotation == list[ItemSchema] + + +def test_get_declaration_list_schema(): + """Test that tool declaration is properly generated for list schema.""" + tool = SetModelResponseTool(list[ItemSchema]) + + declaration = tool._get_declaration() + + assert declaration is not None + assert declaration.name == 'set_model_response' + assert declaration.description is not None + + +@pytest.mark.asyncio +async def test_run_async_list_schema_valid_data(): + """Test tool execution with valid list data.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with valid list data + result = await tool.run_async( + args={ + 'items': [ + {'id': 1, 'name': 'Item 1'}, + {'id': 2, 'name': 'Item 2'}, + {'id': 3, 'name': 'Item 3'}, + ] + }, + tool_context=tool_context, + ) + + # Verify the tool returns list of dicts + assert result is not None + assert isinstance(result, list) + assert len(result) == 3 + assert result[0]['id'] == 1 + assert result[0]['name'] == 'Item 1' + assert result[1]['id'] == 2 + assert result[2]['id'] == 3 + + +@pytest.mark.asyncio +async def test_run_async_list_schema_empty_list(): + """Test tool execution with empty list.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with empty list + result = await tool.run_async( + args={'items': []}, + tool_context=tool_context, + ) + + # Verify the tool returns empty list + assert result is not None + assert isinstance(result, list) + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_run_async_list_schema_validation_error(): + """Test tool execution with invalid list data raises validation error.""" + tool = SetModelResponseTool(list[ItemSchema]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with invalid data (wrong type for id) + with pytest.raises(ValidationError): + await tool.run_async( + args={ + 'items': [ + {'id': 'not_a_number', 'name': 'Item 1'}, + ] + }, + tool_context=tool_context, + ) + + +# Tests for other schema types (list[str], dict, etc.) + + +def test_tool_initialization_list_str_schema(): + """Test tool initialization with list[str] schema.""" + tool = SetModelResponseTool(list[str]) + + assert tool.output_schema == list[str] + assert not tool._is_basemodel + assert not tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert tool.func is not None + + +def test_function_signature_generation_list_str_schema(): + """Test that function signature is correctly generated for list[str] schema.""" + tool = SetModelResponseTool(list[str]) + + sig = inspect.signature(tool.func) + + # Should have a single 'response' parameter with list[str] annotation + assert 'response' in sig.parameters + assert len(sig.parameters) == 1 + assert sig.parameters['response'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['response'].annotation == list[str] + + +@pytest.mark.asyncio +async def test_run_async_list_str_schema(): + """Test tool execution with list[str] data.""" + tool = SetModelResponseTool(list[str]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with list of strings + result = await tool.run_async( + args={'response': ['apple', 'banana', 'cherry']}, + tool_context=tool_context, + ) + + # Verify the tool returns the list directly + assert result is not None + assert isinstance(result, list) + assert result == ['apple', 'banana', 'cherry'] + + +def test_tool_initialization_dict_schema(): + """Test tool initialization with dict schema.""" + tool = SetModelResponseTool(dict[str, int]) + + assert tool.output_schema == dict[str, int] + assert not tool._is_basemodel + assert not tool._is_list_of_basemodel + assert tool.name == 'set_model_response' + assert tool.func is not None + + +def test_function_signature_generation_dict_schema(): + """Test that function signature is correctly generated for dict schema.""" + tool = SetModelResponseTool(dict[str, int]) + + sig = inspect.signature(tool.func) + + # Should have a single 'response' parameter with dict[str, int] annotation + assert 'response' in sig.parameters + assert len(sig.parameters) == 1 + assert sig.parameters['response'].kind == inspect.Parameter.KEYWORD_ONLY + assert sig.parameters['response'].annotation == dict[str, int] + + +@pytest.mark.asyncio +async def test_run_async_dict_schema(): + """Test tool execution with dict data.""" + tool = SetModelResponseTool(dict[str, int]) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Execute with dict data + result = await tool.run_async( + args={'response': {'a': 1, 'b': 2, 'c': 3}}, + tool_context=tool_context, + ) + + # Verify the tool returns the dict directly + assert result is not None + assert isinstance(result, dict) + assert result == {'a': 1, 'b': 2, 'c': 3} diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py index b747d1f8..7ebf4f40 100644 --- a/tests/unittests/tools/test_skill_toolset.py +++ b/tests/unittests/tools/test_skill_toolset.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from unittest import mock +from google.adk.code_executors.base_code_executor import BaseCodeExecutor +from google.adk.code_executors.code_execution_utils import CodeExecutionResult from google.adk.models import llm_request as llm_request_model from google.adk.skills import models from google.adk.tools import skill_toolset @@ -27,6 +30,7 @@ def mock_skill1_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill1" frontmatter.description = "Skill 1 description" + frontmatter.allowed_tools = ["test_tool"] frontmatter.model_dump.return_value = { "name": "skill1", "description": "Skill 1 description", @@ -39,10 +43,18 @@ def mock_skill1(mock_skill1_frontmatter): """Fixture for skill1.""" skill = mock.create_autospec(models.Skill, instance=True) skill.name = "skill1" + skill.description = "Skill 1 description" skill.instructions = "instructions for skill1" skill.frontmatter = mock_skill1_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -55,8 +67,25 @@ def mock_skill1(mock_skill1_frontmatter): return "asset content 1" return None + def get_script(name): + if name == "setup.sh": + return models.Script(src="echo setup") + if name == "run.py": + return models.Script(src="print('hello')") + if name == "build.rb": + return models.Script(src="puts 'hello'") + return None + skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset + skill.resources.get_script.side_effect = get_script + skill.resources.list_references.return_value = ["ref1.md"] + skill.resources.list_assets.return_value = ["asset1.txt"] + skill.resources.list_scripts.return_value = [ + "setup.sh", + "run.py", + "build.rb", + ] return skill @@ -66,6 +95,7 @@ def mock_skill2_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill2" frontmatter.description = "Skill 2 description" + frontmatter.allowed_tools = [] frontmatter.model_dump.return_value = { "name": "skill2", "description": "Skill 2 description", @@ -78,10 +108,18 @@ def mock_skill2(mock_skill2_frontmatter): """Fixture for skill2.""" skill = mock.create_autospec(models.Skill, instance=True) skill.name = "skill2" + skill.description = "Skill 2 description" skill.instructions = "instructions for skill2" skill.frontmatter = mock_skill2_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -96,13 +134,22 @@ def mock_skill2(mock_skill2_frontmatter): skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset + skill.resources.list_references.return_value = ["ref2.md"] + skill.resources.list_assets.return_value = ["asset2.txt"] + skill.resources.list_scripts.return_value = [] return skill @pytest.fixture def tool_context_instance(): """Fixture for tool context.""" - return mock.create_autospec(tool_context.ToolContext, instance=True) + ctx = mock.create_autospec(tool_context.ToolContext, instance=True) + ctx._invocation_context = mock.MagicMock() + ctx._invocation_context.agent = mock.MagicMock() + ctx._invocation_context.agent.name = "test_agent" + ctx._invocation_context.agent_states = {} + ctx.agent_name = "test_agent" + return ctx # SkillToolset tests @@ -114,23 +161,23 @@ def test_get_skill(mock_skill1, mock_skill2): def test_list_skills(mock_skill1, mock_skill2): toolset = skill_toolset.SkillToolset([mock_skill1, mock_skill2]) - frontmatters = toolset._list_skills() - assert len(frontmatters) == 2 - assert mock_skill1.frontmatter in frontmatters - assert mock_skill2.frontmatter in frontmatters + skills = toolset._list_skills() + assert len(skills) == 2 + assert mock_skill1 in skills + assert mock_skill2 in skills @pytest.mark.asyncio async def test_get_tools(mock_skill1, mock_skill2): toolset = skill_toolset.SkillToolset([mock_skill1, mock_skill2]) tools = await toolset.get_tools() - assert len(tools) == 3 + assert len(tools) == 4 assert isinstance(tools[0], skill_toolset.ListSkillsTool) assert isinstance(tools[1], skill_toolset.LoadSkillTool) assert isinstance(tools[2], skill_toolset.LoadSkillResourceTool) + assert isinstance(tools[3], skill_toolset.RunSkillScriptTool) -@pytest.mark.asyncio @pytest.mark.asyncio async def test_list_skills_tool( mock_skill1, mock_skill2, tool_context_instance @@ -203,6 +250,14 @@ async def test_load_skill_run_async( "content": "asset content 1", }, ), + ( + {"skill_name": "skill1", "path": "scripts/setup.sh"}, + { + "skill_name": "skill1", + "path": "scripts/setup.sh", + "content": "echo setup", + }, + ), ( {"skill_name": "nonexistent", "path": "references/ref1.md"}, { @@ -223,7 +278,10 @@ async def test_load_skill_run_async( ( {"skill_name": "skill1", "path": "invalid/path.txt"}, { - "error": "Path must start with 'references/' or 'assets/'.", + "error": ( + "Path must start with 'references/', 'assets/'," + " or 'scripts/'." + ), "error_code": "INVALID_RESOURCE_PATH", }, ), @@ -266,7 +324,951 @@ async def test_process_llm_request( llm_req.append_instructions.assert_called_once() args, _ = llm_req.append_instructions.call_args instructions = args[0] - assert len(instructions) == 1 - assert "" in instructions[0] - assert "skill1" in instructions[0] - assert "skill2" in instructions[0] + assert len(instructions) == 2 + assert instructions[0] == skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert "" in instructions[1] + assert "skill1" in instructions[1] + assert "skill2" in instructions[1] + + +def test_default_skill_system_instruction_warning(): + with pytest.warns( + UserWarning, match="DEFAULT_SKILL_SYSTEM_INSTRUCTION is experimental" + ): + instruction = skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert "specialized 'skills'" in instruction + + +def test_duplicate_skill_name_raises(mock_skill1): + skill_dup = mock.create_autospec(models.Skill, instance=True) + skill_dup.name = "skill1" + with pytest.raises(ValueError, match="Duplicate skill name"): + skill_toolset.SkillToolset([mock_skill1, skill_dup]) + + +@pytest.mark.asyncio +async def test_scripts_resource_not_found(mock_skill1, tool_context_instance): + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.LoadSkillResourceTool(toolset) + result = await tool.run_async( + args={"skill_name": "skill1", "path": "scripts/nonexistent.sh"}, + tool_context=tool_context_instance, + ) + assert result["error_code"] == "RESOURCE_NOT_FOUND" + + +# RunSkillScriptTool tests + + +def _make_tool_context_with_agent(agent=None): + """Creates a mock ToolContext with _invocation_context.agent.""" + ctx = mock.MagicMock(spec=tool_context.ToolContext) + ctx._invocation_context = mock.MagicMock() + ctx._invocation_context.agent = agent or mock.MagicMock() + ctx._invocation_context.agent.name = "test_agent" + ctx._invocation_context.agent_states = {} + ctx.agent_name = "test_agent" + ctx.state = {} + return ctx + + +def _make_mock_executor(stdout="", stderr=""): + """Creates a mock code executor that returns the given output.""" + executor = mock.create_autospec(BaseCodeExecutor, instance=True) + executor.execute_code.return_value = CodeExecutionResult( + stdout=stdout, stderr=stderr + ) + return executor + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "args, expected_error_code", + [ + ( + {"script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1"}, + "MISSING_SCRIPT_PATH", + ), + ( + {"skill_name": "", "script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1", "script_path": ""}, + "MISSING_SCRIPT_PATH", + ), + ], +) +async def test_execute_script_missing_params( + mock_skill1, args, expected_error_code +): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async(args=args, tool_context=ctx) + assert result["error_code"] == expected_error_code + + +@pytest.mark.asyncio +async def test_execute_script_skill_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "nonexistent", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "SKILL_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_execute_script_script_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "nonexistent.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "SCRIPT_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_execute_script_no_code_executor(mock_skill1): + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + # Agent without code_executor attribute + agent = mock.MagicMock(spec=[]) + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_agent_code_executor_none(mock_skill1): + """Agent has code_executor attr but it's None.""" + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = None + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_unsupported_type(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "build.rb"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +@pytest.mark.asyncio +async def test_execute_script_python_success(mock_skill1): + executor = _make_mock_executor(stdout="hello\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello\n" + assert result["stderr"] == "" + assert result["skill_name"] == "skill1" + assert result["script_path"] == "run.py" + + # Verify the code passed to executor runs the python scripts + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "_materialize_and_run()" in code_input.code + assert "import runpy" in code_input.code + assert "sys.argv = ['scripts/run.py']" in code_input.code + assert ( + "runpy.run_path('scripts/run.py', run_name='__main__')" in code_input.code + ) + + +@pytest.mark.asyncio +async def test_execute_script_shell_success(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "setup\n" + + # Verify the code wraps in subprocess.run with JSON envelope + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "subprocess.run" in code_input.code + assert "bash" in code_input.code + assert "__shell_result__" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_with_input_args_python(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": {"verbose": True, "count": "3"}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert ( + "['scripts/run.py', '--verbose', 'True', '--count', '3']" + in code_input.code + ) + + +@pytest.mark.asyncio +async def test_execute_script_with_input_args_shell(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "setup.sh", + "args": {"force": True}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "['bash', 'scripts/setup.sh', '--force', 'True']" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_scripts_prefix_stripping(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "scripts/setup.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["script_path"] == "scripts/setup.sh" + + +@pytest.mark.asyncio +async def test_execute_script_toolset_executor_priority(mock_skill1): + """Toolset-level executor takes priority over agent's.""" + toolset_executor = _make_mock_executor(stdout="from toolset\n") + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=toolset_executor + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from toolset\n" + toolset_executor.execute_code.assert_called_once() + agent_executor.execute_code.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_script_agent_executor_fallback(mock_skill1): + """Falls back to agent's code executor when toolset has none.""" + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from agent\n" + agent_executor.execute_code.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_script_execution_error(mock_skill1): + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("boom") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "boom" in result["error"] + assert result["error"].startswith("Failed to execute script 'run.py':") + + +@pytest.mark.asyncio +async def test_execute_script_stderr_only_sets_error_status(mock_skill1): + """stderr with no stdout should report error status.""" + executor = _make_mock_executor(stdout="", stderr="fatal error\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stderr"] == "fatal error\n" + + +@pytest.mark.asyncio +async def test_execute_script_stderr_with_stdout_sets_warning(mock_skill1): + """stderr alongside stdout should report warning status.""" + executor = _make_mock_executor(stdout="output\n", stderr="deprecation\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "output\n" + assert result["stderr"] == "deprecation\n" + + +@pytest.mark.asyncio +async def test_execute_script_execution_error_truncated(mock_skill1): + """Long exception messages are truncated to avoid wasting LLM tokens.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("x" * 300) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + # 200 chars of the message + "..." suffix + the prefix + assert result["error"].endswith("...") + assert len(result["error"]) < 300 + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_caught(mock_skill1): + """sys.exit() in a script should not terminate the process.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(1) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "exited with code 1" in result["error"] + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_zero_is_success(mock_skill1): + """sys.exit(0) is a normal termination and should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(0) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_none_is_success(mock_skill1): + """sys.exit() with no arg (None) should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(None) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_shell_includes_timeout(mock_skill1): + """Shell wrapper includes timeout in subprocess.run.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=executor, script_timeout=60 + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "timeout=60" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_extensionless_unsupported(mock_skill1): + """Files without extensions should return UNSUPPORTED_SCRIPT_TYPE.""" + # Add a script with no extension to the mock + original_side_effect = mock_skill1.resources.get_script.side_effect + + def get_script_extended(name): + if name == "noext": + return models.Script(src="print('hi')") + return original_side_effect(name) + + mock_skill1.resources.get_script.side_effect = get_script_extended + + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "noext"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +# ── Integration tests using real UnsafeLocalCodeExecutor ── + + +def _make_skill_with_script(skill_name, script_name, script): + """Creates a minimal mock Skill with a single script.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = skill_name + skill.description = f"Test skill {skill_name}" + skill.instructions = "test instructions" + fm = mock.create_autospec(models.Frontmatter, instance=True) + fm.name = skill_name + fm.description = f"Test skill {skill_name}" + skill.frontmatter = fm + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + + def get_script(name): + if name == script_name: + return script + return None + + skill.resources.get_script.side_effect = get_script + skill.resources.get_reference.return_value = None + skill.resources.get_asset.return_value = None + skill.resources.list_references.return_value = [] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = [script_name] + return skill + + +def _make_real_executor_toolset(skills, **kwargs): + """Creates a SkillToolset with a real UnsafeLocalCodeExecutor.""" + from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor + + executor = UnsafeLocalCodeExecutor() + return skill_toolset.SkillToolset(skills, code_executor=executor, **kwargs) + + +@pytest.mark.asyncio +async def test_integration_python_stdout(): + """Real executor: Python script stdout is captured.""" + script = models.Script(src="print('hello world')") + skill = _make_skill_with_script("test_skill", "hello.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "hello.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello world\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_integration_python_sys_exit_zero(): + """Real executor: sys.exit(0) is treated as success.""" + script = models.Script(src="import sys; sys.exit(0)") + skill = _make_skill_with_script("test_skill", "exit_zero.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "exit_zero.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_integration_shell_stdout_and_stderr(): + """Real executor: shell script preserves both stdout and stderr.""" + script = models.Script(src="echo output; echo warning >&2") + skill = _make_skill_with_script("test_skill", "both.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "both.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert "output" in result["stdout"] + assert "warning" in result["stderr"] + + +@pytest.mark.asyncio +async def test_integration_shell_stderr_only(): + """Real executor: shell script with only stderr reports error.""" + script = models.Script(src="echo failure >&2") + skill = _make_skill_with_script("test_skill", "err.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "err.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "failure" in result["stderr"] + + +# ── Shell JSON envelope parsing (unit tests with mock executor) ── + + +@pytest.mark.asyncio +async def test_shell_json_envelope_parsed(mock_skill1): + """Shell JSON envelope is correctly unpacked by run_async.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "hello from shell\n", + "stderr": "", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello from shell\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_nonzero_returncode(mock_skill1): + """Non-zero returncode in shell envelope sets stderr.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "", + "stderr": "", + "returncode": 2, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "Exit code 2" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_json_envelope_with_stderr(mock_skill1): + """Shell envelope with both stdout and stderr reports warning.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "data\n", + "stderr": "deprecation warning\n", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "data\n" + assert result["stderr"] == "deprecation warning\n" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_timeout(mock_skill1): + """Shell envelope from TimeoutExpired reports error status.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "partial output\n", + "stderr": "Timed out after 300s", + "returncode": -1, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stdout"] == "partial output\n" + assert "Timed out" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_non_json_stdout_passthrough(mock_skill1): + """Non-JSON shell stdout is passed through without parsing.""" + executor = _make_mock_executor(stdout="plain text output\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "plain text output\n" + + +# ── input_files packaging ── + + +@pytest.mark.asyncio +async def test_execute_script_input_files_packaged(mock_skill1): + """Verify references, assets, and scripts are packaged inside the wrapper code.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + + # input_files is no longer populated; it's serialized inside the script + assert code_input.input_files is None or len(code_input.input_files) == 0 + + # Ensure the extracted literal contains our fake files + assert "references/ref1.md" in code_input.code + assert "assets/asset1.txt" in code_input.code + assert "scripts/setup.sh" in code_input.code + assert "scripts/run.py" in code_input.code + assert "scripts/build.rb" in code_input.code + + # Verify content mappings exist in the string + assert "'references/ref1.md': 'ref content 1'" in code_input.code + assert "'assets/asset1.txt': 'asset content 1'" in code_input.code + + +# ── Integration: shell non-zero exit ── + + +@pytest.mark.asyncio +async def test_integration_shell_nonzero_exit(): + """Real executor: shell script with non-zero exit via JSON envelope.""" + script = models.Script(src="exit 42") + skill = _make_skill_with_script("test_skill", "fail.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "fail.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "42" in result["stderr"] + + +# ── Finding 1: system instruction references correct tool name ── + + +def test_system_instruction_references_run_skill_script(): + """System instruction must reference the actual tool name.""" + assert "run_skill_script" in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert ( + "execute_skill_script" + not in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + ) + + +# ── Finding 2: empty files are mounted (not silently dropped) ── + + +@pytest.mark.asyncio +async def test_execute_script_empty_files_mounted(): + """Verify empty files are included in wrapper code, not dropped.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_empty" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: "" if n == "empty.md" else None + ) + skill.resources.get_asset.side_effect = ( + lambda n: "" if n == "empty.cfg" else None + ) + skill.resources.get_script.side_effect = ( + lambda n: models.Script(src="") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["empty.md"] + skill.resources.list_assets.return_value = ["empty.cfg"] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_empty", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "'references/empty.md': ''" in code_input.code + assert "'assets/empty.cfg': ''" in code_input.code + assert "'scripts/run.py': ''" in code_input.code + + +# ── Finding 3: invalid args type returns clear error ── + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "bad_args", + [ + "not a dict", + ["a", "list"], + 42, + True, + ], +) +async def test_execute_script_invalid_args_type(mock_skill1, bad_args): + """Non-dict args should return INVALID_ARGS_TYPE, not crash.""" + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": bad_args, + }, + tool_context=ctx, + ) + assert result["error_code"] == "INVALID_ARGS_TYPE" + executor.execute_code.assert_not_called() + + +# ── Finding 4: binary file content is handled in wrapper ── + + +@pytest.mark.asyncio +async def test_execute_script_binary_content_packaged(): + """Verify binary asset content uses 'wb' mode in wrapper code.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_bin" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: b"\x00\x01\x02" if n == "data.bin" else None + ) + skill.resources.get_asset.return_value = None + skill.resources.get_script.side_effect = lambda n: ( + models.Script(src="print('ok')") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["data.bin"] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_bin", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + # Binary content should appear as bytes literal + assert "b'\\x00\\x01\\x02'" in code_input.code + # Wrapper code handles binary with 'wb' mode + assert "'wb' if isinstance(content, bytes)" in code_input.code + + +@pytest.mark.asyncio +async def test_skill_toolset_dynamic_tool_resolution(mock_skill1): + # Set up a skill with additional_tools in metadata + mock_skill1.frontmatter.metadata = { + "adk_additional_tools": ["my_custom_tool", "my_func"] + } + mock_skill1.name = "skill1" + + # Prepare additional tools + custom_tool = mock.create_autospec(skill_toolset.BaseTool, instance=True) + custom_tool.name = "my_custom_tool" + + def my_func(): + """My function description.""" + pass + + toolset = skill_toolset.SkillToolset( + [mock_skill1], + additional_tools=[custom_tool, my_func], + ) + + ctx = _make_tool_context_with_agent() + # Initial tools (only core) + tools = await toolset.get_tools(readonly_context=ctx) + assert len(tools) == 4 + + # Activate skill + load_tool = skill_toolset.LoadSkillTool(toolset) + await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx) + + # Dynamic tools should now be resolved + tools = await toolset.get_tools(readonly_context=ctx) + tool_names = {t.name for t in tools} + assert "my_custom_tool" in tool_names + assert "my_func" in tool_names + + # Check specific tool resolution details + my_func_tool = next(t for t in tools if t.name == "my_func") + assert isinstance(my_func_tool, skill_toolset.FunctionTool) + assert my_func_tool.description == "My function description." + + +@pytest.mark.asyncio +async def test_skill_toolset_resolution_error_handling(mock_skill1, caplog): + mock_skill1.frontmatter.metadata = { + "adk_additional_tools": ["nonexistent_tool"] + } + mock_skill1.name = "skill1" + toolset = skill_toolset.SkillToolset([mock_skill1]) + ctx = _make_tool_context_with_agent() + + # Activate skill + load_tool = skill_toolset.LoadSkillTool(toolset) + await load_tool.run_async(args={"name": "skill1"}, tool_context=ctx) + + with caplog.at_level(logging.WARNING): + tools = await toolset.get_tools(readonly_context=ctx) + + # Should still return basic skill tools + assert len(tools) == 4 diff --git a/tests/unittests/tools/test_url_context_tool.py b/tests/unittests/tools/test_url_context_tool.py index 53ee7e62..8fd44b59 100644 --- a/tests/unittests/tools/test_url_context_tool.py +++ b/tests/unittests/tools/test_url_context_tool.py @@ -190,6 +190,27 @@ class TestUrlContextTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = UrlContextTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].url_context is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 3ade634d..b15d3a1f 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -376,6 +376,29 @@ class TestVertexAiSearchTool: tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = VertexAiSearchTool(data_store_id='test_data_store') + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/utils/test_context_utils.py b/tests/unittests/utils/test_context_utils.py new file mode 100644 index 00000000..a5e2d656 --- /dev/null +++ b/tests/unittests/utils/test_context_utils.py @@ -0,0 +1,107 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for context_utils module.""" + +from typing import Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.context import Context +from google.adk.tools.tool_context import ToolContext +from google.adk.utils.context_utils import find_context_parameter + + +class TestFindContextParameter: + """Tests for find_context_parameter function.""" + + def test_find_context_parameter_with_context_type(self): + """Test detection of Context type annotation.""" + + def my_tool(query: str, ctx: Context) -> str: + return query + + assert find_context_parameter(my_tool) == 'ctx' + + def test_find_context_parameter_with_tool_context_type(self): + """Test detection of ToolContext type annotation.""" + + def my_tool(query: str, tool_context: ToolContext) -> str: + return query + + assert find_context_parameter(my_tool) == 'tool_context' + + def test_find_context_parameter_with_callback_context_type(self): + """Test detection of CallbackContext type annotation.""" + + def my_callback(ctx: CallbackContext) -> None: + pass + + assert find_context_parameter(my_callback) == 'ctx' + + def test_find_context_parameter_with_optional_context(self): + """Test detection of Optional[Context] type annotation.""" + + def my_tool(query: str, context: Optional[Context] = None) -> str: + return query + + assert find_context_parameter(my_tool) == 'context' + + def test_find_context_parameter_with_custom_name(self): + """Test that any parameter name works with Context type.""" + + def my_tool(query: str, my_custom_ctx: Context) -> str: + return query + + assert find_context_parameter(my_tool) == 'my_custom_ctx' + + def test_find_context_parameter_no_context(self): + """Test function without context parameter returns None.""" + + def my_tool(query: str, count: int) -> str: + return query + + assert find_context_parameter(my_tool) is None + + def test_find_context_parameter_no_annotations(self): + """Test function without type annotations returns None.""" + + def my_tool(query, ctx): + return query + + assert find_context_parameter(my_tool) is None + + def test_find_context_parameter_with_none_func(self): + """Test that None function returns None.""" + assert find_context_parameter(None) is None + + def test_find_context_parameter_returns_first_match(self): + """Test that first context parameter is returned if multiple exist.""" + + def my_tool(first_ctx: Context, second_ctx: Context) -> str: + return 'test' + + assert find_context_parameter(my_tool) == 'first_ctx' + + def test_find_context_parameter_with_mixed_params(self): + """Test context parameter detection with various other parameters.""" + + def my_tool( + query: str, + count: int, + ctx: Context, + optional_param: Optional[str] = None, + ) -> str: + return query + + assert find_context_parameter(my_tool) == 'ctx' diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index cbac37e3..2af1584b 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -18,6 +18,7 @@ from google.adk.utils.model_name_utils import extract_model_name from google.adk.utils.model_name_utils import is_gemini_1_model from google.adk.utils.model_name_utils import is_gemini_2_or_above from google.adk.utils.model_name_utils import is_gemini_model +from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled class TestExtractModelName: @@ -318,3 +319,15 @@ class TestModelNameUtilsIntegration: f'Inconsistent Gemini 2.0+ classification for {simple_model} vs' f' {path_model}' ) + + +class TestGeminiModelIdCheckFlag: + """Tests for Gemini model-id check override flag.""" + + def test_default_is_disabled(self, monkeypatch): + monkeypatch.delenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', raising=False) + assert is_gemini_model_id_check_disabled() is False + + def test_true_enables_check_bypass(self, monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + assert is_gemini_model_id_check_disabled() is True diff --git a/tests/unittests/utils/test_output_schema_utils.py b/tests/unittests/utils/test_output_schema_utils.py index fc2f6fb5..cf759c99 100644 --- a/tests/unittests/utils/test_output_schema_utils.py +++ b/tests/unittests/utils/test_output_schema_utils.py @@ -15,6 +15,7 @@ from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.utils.output_schema_utils import can_use_output_schema_with_tools import pytest @@ -37,6 +38,11 @@ import pytest (Claude(model="claude-3.7-sonnet"), "1", False), (Claude(model="claude-3.7-sonnet"), "0", False), (Claude(model="claude-3.7-sonnet"), None, False), + (LiteLlm(model="openai/gpt-4o"), "1", True), + (LiteLlm(model="openai/gpt-4o"), "0", True), + (LiteLlm(model="openai/gpt-4o"), None, True), + (LiteLlm(model="anthropic/claude-3.7-sonnet"), None, True), + (LiteLlm(model="fireworks_ai/llama-v3p1-70b"), None, True), ], ) def test_can_use_output_schema_with_tools( diff --git a/tests/unittests/utils/test_schema_utils.py b/tests/unittests/utils/test_schema_utils.py new file mode 100644 index 00000000..8f68ecdb --- /dev/null +++ b/tests/unittests/utils/test_schema_utils.py @@ -0,0 +1,146 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _schema_utils module.""" + +from google.adk.utils._schema_utils import get_list_inner_type +from google.adk.utils._schema_utils import is_basemodel_schema +from google.adk.utils._schema_utils import is_list_of_basemodel +from google.adk.utils._schema_utils import validate_schema +from pydantic import BaseModel + + +class SampleModel(BaseModel): + """Sample model for testing.""" + + name: str + value: int + + +class TestIsBasemodelSchema: + """Tests for is_basemodel_schema function.""" + + def test_basemodel_class_returns_true(self): + """Test that a BaseModel class returns True.""" + assert is_basemodel_schema(SampleModel) + + def test_list_of_basemodel_returns_false(self): + """Test that list[BaseModel] returns False.""" + assert not is_basemodel_schema(list[SampleModel]) + + def test_list_of_str_returns_false(self): + """Test that list[str] returns False.""" + assert not is_basemodel_schema(list[str]) + + def test_dict_returns_false(self): + """Test that dict types return False.""" + assert not is_basemodel_schema(dict[str, int]) + + def test_plain_str_returns_false(self): + """Test that plain str returns False.""" + assert not is_basemodel_schema(str) + + def test_plain_int_returns_false(self): + """Test that plain int returns False.""" + assert not is_basemodel_schema(int) + + +class TestIsListOfBasemodel: + """Tests for is_list_of_basemodel function.""" + + def test_list_of_basemodel_returns_true(self): + """Test that list[BaseModel] returns True.""" + assert is_list_of_basemodel(list[SampleModel]) + + def test_basemodel_class_returns_false(self): + """Test that a plain BaseModel class returns False.""" + assert not is_list_of_basemodel(SampleModel) + + def test_list_of_str_returns_false(self): + """Test that list[str] returns False.""" + assert not is_list_of_basemodel(list[str]) + + def test_list_of_int_returns_false(self): + """Test that list[int] returns False.""" + assert not is_list_of_basemodel(list[int]) + + def test_dict_returns_false(self): + """Test that dict types return False.""" + assert not is_list_of_basemodel(dict[str, int]) + + def test_plain_list_returns_false(self): + """Test that plain list (no type arg) returns False.""" + assert not is_list_of_basemodel(list) + + +class TestGetListInnerType: + """Tests for get_list_inner_type function.""" + + def test_list_of_basemodel_returns_inner_type(self): + """Test that list[BaseModel] returns the inner type.""" + assert get_list_inner_type(list[SampleModel]) is SampleModel + + def test_basemodel_class_returns_none(self): + """Test that a plain BaseModel class returns None.""" + assert get_list_inner_type(SampleModel) is None + + def test_list_of_str_returns_none(self): + """Test that list[str] returns None.""" + assert get_list_inner_type(list[str]) is None + + def test_dict_returns_none(self): + """Test that dict types return None.""" + assert get_list_inner_type(dict[str, int]) is None + + +class TestValidateSchema: + """Tests for validate_schema function.""" + + def test_basemodel_schema(self): + """Test validation with a BaseModel schema.""" + json_text = '{"name": "test", "value": 42}' + result = validate_schema(SampleModel, json_text) + assert result == {'name': 'test', 'value': 42} + + def test_basemodel_schema_excludes_none(self): + """Test that None values are excluded from the result.""" + + class ModelWithOptional(BaseModel): + name: str + optional_field: str | None = None + + json_text = '{"name": "test", "optional_field": null}' + result = validate_schema(ModelWithOptional, json_text) + assert result == {'name': 'test'} + + def test_list_of_basemodel_schema(self): + """Test validation with a list[BaseModel] schema.""" + json_text = '[{"name": "item1", "value": 1}, {"name": "item2", "value": 2}]' + result = validate_schema(list[SampleModel], json_text) + assert result == [ + {'name': 'item1', 'value': 1}, + {'name': 'item2', 'value': 2}, + ] + + def test_list_of_str_schema(self): + """Test validation with a list[str] schema.""" + json_text = '["a", "b", "c"]' + result = validate_schema(list[str], json_text) + assert result == ['a', 'b', 'c'] + + def test_dict_schema(self): + """Test validation with a dict schema.""" + json_text = '{"key1": 1, "key2": 2}' + result = validate_schema(dict[str, int], json_text) + assert result == {'key1': 1, 'key2': 2}