Implement Hash
for MapObserver (#1989)
* MapObserver implements Hash * Rename the hash utility function (in MapObserver) to hash_easy * Use hash_slice as a helper function to impl Hash trait * define_python_map_observer macro implements Hash trait * Also rename hash_easy to hash_simple * Rename hash_slice to hash_helper * hash_helper is used to define the implementation of hash function/trait * Factor out the Hash trait and function for runtime library structs (#1977) * Simplify hash_simple (of trait MapObserver) (#1977) * Use hash_one function to make hash_simple a one-liner * remove hash_helper --------- Co-authored-by: Edwin Fernando <ef322@ic.ac.uk> Co-authored-by: Addison Crump <addison.crump@cispa.de>
This commit is contained in:
parent
04cd792df2
commit
c238b69498
@ -6,7 +6,7 @@ use alloc::{
|
||||
};
|
||||
use core::{
|
||||
fmt::Debug,
|
||||
hash::{BuildHasher, Hasher},
|
||||
hash::{Hash, Hasher},
|
||||
iter::Flatten,
|
||||
marker::PhantomData,
|
||||
mem::size_of,
|
||||
@ -69,17 +69,6 @@ fn init_count_class_16() {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the hash of a slice
|
||||
fn hash_slice<T>(slice: &[T]) -> u64 {
|
||||
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
|
||||
let ptr = slice.as_ptr() as *const u8;
|
||||
let map_size = slice.len() / size_of::<T>();
|
||||
unsafe {
|
||||
hasher.write(slice::from_raw_parts(ptr, map_size));
|
||||
}
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
/// Trait marker which indicates that this [`MapObserver`] is tracked for indices or novelties.
|
||||
/// Implementors of feedbacks similar to [`crate::feedbacks::MapFeedback`] may wish to use this to
|
||||
/// ensure that edge metadata is recorded as is appropriate for the provided observer.
|
||||
@ -420,12 +409,12 @@ pub mod macros {
|
||||
///
|
||||
/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize
|
||||
pub trait MapObserver:
|
||||
HasLen + Named + Serialize + serde::de::DeserializeOwned + AsRef<Self> + AsMut<Self>
|
||||
HasLen + Named + Serialize + serde::de::DeserializeOwned + AsRef<Self> + AsMut<Self> + Hash
|
||||
// where
|
||||
// for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry>
|
||||
{
|
||||
/// Type of each entry in this map
|
||||
type Entry: Bounded + PartialEq + Default + Copy + Debug + 'static;
|
||||
type Entry: Bounded + PartialEq + Default + Copy + Debug + Hash + 'static;
|
||||
|
||||
/// Get the value at `idx`
|
||||
fn get(&self, idx: usize) -> &Self::Entry;
|
||||
@ -439,8 +428,8 @@ pub trait MapObserver:
|
||||
/// Count the set bytes in the map
|
||||
fn count_bytes(&self) -> u64;
|
||||
|
||||
/// Compute the hash of the map
|
||||
fn hash(&self) -> u64;
|
||||
/// Compute the hash of the map without needing to provide a hasher
|
||||
fn hash_simple(&self) -> u64;
|
||||
|
||||
/// Get the initial value for `reset()`
|
||||
fn initial(&self) -> Self::Entry;
|
||||
@ -553,6 +542,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -604,6 +594,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -624,6 +615,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -644,6 +636,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -665,6 +658,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -685,6 +679,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -701,6 +696,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, const DIFFERENTIAL: bool> Hash for StdMapObserver<'a, T, DIFFERENTIAL>
|
||||
where
|
||||
T: Bounded
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug,
|
||||
{
|
||||
#[inline]
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.as_slice().hash(hasher);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for StdMapObserver<'a, T, DIFFERENTIAL>
|
||||
where
|
||||
T: Default + Copy + 'static + Serialize,
|
||||
@ -725,6 +738,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -761,8 +775,9 @@ where
|
||||
self.as_slice().len()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
hash_slice(self.as_slice())
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -1108,6 +1123,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1128,6 +1144,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1148,6 +1165,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1168,6 +1186,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1188,6 +1207,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1204,6 +1224,23 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, const N: usize> Hash for ConstMapObserver<'a, T, N>
|
||||
where
|
||||
T: Bounded
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug,
|
||||
{
|
||||
#[inline]
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.as_slice().hash(hasher);
|
||||
}
|
||||
}
|
||||
impl<'a, T, const N: usize> AsRef<Self> for ConstMapObserver<'a, T, N>
|
||||
where
|
||||
T: Default + Copy + 'static + Serialize,
|
||||
@ -1228,6 +1265,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1268,8 +1306,9 @@ where
|
||||
self.as_slice().len()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
hash_slice(self.as_slice())
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
/// Reset the map
|
||||
@ -1429,6 +1468,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1451,6 +1491,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1473,6 +1514,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1495,6 +1537,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1517,6 +1560,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1535,6 +1579,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Hash for VariableMapObserver<'a, T>
|
||||
where
|
||||
T: Bounded
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug
|
||||
+ PartialEq
|
||||
+ Bounded,
|
||||
{
|
||||
#[inline]
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.as_slice().hash(hasher);
|
||||
}
|
||||
}
|
||||
impl<'a, T> AsRef<Self> for VariableMapObserver<'a, T>
|
||||
where
|
||||
T: Default + Copy + 'static + Serialize + PartialEq + Bounded,
|
||||
@ -1559,6 +1622,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1599,8 +1663,10 @@ where
|
||||
}
|
||||
res
|
||||
}
|
||||
fn hash(&self) -> u64 {
|
||||
hash_slice(self.as_slice())
|
||||
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
/// Reset the map
|
||||
@ -1640,6 +1706,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ 'static
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
@ -1660,6 +1727,7 @@ where
|
||||
T: 'static
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug
|
||||
@ -1719,7 +1787,7 @@ where
|
||||
///
|
||||
/// [`MapObserver`]s that are not slice-backed, such as [`MultiMapObserver`], can use
|
||||
/// [`HitcountsIterableMapObserver`] instead.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
|
||||
#[serde(bound = "M: serde::de::DeserializeOwned")]
|
||||
pub struct HitcountsMapObserver<M>
|
||||
where
|
||||
@ -1863,8 +1931,9 @@ where
|
||||
self.base.reset_map()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
self.base.hash()
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
self.base.hash_simple()
|
||||
}
|
||||
fn to_vec(&self) -> Vec<u8> {
|
||||
self.base.to_vec()
|
||||
@ -2019,7 +2088,7 @@ where
|
||||
/// Map observer with hitcounts postprocessing
|
||||
/// Less optimized version for non-slice iterators.
|
||||
/// Slice-backed observers should use a [`HitcountsMapObserver`].
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Hash)]
|
||||
#[serde(bound = "M: serde::de::DeserializeOwned")]
|
||||
pub struct HitcountsIterableMapObserver<M>
|
||||
where
|
||||
@ -2133,8 +2202,9 @@ where
|
||||
self.base.reset_map()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
self.base.hash()
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
self.base.hash_simple()
|
||||
}
|
||||
fn to_vec(&self) -> Vec<u8> {
|
||||
self.base.to_vec()
|
||||
@ -2341,6 +2411,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, const DIFFERENTIAL: bool> Hash for MultiMapObserver<'a, T, DIFFERENTIAL>
|
||||
where
|
||||
T: 'static + Default + Copy + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
for map in &self.maps {
|
||||
let slice = map.as_slice();
|
||||
let ptr = slice.as_ptr() as *const u8;
|
||||
let map_size = slice.len() / size_of::<T>();
|
||||
unsafe {
|
||||
hasher.write(slice::from_raw_parts(ptr, map_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, const DIFFERENTIAL: bool> AsRef<Self> for MultiMapObserver<'a, T, DIFFERENTIAL>
|
||||
where
|
||||
T: 'static + Default + Copy + Serialize + Debug,
|
||||
@ -2366,6 +2452,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug,
|
||||
@ -2406,17 +2493,9 @@ where
|
||||
res
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
|
||||
for map in &self.maps {
|
||||
let slice = map.as_slice();
|
||||
let ptr = slice.as_ptr() as *const u8;
|
||||
let map_size = slice.len() / size_of::<T>();
|
||||
unsafe {
|
||||
hasher.write(slice::from_raw_parts(ptr, map_size));
|
||||
}
|
||||
}
|
||||
hasher.finish()
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
fn reset_map(&mut self) -> Result<(), Error> {
|
||||
@ -2714,6 +2793,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Hash for OwnedMapObserver<T>
|
||||
where
|
||||
T: 'static + Hash + Default + Copy + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
#[inline]
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.as_slice().hash(hasher);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsRef<Self> for OwnedMapObserver<T>
|
||||
where
|
||||
T: 'static + Default + Copy + Serialize,
|
||||
@ -2739,6 +2828,7 @@ where
|
||||
+ PartialEq
|
||||
+ Default
|
||||
+ Copy
|
||||
+ Hash
|
||||
+ Serialize
|
||||
+ serde::de::DeserializeOwned
|
||||
+ Debug,
|
||||
@ -2774,8 +2864,9 @@ where
|
||||
self.as_slice().len()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
hash_slice(self.as_slice())
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -122,7 +122,7 @@ where
|
||||
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?
|
||||
.as_ref();
|
||||
|
||||
let mut hash = observer.hash() as usize;
|
||||
let mut hash = observer.hash_simple() as usize;
|
||||
|
||||
let psmeta = state.metadata_mut::<SchedulerMetadata>()?;
|
||||
|
||||
|
@ -325,7 +325,7 @@ where
|
||||
.ok_or_else(|| Error::key_not_found("MapObserver not found".to_string()))?
|
||||
.as_ref();
|
||||
|
||||
let hash = observer.hash() as usize;
|
||||
let hash = observer.hash_simple() as usize;
|
||||
|
||||
executor
|
||||
.observers_mut()
|
||||
|
@ -400,7 +400,7 @@ where
|
||||
let obs = observers
|
||||
.match_name::<M>(self.observer_name())
|
||||
.expect("Should have been provided valid observer name.");
|
||||
Ok(obs.hash() == self.orig_hash)
|
||||
Ok(obs.hash_simple() == self.orig_hash)
|
||||
}
|
||||
}
|
||||
|
||||
@ -443,7 +443,7 @@ where
|
||||
MapEqualityFeedback {
|
||||
name: "MapEq".to_string(),
|
||||
obs_name: self.obs_name.clone(),
|
||||
orig_hash: obs.hash(),
|
||||
orig_hash: obs.hash_simple(),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -82,6 +82,23 @@ impl<M, O> Named for MappedEdgeMapObserver<M, O> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, O> Hash for MappedEdgeMapObserver<M, O>
|
||||
where
|
||||
M: MapObserver + for<'it> AsIter<'it, Item = M::Entry>,
|
||||
O: ValueObserver,
|
||||
{
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
let initial = self.inner.initial();
|
||||
for e in self.inner.as_iter() {
|
||||
if *e == initial {
|
||||
self.value_observer.default_value().hash(hasher);
|
||||
} else {
|
||||
self.value_observer.value().hash(hasher);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, O> MapObserver for MappedEdgeMapObserver<M, O>
|
||||
where
|
||||
M: MapObserver + for<'it> AsIter<'it, Item = M::Entry>,
|
||||
@ -110,16 +127,9 @@ where
|
||||
self.inner.count_bytes()
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
fn hash_simple(&self) -> u64 {
|
||||
let mut hasher = AHasher::default();
|
||||
let initial = self.inner.initial();
|
||||
for e in self.inner.as_iter() {
|
||||
if *e == initial {
|
||||
self.value_observer.default_value().hash(&mut hasher);
|
||||
} else {
|
||||
self.value_observer.value().hash(&mut hasher);
|
||||
}
|
||||
}
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ mod observers {
|
||||
};
|
||||
use core::{
|
||||
fmt::Debug,
|
||||
hash::{BuildHasher, Hasher},
|
||||
hash::{Hash, Hasher},
|
||||
iter::Flatten,
|
||||
ptr::{addr_of, addr_of_mut},
|
||||
slice::{from_raw_parts, Iter, IterMut},
|
||||
@ -159,6 +159,19 @@ mod observers {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIFFERENTIAL: bool> Hash for CountersMultiMapObserver<DIFFERENTIAL> {
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
for map in unsafe { &*addr_of!(COUNTERS_MAPS) } {
|
||||
let slice = map.as_slice();
|
||||
let ptr = slice.as_ptr();
|
||||
let map_size = slice.len() / core::mem::size_of::<u8>();
|
||||
unsafe {
|
||||
hasher.write(from_raw_parts(ptr, map_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DIFFERENTIAL: bool> AsRef<Self> for CountersMultiMapObserver<DIFFERENTIAL> {
|
||||
fn as_ref(&self) -> &Self {
|
||||
self
|
||||
@ -208,17 +221,9 @@ mod observers {
|
||||
res
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
let mut hasher = RandomState::with_seeds(0, 0, 0, 0).build_hasher();
|
||||
for map in unsafe { &*addr_of!(COUNTERS_MAPS) } {
|
||||
let slice = map.as_slice();
|
||||
let ptr = slice.as_ptr();
|
||||
let map_size = slice.len() / core::mem::size_of::<u8>();
|
||||
unsafe {
|
||||
hasher.write(from_raw_parts(ptr, map_size));
|
||||
}
|
||||
}
|
||||
hasher.finish()
|
||||
#[inline]
|
||||
fn hash_simple(&self) -> u64 {
|
||||
RandomState::with_seeds(0, 0, 0, 0).hash_one(self)
|
||||
}
|
||||
|
||||
fn reset_map(&mut self) -> Result<(), Error> {
|
||||
|
Loading…
x
Reference in New Issue
Block a user