Implement Hash for MapObserver (#1989)

* MapObserver implements Hash

* Rename the hash utility function (in MapObserver) to hash_easy

* Use hash_slice as a helper function to impl Hash trait

* define_python_map_observer macro implements Hash trait

* Also rename hash_easy to hash_simple

* Rename hash_slice to hash_helper

* hash_helper is used to define the implementation of hash function/trait

* Factor out the Hash trait and function for runtime library structs (#1977)

* Simplify hash_simple (of trait MapObserver) (#1977)

 * Use hash_one function to make hash_simple a one-liner

* remove hash_helper

---------

Co-authored-by: Edwin Fernando <ef322@ic.ac.uk>
Co-authored-by: Addison Crump <addison.crump@cispa.de>
This commit is contained in:
edwin1729 2024-04-19 14:06:14 +01:00 committed by GitHub
parent 04cd792df2
commit c238b69498
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 172 additions and 66 deletions

View File

@ -6,7 +6,7 @@ use alloc::{
}; };
use core::{ use core::{
fmt::Debug, fmt::Debug,
hash::{BuildHasher, Hasher}, hash::{Hash, Hasher},
iter::Flatten, iter::Flatten,
marker::PhantomData, marker::PhantomData,
mem::size_of, mem::size_of,
@ -69,17 +69,6 @@ fn init_count_class_16() {
} }
} }
/// Compute the hash of a slice
fn hash_slice<T>(slice: &[T]) -> u64 {
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
let ptr = slice.as_ptr() as *const u8;
let map_size = slice.len() / size_of::<T>();
unsafe {
hasher.write(slice::from_raw_parts(ptr, map_size));
}
hasher.finish()
}
/// Trait marker which indicates that this [`MapObserver`] is tracked for indices or novelties. /// Trait marker which indicates that this [`MapObserver`] is tracked for indices or novelties.
/// Implementors of feedbacks similar to [`crate::feedbacks::MapFeedback`] may wish to use this to /// Implementors of feedbacks similar to [`crate::feedbacks::MapFeedback`] may wish to use this to
/// ensure that edge metadata is recorded as is appropriate for the provided observer. /// ensure that edge metadata is recorded as is appropriate for the provided observer.
@ -420,12 +409,12 @@ pub mod macros {
/// ///
/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize /// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize
pub trait MapObserver: pub trait MapObserver:
HasLen + Named + Serialize + serde::de::DeserializeOwned + AsRef<Self> + AsMut<Self> HasLen + Named + Serialize + serde::de::DeserializeOwned + AsRef<Self> + AsMut<Self> + Hash
// where // where
// for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry> // for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry>
{ {
/// Type of each entry in this map /// Type of each entry in this map
type Entry: Bounded + PartialEq + Default + Copy + Debug + 'static; type Entry: Bounded + PartialEq + Default + Copy + Debug + Hash + 'static;
/// Get the value at `idx` /// Get the value at `idx`
fn get(&self, idx: usize) -> &Self::Entry; fn get(&self, idx: usize) -> &Self::Entry;
@ -439,8 +428,8 @@ pub trait MapObserver:
/// Count the set bytes in the map /// Count the set bytes in the map
fn count_bytes(&self) -> u64; fn count_bytes(&self) -> u64;
/// Compute the hash of the map /// Compute the hash of the map without needing to provide a hasher
fn hash(&self) -> u64; fn hash_simple(&self) -> u64;
/// Get the initial value for `reset()` /// Get the initial value for `reset()`
fn initial(&self) -> Self::Entry; fn initial(&self) -> Self::Entry;
@ -553,6 +542,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -604,6 +594,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -624,6 +615,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -644,6 +636,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -665,6 +658,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -685,6 +679,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -701,6 +696,24 @@ where
} }
} }
impl<'a, T, const DIFFERENTIAL: bool> Hash for StdMapObserver<'a, T, DIFFERENTIAL>
where
T: Bounded
+ PartialEq
+ Default
+ Copy
+ Hash
+ 'static
+ Serialize
+ serde::de::DeserializeOwned
+ Debug,
{
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.as_slice().hash(hasher);
}
}
impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for StdMapObserver<'a, T, DIFFERENTIAL> impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for StdMapObserver<'a, T, DIFFERENTIAL>
where where
T: Default + Copy + 'static + Serialize, T: Default + Copy + 'static + Serialize,
@ -725,6 +738,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -761,8 +775,9 @@ where
self.as_slice().len() self.as_slice().len()
} }
fn hash(&self) -> u64 { #[inline]
hash_slice(self.as_slice()) fn hash_simple(&self) -> u64 {
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
} }
#[inline] #[inline]
@ -1108,6 +1123,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1128,6 +1144,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1148,6 +1165,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1168,6 +1186,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1188,6 +1207,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1204,6 +1224,23 @@ where
} }
} }
impl<'a, T, const N: usize> Hash for ConstMapObserver<'a, T, N>
where
T: Bounded
+ PartialEq
+ Default
+ Copy
+ Hash
+ 'static
+ Serialize
+ serde::de::DeserializeOwned
+ Debug,
{
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.as_slice().hash(hasher);
}
}
impl<'a, T, const N: usize> AsRef<Self> for ConstMapObserver<'a, T, N> impl<'a, T, const N: usize> AsRef<Self> for ConstMapObserver<'a, T, N>
where where
T: Default + Copy + 'static + Serialize, T: Default + Copy + 'static + Serialize,
@ -1228,6 +1265,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1268,8 +1306,9 @@ where
self.as_slice().len() self.as_slice().len()
} }
fn hash(&self) -> u64 { #[inline]
hash_slice(self.as_slice()) fn hash_simple(&self) -> u64 {
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
} }
/// Reset the map /// Reset the map
@ -1429,6 +1468,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1451,6 +1491,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1473,6 +1514,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1495,6 +1537,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1517,6 +1560,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1535,6 +1579,25 @@ where
} }
} }
impl<'a, T> Hash for VariableMapObserver<'a, T>
where
T: Bounded
+ PartialEq
+ Default
+ Copy
+ Hash
+ 'static
+ Serialize
+ serde::de::DeserializeOwned
+ Debug
+ PartialEq
+ Bounded,
{
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.as_slice().hash(hasher);
}
}
impl<'a, T> AsRef<Self> for VariableMapObserver<'a, T> impl<'a, T> AsRef<Self> for VariableMapObserver<'a, T>
where where
T: Default + Copy + 'static + Serialize + PartialEq + Bounded, T: Default + Copy + 'static + Serialize + PartialEq + Bounded,
@ -1559,6 +1622,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1599,8 +1663,10 @@ where
} }
res res
} }
fn hash(&self) -> u64 {
hash_slice(self.as_slice()) #[inline]
fn hash_simple(&self) -> u64 {
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
} }
/// Reset the map /// Reset the map
@ -1640,6 +1706,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ 'static + 'static
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
@ -1660,6 +1727,7 @@ where
T: 'static T: 'static
+ Default + Default
+ Copy + Copy
+ Hash
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
+ Debug + Debug
@ -1719,7 +1787,7 @@ where
/// ///
/// [`MapObserver`]s that are not slice-backed, such as [`MultiMapObserver`], can use /// [`MapObserver`]s that are not slice-backed, such as [`MultiMapObserver`], can use
/// [`HitcountsIterableMapObserver`] instead. /// [`HitcountsIterableMapObserver`] instead.
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "M: serde::de::DeserializeOwned")] #[serde(bound = "M: serde::de::DeserializeOwned")]
pub struct HitcountsMapObserver<M> pub struct HitcountsMapObserver<M>
where where
@ -1863,8 +1931,9 @@ where
self.base.reset_map() self.base.reset_map()
} }
fn hash(&self) -> u64 { #[inline]
self.base.hash() fn hash_simple(&self) -> u64 {
self.base.hash_simple()
} }
fn to_vec(&self) -> Vec<u8> { fn to_vec(&self) -> Vec<u8> {
self.base.to_vec() self.base.to_vec()
@ -2019,7 +2088,7 @@ where
/// Map observer with hitcounts postprocessing /// Map observer with hitcounts postprocessing
/// Less optimized version for non-slice iterators. /// Less optimized version for non-slice iterators.
/// Slice-backed observers should use a [`HitcountsMapObserver`]. /// Slice-backed observers should use a [`HitcountsMapObserver`].
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug, Hash)]
#[serde(bound = "M: serde::de::DeserializeOwned")] #[serde(bound = "M: serde::de::DeserializeOwned")]
pub struct HitcountsIterableMapObserver<M> pub struct HitcountsIterableMapObserver<M>
where where
@ -2133,8 +2202,9 @@ where
self.base.reset_map() self.base.reset_map()
} }
fn hash(&self) -> u64 { #[inline]
self.base.hash() fn hash_simple(&self) -> u64 {
self.base.hash_simple()
} }
fn to_vec(&self) -> Vec<u8> { fn to_vec(&self) -> Vec<u8> {
self.base.to_vec() self.base.to_vec()
@ -2341,6 +2411,22 @@ where
} }
} }
impl<'a, T, const DIFFERENTIAL: bool> Hash for MultiMapObserver<'a, T, DIFFERENTIAL>
where
T: 'static + Default + Copy + Serialize + serde::de::DeserializeOwned + Debug,
{
fn hash<H: Hasher>(&self, hasher: &mut H) {
for map in &self.maps {
let slice = map.as_slice();
let ptr = slice.as_ptr() as *const u8;
let map_size = slice.len() / size_of::<T>();
unsafe {
hasher.write(slice::from_raw_parts(ptr, map_size));
}
}
}
}
impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for MultiMapObserver<'a, T, DIFFERENTIAL> impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for MultiMapObserver<'a, T, DIFFERENTIAL>
where where
T: 'static + Default + Copy + Serialize + Debug, T: 'static + Default + Copy + Serialize + Debug,
@ -2366,6 +2452,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
+ Debug, + Debug,
@ -2406,17 +2493,9 @@ where
res res
} }
fn hash(&self) -> u64 { #[inline]
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher(); fn hash_simple(&self) -> u64 {
for map in &self.maps { RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
let slice = map.as_slice();
let ptr = slice.as_ptr() as *const u8;
let map_size = slice.len() / size_of::<T>();
unsafe {
hasher.write(slice::from_raw_parts(ptr, map_size));
}
}
hasher.finish()
} }
fn reset_map(&mut self) -> Result<(), Error> { fn reset_map(&mut self) -> Result<(), Error> {
@ -2714,6 +2793,16 @@ where
} }
} }
impl<T> Hash for OwnedMapObserver<T>
where
T: 'static + Hash + Default + Copy + Serialize + serde::de::DeserializeOwned + Debug,
{
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.as_slice().hash(hasher);
}
}
impl<T> AsRef<Self> for OwnedMapObserver<T> impl<T> AsRef<Self> for OwnedMapObserver<T>
where where
T: 'static + Default + Copy + Serialize, T: 'static + Default + Copy + Serialize,
@ -2739,6 +2828,7 @@ where
+ PartialEq + PartialEq
+ Default + Default
+ Copy + Copy
+ Hash
+ Serialize + Serialize
+ serde::de::DeserializeOwned + serde::de::DeserializeOwned
+ Debug, + Debug,
@ -2774,8 +2864,9 @@ where
self.as_slice().len() self.as_slice().len()
} }
fn hash(&self) -> u64 { #[inline]
hash_slice(self.as_slice()) fn hash_simple(&self) -> u64 {
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
} }
#[inline] #[inline]

