StateRestorer.reset() removes old tmpfile (#242)

* StateRestorer.reset() removes old tmpfile

* checking map size on deref for extra safety

* clippy
This commit is contained in:
Dominik Maier 2021-08-04 15:13:54 +02:00 committed by GitHub
parent 3fac056b58
commit 92ba3f59f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,8 +6,10 @@ use postcard;
use serde::{de::DeserializeOwned, Serialize};
use std::{
env::temp_dir,
fs::File,
fs::{self, File},
io::{Read, Write},
path::PathBuf,
ptr::read_volatile,
};
use crate::{
@ -15,6 +17,39 @@ use crate::{
Error,
};
/// The struct stored on the shared map, containing either the data, or the filename to read contents from.
#[repr(C)]
struct StateShMemContent {
is_disk: bool,
buf_len: usize,
buf: [u8; 0],
}
impl StateShMemContent {
/// Gets the (tmp-)filename, if the contents are stored on disk.
pub fn tmpfile(&self, shmem_size: usize) -> Result<Option<PathBuf>, Error> {
Ok(if self.is_disk {
let bytes = unsafe {
slice::from_raw_parts(self.buf.as_ptr(), self.buf_len_checked(shmem_size)?)
};
let filename = postcard::from_bytes::<String>(bytes)?;
Some(temp_dir().join(&filename))
} else {
None
})
}
/// Get a length that's safe to deref from this map, or error.
pub fn buf_len_checked(&self, shmem_size: usize) -> Result<usize, Error> {
let buf_len = unsafe { read_volatile(&self.buf_len) };
if size_of::<StateShMemContent>() + buf_len > shmem_size {
Err(Error::IllegalState(format!("Stored buf_len is larger than the shared map! Shared data corrupted? Expected {} bytes max, but got {} (buf_len {})", shmem_size, size_of::<StateShMemContent>() + buf_len, buf_len)))
} else {
Ok(buf_len)
}
}
}
/// A [`StateRestorer`] saves and restores bytes to a shared map.
/// If the state gets larger than the preallocated [`ShMem`] shared map,
/// it will instead write to disk, and store the file name into the map.
@ -28,17 +63,15 @@ where
phantom: PhantomData<*const SP>,
}
#[repr(C)]
struct StateShMemContent {
is_disk: bool,
buf_len: usize,
buf: [u8; 0],
}
impl<SP> StateRestorer<SP>
where
SP: ShMemProvider,
{
/// Get the map size backing this [`StateRestorer`].
pub fn mapsize(&self) -> usize {
self.shmem.len()
}
/// Writes this [`StateRestorer`] to env variable, to be restored later
pub fn write_to_env(&self, env_name: &str) -> Result<(), Error> {
self.shmem.write_to_env(env_name)
@ -134,7 +167,12 @@ where
/// Reset this [`StateRestorer`] to an empty state.
pub fn reset(&mut self) {
let mapsize = self.mapsize();
let content_mut = self.content_mut();
if let Ok(Some(tmpfile)) = content_mut.tmpfile(mapsize) {
// Remove tmpfile and ignore result
drop(fs::remove_file(tmpfile));
}
content_mut.is_disk = false;
content_mut.buf_len = 0;
}
@ -160,6 +198,7 @@ where
}
/// Restores the contents saved in this [`StateRestorer`], if any are availiable.
/// Can only be read once.
pub fn restore<S>(&self) -> Result<Option<S>, Error>
where
S: DeserializeOwned,
@ -171,7 +210,7 @@ where
let bytes = unsafe {
slice::from_raw_parts(
state_shmem_content.buf.as_ptr(),
state_shmem_content.buf_len,
state_shmem_content.buf_len_checked(self.mapsize())?,
)
};
let mut state = bytes;
@ -198,9 +237,11 @@ where
#[cfg(test)]
mod tests {
use crate::bolts::shmem::{ShMemProvider, StdShMemProvider};
use super::StateRestorer;
use crate::bolts::{
shmem::{ShMemProvider, StdShMemProvider},
staterestore::StateRestorer,
};
#[test]
fn test_state_restore() {
@ -218,10 +259,12 @@ mod tests {
let restored = state_restorer.restore::<String>().unwrap().unwrap();
println!("Restored {}", restored);
assert_eq!(restored, "hello world");
assert!(!state_restorer.content().is_disk);
state_restorer.reset();
assert!(!state_restorer.has_content());
assert!(!state_restorer.content().is_disk);
assert!(state_restorer.restore::<String>().unwrap().is_none());
let too_large = vec![4u8; TESTMAP_SIZE + 1];
@ -232,5 +275,20 @@ mod tests {
assert_eq!(large_restored, too_large);
assert_eq!(large_restored.len(), too_large.len());
assert_eq!(large_restored[TESTMAP_SIZE], 4u8);
assert!(state_restorer.content().is_disk);
assert_ne!(state_restorer.content().buf_len, 0);
// Check if file removal works.
let state_shmem_content = state_restorer.content();
let tmpfile = state_shmem_content
.tmpfile(state_restorer.mapsize())
.unwrap()
.unwrap();
assert!(tmpfile.exists());
state_restorer.reset();
assert!(!state_restorer.has_content());
assert!(!tmpfile.exists());
}
}