Bug 1205164 - Detect message races in Mach Shmem implementation. r=blassey

This commit is contained in:
Gian-Carlo Pascutto 2015-09-25 12:30:46 +02:00
parent ae2483ab67
commit e259122ad0

View File

@ -127,6 +127,11 @@ struct ListeningThread {
: mThread(thread), mPorts(ports) {}
};
struct SharePortsReply {
uint64_t serial;
mach_port_t port;
};
std::map<pid_t, ListeningThread> gThreads;
static void *
@ -277,8 +282,13 @@ void
HandleSharePortsMessage(MachReceiveMessage* rmsg, MemoryPorts* ports)
{
mach_port_t port = rmsg->GetTranslatedPort(0);
uint64_t* serial = reinterpret_cast<uint64_t*>(rmsg->GetData());
MachSendMessage msg(kReturnIdMsg);
msg.SetData(&port, sizeof(port));
// Construct the reply message, echoing the serial, and adding the port
SharePortsReply replydata;
replydata.port = port;
replydata.serial = *serial;
msg.SetData(&replydata, sizeof(SharePortsReply));
kern_return_t err = ports->mSender->SendMessage(msg, kTimeout);
if (KERN_SUCCESS != err) {
LOG_ERROR("SendMessage failed 0x%x %s\n", err, mach_error_string(err));
@ -584,6 +594,12 @@ SharedMemoryBasic::ShareToProcess(base::ProcessId pid,
}
StaticMutexAutoLock smal(gMutex);
// Serially number the messages, to check whether
// the reply we get was meant for us.
static uint64_t serial = 0;
uint64_t my_serial = serial;
serial++;
MemoryPorts* ports = GetMemoryPortsForPid(pid);
if (!ports) {
LOG_ERROR("Unable to get ports for process.\n");
@ -591,6 +607,7 @@ SharedMemoryBasic::ShareToProcess(base::ProcessId pid,
}
MachSendMessage smsg(kSharePortsMsg);
smsg.AddDescriptor(MachMsgPortDescriptor(mPort, MACH_MSG_TYPE_COPY_SEND));
smsg.SetData(&my_serial, sizeof(uint64_t));
kern_return_t err = ports->mSender->SendMessage(smsg, kTimeout);
if (err != KERN_SUCCESS) {
LOG_ERROR("sending port failed %s %x\n", mach_error_string(err), err);
@ -602,12 +619,18 @@ SharedMemoryBasic::ShareToProcess(base::ProcessId pid,
LOG_ERROR("didn't get an id %s %x\n", mach_error_string(err), err);
return false;
}
if (msg.GetDataLength() != sizeof(mach_port_t)) {
if (msg.GetDataLength() != sizeof(SharePortsReply)) {
LOG_ERROR("Improperly formatted reply\n");
return false;
}
mach_port_t *id = reinterpret_cast<mach_port_t*>(msg.GetData());
*aNewHandle = *id;
SharePortsReply* msg_data = reinterpret_cast<SharePortsReply*>(msg.GetData());
mach_port_t id = msg_data->port;
uint64_t serial_check = msg_data->serial;
if (serial_check != my_serial) {
LOG_ERROR("Serials do not match up: %d vs %d", serial_check, my_serial);
return false;
}
*aNewHandle = id;
return true;
}