View File

@ -122,7 +122,7 @@ where
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))? .ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?
.as_ref(); .as_ref();
let mut hash = observer.hash() as usize; let mut hash = observer.hash_simple() as usize;
let psmeta = state.metadata_mut::<SchedulerMetadata>()?; let psmeta = state.metadata_mut::<SchedulerMetadata>()?;

View File

@ -325,7 +325,7 @@ where
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))? .ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?
.as_ref(); .as_ref();
let hash = observer.hash() as usize; let hash = observer.hash_simple() as usize;
executor executor
.observers_mut() .observers_mut()

View File

@ -400,7 +400,7 @@ where
let obs = observers let obs = observers
.match_name::<M>(self.observer_name()) .match_name::<M>(self.observer_name())
.expect("Should have been provided valid observer name."); .expect("Should have been provided valid observer name.");
Ok(obs.hash() == self.orig_hash) Ok(obs.hash_simple() == self.orig_hash)
} }
} }
@ -443,7 +443,7 @@ where
MapEqualityFeedback { MapEqualityFeedback {
name: "MapEq".to_string(), name: "MapEq".to_string(),
obs_name: self.obs_name.clone(), obs_name: self.obs_name.clone(),
orig_hash: obs.hash(), orig_hash: obs.hash_simple(),
phantom: PhantomData, phantom: PhantomData,
} }
} }

