ShMem server race-condition fix for #276 (#278)

* This fixes a potential race condition when the parent dies before the child connects after a fork (#276)

* fix docs

* trying to fix shmem server forking

* removed bug where decreasing map count to 0 would not be reallocatable

* ignored clippy warning, refactoring
This commit is contained in:
Dominik Maier 2021-09-07 00:03:37 +02:00 committed by GitHub
parent e7ed5be9a2
commit b71704b14d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 29 deletions

View File

@ -105,6 +105,9 @@ where
/// Send a request to the server, and wait for a response /// Send a request to the server, and wait for a response
#[allow(clippy::similar_names)] // id and fd #[allow(clippy::similar_names)] // id and fd
fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> { fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> {
//let bt = Backtrace::new();
//println!("Sending {:?} with bt:\n{:?}", request, bt);
let body = postcard::to_allocvec(&request)?; let body = postcard::to_allocvec(&request)?;
let header = (body.len() as u32).to_be_bytes(); let header = (body.len() as u32).to_be_bytes();
@ -165,7 +168,7 @@ where
id: -1, id: -1,
service, service,
}; };
let (id, _) = res.send_receive(ServedShMemRequest::Hello(None))?; let (id, _) = res.send_receive(ServedShMemRequest::Hello())?;
res.id = id; res.id = id;
Ok(res) Ok(res)
} }
@ -196,16 +199,24 @@ where
}) })
} }
fn pre_fork(&mut self) -> Result<(), Error> {
self.send_receive(ServedShMemRequest::PreFork())?;
Ok(())
}
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> { fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
if is_child { if is_child {
// After fork, only the parent keeps the join handle. // After fork, only the parent keeps the join handle.
if let ShMemService::Started { bg_thread, .. } = &mut self.service { if let ShMemService::Started { bg_thread, .. } = &mut self.service {
bg_thread.borrow_mut().lock().unwrap().join_handle = None; bg_thread.borrow_mut().lock().unwrap().join_handle = None;
} }
//fn connect(&mut self) -> Result<Self, Error> {
//self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?,
// After fork, the child needs to reconnect as to not share the fds with the parent. // After fork, the child needs to reconnect as to not share the fds with the parent.
self.stream = self.stream =
UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?; UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?;
let (id, _) = self.send_receive(ServedShMemRequest::Hello(Some(self.id)))?; let (id, _) = self.send_receive(ServedShMemRequest::PostForkChildHello(self.id))?;
self.id = id; self.id = id;
} }
Ok(()) Ok(())
@ -234,7 +245,11 @@ pub enum ServedShMemRequest {
Deregister(i32), Deregister(i32),
/// A message that tells us hello, and optionally which other client we were created from, we /// A message that tells us hello, and optionally which other client we were created from, we
/// return a client id. /// return a client id.
Hello(Option<i32>), Hello(),
/// A client tells us that it's about to fork. Already clone all of the maps now so that they will be available by the time the child sends a [`ServedShMemRequest::PostForkChildHello`] request.
PreFork(),
/// The client's child re-registers with us after it forked.
PostForkChildHello(i32),
/// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it? /// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it?
Exit, Exit,
} }
@ -417,12 +432,15 @@ where
} }
/// The struct for the worker, handling incoming requests for [`ShMem`]. /// The struct for the worker, handling incoming requests for [`ShMem`].
#[allow(clippy::type_complexity)]
struct ServedShMemServiceWorker<SP> struct ServedShMemServiceWorker<SP>
where where
SP: ShMemProvider, SP: ShMemProvider,
{ {
provider: SP, provider: SP,
clients: HashMap<RawFd, SharedShMemClient<SP::Mem>>, clients: HashMap<RawFd, SharedShMemClient<SP::Mem>>,
/// Maps from a pre-fork (parent) client id to its cloned maps.
forking_clients: HashMap<RawFd, HashMap<i32, Vec<Rc<RefCell<SP::Mem>>>>>,
all_maps: HashMap<i32, Weak<RefCell<SP::Mem>>>, all_maps: HashMap<i32, Weak<RefCell<SP::Mem>>>,
} }
@ -436,28 +454,53 @@ where
provider: SP::new()?, provider: SP::new()?,
clients: HashMap::new(), clients: HashMap::new(),
all_maps: HashMap::new(), all_maps: HashMap::new(),
forking_clients: HashMap::new(),
}) })
} }
fn upgrade_map_with_id(&mut self, description_id: i32) -> Rc<RefCell<SP::Mem>> {
self.all_maps
.get_mut(&description_id)
.unwrap()
.clone()
.upgrade()
.unwrap()
}
/// Read and handle the client request, send the answer over unix fd. /// Read and handle the client request, send the answer over unix fd.
fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> { fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> {
let request = self.read_request(client_id)?; let request = self.read_request(client_id)?;
//println!("got ashmem client: {}, request:{:?}", client_id, request); // println!("got ashmem client: {}, request:{:?}", client_id, request);
// Handle the client request // Handle the client request
let response = match request { let response = match request {
ServedShMemRequest::Hello(other_id) => { ServedShMemRequest::Hello() => Ok(ServedShMemResponse::Id(client_id)),
if let Some(other_id) = other_id { ServedShMemRequest::PreFork() => {
if other_id != client_id { // We clone the provider already, waiting for it to reconnect [`PostFork`].
// remove temporarily // That wa, even if the parent dies before the child sends its `PostFork`, we should be good.
let other_client = self.clients.remove(&other_id); // See issue https://github.com/AFLplusplus/LibAFL/issues/276
let client = self.clients.get_mut(&client_id).unwrap(); //let forking_client = self.clients[&client_id].maps.clone();
for (id, map) in other_client.as_ref().unwrap().maps.iter() { self.forking_clients
client.maps.insert(*id, map.clone()); .insert(client_id, self.clients[&client_id].maps.clone());
} // Technically, no need to send the client_id here but it keeps the code easier.
self.clients.insert(other_id, other_client.unwrap());
} /*
}; // remove temporarily
let client = self.clients.remove(&client_id);
let mut forking_maps = HashMap::new();
for (id, map) in client.as_ref().unwrap().maps.iter() {
forking_maps.insert(*id, map.clone());
}
self.forking_clients.insert(client_id, forking_maps);
self.clients.insert(client_id, client.unwrap());
*/
Ok(ServedShMemResponse::Id(client_id))
}
ServedShMemRequest::PostForkChildHello(other_id) => {
let client = self.clients.get_mut(&client_id).unwrap();
client.maps = self.forking_clients.remove(&other_id).unwrap();
Ok(ServedShMemResponse::Id(client_id)) Ok(ServedShMemResponse::Id(client_id))
} }
ServedShMemRequest::NewMap(map_size) => { ServedShMemRequest::NewMap(map_size) => {
@ -472,25 +515,25 @@ where
let client = self.clients.get_mut(&client_id).unwrap(); let client = self.clients.get_mut(&client_id).unwrap();
let description_id: i32 = description.id.into(); let description_id: i32 = description.id.into();
if client.maps.contains_key(&description_id) { if client.maps.contains_key(&description_id) {
// Using let else here as self needs to be accessed in the else branch.
#[allow(clippy::option_if_let_else)]
Ok(ServedShMemResponse::Mapping( Ok(ServedShMemResponse::Mapping(
client if let Some(map) = client
.maps .maps
.get_mut(&description_id) .get_mut(&description_id)
.as_mut() .as_mut()
.unwrap() .unwrap()
.first() .first()
.as_mut() .as_mut()
.unwrap() {
.clone(), map.clone()
} else {
self.upgrade_map_with_id(description_id)
},
)) ))
} else { } else {
Ok(ServedShMemResponse::Mapping( Ok(ServedShMemResponse::Mapping(
self.all_maps self.upgrade_map_with_id(description_id),
.get_mut(&description_id)
.unwrap()
.clone()
.upgrade()
.unwrap(),
)) ))
} }
} }
@ -511,7 +554,7 @@ where
return Err(Error::ShuttingDown); return Err(Error::ShuttingDown);
} }
}; };
//println!("send ashmem client: {}, response: {:?}", client_id, &response); // println!("send ashmem client: {}, response: {:?}", client_id, &response);
response response
} }
@ -623,7 +666,7 @@ where
} }
}; };
// println!("Recieved connection from {:?}", addr); println!("Recieved connection from {:?}", _addr);
let pollfd = PollFd::new( let pollfd = PollFd::new(
stream.as_raw_fd(), stream.as_raw_fd(),
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND, PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,

View File

@ -367,8 +367,8 @@ where
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> { fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
if is_child { if is_child {
self.await_parent_done()?; self.await_parent_done()?;
let child_shmem = self.internal.borrow_mut().clone(); //let child_shmem = self.internal.borrow_mut().clone();
self.internal = Rc::new(RefCell::new(child_shmem)); //self.internal = Rc::new(RefCell::new(child_shmem));
} }
self.internal.borrow_mut().post_fork(is_child)?; self.internal.borrow_mut().post_fork(is_child)?;
if is_child { if is_child {