You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: remove hardcoded google-cloud-aiplatform version in agent engine requirements
This fixes e.g. `--trace_to_cloud flag` Co-authored-by: Max Ind <maxind@google.com> PiperOrigin-RevId: 834190152
This commit is contained in:
committed by
Copybara-Service
parent
0cc3d6d6d5
commit
e15e19da05
@@ -776,11 +776,7 @@ def to_agent_engine(
|
||||
if not os.path.exists(requirements_txt_path):
|
||||
click.echo(f'Creating {requirements_txt_path}...')
|
||||
with open(requirements_txt_path, 'w', encoding='utf-8') as f:
|
||||
f.write(
|
||||
'google-cloud-aiplatform[adk,agent_engines] @ '
|
||||
'git+https://github.com/googleapis/python-aiplatform.git@'
|
||||
'bf1851e59cb34e63b509a2a610e72691e1c4ca28'
|
||||
)
|
||||
f.write('google-cloud-aiplatform[adk,agent_engines]')
|
||||
click.echo(f'Created {requirements_txt_path}')
|
||||
agent_config['requirements_file'] = agent_config.get(
|
||||
'requirements',
|
||||
|
||||
@@ -25,22 +25,17 @@ import sys
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import runtime_checkable
|
||||
from typing import TextIO
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import create_mcp_http_client
|
||||
from mcp.client.streamable_http import McpHttpClientFactory
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except ImportError as e:
|
||||
|
||||
@@ -89,11 +84,6 @@ class SseConnectionParams(BaseModel):
|
||||
sse_read_timeout: float = 60 * 5.0
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CheckableMcpHttpClientFactory(McpHttpClientFactory, Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class StreamableHTTPConnectionParams(BaseModel):
|
||||
"""Parameters for the MCP Streamable HTTP connection.
|
||||
|
||||
@@ -109,18 +99,13 @@ class StreamableHTTPConnectionParams(BaseModel):
|
||||
Streamable HTTP server.
|
||||
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
|
||||
when the connection is closed.
|
||||
httpx_client_factory: Factory function to create a custom HTTPX client. If
|
||||
not provided, a default factory will be used.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5.0
|
||||
sse_read_timeout: float = 60 * 5.0
|
||||
terminate_on_close: bool = True
|
||||
httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client
|
||||
|
||||
|
||||
def retry_on_closed_resource(func):
|
||||
@@ -301,7 +286,6 @@ class MCPSessionManager:
|
||||
seconds=self._connection_params.sse_read_timeout
|
||||
),
|
||||
terminate_on_close=self._connection_params.terminate_on_close,
|
||||
httpx_client_factory=self._connection_params.httpx_client_factory,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -146,59 +146,6 @@ class TestMCPSessionManager:
|
||||
|
||||
assert manager._connection_params == http_params
|
||||
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client")
|
||||
def test_init_with_streamable_http_custom_httpx_factory(
|
||||
self, mock_streamablehttp_client
|
||||
):
|
||||
"""Test that streamablehttp_client is called with custom httpx_client_factory."""
|
||||
from datetime import timedelta
|
||||
|
||||
custom_httpx_factory = Mock()
|
||||
|
||||
http_params = StreamableHTTPConnectionParams(
|
||||
url="https://example.com/mcp",
|
||||
timeout=15.0,
|
||||
httpx_client_factory=custom_httpx_factory,
|
||||
)
|
||||
manager = MCPSessionManager(http_params)
|
||||
|
||||
manager._create_client()
|
||||
|
||||
mock_streamablehttp_client.assert_called_once_with(
|
||||
url="https://example.com/mcp",
|
||||
headers=None,
|
||||
timeout=timedelta(seconds=15.0),
|
||||
sse_read_timeout=timedelta(seconds=300.0),
|
||||
terminate_on_close=True,
|
||||
httpx_client_factory=custom_httpx_factory,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client")
|
||||
async def test_init_with_streamable_http_default_httpx_factory(
|
||||
self, mock_streamablehttp_client
|
||||
):
|
||||
"""Test that streamablehttp_client is called with custom httpx_client_factory."""
|
||||
from datetime import timedelta
|
||||
|
||||
from mcp.client.streamable_http import create_mcp_http_client
|
||||
|
||||
http_params = StreamableHTTPConnectionParams(
|
||||
url="https://example.com/mcp", timeout=15.0
|
||||
)
|
||||
manager = MCPSessionManager(http_params)
|
||||
|
||||
manager._create_client()
|
||||
|
||||
mock_streamablehttp_client.assert_called_once_with(
|
||||
url="https://example.com/mcp",
|
||||
headers=None,
|
||||
timeout=timedelta(seconds=15.0),
|
||||
sse_read_timeout=timedelta(seconds=300.0),
|
||||
terminate_on_close=True,
|
||||
httpx_client_factory=create_mcp_http_client,
|
||||
)
|
||||
|
||||
def test_generate_session_key_stdio(self):
|
||||
"""Test session key generation for stdio connections."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
|
||||
Reference in New Issue
Block a user