diff --git a/libafl/src/bolts/staterestore.rs b/libafl/src/bolts/staterestore.rs index 39e2cec551..3c262230a2 100644 --- a/libafl/src/bolts/staterestore.rs +++ b/libafl/src/bolts/staterestore.rs @@ -101,6 +101,7 @@ where ); } shmem_content.buf_len = len; + shmem_content.is_disk = true; } else { // write to shmem directly let len = serialized.len(); @@ -113,6 +114,7 @@ where ); } shmem_content.buf_len = len; + shmem_content.is_disk = false; }; Ok(()) } @@ -146,7 +148,7 @@ where where S: DeserializeOwned, { - if self.has_content() { + if !self.has_content() { return Ok(Option::None); } let state_shmem_content = self.content(); @@ -177,3 +179,42 @@ where Ok(Some(deserialized)) } } + +#[cfg(test)] +mod tests { + use crate::bolts::shmem::{ShMemProvider, StdShMemProvider}; + + use super::StateRestorer; + + #[test] + fn test_state_restore() { + const TESTMAP_SIZE: usize = 1024; + + let mut shmem_provider = StdShMemProvider::new().unwrap(); + let shmem = shmem_provider.new_map(TESTMAP_SIZE).unwrap(); + let mut state_restorer = StateRestorer::::new(shmem); + + let state = "hello world".to_string(); + + state_restorer.save(&state).unwrap(); + + assert!(state_restorer.has_content()); + let restored = state_restorer.restore::().unwrap().unwrap(); + println!("Restored {}", restored); + assert_eq!(restored, "hello world"); + + state_restorer.reset(); + + assert!(!state_restorer.has_content()); + assert!(state_restorer.restore::().unwrap().is_none()); + + let too_large = vec![4u8; TESTMAP_SIZE + 1]; + state_restorer.save(&too_large).unwrap(); + assert!(state_restorer.has_content()); + + let large_restored = state_restorer.restore::>().unwrap().unwrap(); + assert_eq!(large_restored, too_large); + assert_eq!(large_restored.len(), too_large.len()); + assert_eq!(large_restored[TESTMAP_SIZE], 4u8); + } +}