diff --git a/bindings/pylibafl/src/lib.rs b/bindings/pylibafl/src/lib.rs index c004c4c6ae..7bc8d36a00 100644 --- a/bindings/pylibafl/src/lib.rs +++ b/bindings/pylibafl/src/lib.rs @@ -1,6 +1,6 @@ +use libafl; use libafl_qemu; use libafl_sugar; -use libafl; use pyo3::prelude::*; #[pymodule] diff --git a/libafl/src/bolts/mod.rs b/libafl/src/bolts/mod.rs index 5570c0a7f9..ed2ead3cf8 100644 --- a/libafl/src/bolts/mod.rs +++ b/libafl/src/bolts/mod.rs @@ -26,7 +26,7 @@ pub mod staterestore; pub mod tuples; use alloc::string::String; -use core::time; +use core::{iter::Iterator, time}; #[cfg(feature = "std")] use std::time::{SystemTime, UNIX_EPOCH}; @@ -42,6 +42,28 @@ pub trait AsMutSlice { fn as_mut_slice(&mut self) -> &mut [T]; } +/// Create an `Iterator` from a reference +pub trait AsRefIterator<'it> { + /// The item type + type Item: 'it; + /// The iterator type + type IntoIter: Iterator; + + /// Create an interator from &self + fn as_ref_iter(&'it self) -> Self::IntoIter; +} + +/// Create an `Iterator` from a mutable reference +pub trait AsMutIterator<'it> { + /// The item type + type Item: 'it; + /// The iterator type + type IntoIter: Iterator; + + /// Create an interator from &mut self + fn as_mut_iter(&'it mut self) -> Self::IntoIter; +} + /// Has a length field pub trait HasLen { /// The length diff --git a/libafl/src/feedbacks/map.rs b/libafl/src/feedbacks/map.rs index a2280c174e..ea6450a69a 100644 --- a/libafl/src/feedbacks/map.rs +++ b/libafl/src/feedbacks/map.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use crate::{ bolts::{ tuples::{MatchName, Named}, - AsMutSlice, AsSlice, HasRefCnt, + AsMutSlice, AsRefIterator, AsSlice, HasRefCnt, }, corpus::Testcase, events::{Event, EventFirer}, @@ -41,21 +41,21 @@ pub type MaxMapOneOrFilledFeedback = MapFeedback; /// A `Reducer` function is used to aggregate values for the novelty search -pub trait Reducer: Serialize + serde::de::DeserializeOwned + 'static + Debug +pub trait Reducer: 'static + Debug where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { /// Reduce two values to one value, with the current [`Reducer`]. fn reduce(first: T, second: T) -> T; } /// A [`OrReducer`] reduces the values returning the bitwise OR with the old value -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct OrReducer {} impl Reducer for OrReducer where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd, + T: PrimInt + Default + Copy + 'static + PartialOrd, { #[inline] fn reduce(history: T, new: T) -> T { @@ -64,12 +64,12 @@ where } /// A [`AndReducer`] reduces the values returning the bitwise AND with the old value -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct AndReducer {} impl Reducer for AndReducer where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd, + T: PrimInt + Default + Copy + 'static + PartialOrd, { #[inline] fn reduce(history: T, new: T) -> T { @@ -78,12 +78,12 @@ where } /// A [`MaxReducer`] reduces int values and returns their maximum. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct MaxReducer {} impl Reducer for MaxReducer where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd, + T: PrimInt + Default + Copy + 'static + PartialOrd, { #[inline] fn reduce(first: T, second: T) -> T { @@ -96,12 +96,12 @@ where } /// A [`MinReducer`] reduces int values and returns their minimum. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct MinReducer {} impl Reducer for MinReducer where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd, + T: PrimInt + Default + Copy + 'static + PartialOrd, { #[inline] fn reduce(first: T, second: T) -> T { @@ -114,9 +114,9 @@ where } /// A `IsNovel` function is used to discriminate if a reduced value is considered novel. -pub trait IsNovel: Serialize + serde::de::DeserializeOwned + 'static + Debug +pub trait IsNovel: 'static + Debug where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { /// If a new value in the [`MapFeedback`] was found, /// this filter can decide if the result is considered novel or not. @@ -124,12 +124,12 @@ where } /// [`AllIsNovel`] consider everything a novelty. Here mostly just for debugging. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct AllIsNovel {} impl IsNovel for AllIsNovel where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { #[inline] fn is_novel(_old: T, _new: T) -> bool { @@ -152,11 +152,11 @@ fn saturating_next_power_of_two(n: T) -> T { } /// Consider as novelty if the reduced value is different from the old value. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct DifferentIsNovel {} impl IsNovel for DifferentIsNovel where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { #[inline] fn is_novel(old: T, new: T) -> bool { @@ -165,11 +165,11 @@ where } /// Only consider as novel the values which are at least the next pow2 class of the old value -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct NextPow2IsNovel {} impl IsNovel for NextPow2IsNovel where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { #[inline] fn is_novel(old: T, new: T) -> bool { @@ -185,11 +185,11 @@ where } /// A filter that only saves values which are at least the next pow2 class -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct OneOrFilledIsNovel {} impl IsNovel for OneOrFilledIsNovel where - T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned, + T: PrimInt + Default + Copy + 'static, { #[inline] fn is_novel(old: T, new: T) -> bool { @@ -342,13 +342,13 @@ where } /// The most common AFL-like feedback type -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(bound = "T: serde::de::DeserializeOwned")] +#[derive(Clone, Debug)] pub struct MapFeedback where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, R: Reducer, - O: MapObserver, + O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = T>, N: IsNovel, S: HasFeedbackStates, { @@ -369,6 +369,7 @@ where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, R: Reducer, O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = T>, N: IsNovel, I: Input, S: HasFeedbackStates + HasClientPerfMonitor + Debug, @@ -402,10 +403,8 @@ where assert!(size <= observer.len()); if self.novelties.is_some() { - for i in 0..size { + for (i, &item) in observer.as_ref_iter().enumerate() { let history = map_state.history_map[i]; - let item = *observer.get(i); - let reduced = R::reduce(history, item); if N::is_novel(history, reduced) { map_state.history_map[i] = reduced; @@ -414,10 +413,8 @@ where } } } else { - for i in 0..size { + for (i, &item) in observer.as_ref_iter().enumerate() { let history = map_state.history_map[i]; - let item = *observer.get(i); - let reduced = R::reduce(history, item); if N::is_novel(history, reduced) { map_state.history_map[i] = reduced; @@ -478,7 +475,8 @@ where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, R: Reducer, N: IsNovel, - O: MapObserver, + O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = T>, S: HasFeedbackStates, { #[inline] @@ -499,7 +497,8 @@ where + Debug, R: Reducer, N: IsNovel, - O: MapObserver, + O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = T>, S: HasFeedbackStates, { /// Create new `MapFeedback` @@ -562,7 +561,7 @@ where } /// A [`ReachabilityFeedback`] reports if a target has been reached. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct ReachabilityFeedback { name: String, target_idx: Vec, @@ -572,6 +571,7 @@ pub struct ReachabilityFeedback { impl ReachabilityFeedback where O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = usize>, { /// Creates a new [`ReachabilityFeedback`] for a [`MapObserver`]. #[must_use] @@ -598,6 +598,7 @@ impl Feedback for ReachabilityFeedback where I: Input, O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = usize>, S: HasClientPerfMonitor, { #[allow(clippy::wrong_self_convention)] @@ -615,11 +616,10 @@ where { // TODO Replace with match_name_type when stable let observer = observers.match_name::(&self.name).unwrap(); - let size = observer.usable_count(); let mut hit_target: bool = false; //check if we've hit any targets. - for i in 0..size { - if *observer.get(i) > 0 { + for (i, &elem) in observer.as_ref_iter().enumerate() { + if elem > 0 { self.target_idx.push(i); hit_target = true; } @@ -648,12 +648,14 @@ where impl Named for ReachabilityFeedback where O: MapObserver, + for<'it> O: AsRefIterator<'it, Item = usize>, { #[inline] fn name(&self) -> &str { self.name.as_str() } } + #[cfg(test)] mod tests { use crate::feedbacks::{AllIsNovel, IsNovel, NextPow2IsNovel}; @@ -711,7 +713,7 @@ pub mod pybind { } #[pyclass(unsendable, name = $max_map_feedback_py_name)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// Python class for MaxMapFeedback pub struct $max_map_feedback_struct_name { /// Rust wrapped MaxMapFeedback object @@ -719,6 +721,14 @@ pub mod pybind { MaxMapFeedback, } + impl Clone for $max_map_feedback_struct_name { + fn clone(&self) -> Self { + Self { + max_map_feedback: self.max_map_feedback.clone(), + } + } + } + #[pymethods] impl $max_map_feedback_struct_name { #[new] diff --git a/libafl/src/observers/map.rs b/libafl/src/observers/map.rs index 04afbfb91d..94a22de68a 100644 --- a/libafl/src/observers/map.rs +++ b/libafl/src/observers/map.rs @@ -9,6 +9,7 @@ use core::{ fmt::Debug, hash::Hasher, iter::Flatten, + marker::PhantomData, slice::{from_raw_parts, Iter, IterMut}, }; use intervaltree::IntervalTree; @@ -19,7 +20,7 @@ use crate::{ bolts::{ ownedref::{OwnedRefMut, OwnedSliceMut}, tuples::Named, - AsMutSlice, AsSlice, HasLen, + AsMutIterator, AsMutSlice, AsRefIterator, AsSlice, HasLen, }, executors::ExitKind, observers::Observer, @@ -38,9 +39,14 @@ fn hash_slice(slice: &[T]) -> u64 { } /// A [`MapObserver`] observes the static map, as oftentimes used for AFL-like coverage information -pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug { +/// +/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize +pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug +// where +// for<'it> &'it Self: IntoIterator +{ /// Type of each entry in this map - type Entry: PrimInt + Default + Copy + Debug; + type Entry: PrimInt + Default + Copy + Debug + 'static; /// Get the value at `idx` fn get(&self, idx: usize) -> &Self::Entry; @@ -74,7 +80,9 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned fn initial_mut(&mut self) -> &mut Self::Entry; /// Set the initial value for reset() - fn set_initial(&mut self, initial: Self::Entry); + fn set_initial(&mut self, initial: Self::Entry) { + *self.initial_mut() = initial; + } /// Reset the map #[inline] @@ -112,6 +120,64 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned } } +/// A Simple iterator calling `MapObserver::get` +#[derive(Debug)] +pub struct MapObserverSimpleIterator<'a, O> +where + O: 'a + MapObserver, +{ + index: usize, + observer: *const O, + phantom: PhantomData<&'a u8>, +} + +impl<'a, O> Iterator for MapObserverSimpleIterator<'a, O> +where + O: 'a + MapObserver, +{ + type Item = &'a O::Entry; + fn next(&mut self) -> Option { + unsafe { + if self.index >= self.observer.as_ref().unwrap().usable_count() { + None + } else { + let i = self.index; + self.index += 1; + Some(self.observer.as_ref().unwrap().get(i)) + } + } + } +} + +/// A Simple iterator calling `MapObserver::get_mut` +#[derive(Debug)] +pub struct MapObserverSimpleIteratoMut<'a, O> +where + O: 'a + MapObserver, +{ + index: usize, + observer: *mut O, + phantom: PhantomData<&'a u8>, +} + +impl<'a, O> Iterator for MapObserverSimpleIteratoMut<'a, O> +where + O: 'a + MapObserver, +{ + type Item = &'a O::Entry; + fn next(&mut self) -> Option { + unsafe { + if self.index >= self.observer.as_ref().unwrap().usable_count() { + None + } else { + let i = self.index; + self.index += 1; + Some(self.observer.as_mut().unwrap().get_mut(i)) + } + } + } +} + /// The Map Observer retrieves the state of a map, /// that will get updated by the target. /// A well-known example is the AFL-Style coverage map. @@ -130,7 +196,6 @@ where impl<'a, I, S, T> Observer for StdMapObserver<'a, T> where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, - Self: MapObserver, { #[inline] fn pre_exec(&mut self, _state: &mut S, _input: &I) -> Result<(), Error> { @@ -158,6 +223,32 @@ where } } +impl<'a, 'it, T> AsRefIterator<'it> for StdMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = Iter<'it, T>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() + } +} + +impl<'a, 'it, T> AsMutIterator<'it> for StdMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = IterMut<'it, T>; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() + } +} + impl<'a, 'it, T> IntoIterator for &'it StdMapObserver<'a, T> where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, @@ -166,7 +257,8 @@ where type IntoIter = Iter<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_slice().iter() + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() } } @@ -178,7 +270,8 @@ where type IntoIter = IterMut<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_mut_slice().iter_mut() + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() } } @@ -217,11 +310,6 @@ where &mut self.initial } - #[inline] - fn set_initial(&mut self, initial: T) { - self.initial = initial; - } - fn to_vec(&self) -> Vec { self.as_slice().to_vec() } @@ -343,6 +431,32 @@ where } } +impl<'a, 'it, T, const N: usize> AsRefIterator<'it> for ConstMapObserver<'a, T, N> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = Iter<'it, T>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() + } +} + +impl<'a, 'it, T, const N: usize> AsMutIterator<'it> for ConstMapObserver<'a, T, N> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = IterMut<'it, T>; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() + } +} + impl<'a, 'it, T, const N: usize> IntoIterator for &'it ConstMapObserver<'a, T, N> where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, @@ -351,7 +465,8 @@ where type IntoIter = Iter<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_slice().iter() + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() } } @@ -363,7 +478,8 @@ where type IntoIter = IterMut<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_mut_slice().iter_mut() + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() } } @@ -383,11 +499,6 @@ where &mut self.initial } - #[inline] - fn set_initial(&mut self, initial: T) { - self.initial = initial; - } - #[inline] fn get(&self, idx: usize) -> &T { &self.as_slice()[idx] @@ -515,6 +626,32 @@ where } } +impl<'a, 'it, T> AsRefIterator<'it> for VariableMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = Iter<'it, T>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() + } +} + +impl<'a, 'it, T> AsMutIterator<'it> for VariableMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = IterMut<'it, T>; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() + } +} + impl<'a, 'it, T> IntoIterator for &'it VariableMapObserver<'a, T> where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, @@ -523,7 +660,8 @@ where type IntoIter = Iter<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_slice().iter() + let cnt = self.usable_count(); + self.as_slice()[..cnt].iter() } } @@ -535,7 +673,8 @@ where type IntoIter = IterMut<'it, T>; fn into_iter(self) -> Self::IntoIter { - self.as_mut_slice().iter_mut() + let cnt = self.usable_count(); + self.as_mut_slice()[..cnt].iter_mut() } } @@ -555,22 +694,17 @@ where &mut self.initial } - #[inline] - fn set_initial(&mut self, initial: T) { - self.initial = initial; - } - #[inline] fn usable_count(&self) -> usize { *self.size.as_ref() } fn get(&self, idx: usize) -> &T { - &self.as_slice()[idx] + &self.map.as_slice()[idx] } fn get_mut(&mut self, idx: usize) -> &mut T { - &mut self.as_mut_slice()[idx] + &mut self.map.as_mut_slice()[idx] } fn hash(&self) -> u64 { @@ -587,7 +721,8 @@ where { #[inline] fn as_slice(&self) -> &[T] { - self.map.as_slice() + let cnt = self.usable_count(); + &self.map.as_slice()[..cnt] } } impl<'a, T> AsMutSlice for VariableMapObserver<'a, T> @@ -661,6 +796,7 @@ static COUNT_CLASS_LOOKUP: [u8; 256] = [ impl Observer for HitcountsMapObserver where M: MapObserver + Observer, + for<'it> M: AsMutIterator<'it, Item = u8>, { #[inline] fn pre_exec(&mut self, state: &mut S, input: &I) -> Result<(), Error> { @@ -669,9 +805,8 @@ where #[inline] fn post_exec(&mut self, state: &mut S, input: &I, exit_kind: &ExitKind) -> Result<(), Error> { - let cnt = self.usable_count(); - for i in 0..cnt { - *self.get_mut(i) = COUNT_CLASS_LOOKUP[*self.get(i) as usize]; + for elem in self.as_mut_iter() { + *elem = COUNT_CLASS_LOOKUP[*elem as usize]; } self.base.post_exec(state, input, exit_kind) } @@ -700,6 +835,7 @@ where impl MapObserver for HitcountsMapObserver where M: MapObserver, + for<'it> M: AsMutIterator<'it, Item = u8>, { type Entry = u8; @@ -713,11 +849,6 @@ where self.base.initial_mut() } - #[inline] - fn set_initial(&mut self, initial: u8) { - self.base.set_initial(initial); - } - #[inline] fn usable_count(&self) -> usize { self.base.usable_count() @@ -769,6 +900,57 @@ where Self { base } } } + +impl<'it, M> AsRefIterator<'it> for HitcountsMapObserver +where + M: Named + Serialize + serde::de::DeserializeOwned + AsRefIterator<'it, Item = u8>, +{ + type Item = u8; + type IntoIter = >::IntoIter; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + self.base.as_ref_iter() + } +} + +impl<'it, M> AsMutIterator<'it> for HitcountsMapObserver +where + M: Named + Serialize + serde::de::DeserializeOwned + AsMutIterator<'it, Item = u8>, +{ + type Item = u8; + type IntoIter = >::IntoIter; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + self.base.as_mut_iter() + } +} + +impl<'it, M> IntoIterator for &'it HitcountsMapObserver +where + M: Named + Serialize + serde::de::DeserializeOwned, + &'it M: IntoIterator, +{ + type Item = &'it u8; + type IntoIter = <&'it M as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.base.into_iter() + } +} + +impl<'it, M> IntoIterator for &'it mut HitcountsMapObserver +where + M: Named + Serialize + serde::de::DeserializeOwned, + &'it mut M: IntoIterator, +{ + type Item = &'it mut u8; + type IntoIter = <&'it mut M as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.base.into_iter() + } +} + /// The Multi Map Observer merge different maps into one observer #[derive(Serialize, Deserialize, Debug)] #[serde(bound = "T: serde::de::DeserializeOwned")] @@ -848,11 +1030,6 @@ where &mut self.initial } - #[inline] - fn set_initial(&mut self, initial: T) { - self.initial = initial; - } - fn count_bytes(&self) -> u64 { let initial = self.initial(); let mut res = 0; @@ -953,14 +1130,28 @@ where } } -impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T> +impl<'a, 'it, T> AsRefIterator<'it> for MultiMapObserver<'a, T> where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, + 'a: 'it, { - type Item = as Iterator>::Item; + type Item = T; + type IntoIter = Flatten>>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + self.maps.iter().flatten() + } +} + +impl<'a, 'it, T> AsMutIterator<'it> for MultiMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, + 'a: 'it, +{ + type Item = T; type IntoIter = Flatten>>; - fn into_iter(self) -> Self::IntoIter { + fn as_mut_iter(&'it mut self) -> Self::IntoIter { self.maps.iter_mut().flatten() } } @@ -977,6 +1168,18 @@ where } } +impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T> +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = as Iterator>::Item; + type IntoIter = Flatten>>; + + fn into_iter(self) -> Self::IntoIter { + self.maps.iter_mut().flatten() + } +} + /// Exact copy of `StdMapObserver` that owns its map /// Used for python bindings #[derive(Serialize, Deserialize, Debug, Clone)] @@ -1022,6 +1225,30 @@ where } } +impl<'a, 'it, T> AsRefIterator<'it> for OwnedMapObserver +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = Iter<'it, T>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + self.as_slice().iter() + } +} + +impl<'a, 'it, T> AsMutIterator<'it> for OwnedMapObserver +where + T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, +{ + type Item = T; + type IntoIter = IterMut<'it, T>; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + self.as_mut_slice().iter_mut() + } +} + impl<'it, T> IntoIterator for &'it OwnedMapObserver where T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug, @@ -1121,7 +1348,7 @@ where pub fn new(name: &'static str, map: Vec) -> Self { let initial = if map.is_empty() { T::default() } else { map[0] }; Self { - map: map, + map, name: name.to_string(), initial, } @@ -1130,11 +1357,12 @@ where /// `MapObserver` Python bindings #[cfg(feature = "python")] pub mod pybind { - use crate::bolts::{tuples::Named, HasLen}; + use crate::bolts::{tuples::Named, AsMutIterator, AsRefIterator, HasLen}; use crate::observers::{map::OwnedMapObserver, MapObserver, Observer}; use crate::Error; use pyo3::prelude::*; use serde::{Deserialize, Serialize}; + use std::slice::{Iter, IterMut}; macro_rules! define_python_map_observer { ($struct_name:ident, $py_name:tt, $struct_name_trait:ident, $py_name_trait:tt, $datatype:ty, $wrapper_name: ident) => { @@ -1183,6 +1411,32 @@ pub mod pybind { } } + impl<'it> AsRefIterator<'it> for $struct_name_trait { + type Item = $datatype; + type IntoIter = Iter<'it, $datatype>; + + fn as_ref_iter(&'it self) -> Self::IntoIter { + match &self.map_observer { + $wrapper_name::Owned(map_observer) => { + map_observer.owned_map_observer.as_ref_iter() + } + } + } + } + + impl<'it> AsMutIterator<'it> for $struct_name_trait { + type Item = $datatype; + type IntoIter = IterMut<'it, $datatype>; + + fn as_mut_iter(&'it mut self) -> Self::IntoIter { + match &mut self.map_observer { + $wrapper_name::Owned(map_observer) => { + map_observer.owned_map_observer.as_mut_iter() + } + } + } + } + impl MapObserver for $struct_name_trait { type Entry = $datatype;