diff --git a/libafl/src/bolts/os/unix_shmem_server.rs b/libafl/src/bolts/os/unix_shmem_server.rs index 10504adffe..7b8b28d78b 100644 --- a/libafl/src/bolts/os/unix_shmem_server.rs +++ b/libafl/src/bolts/os/unix_shmem_server.rs @@ -105,6 +105,9 @@ where /// Send a request to the server, and wait for a response #[allow(clippy::similar_names)] // id and fd fn send_receive(&mut self, request: ServedShMemRequest) -> Result<(i32, i32), Error> { + //let bt = Backtrace::new(); + //println!("Sending {:?} with bt:\n{:?}", request, bt); + let body = postcard::to_allocvec(&request)?; let header = (body.len() as u32).to_be_bytes(); @@ -165,7 +168,7 @@ where id: -1, service, }; - let (id, _) = res.send_receive(ServedShMemRequest::Hello(None))?; + let (id, _) = res.send_receive(ServedShMemRequest::Hello())?; res.id = id; Ok(res) } @@ -196,16 +199,24 @@ where }) } + fn pre_fork(&mut self) -> Result<(), Error> { + self.send_receive(ServedShMemRequest::PreFork())?; + Ok(()) + } + fn post_fork(&mut self, is_child: bool) -> Result<(), Error> { if is_child { // After fork, only the parent keeps the join handle. if let ShMemService::Started { bg_thread, .. } = &mut self.service { bg_thread.borrow_mut().lock().unwrap().join_handle = None; } + //fn connect(&mut self) -> Result { + //self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?, + // After fork, the child needs to reconnect as to not share the fds with the parent. self.stream = UnixStream::connect_to_unix_addr(&UnixSocketAddr::new(UNIX_SERVER_NAME)?)?; - let (id, _) = self.send_receive(ServedShMemRequest::Hello(Some(self.id)))?; + let (id, _) = self.send_receive(ServedShMemRequest::PostForkChildHello(self.id))?; self.id = id; } Ok(()) @@ -234,7 +245,11 @@ pub enum ServedShMemRequest { Deregister(i32), /// A message that tells us hello, and optionally which other client we were created from, we /// return a client id. - Hello(Option), + Hello(), + /// A client tells us that it's about to fork. Already clone all of the maps now so that they will be available by the time the child sends a [`ServedShMemRequest::PostForkChildHello`] request. + PreFork(), + /// The client's child re-registers with us after it forked. + PostForkChildHello(i32), /// The ShMem Service should exit. This is sually sent internally on `drop`, but feel free to do whatever with it? Exit, } @@ -417,12 +432,15 @@ where } /// The struct for the worker, handling incoming requests for [`ShMem`]. +#[allow(clippy::type_complexity)] struct ServedShMemServiceWorker where SP: ShMemProvider, { provider: SP, clients: HashMap>, + /// Maps from a pre-fork (parent) client id to its cloned maps. + forking_clients: HashMap>>>>, all_maps: HashMap>>, } @@ -436,28 +454,53 @@ where provider: SP::new()?, clients: HashMap::new(), all_maps: HashMap::new(), + forking_clients: HashMap::new(), }) } + fn upgrade_map_with_id(&mut self, description_id: i32) -> Rc> { + self.all_maps + .get_mut(&description_id) + .unwrap() + .clone() + .upgrade() + .unwrap() + } + /// Read and handle the client request, send the answer over unix fd. fn handle_request(&mut self, client_id: RawFd) -> Result, Error> { let request = self.read_request(client_id)?; - //println!("got ashmem client: {}, request:{:?}", client_id, request); + // println!("got ashmem client: {}, request:{:?}", client_id, request); + // Handle the client request let response = match request { - ServedShMemRequest::Hello(other_id) => { - if let Some(other_id) = other_id { - if other_id != client_id { - // remove temporarily - let other_client = self.clients.remove(&other_id); - let client = self.clients.get_mut(&client_id).unwrap(); - for (id, map) in other_client.as_ref().unwrap().maps.iter() { - client.maps.insert(*id, map.clone()); - } - self.clients.insert(other_id, other_client.unwrap()); - } - }; + ServedShMemRequest::Hello() => Ok(ServedShMemResponse::Id(client_id)), + ServedShMemRequest::PreFork() => { + // We clone the provider already, waiting for it to reconnect [`PostFork`]. + // That wa, even if the parent dies before the child sends its `PostFork`, we should be good. + // See issue https://github.com/AFLplusplus/LibAFL/issues/276 + //let forking_client = self.clients[&client_id].maps.clone(); + self.forking_clients + .insert(client_id, self.clients[&client_id].maps.clone()); + // Technically, no need to send the client_id here but it keeps the code easier. + + /* + // remove temporarily + let client = self.clients.remove(&client_id); + let mut forking_maps = HashMap::new(); + for (id, map) in client.as_ref().unwrap().maps.iter() { + forking_maps.insert(*id, map.clone()); + } + self.forking_clients.insert(client_id, forking_maps); + self.clients.insert(client_id, client.unwrap()); + */ + + Ok(ServedShMemResponse::Id(client_id)) + } + ServedShMemRequest::PostForkChildHello(other_id) => { + let client = self.clients.get_mut(&client_id).unwrap(); + client.maps = self.forking_clients.remove(&other_id).unwrap(); Ok(ServedShMemResponse::Id(client_id)) } ServedShMemRequest::NewMap(map_size) => { @@ -472,25 +515,25 @@ where let client = self.clients.get_mut(&client_id).unwrap(); let description_id: i32 = description.id.into(); if client.maps.contains_key(&description_id) { + // Using let else here as self needs to be accessed in the else branch. + #[allow(clippy::option_if_let_else)] Ok(ServedShMemResponse::Mapping( - client + if let Some(map) = client .maps .get_mut(&description_id) .as_mut() .unwrap() .first() .as_mut() - .unwrap() - .clone(), + { + map.clone() + } else { + self.upgrade_map_with_id(description_id) + }, )) } else { Ok(ServedShMemResponse::Mapping( - self.all_maps - .get_mut(&description_id) - .unwrap() - .clone() - .upgrade() - .unwrap(), + self.upgrade_map_with_id(description_id), )) } } @@ -511,7 +554,7 @@ where return Err(Error::ShuttingDown); } }; - //println!("send ashmem client: {}, response: {:?}", client_id, &response); + // println!("send ashmem client: {}, response: {:?}", client_id, &response); response } @@ -623,7 +666,7 @@ where } }; - // println!("Recieved connection from {:?}", addr); + println!("Recieved connection from {:?}", _addr); let pollfd = PollFd::new( stream.as_raw_fd(), PollFlags::POLLIN | PollFlags::POLLRDNORM | PollFlags::POLLRDBAND, diff --git a/libafl/src/bolts/shmem.rs b/libafl/src/bolts/shmem.rs index a4e9e75f35..3d08779228 100644 --- a/libafl/src/bolts/shmem.rs +++ b/libafl/src/bolts/shmem.rs @@ -367,8 +367,8 @@ where fn post_fork(&mut self, is_child: bool) -> Result<(), Error> { if is_child { self.await_parent_done()?; - let child_shmem = self.internal.borrow_mut().clone(); - self.internal = Rc::new(RefCell::new(child_shmem)); + //let child_shmem = self.internal.borrow_mut().clone(); + //self.internal = Rc::new(RefCell::new(child_shmem)); } self.internal.borrow_mut().post_fork(is_child)?; if is_child {