diff --git a/libafl/src/bolts/staterestore.rs b/libafl/src/bolts/staterestore.rs index 350d98dc53..10562f7ca2 100644 --- a/libafl/src/bolts/staterestore.rs +++ b/libafl/src/bolts/staterestore.rs @@ -6,8 +6,10 @@ use postcard; use serde::{de::DeserializeOwned, Serialize}; use std::{ env::temp_dir, - fs::File, + fs::{self, File}, io::{Read, Write}, + path::PathBuf, + ptr::read_volatile, }; use crate::{ @@ -15,6 +17,39 @@ use crate::{ Error, }; +/// The struct stored on the shared map, containing either the data, or the filename to read contents from. +#[repr(C)] +struct StateShMemContent { + is_disk: bool, + buf_len: usize, + buf: [u8; 0], +} + +impl StateShMemContent { + /// Gets the (tmp-)filename, if the contents are stored on disk. + pub fn tmpfile(&self, shmem_size: usize) -> Result, Error> { + Ok(if self.is_disk { + let bytes = unsafe { + slice::from_raw_parts(self.buf.as_ptr(), self.buf_len_checked(shmem_size)?) + }; + let filename = postcard::from_bytes::(bytes)?; + Some(temp_dir().join(&filename)) + } else { + None + }) + } + + /// Get a length that's safe to deref from this map, or error. + pub fn buf_len_checked(&self, shmem_size: usize) -> Result { + let buf_len = unsafe { read_volatile(&self.buf_len) }; + if size_of::() + buf_len > shmem_size { + Err(Error::IllegalState(format!("Stored buf_len is larger than the shared map! Shared data corrupted? Expected {} bytes max, but got {} (buf_len {})", shmem_size, size_of::() + buf_len, buf_len))) + } else { + Ok(buf_len) + } + } +} + /// A [`StateRestorer`] saves and restores bytes to a shared map. /// If the state gets larger than the preallocated [`ShMem`] shared map, /// it will instead write to disk, and store the file name into the map. @@ -28,17 +63,15 @@ where phantom: PhantomData<*const SP>, } -#[repr(C)] -struct StateShMemContent { - is_disk: bool, - buf_len: usize, - buf: [u8; 0], -} - impl StateRestorer where SP: ShMemProvider, { + /// Get the map size backing this [`StateRestorer`]. + pub fn mapsize(&self) -> usize { + self.shmem.len() + } + /// Writes this [`StateRestorer`] to env variable, to be restored later pub fn write_to_env(&self, env_name: &str) -> Result<(), Error> { self.shmem.write_to_env(env_name) @@ -134,7 +167,12 @@ where /// Reset this [`StateRestorer`] to an empty state. pub fn reset(&mut self) { + let mapsize = self.mapsize(); let content_mut = self.content_mut(); + if let Ok(Some(tmpfile)) = content_mut.tmpfile(mapsize) { + // Remove tmpfile and ignore result + drop(fs::remove_file(tmpfile)); + } content_mut.is_disk = false; content_mut.buf_len = 0; } @@ -160,6 +198,7 @@ where } /// Restores the contents saved in this [`StateRestorer`], if any are availiable. + /// Can only be read once. pub fn restore(&self) -> Result, Error> where S: DeserializeOwned, @@ -171,7 +210,7 @@ where let bytes = unsafe { slice::from_raw_parts( state_shmem_content.buf.as_ptr(), - state_shmem_content.buf_len, + state_shmem_content.buf_len_checked(self.mapsize())?, ) }; let mut state = bytes; @@ -198,9 +237,11 @@ where #[cfg(test)] mod tests { - use crate::bolts::shmem::{ShMemProvider, StdShMemProvider}; - use super::StateRestorer; + use crate::bolts::{ + shmem::{ShMemProvider, StdShMemProvider}, + staterestore::StateRestorer, + }; #[test] fn test_state_restore() { @@ -218,10 +259,12 @@ mod tests { let restored = state_restorer.restore::().unwrap().unwrap(); println!("Restored {}", restored); assert_eq!(restored, "hello world"); + assert!(!state_restorer.content().is_disk); state_restorer.reset(); assert!(!state_restorer.has_content()); + assert!(!state_restorer.content().is_disk); assert!(state_restorer.restore::().unwrap().is_none()); let too_large = vec![4u8; TESTMAP_SIZE + 1]; @@ -232,5 +275,20 @@ mod tests { assert_eq!(large_restored, too_large); assert_eq!(large_restored.len(), too_large.len()); assert_eq!(large_restored[TESTMAP_SIZE], 4u8); + + assert!(state_restorer.content().is_disk); + assert_ne!(state_restorer.content().buf_len, 0); + + // Check if file removal works. + let state_shmem_content = state_restorer.content(); + let tmpfile = state_shmem_content + .tmpfile(state_restorer.mapsize()) + .unwrap() + .unwrap(); + assert!(tmpfile.exists()); + + state_restorer.reset(); + assert!(!state_restorer.has_content()); + assert!(!tmpfile.exists()); } }