diff --git a/libafl/src/executors/forkserver.rs b/libafl/src/executors/forkserver.rs index c3b430419e..3f5b3d2ed2 100644 --- a/libafl/src/executors/forkserver.rs +++ b/libafl/src/executors/forkserver.rs @@ -89,6 +89,9 @@ const FS_ERROR_OLD_CMPLOG: i32 = 32_u32 as i32; #[allow(clippy::cast_possible_wrap)] const FS_ERROR_OLD_CMPLOG_QEMU: i32 = 64_u32 as i32; +/// Forkserver message. We'll reuse it in a testcase. +const FAILED_TO_START_FORKSERVER_MSG: &str = "Failed to start forkserver"; + fn report_error_and_exit(status: i32) -> Result<(), Error> { /* Report on the error received via the forkserver controller and exit */ match status { @@ -488,27 +491,53 @@ impl Forkserver { } /// Read from the st pipe - pub fn read_st(&mut self) -> Result<(usize, i32), Error> { + pub fn read_st(&mut self) -> Result { let mut buf: [u8; 4] = [0_u8; 4]; - let rlen = self.st_pipe.read(&mut buf)?; - let val: i32 = i32::from_ne_bytes(buf); - Ok((rlen, val)) + if rlen == size_of::() { + Ok(i32::from_ne_bytes(buf)) + } else { + // NOTE: The underlying API does not guarantee that the read will return + // exactly four bytes, but the chance of this happening is very low. + // This is a sacrifice of correctness for performance. + Err(Error::illegal_state(format!( + "Could not read from st pipe. Expected {} bytes, got {rlen} bytes", + size_of::() + ))) + } } /// Read bytes of any length from the st pipe - pub fn read_st_size(&mut self, size: usize) -> Result<(usize, Vec), Error> { - let mut buf = vec![0; size]; - - let rlen = self.st_pipe.read(&mut buf)?; - Ok((rlen, buf)) + pub fn read_st_of_len(&mut self, size: usize) -> Result, Error> { + let mut buf = Vec::with_capacity(size); + // SAFETY: `buf` will not be returned with `Ok` unless it is filled with `size` bytes. + // So it is ok to set the length to `size` such that the length of `&mut buf` is `size` + // and the `read_exact` call will try to read `size` bytes. + #[allow( + clippy::uninit_vec, + reason = "The vec will be filled right after setting the length." + )] + unsafe { + buf.set_len(size); + } + self.st_pipe.read_exact(&mut buf)?; + Ok(buf) } /// Write to the ctl pipe - pub fn write_ctl(&mut self, val: i32) -> Result { + pub fn write_ctl(&mut self, val: i32) -> Result<(), Error> { let slen = self.ctl_pipe.write(&val.to_ne_bytes())?; - - Ok(slen) + if slen == size_of::() { + Ok(()) + } else { + // NOTE: The underlying API does not guarantee that exactly four bytes + // are written, but the chance of this happening is very low. + // This is a sacrifice of correctness for performance. + Err(Error::illegal_state(format!( + "Could not write to ctl pipe. Expected {} bytes, wrote {slen} bytes", + size_of::() + ))) + } } /// Read a message from the child process. @@ -846,11 +875,10 @@ where } }; - let (rlen, version_status) = forkserver.read_st()?; // Initial handshake, read 4-bytes hello message from the forkserver. - - if rlen != 4 { - return Err(Error::unknown("Failed to start a forkserver".to_string())); - } + // Initial handshake, read 4-bytes hello message from the forkserver. + let version_status = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!("{FAILED_TO_START_FORKSERVER_MSG}: {err:?}")) + })?; if (version_status & FS_NEW_ERROR) == FS_NEW_ERROR { report_error_and_exit(version_status & 0x0000ffff)?; @@ -882,13 +910,13 @@ where let version: u32 = status as u32 - 0x41464c00_u32; match version { 0 => { - return Err(Error::unknown("Fork server version is not assigned, this should not happen. Recompile target.")); + return Err(Error::illegal_state("Fork server version is not assigned, this should not happen. Recompile target.")); } FS_NEW_VERSION_MIN..=FS_NEW_VERSION_MAX => { // good, do nothing } _ => { - return Err(Error::unknown( + return Err(Error::illegal_state( "Fork server version is not supported. Recompile the target.", )); } @@ -896,9 +924,10 @@ where let xored_status = (status as u32 ^ 0xffffffff) as i32; - let send_len = forkserver.write_ctl(xored_status)?; - if send_len != 4 { - return Err(Error::unknown("Writing to forkserver failed.".to_string())); + if let Err(err) = forkserver.write_ctl(xored_status) { + return Err(Error::illegal_state(format!( + "Writing to forkserver failed: {err:?}" + ))); } log::info!( @@ -906,20 +935,14 @@ where version ); - let (read_len, status) = forkserver.read_st()?; - if read_len != 4 { - return Err(Error::unknown( - "Reading from forkserver failed.".to_string(), - )); - } + let status = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!("Reading from forkserver failed: {err:?}")) + })?; if status & FS_NEW_OPT_MAPSIZE == FS_NEW_OPT_MAPSIZE { - let (read_len, fsrv_map_size) = forkserver.read_st()?; - if read_len != 4 { - return Err(Error::unknown( - "Failed to read map size from forkserver".to_string(), - )); - } + let fsrv_map_size = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!("Failed to read map size from forkserver: {err:?}")) + })?; self.set_map_size(fsrv_map_size)?; } @@ -928,7 +951,7 @@ where log::info!("Using SHARED MEMORY FUZZING feature."); self.uses_shmem_testcase = true; } else { - return Err(Error::unknown( + return Err(Error::illegal_state( "Target requested sharedmem fuzzing, but you didn't prepare shmem", )); } @@ -937,12 +960,11 @@ where if status & FS_NEW_OPT_AUTODICT != 0 { // Here unlike shmem input fuzzing, we are forced to read things // hence no self.autotokens.is_some() to check if we proceed - let (read_len, autotokens_size) = forkserver.read_st()?; - if read_len != 4 { - return Err(Error::unknown( - "Failed to read autotokens size from forkserver".to_string(), - )); - } + let autotokens_size = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!( + "Failed to read autotokens size from forkserver: {err:?}", + )) + })?; let tokens_size_max = 0xffffff; @@ -952,20 +974,19 @@ where )); } log::info!("Autotokens size {autotokens_size:x}"); - let (rlen, buf) = forkserver.read_st_size(autotokens_size as usize)?; - - if rlen != autotokens_size as usize { - return Err(Error::unknown("Failed to load autotokens".to_string())); - } + let buf = forkserver + .read_st_of_len(autotokens_size as usize) + .map_err(|err| { + Error::illegal_state(format!("Failed to load autotokens: {err:?}")) + })?; if let Some(t) = &mut self.autotokens { t.parse_autodict(&buf, autotokens_size as usize); } } - let (read_len, aflx) = forkserver.read_st()?; - if read_len != 4 { - return Err(Error::unknown("Reading from forkserver failed".to_string())); - } + let aflx = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!("Reading from forkserver failed: {err:?}")) + })?; if aflx != keep { return Err(Error::unknown(format!( @@ -1015,18 +1036,16 @@ where // if send_status is not changed (Options are available but we didn't use any), then don't send the next write_ctl message. // This is important - let send_len = forkserver.write_ctl(send_status)?; - if send_len != 4 { - return Err(Error::unknown("Writing to forkserver failed.".to_string())); + if let Err(err) = forkserver.write_ctl(send_status) { + return Err(Error::illegal_state(format!( + "Writing to forkserver failed: {err:?}" + ))); } if (send_status & FS_OPT_AUTODICT) == FS_OPT_AUTODICT { - let (read_len, dict_size) = forkserver.read_st()?; - if read_len != 4 { - return Err(Error::unknown( - "Reading from forkserver failed.".to_string(), - )); - } + let dict_size = forkserver.read_st().map_err(|err| { + Error::illegal_state(format!("Reading from forkserver failed: {err:?}")) + })?; if !(2..=0xffffff).contains(&dict_size) { return Err(Error::illegal_state( @@ -1036,11 +1055,11 @@ where log::info!("Autodict size {dict_size:x}"); - let (rlen, buf) = forkserver.read_st_size(dict_size as usize)?; - - if rlen != dict_size as usize { - return Err(Error::unknown("Failed to load autodictionary".to_string())); - } + let buf = forkserver + .read_st_of_len(dict_size as usize) + .map_err(|err| { + Error::unknown(format!("Failed to load autodictionary: {err:?}")) + })?; if let Some(t) = &mut self.autotokens { t.parse_autodict(&buf, dict_size as usize); } @@ -1422,22 +1441,18 @@ where .write_buf(&input_bytes.as_slice()[..input_size])?; } - let send_len = self.forkserver.write_ctl(last_run_timed_out)?; - self.forkserver.set_last_run_timed_out(false); - - if send_len != 4 { - return Err(Error::unknown( - "Unable to request new process from fork server (OOM?)".to_string(), - )); + if let Err(err) = self.forkserver.write_ctl(last_run_timed_out) { + return Err(Error::unknown(format!( + "Unable to request new process from fork server (OOM?): {err:?}" + ))); } - let (recv_pid_len, pid) = self.forkserver.read_st()?; - if recv_pid_len != 4 { - return Err(Error::unknown( - "Unable to request new process from fork server (OOM?)".to_string(), - )); - } + let pid = self.forkserver.read_st().map_err(|err| { + Error::unknown(format!( + "Unable to request new process from fork server (OOM?): {err:?}" + )) + })?; if pid <= 0 { return Err(Error::unknown( @@ -1466,9 +1481,10 @@ where // We need to kill the child in case he has timed out, or we can't get the correct pid in the next call to self.executor.forkserver_mut().read_st()? let _ = kill(self.forkserver().child_pid(), self.forkserver.kill_signal); - let (recv_status_len, _) = self.forkserver.read_st()?; - if recv_status_len != 4 { - return Err(Error::unknown("Could not kill timed-out child".to_string())); + if let Err(err) = self.forkserver.read_st() { + return Err(Error::unknown(format!( + "Could not kill timed-out child: {err:?}" + ))); } exit_kind = ExitKind::Timeout; } @@ -1520,7 +1536,7 @@ mod tests { use serial_test::serial; use crate::{ - executors::forkserver::ForkserverExecutor, + executors::forkserver::{ForkserverExecutor, FAILED_TO_START_FORKSERVER_MSG}, observers::{ConstMapObserver, HitcountsMapObserver}, Error, }; @@ -1555,10 +1571,13 @@ mod tests { // Since /usr/bin/echo is not a instrumented binary file, the test will just check if the forkserver has failed at the initial handshake let result = match executor { Ok(_) => true, - Err(e) => match e { - Error::Unknown(s, _) => s == "Failed to start a forkserver", - _ => false, - }, + Err(e) => { + println!("Error: {e:?}"); + match e { + Error::IllegalState(s, _) => s.contains(FAILED_TO_START_FORKSERVER_MSG), + _ => false, + } + } }; assert!(result); }