View File

@ -82,6 +82,23 @@ impl<M, O> Named for MappedEdgeMapObserver<M, O> {
} }
} }
impl<M, O> Hash for MappedEdgeMapObserver<M, O>
where
M: MapObserver + for<'it> AsIter<'it, Item = M::Entry>,
O: ValueObserver,
{
fn hash<H: Hasher>(&self, hasher: &mut H) {
let initial = self.inner.initial();
for e in self.inner.as_iter() {
if *e == initial {
self.value_observer.default_value().hash(hasher);
} else {
self.value_observer.value().hash(hasher);
}
}
}
}
impl<M, O> MapObserver for MappedEdgeMapObserver<M, O> impl<M, O> MapObserver for MappedEdgeMapObserver<M, O>
where where
M: MapObserver + for<'it> AsIter<'it, Item = M::Entry>, M: MapObserver + for<'it> AsIter<'it, Item = M::Entry>,
@ -110,16 +127,9 @@ where
self.inner.count_bytes() self.inner.count_bytes()
} }
fn hash(&self) -> u64 { fn hash_simple(&self) -> u64 {
let mut hasher = AHasher::default(); let mut hasher = AHasher::default();
let initial = self.inner.initial(); self.hash(&mut hasher);
for e in self.inner.as_iter() {
if *e == initial {
self.value_observer.default_value().hash(&mut hasher);
} else {
self.value_observer.value().hash(&mut hasher);
}
}
hasher.finish() hasher.finish()
} }

