You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: allow Vertex AI Client initialization with API Key
Co-authored-by: Ankur Sharma <ankusharma@google.com> PiperOrigin-RevId: 861335865
This commit is contained in:
committed by
Copybara-Service
parent
553e376718
commit
43d6075ea7
@@ -159,18 +159,25 @@ class _VertexAiEvalFacade(Evaluator):
|
||||
"""
|
||||
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", None)
|
||||
location = os.environ.get("GOOGLE_CLOUD_LOCATION", None)
|
||||
api_key = os.environ.get("GOOGLE_API_KEY", None)
|
||||
|
||||
if not project_id:
|
||||
raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
|
||||
if not location:
|
||||
raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
|
||||
from ..dependencies.vertexai import vertexai
|
||||
|
||||
from vertexai import Client
|
||||
from vertexai import types as vertexai_types
|
||||
|
||||
client = Client(project=project_id, location=location)
|
||||
if api_key:
|
||||
client = vertexai.Client(api_key=api_key)
|
||||
elif project_id or location:
|
||||
if not project_id:
|
||||
raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
|
||||
if not location:
|
||||
raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
|
||||
client = vertexai.Client(project=project_id, location=location)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either API Key or Google cloud Project id and location should be"
|
||||
" specified."
|
||||
)
|
||||
|
||||
return client.evals.evaluate(
|
||||
dataset=vertexai_types.EvaluationDataset(eval_dataset_df=dataset),
|
||||
dataset=vertexai.types.EvaluationDataset(eval_dataset_df=dataset),
|
||||
metrics=metrics,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
"""Tests for the Response Evaluator."""
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
from google.adk.dependencies.vertexai import vertexai
|
||||
@@ -23,6 +24,7 @@ from google.adk.evaluation.eval_case import Invocation
|
||||
from google.adk.evaluation.evaluator import EvalStatus
|
||||
from google.adk.evaluation.vertex_ai_eval_facade import _VertexAiEvalFacade
|
||||
from google.genai import types as genai_types
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
vertexai_types = vertexai.types
|
||||
@@ -246,3 +248,89 @@ class TestVertexAiEvalFacade:
|
||||
)
|
||||
assert evaluation_result.overall_eval_status == EvalStatus.FAILED
|
||||
assert mock_perform_eval.call_count == num_invocations
|
||||
|
||||
def test_perform_eval_with_api_key(self, mocker):
|
||||
mocker.patch.dict(
|
||||
os.environ, {"GOOGLE_API_KEY": "test_api_key"}, clear=True
|
||||
)
|
||||
mock_client_cls = mocker.patch(
|
||||
"google.adk.dependencies.vertexai.vertexai.Client"
|
||||
)
|
||||
mock_client_instance = mock_client_cls.return_value
|
||||
dummy_dataset = pd.DataFrame(
|
||||
[{"prompt": "p", "reference": "r", "response": "r"}]
|
||||
)
|
||||
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
|
||||
|
||||
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
|
||||
|
||||
mock_client_cls.assert_called_once_with(api_key="test_api_key")
|
||||
mock_client_instance.evals.evaluate.assert_called_once()
|
||||
|
||||
def test_perform_eval_with_project_and_location(self, mocker):
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GOOGLE_CLOUD_PROJECT": "test_project",
|
||||
"GOOGLE_CLOUD_LOCATION": "test_location",
|
||||
},
|
||||
clear=True,
|
||||
)
|
||||
mock_client_cls = mocker.patch(
|
||||
"google.adk.dependencies.vertexai.vertexai.Client"
|
||||
)
|
||||
mock_client_instance = mock_client_cls.return_value
|
||||
dummy_dataset = pd.DataFrame(
|
||||
[{"prompt": "p", "reference": "r", "response": "r"}]
|
||||
)
|
||||
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
|
||||
|
||||
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
|
||||
|
||||
mock_client_cls.assert_called_once_with(
|
||||
project="test_project", location="test_location"
|
||||
)
|
||||
mock_client_instance.evals.evaluate.assert_called_once()
|
||||
|
||||
def test_perform_eval_with_project_only_raises_error(self, mocker):
|
||||
mocker.patch.dict(
|
||||
os.environ, {"GOOGLE_CLOUD_PROJECT": "test_project"}, clear=True
|
||||
)
|
||||
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
|
||||
dummy_dataset = pd.DataFrame(
|
||||
[{"prompt": "p", "reference": "r", "response": "r"}]
|
||||
)
|
||||
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing location."):
|
||||
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
|
||||
|
||||
def test_perform_eval_with_location_only_raises_error(self, mocker):
|
||||
mocker.patch.dict(
|
||||
os.environ, {"GOOGLE_CLOUD_LOCATION": "test_location"}, clear=True
|
||||
)
|
||||
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
|
||||
dummy_dataset = pd.DataFrame(
|
||||
[{"prompt": "p", "reference": "r", "response": "r"}]
|
||||
)
|
||||
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing project id."):
|
||||
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
|
||||
|
||||
def test_perform_eval_with_no_env_vars_raises_error(self, mocker):
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
|
||||
dummy_dataset = pd.DataFrame(
|
||||
[{"prompt": "p", "reference": "r", "response": "r"}]
|
||||
)
|
||||
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either API Key or Google cloud Project id and location should be"
|
||||
" specified."
|
||||
),
|
||||
):
|
||||
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
|
||||
|
||||
Reference in New Issue
Block a user