Walk the map observer using as_ref_iter() in the map feedback (#535)

* Walk the map observer using into_iter() in the map feedback

* fmt

* map observers as iterators

* perf

* IntoMutIterator and IntoRefIterator

* Clone

* clippy
This commit is contained in:
Andrea Fioraldi 2022-02-14 18:12:19 +01:00 committed by GitHub
parent 2dcdaaa89f
commit 479f9471ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 372 additions and 86 deletions

View File

@ -1,6 +1,6 @@
use libafl;
use libafl_qemu;
use libafl_sugar;
use libafl;
use pyo3::prelude::*;
#[pymodule]

View File

@ -26,7 +26,7 @@ pub mod staterestore;
pub mod tuples;
use alloc::string::String;
use core::time;
use core::{iter::Iterator, time};
#[cfg(feature = "std")]
use std::time::{SystemTime, UNIX_EPOCH};
@ -42,6 +42,28 @@ pub trait AsMutSlice<T> {
fn as_mut_slice(&mut self) -> &mut [T];
}
/// Create an `Iterator` from a reference
pub trait AsRefIterator<'it> {
/// The item type
type Item: 'it;
/// The iterator type
type IntoIter: Iterator<Item = &'it Self::Item>;
/// Create an interator from &self
fn as_ref_iter(&'it self) -> Self::IntoIter;
}
/// Create an `Iterator` from a mutable reference
pub trait AsMutIterator<'it> {
/// The item type
type Item: 'it;
/// The iterator type
type IntoIter: Iterator<Item = &'it mut Self::Item>;
/// Create an interator from &mut self
fn as_mut_iter(&'it mut self) -> Self::IntoIter;
}
/// Has a length field
pub trait HasLen {
/// The length

View File

@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
use crate::{
bolts::{
tuples::{MatchName, Named},
AsMutSlice, AsSlice, HasRefCnt,
AsMutSlice, AsRefIterator, AsSlice, HasRefCnt,
},
corpus::Testcase,
events::{Event, EventFirer},
@ -41,21 +41,21 @@ pub type MaxMapOneOrFilledFeedback<I, O, S, T> =
MapFeedback<I, OneOrFilledIsNovel, O, MaxReducer, S, T>;
/// A `Reducer` function is used to aggregate values for the novelty search
pub trait Reducer<T>: Serialize + serde::de::DeserializeOwned + 'static + Debug
pub trait Reducer<T>: 'static + Debug
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
/// Reduce two values to one value, with the current [`Reducer`].
fn reduce(first: T, second: T) -> T;
}
/// A [`OrReducer`] reduces the values returning the bitwise OR with the old value
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct OrReducer {}
impl<T> Reducer<T> for OrReducer
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
T: PrimInt + Default + Copy + 'static + PartialOrd,
{
#[inline]
fn reduce(history: T, new: T) -> T {
@ -64,12 +64,12 @@ where
}
/// A [`AndReducer`] reduces the values returning the bitwise AND with the old value
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct AndReducer {}
impl<T> Reducer<T> for AndReducer
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
T: PrimInt + Default + Copy + 'static + PartialOrd,
{
#[inline]
fn reduce(history: T, new: T) -> T {
@ -78,12 +78,12 @@ where
}
/// A [`MaxReducer`] reduces int values and returns their maximum.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct MaxReducer {}
impl<T> Reducer<T> for MaxReducer
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
T: PrimInt + Default + Copy + 'static + PartialOrd,
{
#[inline]
fn reduce(first: T, second: T) -> T {
@ -96,12 +96,12 @@ where
}
/// A [`MinReducer`] reduces int values and returns their minimum.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct MinReducer {}
impl<T> Reducer<T> for MinReducer
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
T: PrimInt + Default + Copy + 'static + PartialOrd,
{
#[inline]
fn reduce(first: T, second: T) -> T {
@ -114,9 +114,9 @@ where
}
/// A `IsNovel` function is used to discriminate if a reduced value is considered novel.
pub trait IsNovel<T>: Serialize + serde::de::DeserializeOwned + 'static + Debug
pub trait IsNovel<T>: 'static + Debug
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
/// If a new value in the [`MapFeedback`] was found,
/// this filter can decide if the result is considered novel or not.
@ -124,12 +124,12 @@ where
}
/// [`AllIsNovel`] consider everything a novelty. Here mostly just for debugging.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct AllIsNovel {}
impl<T> IsNovel<T> for AllIsNovel
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
#[inline]
fn is_novel(_old: T, _new: T) -> bool {
@ -152,11 +152,11 @@ fn saturating_next_power_of_two<T: PrimInt>(n: T) -> T {
}
/// Consider as novelty if the reduced value is different from the old value.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct DifferentIsNovel {}
impl<T> IsNovel<T> for DifferentIsNovel
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
#[inline]
fn is_novel(old: T, new: T) -> bool {
@ -165,11 +165,11 @@ where
}
/// Only consider as novel the values which are at least the next pow2 class of the old value
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct NextPow2IsNovel {}
impl<T> IsNovel<T> for NextPow2IsNovel
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
#[inline]
fn is_novel(old: T, new: T) -> bool {
@ -185,11 +185,11 @@ where
}
/// A filter that only saves values which are at least the next pow2 class
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct OneOrFilledIsNovel {}
impl<T> IsNovel<T> for OneOrFilledIsNovel
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
T: PrimInt + Default + Copy + 'static,
{
#[inline]
fn is_novel(old: T, new: T) -> bool {
@ -342,13 +342,13 @@ where
}
/// The most common AFL-like feedback type
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "T: serde::de::DeserializeOwned")]
#[derive(Clone, Debug)]
pub struct MapFeedback<I, N, O, R, S, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
R: Reducer<T>,
O: MapObserver,
O: MapObserver<Entry = T>,
for<'it> O: AsRefIterator<'it, Item = T>,
N: IsNovel<T>,
S: HasFeedbackStates,
{
@ -369,6 +369,7 @@ where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
R: Reducer<T>,
O: MapObserver<Entry = T>,
for<'it> O: AsRefIterator<'it, Item = T>,
N: IsNovel<T>,
I: Input,
S: HasFeedbackStates + HasClientPerfMonitor + Debug,
@ -402,10 +403,8 @@ where
assert!(size <= observer.len());
if self.novelties.is_some() {
for i in 0..size {
for (i, &item) in observer.as_ref_iter().enumerate() {
let history = map_state.history_map[i];
let item = *observer.get(i);
let reduced = R::reduce(history, item);
if N::is_novel(history, reduced) {
map_state.history_map[i] = reduced;
@ -414,10 +413,8 @@ where
}
}
} else {
for i in 0..size {
for (i, &item) in observer.as_ref_iter().enumerate() {
let history = map_state.history_map[i];
let item = *observer.get(i);
let reduced = R::reduce(history, item);
if N::is_novel(history, reduced) {
map_state.history_map[i] = reduced;
@ -478,7 +475,8 @@ where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
R: Reducer<T>,
N: IsNovel<T>,
O: MapObserver,
O: MapObserver<Entry = T>,
for<'it> O: AsRefIterator<'it, Item = T>,
S: HasFeedbackStates,
{
#[inline]
@ -499,7 +497,8 @@ where
+ Debug,
R: Reducer<T>,
N: IsNovel<T>,
O: MapObserver,
O: MapObserver<Entry = T>,
for<'it> O: AsRefIterator<'it, Item = T>,
S: HasFeedbackStates,
{
/// Create new `MapFeedback`
@ -562,7 +561,7 @@ where
}
/// A [`ReachabilityFeedback`] reports if a target has been reached.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct ReachabilityFeedback<O> {
name: String,
target_idx: Vec<usize>,
@ -572,6 +571,7 @@ pub struct ReachabilityFeedback<O> {
impl<O> ReachabilityFeedback<O>
where
O: MapObserver<Entry = usize>,
for<'it> O: AsRefIterator<'it, Item = usize>,
{
/// Creates a new [`ReachabilityFeedback`] for a [`MapObserver`].
#[must_use]
@ -598,6 +598,7 @@ impl<I, O, S> Feedback<I, S> for ReachabilityFeedback<O>
where
I: Input,
O: MapObserver<Entry = usize>,
for<'it> O: AsRefIterator<'it, Item = usize>,
S: HasClientPerfMonitor,
{
#[allow(clippy::wrong_self_convention)]
@ -615,11 +616,10 @@ where
{
// TODO Replace with match_name_type when stable
let observer = observers.match_name::<O>(&self.name).unwrap();
let size = observer.usable_count();
let mut hit_target: bool = false;
//check if we've hit any targets.
for i in 0..size {
if *observer.get(i) > 0 {
for (i, &elem) in observer.as_ref_iter().enumerate() {
if elem > 0 {
self.target_idx.push(i);
hit_target = true;
}
@ -648,12 +648,14 @@ where
impl<O> Named for ReachabilityFeedback<O>
where
O: MapObserver<Entry = usize>,
for<'it> O: AsRefIterator<'it, Item = usize>,
{
#[inline]
fn name(&self) -> &str {
self.name.as_str()
}
}
#[cfg(test)]
mod tests {
use crate::feedbacks::{AllIsNovel, IsNovel, NextPow2IsNovel};
@ -711,7 +713,7 @@ pub mod pybind {
}
#[pyclass(unsendable, name = $max_map_feedback_py_name)]
#[derive(Clone, Debug)]
#[derive(Debug)]
/// Python class for MaxMapFeedback
pub struct $max_map_feedback_struct_name {
/// Rust wrapped MaxMapFeedback object
@ -719,6 +721,14 @@ pub mod pybind {
MaxMapFeedback<BytesInput, $map_observer_name, $std_state_name, $datatype>,
}
impl Clone for $max_map_feedback_struct_name {
fn clone(&self) -> Self {
Self {
max_map_feedback: self.max_map_feedback.clone(),
}
}
}
#[pymethods]
impl $max_map_feedback_struct_name {
#[new]

View File

@ -9,6 +9,7 @@ use core::{
fmt::Debug,
hash::Hasher,
iter::Flatten,
marker::PhantomData,
slice::{from_raw_parts, Iter, IterMut},
};
use intervaltree::IntervalTree;
@ -19,7 +20,7 @@ use crate::{
bolts::{
ownedref::{OwnedRefMut, OwnedSliceMut},
tuples::Named,
AsMutSlice, AsSlice, HasLen,
AsMutIterator, AsMutSlice, AsRefIterator, AsSlice, HasLen,
},
executors::ExitKind,
observers::Observer,
@ -38,9 +39,14 @@ fn hash_slice<T: PrimInt>(slice: &[T]) -> u64 {
}
/// A [`MapObserver`] observes the static map, as oftentimes used for AFL-like coverage information
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug {
///
/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug
// where
// for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry>
{
/// Type of each entry in this map
type Entry: PrimInt + Default + Copy + Debug;
type Entry: PrimInt + Default + Copy + Debug + 'static;
/// Get the value at `idx`
fn get(&self, idx: usize) -> &Self::Entry;
@ -74,7 +80,9 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
fn initial_mut(&mut self) -> &mut Self::Entry;
/// Set the initial value for reset()
fn set_initial(&mut self, initial: Self::Entry);
fn set_initial(&mut self, initial: Self::Entry) {
*self.initial_mut() = initial;
}
/// Reset the map
#[inline]
@ -112,6 +120,64 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
}
}
/// A Simple iterator calling `MapObserver::get`
#[derive(Debug)]
pub struct MapObserverSimpleIterator<'a, O>
where
O: 'a + MapObserver,
{
index: usize,
observer: *const O,
phantom: PhantomData<&'a u8>,
}
impl<'a, O> Iterator for MapObserverSimpleIterator<'a, O>
where
O: 'a + MapObserver,
{
type Item = &'a O::Entry;
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.index >= self.observer.as_ref().unwrap().usable_count() {
None
} else {
let i = self.index;
self.index += 1;
Some(self.observer.as_ref().unwrap().get(i))
}
}
}
}
/// A Simple iterator calling `MapObserver::get_mut`
#[derive(Debug)]
pub struct MapObserverSimpleIteratoMut<'a, O>
where
O: 'a + MapObserver,
{
index: usize,
observer: *mut O,
phantom: PhantomData<&'a u8>,
}
impl<'a, O> Iterator for MapObserverSimpleIteratoMut<'a, O>
where
O: 'a + MapObserver,
{
type Item = &'a O::Entry;
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.index >= self.observer.as_ref().unwrap().usable_count() {
None
} else {
let i = self.index;
self.index += 1;
Some(self.observer.as_mut().unwrap().get_mut(i))
}
}
}
}
/// The Map Observer retrieves the state of a map,
/// that will get updated by the target.
/// A well-known example is the AFL-Style coverage map.
@ -130,7 +196,6 @@ where
impl<'a, I, S, T> Observer<I, S> for StdMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
Self: MapObserver,
{
#[inline]
fn pre_exec(&mut self, _state: &mut S, _input: &I) -> Result<(), Error> {
@ -158,6 +223,32 @@ where
}
}
impl<'a, 'it, T> AsRefIterator<'it> for StdMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = Iter<'it, T>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
impl<'a, 'it, T> AsMutIterator<'it> for StdMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = IterMut<'it, T>;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
impl<'a, 'it, T> IntoIterator for &'it StdMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
@ -166,7 +257,8 @@ where
type IntoIter = Iter<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_slice().iter()
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
@ -178,7 +270,8 @@ where
type IntoIter = IterMut<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_mut_slice().iter_mut()
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
@ -217,11 +310,6 @@ where
&mut self.initial
}
#[inline]
fn set_initial(&mut self, initial: T) {
self.initial = initial;
}
fn to_vec(&self) -> Vec<T> {
self.as_slice().to_vec()
}
@ -343,6 +431,32 @@ where
}
}
impl<'a, 'it, T, const N: usize> AsRefIterator<'it> for ConstMapObserver<'a, T, N>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = Iter<'it, T>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
impl<'a, 'it, T, const N: usize> AsMutIterator<'it> for ConstMapObserver<'a, T, N>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = IterMut<'it, T>;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
impl<'a, 'it, T, const N: usize> IntoIterator for &'it ConstMapObserver<'a, T, N>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
@ -351,7 +465,8 @@ where
type IntoIter = Iter<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_slice().iter()
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
@ -363,7 +478,8 @@ where
type IntoIter = IterMut<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_mut_slice().iter_mut()
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
@ -383,11 +499,6 @@ where
&mut self.initial
}
#[inline]
fn set_initial(&mut self, initial: T) {
self.initial = initial;
}
#[inline]
fn get(&self, idx: usize) -> &T {
&self.as_slice()[idx]
@ -515,6 +626,32 @@ where
}
}
impl<'a, 'it, T> AsRefIterator<'it> for VariableMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = Iter<'it, T>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
impl<'a, 'it, T> AsMutIterator<'it> for VariableMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = IterMut<'it, T>;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
impl<'a, 'it, T> IntoIterator for &'it VariableMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
@ -523,7 +660,8 @@ where
type IntoIter = Iter<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_slice().iter()
let cnt = self.usable_count();
self.as_slice()[..cnt].iter()
}
}
@ -535,7 +673,8 @@ where
type IntoIter = IterMut<'it, T>;
fn into_iter(self) -> Self::IntoIter {
self.as_mut_slice().iter_mut()
let cnt = self.usable_count();
self.as_mut_slice()[..cnt].iter_mut()
}
}
@ -555,22 +694,17 @@ where
&mut self.initial
}
#[inline]
fn set_initial(&mut self, initial: T) {
self.initial = initial;
}
#[inline]
fn usable_count(&self) -> usize {
*self.size.as_ref()
}
fn get(&self, idx: usize) -> &T {
&self.as_slice()[idx]
&self.map.as_slice()[idx]
}
fn get_mut(&mut self, idx: usize) -> &mut T {
&mut self.as_mut_slice()[idx]
&mut self.map.as_mut_slice()[idx]
}
fn hash(&self) -> u64 {
@ -587,7 +721,8 @@ where
{
#[inline]
fn as_slice(&self) -> &[T] {
self.map.as_slice()
let cnt = self.usable_count();
&self.map.as_slice()[..cnt]
}
}
impl<'a, T> AsMutSlice<T> for VariableMapObserver<'a, T>
@ -661,6 +796,7 @@ static COUNT_CLASS_LOOKUP: [u8; 256] = [
impl<I, S, M> Observer<I, S> for HitcountsMapObserver<M>
where
M: MapObserver<Entry = u8> + Observer<I, S>,
for<'it> M: AsMutIterator<'it, Item = u8>,
{
#[inline]
fn pre_exec(&mut self, state: &mut S, input: &I) -> Result<(), Error> {
@ -669,9 +805,8 @@ where
#[inline]
fn post_exec(&mut self, state: &mut S, input: &I, exit_kind: &ExitKind) -> Result<(), Error> {
let cnt = self.usable_count();
for i in 0..cnt {
*self.get_mut(i) = COUNT_CLASS_LOOKUP[*self.get(i) as usize];
for elem in self.as_mut_iter() {
*elem = COUNT_CLASS_LOOKUP[*elem as usize];
}
self.base.post_exec(state, input, exit_kind)
}
@ -700,6 +835,7 @@ where
impl<M> MapObserver for HitcountsMapObserver<M>
where
M: MapObserver<Entry = u8>,
for<'it> M: AsMutIterator<'it, Item = u8>,
{
type Entry = u8;
@ -713,11 +849,6 @@ where
self.base.initial_mut()
}
#[inline]
fn set_initial(&mut self, initial: u8) {
self.base.set_initial(initial);
}
#[inline]
fn usable_count(&self) -> usize {
self.base.usable_count()
@ -769,6 +900,57 @@ where
Self { base }
}
}
impl<'it, M> AsRefIterator<'it> for HitcountsMapObserver<M>
where
M: Named + Serialize + serde::de::DeserializeOwned + AsRefIterator<'it, Item = u8>,
{
type Item = u8;
type IntoIter = <M as AsRefIterator<'it>>::IntoIter;
fn as_ref_iter(&'it self) -> Self::IntoIter {
self.base.as_ref_iter()
}
}
impl<'it, M> AsMutIterator<'it> for HitcountsMapObserver<M>
where
M: Named + Serialize + serde::de::DeserializeOwned + AsMutIterator<'it, Item = u8>,
{
type Item = u8;
type IntoIter = <M as AsMutIterator<'it>>::IntoIter;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
self.base.as_mut_iter()
}
}
impl<'it, M> IntoIterator for &'it HitcountsMapObserver<M>
where
M: Named + Serialize + serde::de::DeserializeOwned,
&'it M: IntoIterator<Item = &'it u8>,
{
type Item = &'it u8;
type IntoIter = <&'it M as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.base.into_iter()
}
}
impl<'it, M> IntoIterator for &'it mut HitcountsMapObserver<M>
where
M: Named + Serialize + serde::de::DeserializeOwned,
&'it mut M: IntoIterator<Item = &'it mut u8>,
{
type Item = &'it mut u8;
type IntoIter = <&'it mut M as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.base.into_iter()
}
}
/// The Multi Map Observer merge different maps into one observer
#[derive(Serialize, Deserialize, Debug)]
#[serde(bound = "T: serde::de::DeserializeOwned")]
@ -848,11 +1030,6 @@ where
&mut self.initial
}
#[inline]
fn set_initial(&mut self, initial: T) {
self.initial = initial;
}
fn count_bytes(&self) -> u64 {
let initial = self.initial();
let mut res = 0;
@ -953,14 +1130,28 @@ where
}
}
impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T>
impl<'a, 'it, T> AsRefIterator<'it> for MultiMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
'a: 'it,
{
type Item = <IterMut<'it, T> as Iterator>::Item;
type Item = T;
type IntoIter = Flatten<Iter<'it, OwnedSliceMut<'a, T>>>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
self.maps.iter().flatten()
}
}
impl<'a, 'it, T> AsMutIterator<'it> for MultiMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
'a: 'it,
{
type Item = T;
type IntoIter = Flatten<IterMut<'it, OwnedSliceMut<'a, T>>>;
fn into_iter(self) -> Self::IntoIter {
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
self.maps.iter_mut().flatten()
}
}
@ -977,6 +1168,18 @@ where
}
}
impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = <IterMut<'it, T> as Iterator>::Item;
type IntoIter = Flatten<IterMut<'it, OwnedSliceMut<'a, T>>>;
fn into_iter(self) -> Self::IntoIter {
self.maps.iter_mut().flatten()
}
}
/// Exact copy of `StdMapObserver` that owns its map
/// Used for python bindings
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -1022,6 +1225,30 @@ where
}
}
impl<'a, 'it, T> AsRefIterator<'it> for OwnedMapObserver<T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = Iter<'it, T>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
self.as_slice().iter()
}
}
impl<'a, 'it, T> AsMutIterator<'it> for OwnedMapObserver<T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
{
type Item = T;
type IntoIter = IterMut<'it, T>;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
self.as_mut_slice().iter_mut()
}
}
impl<'it, T> IntoIterator for &'it OwnedMapObserver<T>
where
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
@ -1121,7 +1348,7 @@ where
pub fn new(name: &'static str, map: Vec<T>) -> Self {
let initial = if map.is_empty() { T::default() } else { map[0] };
Self {
map: map,
map,
name: name.to_string(),
initial,
}
@ -1130,11 +1357,12 @@ where
/// `MapObserver` Python bindings
#[cfg(feature = "python")]
pub mod pybind {
use crate::bolts::{tuples::Named, HasLen};
use crate::bolts::{tuples::Named, AsMutIterator, AsRefIterator, HasLen};
use crate::observers::{map::OwnedMapObserver, MapObserver, Observer};
use crate::Error;
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
use std::slice::{Iter, IterMut};
macro_rules! define_python_map_observer {
($struct_name:ident, $py_name:tt, $struct_name_trait:ident, $py_name_trait:tt, $datatype:ty, $wrapper_name: ident) => {
@ -1183,6 +1411,32 @@ pub mod pybind {
}
}
impl<'it> AsRefIterator<'it> for $struct_name_trait {
type Item = $datatype;
type IntoIter = Iter<'it, $datatype>;
fn as_ref_iter(&'it self) -> Self::IntoIter {
match &self.map_observer {
$wrapper_name::Owned(map_observer) => {
map_observer.owned_map_observer.as_ref_iter()
}
}
}
}
impl<'it> AsMutIterator<'it> for $struct_name_trait {
type Item = $datatype;
type IntoIter = IterMut<'it, $datatype>;
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
match &mut self.map_observer {
$wrapper_name::Owned(map_observer) => {
map_observer.owned_map_observer.as_mut_iter()
}
}
}
}
impl MapObserver for $struct_name_trait {
type Entry = $datatype;