View File

@ -65,7 +65,7 @@ mod observers {
}; };
use core::{ use core::{
fmt::Debug, fmt::Debug,
hash::{BuildHasher, Hasher}, hash::{Hash, Hasher},
iter::Flatten, iter::Flatten,
ptr::{addr_of, addr_of_mut}, ptr::{addr_of, addr_of_mut},
slice::{from_raw_parts, Iter, IterMut}, slice::{from_raw_parts, Iter, IterMut},
@ -159,6 +159,19 @@ mod observers {
} }
} }
impl<const DIFFERENTIAL: bool> Hash for CountersMultiMapObserver<DIFFERENTIAL> {
fn hash<H: Hasher>(&self, hasher: &mut H) {
for map in unsafe { &*addr_of!(COUNTERS_MAPS) } {
let slice = map.as_slice();
let ptr = slice.as_ptr();
let map_size = slice.len() / core::mem::size_of::<u8>();
unsafe {
hasher.write(from_raw_parts(ptr, map_size));
}
}
}
}
impl<const DIFFERENTIAL: bool> AsRef<Self> for CountersMultiMapObserver<DIFFERENTIAL> { impl<const DIFFERENTIAL: bool> AsRef<Self> for CountersMultiMapObserver<DIFFERENTIAL> {
fn as_ref(&self) -> &Self { fn as_ref(&self) -> &Self {
self self
@ -208,17 +221,9 @@ mod observers {
res res
} }
fn hash(&self) -> u64 { #[inline]
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher(); fn hash_simple(&self) -> u64 {
for map in unsafe { &*addr_of!(COUNTERS_MAPS) } { RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
let slice = map.as_slice();
let ptr = slice.as_ptr();
let map_size = slice.len() / core::mem::size_of::<u8>();
unsafe {
hasher.write(from_raw_parts(ptr, map_size));
}
}
hasher.finish()
} }
fn reset_map(&mut self) -> Result<(), Error> { fn reset_map(&mut self) -> Result<(), Error> {