You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add custom instructions support to LlmBackedUserSimulator
Details: - Allows users to provide custom instructions for the LLM-backed user simulator via the `custom_instructions` field in `LlmBackedUserSimulatorConfig`. - The custom instructions must include placeholders for the stop signal, conversation plan, and conversation history. A pydantic validator ensures these placeholders are present. - If no custom instructions are provided, the current default template is used. Co-authored-by: Keyur Joshi <keyurj@google.com> PiperOrigin-RevId: 850471448
This commit is contained in:
committed by
Copybara-Service
parent
0918b647df
commit
a364388d97
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from ...events.event import Event
|
||||
@@ -40,7 +41,7 @@ logger = logging.getLogger("google_adk." + __name__)
|
||||
_AUTHOR_USER = "user"
|
||||
_STOP_SIGNAL = "</finished>"
|
||||
|
||||
_USER_AGENT_INSTRUCTIONS_TEMPLATE = """You are a Simulated User designed to test an AI Agent.
|
||||
_DEFAULT_USER_AGENT_INSTRUCTIONS = """You are a Simulated User designed to test an AI Agent.
|
||||
|
||||
Your single most important job is to react logically to the Agent's last message.
|
||||
The Conversation Plan is your canonical grounding, not a script; your response MUST be dictated by what the Agent just said.
|
||||
@@ -126,6 +127,38 @@ prompt is also counted as an invocation.
|
||||
(Not recommended) If you don't want a limit, you can set the value to -1.""",
|
||||
)
|
||||
|
||||
custom_instructions: Optional[str] = Field(
|
||||
default=None,
|
||||
description="""Custom instructions for the LlmBackedUserSimulator. The
|
||||
instructions must contain the following formatting placeholders:
|
||||
* {stop_signal} : text to be generated when the user simulator decides that the
|
||||
conversation is over.
|
||||
* {conversation_plan} : the overall plan for the conversation that the user
|
||||
simulator must follow.
|
||||
* {conversation_history} : the conversation between the user and the agent so
|
||||
far.""",
|
||||
)
|
||||
|
||||
@field_validator("custom_instructions")
|
||||
@classmethod
|
||||
def validate_custom_instructions(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
if not all(
|
||||
placeholder in value
|
||||
for placeholder in [
|
||||
"{stop_signal}",
|
||||
"{conversation_plan}",
|
||||
"{conversation_history}",
|
||||
]
|
||||
):
|
||||
raise ValueError(
|
||||
"custom_instructions must contain each of the following formatting"
|
||||
" placeholders:"
|
||||
" {stop_signal}, {conversation_plan}, {conversation_history}"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
@experimental
|
||||
class LlmBackedUserSimulator(UserSimulator):
|
||||
@@ -147,6 +180,11 @@ class LlmBackedUserSimulator(UserSimulator):
|
||||
llm_registry = LLMRegistry()
|
||||
llm_class = llm_registry.resolve(self._config.model)
|
||||
self._llm = llm_class(model=self._config.model)
|
||||
self._instructions = (
|
||||
self._config.custom_instructions
|
||||
if self._config.custom_instructions
|
||||
else _DEFAULT_USER_AGENT_INSTRUCTIONS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _summarize_conversation(
|
||||
@@ -183,7 +221,7 @@ class LlmBackedUserSimulator(UserSimulator):
|
||||
# first invocation - send the static starting prompt
|
||||
return self._conversation_scenario.starting_prompt
|
||||
|
||||
user_agent_instructions = _USER_AGENT_INSTRUCTIONS_TEMPLATE.format(
|
||||
user_agent_instructions = self._instructions.format(
|
||||
stop_signal=_STOP_SIGNAL,
|
||||
conversation_plan=self._conversation_scenario.conversation_plan,
|
||||
conversation_history=rewritten_dialogue,
|
||||
|
||||
@@ -20,6 +20,7 @@ from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBacked
|
||||
from google.adk.evaluation.simulation.user_simulator import Status
|
||||
from google.adk.events.event import Event
|
||||
from google.genai import types
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
|
||||
_INPUT_EVENTS = [
|
||||
@@ -88,6 +89,20 @@ user: I need to book a flight.
|
||||
helpful_assistant: Sure, what is your departure date and destination?"""
|
||||
|
||||
|
||||
def test_llm_backed_user_simulator_config_validation():
|
||||
"""Tests for LlmBackedUserSimulatorConfig."""
|
||||
config = LlmBackedUserSimulatorConfig(custom_instructions=None)
|
||||
assert config.custom_instructions is None
|
||||
valid_instructions = (
|
||||
"{stop_signal} {conversation_plan} {conversation_history}"
|
||||
)
|
||||
config = LlmBackedUserSimulatorConfig(custom_instructions=valid_instructions)
|
||||
assert config.custom_instructions == valid_instructions
|
||||
invalid_instructions = "Instructions with missing formatting placeholders"
|
||||
with pytest.raises(ValidationError):
|
||||
LlmBackedUserSimulatorConfig(custom_instructions=invalid_instructions)
|
||||
|
||||
|
||||
class TestHelperMethods:
|
||||
"""Test cases for LlmBackedUserSimulator helper methods."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user