ShMem server race-condition fix for #276 (#278)

* This fixes a potential race condition when the parent dies before the child connects after a fork (#276)

* fix docs

* trying to fix shmem server forking

* removed bug where decreasing map count to 0 would not be reallocatable

* ignored clippy warning, refactoring
This commit is contained in:
Dominik Maier 2021-09-07 00:03:37 +02:00 committed by GitHub
parent e7ed5be9a2
commit b71704b14d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 29 deletions

View File

@ -105,6 +105,9 @@ where
/// Send a request to the server, and wait for a response
#[allow(clippy::similar_names)] // id and fd
fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> {
//let bt = Backtrace::new();
//println!("Sending {:?} with bt:\n{:?}", request, bt);
let body = postcard::to_allocvec(&request)?;
let header = (body.len() as u32).to_be_bytes();
@ -165,7 +168,7 @@ where
id: -1,
service,
};
let (id, _) = res.send_receive(ServedShMemRequest::Hello(None))?;
let (id, _) = res.send_receive(ServedShMemRequest::Hello())?;
res.id = id;
Ok(res)
}
@ -196,16 +199,24 @@ where
})
}
fn pre_fork(&mut self) -> Result<(), Error> {
self.send_receive(ServedShMemRequest::PreFork())?;
Ok(())
}
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
if is_child {
// After fork, only the parent keeps the join handle.
if let ShMemService::Started { bg_thread, .. } = &mut self.service {
bg_thread.borrow_mut().lock().unwrap().join_handle = None;
}
//fn connect(&mut self) -> Result<Self, Error> {
//self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?,
// After fork, the child needs to reconnect as to not share the fds with the parent.
self.stream =
UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?;
let (id, _) = self.send_receive(ServedShMemRequest::Hello(Some(self.id)))?;
let (id, _) = self.send_receive(ServedShMemRequest::PostForkChildHello(self.id))?;
self.id = id;
}
Ok(())
@ -234,7 +245,11 @@ pub enum ServedShMemRequest {
Deregister(i32),
/// A message that tells us hello, and optionally which other client we were created from, we
/// return a client id.
Hello(Option<i32>),
Hello(),
/// A client tells us that it's about to fork. Already clone all of the maps now so that they will be available by the time the child sends a [`ServedShMemRequest::PostForkChildHello`] request.
PreFork(),
/// The client's child re-registers with us after it forked.
PostForkChildHello(i32),
/// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it?
Exit,
}
@ -417,12 +432,15 @@ where
}
/// The struct for the worker, handling incoming requests for [`ShMem`].
#[allow(clippy::type_complexity)]
struct ServedShMemServiceWorker<SP>
where
SP: ShMemProvider,
{
provider: SP,
clients: HashMap<RawFd, SharedShMemClient<SP::Mem>>,
/// Maps from a pre-fork (parent) client id to its cloned maps.
forking_clients: HashMap<RawFd, HashMap<i32, Vec<Rc<RefCell<SP::Mem>>>>>,
all_maps: HashMap<i32, Weak<RefCell<SP::Mem>>>,
}
@ -436,28 +454,53 @@ where
provider: SP::new()?,
clients: HashMap::new(),
all_maps: HashMap::new(),
forking_clients: HashMap::new(),
})
}
fn upgrade_map_with_id(&mut self, description_id: i32) -> Rc<RefCell<SP::Mem>> {
self.all_maps
.get_mut(&description_id)
.unwrap()
.clone()
.upgrade()
.unwrap()
}
/// Read and handle the client request, send the answer over unix fd.
fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> {
let request = self.read_request(client_id)?;
//println!("got ashmem client: {}, request:{:?}", client_id, request);
// println!("got ashmem client: {}, request:{:?}", client_id, request);
// Handle the client request
let response = match request {
ServedShMemRequest::Hello(other_id) => {
if let Some(other_id) = other_id {
if other_id != client_id {
ServedShMemRequest::Hello() => Ok(ServedShMemResponse::Id(client_id)),
ServedShMemRequest::PreFork() => {
// We clone the provider already, waiting for it to reconnect [`PostFork`].
// That wa, even if the parent dies before the child sends its `PostFork`, we should be good.
// See issue https://github.com/AFLplusplus/LibAFL/issues/276
//let forking_client = self.clients[&client_id].maps.clone();
self.forking_clients
.insert(client_id, self.clients[&client_id].maps.clone());
// Technically, no need to send the client_id here but it keeps the code easier.
/*
// remove temporarily
let other_client = self.clients.remove(&other_id);
let client = self.clients.remove(&client_id);
let mut forking_maps = HashMap::new();
for (id, map) in client.as_ref().unwrap().maps.iter() {
forking_maps.insert(*id, map.clone());
}
self.forking_clients.insert(client_id, forking_maps);
self.clients.insert(client_id, client.unwrap());
*/
Ok(ServedShMemResponse::Id(client_id))
}
ServedShMemRequest::PostForkChildHello(other_id) => {
let client = self.clients.get_mut(&client_id).unwrap();
for (id, map) in other_client.as_ref().unwrap().maps.iter() {
client.maps.insert(*id, map.clone());
}
self.clients.insert(other_id, other_client.unwrap());
}
};
client.maps = self.forking_clients.remove(&other_id).unwrap();
Ok(ServedShMemResponse::Id(client_id))
}
ServedShMemRequest::NewMap(map_size) => {
@ -472,25 +515,25 @@ where
let client = self.clients.get_mut(&client_id).unwrap();
let description_id: i32 = description.id.into();
if client.maps.contains_key(&description_id) {
// Using let else here as self needs to be accessed in the else branch.
#[allow(clippy::option_if_let_else)]
Ok(ServedShMemResponse::Mapping(
client
if let Some(map) = client
.maps
.get_mut(&description_id)
.as_mut()
.unwrap()
.first()
.as_mut()
.unwrap()
.clone(),
{
map.clone()
} else {
self.upgrade_map_with_id(description_id)
},
))
} else {
Ok(ServedShMemResponse::Mapping(
self.all_maps
.get_mut(&description_id)
.unwrap()
.clone()
.upgrade()
.unwrap(),
self.upgrade_map_with_id(description_id),
))
}
}
@ -511,7 +554,7 @@ where
return Err(Error::ShuttingDown);
}
};
//println!("send ashmem client: {}, response: {:?}", client_id, &response);
// println!("send ashmem client: {}, response: {:?}", client_id, &response);
response
}
@ -623,7 +666,7 @@ where
}
};
// println!("Recieved connection from {:?}", addr);
println!("Recieved connection from {:?}", _addr);
let pollfd = PollFd::new(
stream.as_raw_fd(),
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,

View File

@ -367,8 +367,8 @@ where
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
if is_child {
self.await_parent_done()?;
let child_shmem = self.internal.borrow_mut().clone();
self.internal = Rc::new(RefCell::new(child_shmem));
//let child_shmem = self.internal.borrow_mut().clone();
//self.internal = Rc::new(RefCell::new(child_shmem));
}
self.internal.borrow_mut().post_fork(is_child)?;
if is_child {