StateRestorer.reset() removes old tmpfile (#242)
* StateRestorer.reset() removes old tmpfile * checking map size on deref for extra safety * clippy
This commit is contained in:
parent
3fac056b58
commit
92ba3f59f9
@ -6,8 +6,10 @@ use postcard;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{
|
||||
env::temp_dir,
|
||||
fs::File,
|
||||
fs::{self, File},
|
||||
io::{Read, Write},
|
||||
path::PathBuf,
|
||||
ptr::read_volatile,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@ -15,6 +17,39 @@ use crate::{
|
||||
Error,
|
||||
};
|
||||
|
||||
/// The struct stored on the shared map, containing either the data, or the filename to read contents from.
|
||||
#[repr(C)]
|
||||
struct StateShMemContent {
|
||||
is_disk: bool,
|
||||
buf_len: usize,
|
||||
buf: [u8; 0],
|
||||
}
|
||||
|
||||
impl StateShMemContent {
|
||||
/// Gets the (tmp-)filename, if the contents are stored on disk.
|
||||
pub fn tmpfile(&self, shmem_size: usize) -> Result<Option<PathBuf>, Error> {
|
||||
Ok(if self.is_disk {
|
||||
let bytes = unsafe {
|
||||
slice::from_raw_parts(self.buf.as_ptr(), self.buf_len_checked(shmem_size)?)
|
||||
};
|
||||
let filename = postcard::from_bytes::<String>(bytes)?;
|
||||
Some(temp_dir().join(&filename))
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a length that's safe to deref from this map, or error.
|
||||
pub fn buf_len_checked(&self, shmem_size: usize) -> Result<usize, Error> {
|
||||
let buf_len = unsafe { read_volatile(&self.buf_len) };
|
||||
if size_of::<StateShMemContent>() + buf_len > shmem_size {
|
||||
Err(Error::IllegalState(format!("Stored buf_len is larger than the shared map! Shared data corrupted? Expected {} bytes max, but got {} (buf_len {})", shmem_size, size_of::<StateShMemContent>() + buf_len, buf_len)))
|
||||
} else {
|
||||
Ok(buf_len)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`StateRestorer`] saves and restores bytes to a shared map.
|
||||
/// If the state gets larger than the preallocated [`ShMem`] shared map,
|
||||
/// it will instead write to disk, and store the file name into the map.
|
||||
@ -28,17 +63,15 @@ where
|
||||
phantom: PhantomData<*const SP>,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
struct StateShMemContent {
|
||||
is_disk: bool,
|
||||
buf_len: usize,
|
||||
buf: [u8; 0],
|
||||
}
|
||||
|
||||
impl<SP> StateRestorer<SP>
|
||||
where
|
||||
SP: ShMemProvider,
|
||||
{
|
||||
/// Get the map size backing this [`StateRestorer`].
|
||||
pub fn mapsize(&self) -> usize {
|
||||
self.shmem.len()
|
||||
}
|
||||
|
||||
/// Writes this [`StateRestorer`] to env variable, to be restored later
|
||||
pub fn write_to_env(&self, env_name: &str) -> Result<(), Error> {
|
||||
self.shmem.write_to_env(env_name)
|
||||
@ -134,7 +167,12 @@ where
|
||||
|
||||
/// Reset this [`StateRestorer`] to an empty state.
|
||||
pub fn reset(&mut self) {
|
||||
let mapsize = self.mapsize();
|
||||
let content_mut = self.content_mut();
|
||||
if let Ok(Some(tmpfile)) = content_mut.tmpfile(mapsize) {
|
||||
// Remove tmpfile and ignore result
|
||||
drop(fs::remove_file(tmpfile));
|
||||
}
|
||||
content_mut.is_disk = false;
|
||||
content_mut.buf_len = 0;
|
||||
}
|
||||
@ -160,6 +198,7 @@ where
|
||||
}
|
||||
|
||||
/// Restores the contents saved in this [`StateRestorer`], if any are availiable.
|
||||
/// Can only be read once.
|
||||
pub fn restore<S>(&self) -> Result<Option<S>, Error>
|
||||
where
|
||||
S: DeserializeOwned,
|
||||
@ -171,7 +210,7 @@ where
|
||||
let bytes = unsafe {
|
||||
slice::from_raw_parts(
|
||||
state_shmem_content.buf.as_ptr(),
|
||||
state_shmem_content.buf_len,
|
||||
state_shmem_content.buf_len_checked(self.mapsize())?,
|
||||
)
|
||||
};
|
||||
let mut state = bytes;
|
||||
@ -198,9 +237,11 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bolts::shmem::{ShMemProvider, StdShMemProvider};
|
||||
|
||||
use super::StateRestorer;
|
||||
use crate::bolts::{
|
||||
shmem::{ShMemProvider, StdShMemProvider},
|
||||
staterestore::StateRestorer,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_state_restore() {
|
||||
@ -218,10 +259,12 @@ mod tests {
|
||||
let restored = state_restorer.restore::<String>().unwrap().unwrap();
|
||||
println!("Restored {}", restored);
|
||||
assert_eq!(restored, "hello world");
|
||||
assert!(!state_restorer.content().is_disk);
|
||||
|
||||
state_restorer.reset();
|
||||
|
||||
assert!(!state_restorer.has_content());
|
||||
assert!(!state_restorer.content().is_disk);
|
||||
assert!(state_restorer.restore::<String>().unwrap().is_none());
|
||||
|
||||
let too_large = vec![4u8; TESTMAP_SIZE + 1];
|
||||
@ -232,5 +275,20 @@ mod tests {
|
||||
assert_eq!(large_restored, too_large);
|
||||
assert_eq!(large_restored.len(), too_large.len());
|
||||
assert_eq!(large_restored[TESTMAP_SIZE], 4u8);
|
||||
|
||||
assert!(state_restorer.content().is_disk);
|
||||
assert_ne!(state_restorer.content().buf_len, 0);
|
||||
|
||||
// Check if file removal works.
|
||||
let state_shmem_content = state_restorer.content();
|
||||
let tmpfile = state_shmem_content
|
||||
.tmpfile(state_restorer.mapsize())
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(tmpfile.exists());
|
||||
|
||||
state_restorer.reset();
|
||||
assert!(!state_restorer.has_content());
|
||||
assert!(!tmpfile.exists());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user