Move ShMem persisting flag to a new constructor (#2649)

* moving shmem persisting to take an owned value, adding test

* clean code updates

* adding imports conditionally

* fixing tests

* moving persistent mmap shmem to custom constructor

* excluding miri properly

* fixing formatting
This commit is contained in:
Valentin Huber 2024-11-03 03:13:10 +01:00 committed by GitHub
parent 89cff63702
commit d4fbe1754f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -838,85 +838,6 @@ pub mod unix_shmem {
pub fn filename_path(&self) -> &Option<[u8; MAX_MMAP_FILENAME_LEN]> { pub fn filename_path(&self) -> &Option<[u8; MAX_MMAP_FILENAME_LEN]> {
&self.filename_path &self.filename_path
} }
/// If called, the shared memory will also be available in subprocesses.
///
/// You likely want to pass the [`crate::shmem::ShMemDescription`] and reopen the shared memory in the child process using [`crate::shmem::ShMemProvider::shmem_from_description`].
///
/// # Errors
///
/// This function will return an error if the appropriate flags could not be extracted or set.
pub fn persist_for_child_processes(&self) -> Result<&Self, Error> {
// # Safety
// No user-provided potentially unsafe parameters.
// FFI Calls.
unsafe {
let flags = fcntl(self.shm_fd, libc::F_GETFD);
if flags == -1 {
return Err(Error::os_error(
io::Error::last_os_error(),
"Failed to retrieve FD flags",
));
}
if fcntl(self.shm_fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) == -1 {
return Err(Error::os_error(
io::Error::last_os_error(),
"Failed to set FD flags",
));
}
}
Ok(self)
}
}
/// A [`ShMemProvider`] which uses [`shm_open`] and [`mmap`] to provide shared memory mappings.
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct MmapShMemProvider {}
unsafe impl Send for MmapShMemProvider {}
#[cfg(unix)]
impl Default for MmapShMemProvider {
fn default() -> Self {
Self::new().unwrap()
}
}
/// Implement [`ShMemProvider`] for [`MmapShMemProvider`].
#[cfg(unix)]
impl ShMemProvider for MmapShMemProvider {
type ShMem = MmapShMem;
fn new() -> Result<Self, Error> {
Ok(Self {})
}
fn new_shmem(&mut self, map_size: usize) -> Result<Self::ShMem, Error> {
let mut rand = StdRand::with_seed(crate::rands::random_seed());
let id = rand.next() as u32;
MmapShMem::new(map_size, id)
}
fn shmem_from_id_and_size(
&mut self,
id: ShMemId,
size: usize,
) -> Result<Self::ShMem, Error> {
MmapShMem::shmem_from_id_and_size(id, size)
}
fn release_shmem(&mut self, shmem: &mut Self::ShMem) {
let fd = CStr::from_bytes_until_nul(shmem.id().as_array())
.unwrap()
.to_str()
.unwrap()
.parse()
.unwrap();
unsafe { close(fd) };
}
} }
impl ShMem for MmapShMem { impl ShMem for MmapShMem {
@ -974,6 +895,97 @@ pub mod unix_shmem {
} }
} }
/// A [`ShMemProvider`] which uses [`shm_open`] and [`mmap`] to provide shared memory mappings.
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct MmapShMemProvider {}
impl MmapShMemProvider {
/// Creates a new shared memory mapping, which is available in other processes.
///
/// Only available on UNIX systems at the moment.
///
/// You likely want to pass the [`crate::shmem::ShMemDescription`] of the returned [`ShMem`]
/// and reopen the shared memory in the child process using [`crate::shmem::ShMemProvider::shmem_from_description`].
///
/// # Errors
///
/// This function will return an error if the appropriate flags could not be extracted or set.
#[cfg(any(unix, doc))]
pub fn new_shmem_persistent(
&mut self,
map_size: usize,
) -> Result<<Self as ShMemProvider>::ShMem, Error> {
let shmem = self.new_shmem(map_size)?;
let fd = shmem.shm_fd;
// # Safety
// No user-provided potentially unsafe parameters.
// FFI Calls.
unsafe {
let flags = fcntl(fd, libc::F_GETFD);
if flags == -1 {
return Err(Error::os_error(
io::Error::last_os_error(),
"Failed to retrieve FD flags",
));
}
if fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) == -1 {
return Err(Error::os_error(
io::Error::last_os_error(),
"Failed to set FD flags",
));
}
}
Ok(shmem)
}
}
unsafe impl Send for MmapShMemProvider {}
#[cfg(unix)]
impl Default for MmapShMemProvider {
fn default() -> Self {
Self::new().unwrap()
}
}
/// Implement [`ShMemProvider`] for [`MmapShMemProvider`].
#[cfg(unix)]
impl ShMemProvider for MmapShMemProvider {
type ShMem = MmapShMem;
fn new() -> Result<Self, Error> {
Ok(Self {})
}
fn new_shmem(&mut self, map_size: usize) -> Result<Self::ShMem, Error> {
let mut rand = StdRand::with_seed(crate::rands::random_seed());
let id = rand.next() as u32;
MmapShMem::new(map_size, id)
}
fn shmem_from_id_and_size(
&mut self,
id: ShMemId,
size: usize,
) -> Result<Self::ShMem, Error> {
MmapShMem::shmem_from_id_and_size(id, size)
}
fn release_shmem(&mut self, shmem: &mut Self::ShMem) {
let fd = CStr::from_bytes_until_nul(shmem.id().as_array())
.unwrap()
.to_str()
.unwrap()
.parse()
.unwrap();
unsafe { close(fd) };
}
}
/// The default sharedmap impl for unix using shmctl & shmget /// The default sharedmap impl for unix using shmctl & shmget
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct CommonUnixShMem { pub struct CommonUnixShMem {
@ -1622,16 +1634,42 @@ mod tests {
use crate::{ use crate::{
shmem::{ShMemProvider, StdShMemProvider}, shmem::{ShMemProvider, StdShMemProvider},
AsSlice, AsSliceMut, AsSlice, AsSliceMut, Error,
}; };
#[test] #[test]
#[serial] #[serial]
#[cfg_attr(miri, ignore)] #[cfg_attr(miri, ignore)]
fn test_shmem_service() { fn test_shmem_service() -> Result<(), Error> {
let mut provider = StdShMemProvider::new().unwrap(); let mut provider = StdShMemProvider::new()?;
let mut map = provider.new_shmem(1024).unwrap(); let mut map = provider.new_shmem(1024)?;
map.as_slice_mut()[0] = 1; map.as_slice_mut()[0] = 1;
assert!(map.as_slice()[0] == 1); assert_eq!(1, map.as_slice()[0]);
Ok(())
}
#[test]
#[cfg(all(unix, not(miri)))]
#[cfg_attr(miri, ignore)]
fn test_persist_shmem() -> Result<(), Error> {
use std::thread;
use crate::shmem::{MmapShMemProvider, ShMem as _};
let mut provider = MmapShMemProvider::new()?;
let mut shmem = provider.new_shmem_persistent(1)?;
shmem.fill(0);
let description = shmem.description();
let handle = thread::spawn(move || -> Result<(), Error> {
let mut provider = MmapShMemProvider::new()?;
let mut shmem = provider.shmem_from_description(description)?;
shmem.as_slice_mut()[0] = 1;
Ok(())
});
handle.join().unwrap()?;
assert_eq!(1, shmem.as_slice()[0]);
Ok(())
} }
} }