You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Set agent_engine_id in the service constructor, also use the agent_engine_id field instead of overriding app_name in FastAPI endpoint
PiperOrigin-RevId: 770427903
This commit is contained in:
committed by
Copybara-Service
parent
a7ea374dfb
commit
fc65873d7c
@@ -276,7 +276,6 @@ def get_fast_api_app(
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
if session_service_uri:
|
||||
if session_service_uri.startswith("agentengine://"):
|
||||
# Create vertex session service
|
||||
@@ -285,8 +284,9 @@ def get_fast_api_app(
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
session_service = VertexAiSessionService(
|
||||
os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
project=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
location=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
agent_engine_id=agent_engine_id,
|
||||
)
|
||||
else:
|
||||
session_service = DatabaseSessionService(db_url=session_service_uri)
|
||||
@@ -357,8 +357,6 @@ def get_fast_api_app(
|
||||
async def get_session(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -371,8 +369,6 @@ def get_fast_api_app(
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
list_sessions_response = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
@@ -393,8 +389,6 @@ def get_fast_api_app(
|
||||
session_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
if (
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
@@ -419,8 +413,6 @@ def get_fast_api_app(
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
logger.info("New session created")
|
||||
return await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
@@ -660,8 +652,6 @@ def get_fast_api_app(
|
||||
|
||||
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
||||
async def delete_session(app_name: str, user_id: str, session_id: str):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
await session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -677,7 +667,6 @@ def get_fast_api_app(
|
||||
artifact_name: str,
|
||||
version: Optional[int] = Query(None),
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = await artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -700,7 +689,6 @@ def get_fast_api_app(
|
||||
artifact_name: str,
|
||||
version_id: int,
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = await artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -719,7 +707,6 @@ def get_fast_api_app(
|
||||
async def list_artifact_names(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return await artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -731,7 +718,6 @@ def get_fast_api_app(
|
||||
async def list_artifact_versions(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
) -> list[int]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return await artifact_service.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -745,7 +731,6 @@ def get_fast_api_app(
|
||||
async def delete_artifact(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
):
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
await artifact_service.delete_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -755,10 +740,8 @@ def get_fast_api_app(
|
||||
|
||||
@app.post("/run", response_model_exclude_none=True)
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
@@ -776,11 +759,9 @@ def get_fast_api_app(
|
||||
|
||||
@app.post("/run_sse")
|
||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
@@ -818,8 +799,6 @@ def get_fast_api_app(
|
||||
async def get_event_graph(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -875,8 +854,6 @@ def get_fast_api_app(
|
||||
) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
@@ -940,7 +917,7 @@ def get_fast_api_app(
|
||||
return runner_dict[app_name]
|
||||
root_agent = agent_loader.load_agent(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import Optional
|
||||
import urllib.parse
|
||||
|
||||
from dateutil import parser
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from google import genai
|
||||
@@ -40,15 +39,27 @@ logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class VertexAiSessionService(BaseSessionService):
|
||||
"""Connects to the managed Vertex AI Session Service."""
|
||||
"""Connects to the Vertex AI Agent Engine Session Service using GenAI API client.
|
||||
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/sessions/overview
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str = None,
|
||||
location: str = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
agent_engine_id: Optional[str] = None,
|
||||
):
|
||||
self.project = project
|
||||
self.location = location
|
||||
"""Initializes the VertexAiSessionService.
|
||||
|
||||
Args:
|
||||
project: The project id of the project to use.
|
||||
location: The location of the project to use.
|
||||
agent_engine_id: The resource ID of the agent engine to use.
|
||||
"""
|
||||
self._project = project
|
||||
self._location = location
|
||||
self._agent_engine_id = agent_engine_id
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
@@ -64,14 +75,13 @@ class VertexAiSessionService(BaseSessionService):
|
||||
'User-provided Session id is not supported for'
|
||||
' VertexAISessionService.'
|
||||
)
|
||||
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
session_json_dict = {'user_id': user_id}
|
||||
if state:
|
||||
session_json_dict['session_state'] = state
|
||||
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
api_response = await api_client.async_request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
|
||||
@@ -130,10 +140,10 @@ class VertexAiSessionService(BaseSessionService):
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
# Get session resource
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
get_session_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
@@ -203,14 +213,14 @@ class VertexAiSessionService(BaseSessionService):
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
path = f'reasoningEngines/{reasoning_engine_id}/sessions'
|
||||
if user_id:
|
||||
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
|
||||
path = path + f'?filter=user_id={parsed_user_id}'
|
||||
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=path,
|
||||
@@ -236,8 +246,9 @@ class VertexAiSessionService(BaseSessionService):
|
||||
async def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
|
||||
api_client = self._get_api_client()
|
||||
|
||||
try:
|
||||
await api_client.async_request(
|
||||
http_method='DELETE',
|
||||
@@ -253,8 +264,8 @@ class VertexAiSessionService(BaseSessionService):
|
||||
# Update the in-memory session.
|
||||
await super().append_event(session=session, event=event)
|
||||
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)
|
||||
api_client = self._get_api_client()
|
||||
await api_client.async_request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
|
||||
@@ -262,15 +273,34 @@ class VertexAiSessionService(BaseSessionService):
|
||||
)
|
||||
return event
|
||||
|
||||
def _get_reasoning_engine_id(self, app_name: str):
|
||||
if self._agent_engine_id:
|
||||
return self._agent_engine_id
|
||||
|
||||
def _get_api_client(project: str, location: str):
|
||||
"""Instantiates an API client for the given project and location.
|
||||
if app_name.isdigit():
|
||||
return app_name
|
||||
|
||||
It needs to be instantiated inside each request so that the event loop
|
||||
management.
|
||||
"""
|
||||
client = genai.Client(vertexai=True, project=project, location=location)
|
||||
return client._api_client
|
||||
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
|
||||
match = re.fullmatch(pattern, app_name)
|
||||
|
||||
if not bool(match):
|
||||
raise ValueError(
|
||||
f'App name {app_name} is not valid. It should either be the full'
|
||||
' ReasoningEngine resource name, or the reasoning engine id.'
|
||||
)
|
||||
|
||||
return match.groups()[-1]
|
||||
|
||||
def _get_api_client(self):
|
||||
"""Instantiates an API client for the given project and location.
|
||||
|
||||
It needs to be instantiated inside each request so that the event loop
|
||||
management can be properly propagated.
|
||||
"""
|
||||
client = genai.Client(
|
||||
vertexai=True, project=self._project, location=self._location
|
||||
)
|
||||
return client._api_client
|
||||
|
||||
|
||||
def _convert_event_to_json(event: Event) -> Dict[str, Any]:
|
||||
@@ -366,19 +396,3 @@ def _from_api_event(api_event: Dict[str, Any]) -> Event:
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def _parse_reasoning_engine_id(app_name: str):
|
||||
if app_name.isdigit():
|
||||
return app_name
|
||||
|
||||
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
|
||||
match = re.fullmatch(pattern, app_name)
|
||||
|
||||
if not bool(match):
|
||||
raise ValueError(
|
||||
f'App name {app_name} is not valid. It should either be the full'
|
||||
' ReasoningEngine resource name, or the reasoning engine id.'
|
||||
)
|
||||
|
||||
return match.groups()[-1]
|
||||
|
||||
@@ -245,8 +245,14 @@ class MockApiClient:
|
||||
raise ValueError(f'Unsupported http method: {http_method}')
|
||||
|
||||
|
||||
def mock_vertex_ai_session_service():
|
||||
def mock_vertex_ai_session_service(agent_engine_id: Optional[str] = None):
|
||||
"""Creates a mock Vertex AI Session service for testing."""
|
||||
if agent_engine_id:
|
||||
return VertexAiSessionService(
|
||||
project='test-project',
|
||||
location='test-location',
|
||||
agent_engine_id=agent_engine_id,
|
||||
)
|
||||
return VertexAiSessionService(
|
||||
project='test-project', location='test-location'
|
||||
)
|
||||
@@ -265,7 +271,7 @@ def mock_get_api_client():
|
||||
'2': (MOCK_EVENT_JSON_2, 'my_token'),
|
||||
}
|
||||
with mock.patch(
|
||||
'google.adk.sessions.vertex_ai_session_service._get_api_client',
|
||||
'google.adk.sessions.vertex_ai_session_service.VertexAiSessionService._get_api_client',
|
||||
return_value=api_client,
|
||||
):
|
||||
yield
|
||||
@@ -273,8 +279,12 @@ def mock_get_api_client():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_get_empty_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
@pytest.mark.parametrize('agent_engine_id', [None, '123'])
|
||||
async def test_get_empty_session(agent_engine_id):
|
||||
if agent_engine_id:
|
||||
session_service = mock_vertex_ai_session_service(agent_engine_id)
|
||||
else:
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='0'
|
||||
|
||||
Reference in New Issue
Block a user