* This fixes a potential race condition when the parent dies before the child connects after a fork (#276) * fix docs * trying to fix shmem server forking * removed bug where decreasing map count to 0 would not be reallocatable * ignored clippy warning, refactoring
This commit is contained in:
parent
e7ed5be9a2
commit
b71704b14d
@ -105,6 +105,9 @@ where
|
|||||||
/// Send a request to the server, and wait for a response
|
/// Send a request to the server, and wait for a response
|
||||||
#[allow(clippy::similar_names)] // id and fd
|
#[allow(clippy::similar_names)] // id and fd
|
||||||
fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> {
|
fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> {
|
||||||
|
//let bt = Backtrace::new();
|
||||||
|
//println!("Sending {:?} with bt:\n{:?}", request, bt);
|
||||||
|
|
||||||
let body = postcard::to_allocvec(&request)?;
|
let body = postcard::to_allocvec(&request)?;
|
||||||
|
|
||||||
let header = (body.len() as u32).to_be_bytes();
|
let header = (body.len() as u32).to_be_bytes();
|
||||||
@ -165,7 +168,7 @@ where
|
|||||||
id: -1,
|
id: -1,
|
||||||
service,
|
service,
|
||||||
};
|
};
|
||||||
let (id, _) = res.send_receive(ServedShMemRequest::Hello(None))?;
|
let (id, _) = res.send_receive(ServedShMemRequest::Hello())?;
|
||||||
res.id = id;
|
res.id = id;
|
||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
@ -196,16 +199,24 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn pre_fork(&mut self) -> Result<(), Error> {
|
||||||
|
self.send_receive(ServedShMemRequest::PreFork())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
|
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
|
||||||
if is_child {
|
if is_child {
|
||||||
// After fork, only the parent keeps the join handle.
|
// After fork, only the parent keeps the join handle.
|
||||||
if let ShMemService::Started { bg_thread, .. } = &mut self.service {
|
if let ShMemService::Started { bg_thread, .. } = &mut self.service {
|
||||||
bg_thread.borrow_mut().lock().unwrap().join_handle = None;
|
bg_thread.borrow_mut().lock().unwrap().join_handle = None;
|
||||||
}
|
}
|
||||||
|
//fn connect(&mut self) -> Result<Self, Error> {
|
||||||
|
//self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?,
|
||||||
|
|
||||||
// After fork, the child needs to reconnect as to not share the fds with the parent.
|
// After fork, the child needs to reconnect as to not share the fds with the parent.
|
||||||
self.stream =
|
self.stream =
|
||||||
UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?;
|
UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?;
|
||||||
let (id, _) = self.send_receive(ServedShMemRequest::Hello(Some(self.id)))?;
|
let (id, _) = self.send_receive(ServedShMemRequest::PostForkChildHello(self.id))?;
|
||||||
self.id = id;
|
self.id = id;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -234,7 +245,11 @@ pub enum ServedShMemRequest {
|
|||||||
Deregister(i32),
|
Deregister(i32),
|
||||||
/// A message that tells us hello, and optionally which other client we were created from, we
|
/// A message that tells us hello, and optionally which other client we were created from, we
|
||||||
/// return a client id.
|
/// return a client id.
|
||||||
Hello(Option<i32>),
|
Hello(),
|
||||||
|
/// A client tells us that it's about to fork. Already clone all of the maps now so that they will be available by the time the child sends a [`ServedShMemRequest::PostForkChildHello`] request.
|
||||||
|
PreFork(),
|
||||||
|
/// The client's child re-registers with us after it forked.
|
||||||
|
PostForkChildHello(i32),
|
||||||
/// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it?
|
/// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it?
|
||||||
Exit,
|
Exit,
|
||||||
}
|
}
|
||||||
@ -417,12 +432,15 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The struct for the worker, handling incoming requests for [`ShMem`].
|
/// The struct for the worker, handling incoming requests for [`ShMem`].
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
struct ServedShMemServiceWorker<SP>
|
struct ServedShMemServiceWorker<SP>
|
||||||
where
|
where
|
||||||
SP: ShMemProvider,
|
SP: ShMemProvider,
|
||||||
{
|
{
|
||||||
provider: SP,
|
provider: SP,
|
||||||
clients: HashMap<RawFd, SharedShMemClient<SP::Mem>>,
|
clients: HashMap<RawFd, SharedShMemClient<SP::Mem>>,
|
||||||
|
/// Maps from a pre-fork (parent) client id to its cloned maps.
|
||||||
|
forking_clients: HashMap<RawFd, HashMap<i32, Vec<Rc<RefCell<SP::Mem>>>>>,
|
||||||
all_maps: HashMap<i32, Weak<RefCell<SP::Mem>>>,
|
all_maps: HashMap<i32, Weak<RefCell<SP::Mem>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -436,28 +454,53 @@ where
|
|||||||
provider: SP::new()?,
|
provider: SP::new()?,
|
||||||
clients: HashMap::new(),
|
clients: HashMap::new(),
|
||||||
all_maps: HashMap::new(),
|
all_maps: HashMap::new(),
|
||||||
|
forking_clients: HashMap::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn upgrade_map_with_id(&mut self, description_id: i32) -> Rc<RefCell<SP::Mem>> {
|
||||||
|
self.all_maps
|
||||||
|
.get_mut(&description_id)
|
||||||
|
.unwrap()
|
||||||
|
.clone()
|
||||||
|
.upgrade()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
/// Read and handle the client request, send the answer over unix fd.
|
/// Read and handle the client request, send the answer over unix fd.
|
||||||
fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> {
|
fn handle_request(&mut self, client_id: RawFd) -> Result<ServedShMemResponse<SP>, Error> {
|
||||||
let request = self.read_request(client_id)?;
|
let request = self.read_request(client_id)?;
|
||||||
|
|
||||||
//println!("got ashmem client: {}, request:{:?}", client_id, request);
|
// println!("got ashmem client: {}, request:{:?}", client_id, request);
|
||||||
|
|
||||||
// Handle the client request
|
// Handle the client request
|
||||||
let response = match request {
|
let response = match request {
|
||||||
ServedShMemRequest::Hello(other_id) => {
|
ServedShMemRequest::Hello() => Ok(ServedShMemResponse::Id(client_id)),
|
||||||
if let Some(other_id) = other_id {
|
ServedShMemRequest::PreFork() => {
|
||||||
if other_id != client_id {
|
// We clone the provider already, waiting for it to reconnect [`PostFork`].
|
||||||
|
// That wa, even if the parent dies before the child sends its `PostFork`, we should be good.
|
||||||
|
// See issue https://github.com/AFLplusplus/LibAFL/issues/276
|
||||||
|
//let forking_client = self.clients[&client_id].maps.clone();
|
||||||
|
self.forking_clients
|
||||||
|
.insert(client_id, self.clients[&client_id].maps.clone());
|
||||||
|
// Technically, no need to send the client_id here but it keeps the code easier.
|
||||||
|
|
||||||
|
/*
|
||||||
// remove temporarily
|
// remove temporarily
|
||||||
let other_client = self.clients.remove(&other_id);
|
let client = self.clients.remove(&client_id);
|
||||||
|
let mut forking_maps = HashMap::new();
|
||||||
|
for (id, map) in client.as_ref().unwrap().maps.iter() {
|
||||||
|
forking_maps.insert(*id, map.clone());
|
||||||
|
}
|
||||||
|
self.forking_clients.insert(client_id, forking_maps);
|
||||||
|
self.clients.insert(client_id, client.unwrap());
|
||||||
|
*/
|
||||||
|
|
||||||
|
Ok(ServedShMemResponse::Id(client_id))
|
||||||
|
}
|
||||||
|
ServedShMemRequest::PostForkChildHello(other_id) => {
|
||||||
let client = self.clients.get_mut(&client_id).unwrap();
|
let client = self.clients.get_mut(&client_id).unwrap();
|
||||||
for (id, map) in other_client.as_ref().unwrap().maps.iter() {
|
client.maps = self.forking_clients.remove(&other_id).unwrap();
|
||||||
client.maps.insert(*id, map.clone());
|
|
||||||
}
|
|
||||||
self.clients.insert(other_id, other_client.unwrap());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(ServedShMemResponse::Id(client_id))
|
Ok(ServedShMemResponse::Id(client_id))
|
||||||
}
|
}
|
||||||
ServedShMemRequest::NewMap(map_size) => {
|
ServedShMemRequest::NewMap(map_size) => {
|
||||||
@ -472,25 +515,25 @@ where
|
|||||||
let client = self.clients.get_mut(&client_id).unwrap();
|
let client = self.clients.get_mut(&client_id).unwrap();
|
||||||
let description_id: i32 = description.id.into();
|
let description_id: i32 = description.id.into();
|
||||||
if client.maps.contains_key(&description_id) {
|
if client.maps.contains_key(&description_id) {
|
||||||
|
// Using let else here as self needs to be accessed in the else branch.
|
||||||
|
#[allow(clippy::option_if_let_else)]
|
||||||
Ok(ServedShMemResponse::Mapping(
|
Ok(ServedShMemResponse::Mapping(
|
||||||
client
|
if let Some(map) = client
|
||||||
.maps
|
.maps
|
||||||
.get_mut(&description_id)
|
.get_mut(&description_id)
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.first()
|
.first()
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.unwrap()
|
{
|
||||||
.clone(),
|
map.clone()
|
||||||
|
} else {
|
||||||
|
self.upgrade_map_with_id(description_id)
|
||||||
|
},
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok(ServedShMemResponse::Mapping(
|
Ok(ServedShMemResponse::Mapping(
|
||||||
self.all_maps
|
self.upgrade_map_with_id(description_id),
|
||||||
.get_mut(&description_id)
|
|
||||||
.unwrap()
|
|
||||||
.clone()
|
|
||||||
.upgrade()
|
|
||||||
.unwrap(),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -511,7 +554,7 @@ where
|
|||||||
return Err(Error::ShuttingDown);
|
return Err(Error::ShuttingDown);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
//println!("send ashmem client: {}, response: {:?}", client_id, &response);
|
// println!("send ashmem client: {}, response: {:?}", client_id, &response);
|
||||||
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
@ -623,7 +666,7 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// println!("Recieved connection from {:?}", addr);
|
println!("Recieved connection from {:?}", _addr);
|
||||||
let pollfd = PollFd::new(
|
let pollfd = PollFd::new(
|
||||||
stream.as_raw_fd(),
|
stream.as_raw_fd(),
|
||||||
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,
|
PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND,
|
||||||
|
@ -367,8 +367,8 @@ where
|
|||||||
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
|
fn post_fork(&mut self, is_child: bool) -> Result<(), Error> {
|
||||||
if is_child {
|
if is_child {
|
||||||
self.await_parent_done()?;
|
self.await_parent_done()?;
|
||||||
let child_shmem = self.internal.borrow_mut().clone();
|
//let child_shmem = self.internal.borrow_mut().clone();
|
||||||
self.internal = Rc::new(RefCell::new(child_shmem));
|
//self.internal = Rc::new(RefCell::new(child_shmem));
|
||||||
}
|
}
|
||||||
self.internal.borrow_mut().post_fork(is_child)?;
|
self.internal.borrow_mut().post_fork(is_child)?;
|
||||||
if is_child {
|
if is_child {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user