Bug 525181, part 2: Implement protocol state machines in C++. r=bent

This commit is contained in:
Chris Jones 2010-07-15 14:27:43 -05:00
parent bb1ac58a3c
commit dce4511b26
4 changed files with 152 additions and 7 deletions

View File

@ -72,6 +72,24 @@ struct ActorHandle
int mId;
};
// Used internally to represent a "trigger" that might cause a state
// transition. Triggers are normalized across parent+child to Send
// and Recv (instead of child-in, child-out, parent-in, parent-out) so
// that they can share the same state machine implementation. To
// further normalize, |Send| is used for 'call', |Recv| for 'answer'.
struct Trigger
{
enum Action { Send, Recv };
Trigger(Action action, int32 msg) :
mAction(action),
mMsg(msg)
{}
Action mAction;
int32 mMsg;
};
template<class ListenerT>
class /*NS_INTERFACE_CLASS*/ IProtocolManager
{

View File

@ -753,6 +753,7 @@ class StmtSwitch(Block):
assert not isinstance(case, str)
assert (isinstance(block, StmtBreak)
or isinstance(block, StmtReturn)
or isinstance(block, StmtSwitch)
or (hasattr(block, 'stmts')
and (isinstance(block.stmts[-1], StmtBreak)
or isinstance(block.stmts[-1], StmtReturn))))
@ -760,6 +761,10 @@ class StmtSwitch(Block):
self.addstmt(block)
self.nr_cases += 1
def addfallthrough(self, case):
self.addstmt(case)
self.nr_cases += 1
class StmtBreak(Node):
def __init__(self):
Node.__init__(self)

View File

@ -154,6 +154,9 @@ def _startState(proto=None, fq=False):
else: pfx = proto.name() +'::'
return ExprVar(pfx +'__Start')
def _deleteId():
return ExprVar('Msg___delete____ID')
def _lookupListener(idexpr):
return ExprCall(ExprVar('Lookup'), args=[ idexpr ])
@ -1396,6 +1399,8 @@ child actors.'''
msgenum.addId(self.protocol.name +'End')
ns.addstmts([ StmtDecl(Decl(msgenum, '')), Whitespace.NL ])
ns.addstmts([ self.genTransitionFunc(), Whitespace.NL ])
typedefs = self.protocol.decl.cxxtypedefs
for md in p.messageDecls:
ns.addstmts([
@ -1411,6 +1416,122 @@ child actors.'''
ns.addstmts([ Whitespace.NL, Whitespace.NL ])
def genTransitionFunc(self):
ptype = self.protocol.decl.type
usesend, sendvar = set(), ExprVar('__Send')
userecv, recvvar = set(), ExprVar('__Recv')
def sameTrigger(trigger, actionexpr):
if trigger is ipdl.ast.SEND or trigger is ipdl.ast.CALL:
usesend.add('yes')
return ExprBinary(sendvar, '==', actionexpr)
else:
userecv.add('yes')
return ExprBinary(recvvar, '==',
actionexpr)
def stateEnum(s):
if s is ipdl.ast.State.DEAD:
return _deadState()
else:
return ExprVar(s.decl.cxxname)
# bool Transition(State from, Trigger trigger, State* next)
fromvar = ExprVar('from')
triggervar = ExprVar('trigger')
nextvar = ExprVar('next')
msgexpr = ExprSelect(triggervar, '.', 'mMsg')
actionexpr = ExprSelect(triggervar, '.', 'mAction')
transitionfunc = MethodDefn(MethodDecl(
'Transition',
params=[ Decl(Type('State'), fromvar.name),
Decl(Type('mozilla::ipc::Trigger'), triggervar.name),
Decl(Type('State', ptr=1), nextvar.name) ],
ret=Type.BOOL,
inline=1))
fromswitch = StmtSwitch(fromvar)
for ts in self.protocol.transitionStmts:
msgswitch = StmtSwitch(msgexpr)
msgToTransitions = { }
for t in ts.transitions:
msgid = t.msg._md.msgId()
ifsametrigger = StmtIf(sameTrigger(t.trigger, actionexpr))
# FIXME multi-out states
for nextstate in t.toStates: break
ifsametrigger.addifstmts([
StmtExpr(ExprAssn(ExprDeref(nextvar),
stateEnum(nextstate))),
StmtReturn(ExprLiteral.TRUE)
])
transitions = msgToTransitions.get(msgid, [ ])
transitions.append(ifsametrigger)
msgToTransitions[msgid] = transitions
for msgid, transitions in msgToTransitions.iteritems():
block = Block()
block.addstmts(transitions +[ StmtBreak() ])
msgswitch.addcase(CaseLabel(msgid), block)
msgblock = Block()
msgblock.addstmts([
msgswitch,
StmtBreak()
])
fromswitch.addcase(CaseLabel(ts.state.decl.cxxname), msgblock)
# special cases for Null and Error
nullerrorblock = Block()
if ptype.hasDelete:
ifdelete = StmtIf(ExprBinary(_deleteId(), '==', msgexpr))
ifdelete.addifstmts([
StmtExpr(ExprAssn(ExprDeref(nextvar), _deadState())),
StmtReturn(ExprLiteral.TRUE) ])
nullerrorblock.addstmt(ifdelete)
nullerrorblock.addstmt(
StmtReturn(ExprBinary(_nullState(), '==', fromvar)))
fromswitch.addfallthrough(CaseLabel(_nullState().name))
fromswitch.addcase(CaseLabel(_errorState().name), nullerrorblock)
# special case for Dead
deadblock = Block()
deadblock.addstmts([
_runtimeAbort('__delete__()d actor'),
StmtReturn(ExprLiteral.FALSE) ])
fromswitch.addcase(CaseLabel(_deadState().name), deadblock)
unreachedblock = Block()
unreachedblock.addstmts([
_runtimeAbort('corrupted actor state'),
StmtReturn(ExprLiteral.FALSE) ])
fromswitch.addcase(DefaultLabel(), unreachedblock)
if usesend:
transitionfunc.addstmt(
StmtDecl(Decl(Type('int32', const=1), sendvar.name),
init=ExprVar('mozilla::ipc::Trigger::Send')))
if userecv:
transitionfunc.addstmt(
StmtDecl(Decl(Type('int32', const=1), recvvar.name),
init=ExprVar('mozilla::ipc::Trigger::Recv')))
if usesend or userecv:
transitionfunc.addstmt(Whitespace.NL)
transitionfunc.addstmts([
fromswitch,
# all --> Error transitions break to here
StmtExpr(ExprAssn(ExprDeref(nextvar), _errorState())),
StmtReturn(ExprLiteral.FALSE)
])
return transitionfunc
##--------------------------------------------------
def _generateMessageClass(clsname, msgid, typedefs, prettyName):

View File

@ -298,6 +298,7 @@ class ProtocolType(IPDLType):
self.managers = set() # ProtocolType
self.manages = [ ]
self.stateless = stateless
self.hasDelete = False
def isProtocol(self): return True
def name(self):
@ -801,13 +802,12 @@ class GatherDecls(TcheckVisitor):
msg.accept(self)
del self.currentProtocolDecl
if not p.decl.type.isToplevel():
dtordecl = self.symtab.lookup(_DELETE_MSG)
if not dtordecl:
self.error(
p.loc,
"destructor declaration `%s(...)' required for managed protocol `%s'",
_DELETE_MSG, p.name)
p.decl.type.hasDelete = (not not self.symtab.lookup(_DELETE_MSG))
if not (p.decl.type.hasDelete or p.decl.type.isToplevel()):
self.error(
p.loc,
"destructor declaration `%s(...)' required for managed protocol `%s'",
_DELETE_MSG, p.name)
for managed in p.managesStmts:
mgdname = managed.name
@ -1038,6 +1038,7 @@ class GatherDecls(TcheckVisitor):
type=msgtype,
progname=msgname)
md.protocolDecl = self.currentProtocolDecl
md.decl._md = md
def visitTransitionStmt(self, ts):