Walk the map observer using as_ref_iter() in the map feedback (#535)
* Walk the map observer using into_iter() in the map feedback * fmt * map observers as iterators * perf * IntoMutIterator and IntoRefIterator * Clone * clippy
This commit is contained in:
parent
2dcdaaa89f
commit
479f9471ff
@ -1,6 +1,6 @@
|
||||
use libafl;
|
||||
use libafl_qemu;
|
||||
use libafl_sugar;
|
||||
use libafl;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
#[pymodule]
|
||||
|
@ -26,7 +26,7 @@ pub mod staterestore;
|
||||
pub mod tuples;
|
||||
|
||||
use alloc::string::String;
|
||||
use core::time;
|
||||
use core::{iter::Iterator, time};
|
||||
#[cfg(feature = "std")]
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
@ -42,6 +42,28 @@ pub trait AsMutSlice<T> {
|
||||
fn as_mut_slice(&mut self) -> &mut [T];
|
||||
}
|
||||
|
||||
/// Create an `Iterator` from a reference
|
||||
pub trait AsRefIterator<'it> {
|
||||
/// The item type
|
||||
type Item: 'it;
|
||||
/// The iterator type
|
||||
type IntoIter: Iterator<Item = &'it Self::Item>;
|
||||
|
||||
/// Create an interator from &self
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter;
|
||||
}
|
||||
|
||||
/// Create an `Iterator` from a mutable reference
|
||||
pub trait AsMutIterator<'it> {
|
||||
/// The item type
|
||||
type Item: 'it;
|
||||
/// The iterator type
|
||||
type IntoIter: Iterator<Item = &'it mut Self::Item>;
|
||||
|
||||
/// Create an interator from &mut self
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter;
|
||||
}
|
||||
|
||||
/// Has a length field
|
||||
pub trait HasLen {
|
||||
/// The length
|
||||
|
@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
bolts::{
|
||||
tuples::{MatchName, Named},
|
||||
AsMutSlice, AsSlice, HasRefCnt,
|
||||
AsMutSlice, AsRefIterator, AsSlice, HasRefCnt,
|
||||
},
|
||||
corpus::Testcase,
|
||||
events::{Event, EventFirer},
|
||||
@ -41,21 +41,21 @@ pub type MaxMapOneOrFilledFeedback<I, O, S, T> =
|
||||
MapFeedback<I, OneOrFilledIsNovel, O, MaxReducer, S, T>;
|
||||
|
||||
/// A `Reducer` function is used to aggregate values for the novelty search
|
||||
pub trait Reducer<T>: Serialize + serde::de::DeserializeOwned + 'static + Debug
|
||||
pub trait Reducer<T>: 'static + Debug
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
/// Reduce two values to one value, with the current [`Reducer`].
|
||||
fn reduce(first: T, second: T) -> T;
|
||||
}
|
||||
|
||||
/// A [`OrReducer`] reduces the values returning the bitwise OR with the old value
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OrReducer {}
|
||||
|
||||
impl<T> Reducer<T> for OrReducer
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
|
||||
T: PrimInt + Default + Copy + 'static + PartialOrd,
|
||||
{
|
||||
#[inline]
|
||||
fn reduce(history: T, new: T) -> T {
|
||||
@ -64,12 +64,12 @@ where
|
||||
}
|
||||
|
||||
/// A [`AndReducer`] reduces the values returning the bitwise AND with the old value
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AndReducer {}
|
||||
|
||||
impl<T> Reducer<T> for AndReducer
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
|
||||
T: PrimInt + Default + Copy + 'static + PartialOrd,
|
||||
{
|
||||
#[inline]
|
||||
fn reduce(history: T, new: T) -> T {
|
||||
@ -78,12 +78,12 @@ where
|
||||
}
|
||||
|
||||
/// A [`MaxReducer`] reduces int values and returns their maximum.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MaxReducer {}
|
||||
|
||||
impl<T> Reducer<T> for MaxReducer
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
|
||||
T: PrimInt + Default + Copy + 'static + PartialOrd,
|
||||
{
|
||||
#[inline]
|
||||
fn reduce(first: T, second: T) -> T {
|
||||
@ -96,12 +96,12 @@ where
|
||||
}
|
||||
|
||||
/// A [`MinReducer`] reduces int values and returns their minimum.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MinReducer {}
|
||||
|
||||
impl<T> Reducer<T> for MinReducer
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + PartialOrd,
|
||||
T: PrimInt + Default + Copy + 'static + PartialOrd,
|
||||
{
|
||||
#[inline]
|
||||
fn reduce(first: T, second: T) -> T {
|
||||
@ -114,9 +114,9 @@ where
|
||||
}
|
||||
|
||||
/// A `IsNovel` function is used to discriminate if a reduced value is considered novel.
|
||||
pub trait IsNovel<T>: Serialize + serde::de::DeserializeOwned + 'static + Debug
|
||||
pub trait IsNovel<T>: 'static + Debug
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
/// If a new value in the [`MapFeedback`] was found,
|
||||
/// this filter can decide if the result is considered novel or not.
|
||||
@ -124,12 +124,12 @@ where
|
||||
}
|
||||
|
||||
/// [`AllIsNovel`] consider everything a novelty. Here mostly just for debugging.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AllIsNovel {}
|
||||
|
||||
impl<T> IsNovel<T> for AllIsNovel
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
#[inline]
|
||||
fn is_novel(_old: T, _new: T) -> bool {
|
||||
@ -152,11 +152,11 @@ fn saturating_next_power_of_two<T: PrimInt>(n: T) -> T {
|
||||
}
|
||||
|
||||
/// Consider as novelty if the reduced value is different from the old value.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DifferentIsNovel {}
|
||||
impl<T> IsNovel<T> for DifferentIsNovel
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
#[inline]
|
||||
fn is_novel(old: T, new: T) -> bool {
|
||||
@ -165,11 +165,11 @@ where
|
||||
}
|
||||
|
||||
/// Only consider as novel the values which are at least the next pow2 class of the old value
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NextPow2IsNovel {}
|
||||
impl<T> IsNovel<T> for NextPow2IsNovel
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
#[inline]
|
||||
fn is_novel(old: T, new: T) -> bool {
|
||||
@ -185,11 +185,11 @@ where
|
||||
}
|
||||
|
||||
/// A filter that only saves values which are at least the next pow2 class
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OneOrFilledIsNovel {}
|
||||
impl<T> IsNovel<T> for OneOrFilledIsNovel
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned,
|
||||
T: PrimInt + Default + Copy + 'static,
|
||||
{
|
||||
#[inline]
|
||||
fn is_novel(old: T, new: T) -> bool {
|
||||
@ -342,13 +342,13 @@ where
|
||||
}
|
||||
|
||||
/// The most common AFL-like feedback type
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(bound = "T: serde::de::DeserializeOwned")]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MapFeedback<I, N, O, R, S, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
R: Reducer<T>,
|
||||
O: MapObserver,
|
||||
O: MapObserver<Entry = T>,
|
||||
for<'it> O: AsRefIterator<'it, Item = T>,
|
||||
N: IsNovel<T>,
|
||||
S: HasFeedbackStates,
|
||||
{
|
||||
@ -369,6 +369,7 @@ where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
R: Reducer<T>,
|
||||
O: MapObserver<Entry = T>,
|
||||
for<'it> O: AsRefIterator<'it, Item = T>,
|
||||
N: IsNovel<T>,
|
||||
I: Input,
|
||||
S: HasFeedbackStates + HasClientPerfMonitor + Debug,
|
||||
@ -402,10 +403,8 @@ where
|
||||
assert!(size <= observer.len());
|
||||
|
||||
if self.novelties.is_some() {
|
||||
for i in 0..size {
|
||||
for (i, &item) in observer.as_ref_iter().enumerate() {
|
||||
let history = map_state.history_map[i];
|
||||
let item = *observer.get(i);
|
||||
|
||||
let reduced = R::reduce(history, item);
|
||||
if N::is_novel(history, reduced) {
|
||||
map_state.history_map[i] = reduced;
|
||||
@ -414,10 +413,8 @@ where
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in 0..size {
|
||||
for (i, &item) in observer.as_ref_iter().enumerate() {
|
||||
let history = map_state.history_map[i];
|
||||
let item = *observer.get(i);
|
||||
|
||||
let reduced = R::reduce(history, item);
|
||||
if N::is_novel(history, reduced) {
|
||||
map_state.history_map[i] = reduced;
|
||||
@ -478,7 +475,8 @@ where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
R: Reducer<T>,
|
||||
N: IsNovel<T>,
|
||||
O: MapObserver,
|
||||
O: MapObserver<Entry = T>,
|
||||
for<'it> O: AsRefIterator<'it, Item = T>,
|
||||
S: HasFeedbackStates,
|
||||
{
|
||||
#[inline]
|
||||
@ -499,7 +497,8 @@ where
|
||||
+ Debug,
|
||||
R: Reducer<T>,
|
||||
N: IsNovel<T>,
|
||||
O: MapObserver,
|
||||
O: MapObserver<Entry = T>,
|
||||
for<'it> O: AsRefIterator<'it, Item = T>,
|
||||
S: HasFeedbackStates,
|
||||
{
|
||||
/// Create new `MapFeedback`
|
||||
@ -562,7 +561,7 @@ where
|
||||
}
|
||||
|
||||
/// A [`ReachabilityFeedback`] reports if a target has been reached.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ReachabilityFeedback<O> {
|
||||
name: String,
|
||||
target_idx: Vec<usize>,
|
||||
@ -572,6 +571,7 @@ pub struct ReachabilityFeedback<O> {
|
||||
impl<O> ReachabilityFeedback<O>
|
||||
where
|
||||
O: MapObserver<Entry = usize>,
|
||||
for<'it> O: AsRefIterator<'it, Item = usize>,
|
||||
{
|
||||
/// Creates a new [`ReachabilityFeedback`] for a [`MapObserver`].
|
||||
#[must_use]
|
||||
@ -598,6 +598,7 @@ impl<I, O, S> Feedback<I, S> for ReachabilityFeedback<O>
|
||||
where
|
||||
I: Input,
|
||||
O: MapObserver<Entry = usize>,
|
||||
for<'it> O: AsRefIterator<'it, Item = usize>,
|
||||
S: HasClientPerfMonitor,
|
||||
{
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
@ -615,11 +616,10 @@ where
|
||||
{
|
||||
// TODO Replace with match_name_type when stable
|
||||
let observer = observers.match_name::<O>(&self.name).unwrap();
|
||||
let size = observer.usable_count();
|
||||
let mut hit_target: bool = false;
|
||||
//check if we've hit any targets.
|
||||
for i in 0..size {
|
||||
if *observer.get(i) > 0 {
|
||||
for (i, &elem) in observer.as_ref_iter().enumerate() {
|
||||
if elem > 0 {
|
||||
self.target_idx.push(i);
|
||||
hit_target = true;
|
||||
}
|
||||
@ -648,12 +648,14 @@ where
|
||||
impl<O> Named for ReachabilityFeedback<O>
|
||||
where
|
||||
O: MapObserver<Entry = usize>,
|
||||
for<'it> O: AsRefIterator<'it, Item = usize>,
|
||||
{
|
||||
#[inline]
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::feedbacks::{AllIsNovel, IsNovel, NextPow2IsNovel};
|
||||
@ -711,7 +713,7 @@ pub mod pybind {
|
||||
}
|
||||
|
||||
#[pyclass(unsendable, name = $max_map_feedback_py_name)]
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
/// Python class for MaxMapFeedback
|
||||
pub struct $max_map_feedback_struct_name {
|
||||
/// Rust wrapped MaxMapFeedback object
|
||||
@ -719,6 +721,14 @@ pub mod pybind {
|
||||
MaxMapFeedback<BytesInput, $map_observer_name, $std_state_name, $datatype>,
|
||||
}
|
||||
|
||||
impl Clone for $max_map_feedback_struct_name {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
max_map_feedback: self.max_map_feedback.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl $max_map_feedback_struct_name {
|
||||
#[new]
|
||||
|
@ -9,6 +9,7 @@ use core::{
|
||||
fmt::Debug,
|
||||
hash::Hasher,
|
||||
iter::Flatten,
|
||||
marker::PhantomData,
|
||||
slice::{from_raw_parts, Iter, IterMut},
|
||||
};
|
||||
use intervaltree::IntervalTree;
|
||||
@ -19,7 +20,7 @@ use crate::{
|
||||
bolts::{
|
||||
ownedref::{OwnedRefMut, OwnedSliceMut},
|
||||
tuples::Named,
|
||||
AsMutSlice, AsSlice, HasLen,
|
||||
AsMutIterator, AsMutSlice, AsRefIterator, AsSlice, HasLen,
|
||||
},
|
||||
executors::ExitKind,
|
||||
observers::Observer,
|
||||
@ -38,9 +39,14 @@ fn hash_slice<T: PrimInt>(slice: &[T]) -> u64 {
|
||||
}
|
||||
|
||||
/// A [`MapObserver`] observes the static map, as oftentimes used for AFL-like coverage information
|
||||
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug {
|
||||
///
|
||||
/// TODO: enforce `iter() -> AssociatedTypeIter` when generic associated types stabilize
|
||||
pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned + Debug
|
||||
// where
|
||||
// for<'it> &'it Self: IntoIterator<Item = &'it Self::Entry>
|
||||
{
|
||||
/// Type of each entry in this map
|
||||
type Entry: PrimInt + Default + Copy + Debug;
|
||||
type Entry: PrimInt + Default + Copy + Debug + 'static;
|
||||
|
||||
/// Get the value at `idx`
|
||||
fn get(&self, idx: usize) -> &Self::Entry;
|
||||
@ -74,7 +80,9 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
|
||||
fn initial_mut(&mut self) -> &mut Self::Entry;
|
||||
|
||||
/// Set the initial value for reset()
|
||||
fn set_initial(&mut self, initial: Self::Entry);
|
||||
fn set_initial(&mut self, initial: Self::Entry) {
|
||||
*self.initial_mut() = initial;
|
||||
}
|
||||
|
||||
/// Reset the map
|
||||
#[inline]
|
||||
@ -112,6 +120,64 @@ pub trait MapObserver: HasLen + Named + Serialize + serde::de::DeserializeOwned
|
||||
}
|
||||
}
|
||||
|
||||
/// A Simple iterator calling `MapObserver::get`
|
||||
#[derive(Debug)]
|
||||
pub struct MapObserverSimpleIterator<'a, O>
|
||||
where
|
||||
O: 'a + MapObserver,
|
||||
{
|
||||
index: usize,
|
||||
observer: *const O,
|
||||
phantom: PhantomData<&'a u8>,
|
||||
}
|
||||
|
||||
impl<'a, O> Iterator for MapObserverSimpleIterator<'a, O>
|
||||
where
|
||||
O: 'a + MapObserver,
|
||||
{
|
||||
type Item = &'a O::Entry;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
unsafe {
|
||||
if self.index >= self.observer.as_ref().unwrap().usable_count() {
|
||||
None
|
||||
} else {
|
||||
let i = self.index;
|
||||
self.index += 1;
|
||||
Some(self.observer.as_ref().unwrap().get(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Simple iterator calling `MapObserver::get_mut`
|
||||
#[derive(Debug)]
|
||||
pub struct MapObserverSimpleIteratoMut<'a, O>
|
||||
where
|
||||
O: 'a + MapObserver,
|
||||
{
|
||||
index: usize,
|
||||
observer: *mut O,
|
||||
phantom: PhantomData<&'a u8>,
|
||||
}
|
||||
|
||||
impl<'a, O> Iterator for MapObserverSimpleIteratoMut<'a, O>
|
||||
where
|
||||
O: 'a + MapObserver,
|
||||
{
|
||||
type Item = &'a O::Entry;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
unsafe {
|
||||
if self.index >= self.observer.as_ref().unwrap().usable_count() {
|
||||
None
|
||||
} else {
|
||||
let i = self.index;
|
||||
self.index += 1;
|
||||
Some(self.observer.as_mut().unwrap().get_mut(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The Map Observer retrieves the state of a map,
|
||||
/// that will get updated by the target.
|
||||
/// A well-known example is the AFL-Style coverage map.
|
||||
@ -130,7 +196,6 @@ where
|
||||
impl<'a, I, S, T> Observer<I, S> for StdMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
Self: MapObserver,
|
||||
{
|
||||
#[inline]
|
||||
fn pre_exec(&mut self, _state: &mut S, _input: &I) -> Result<(), Error> {
|
||||
@ -158,6 +223,32 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsRefIterator<'it> for StdMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsMutIterator<'it> for StdMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> IntoIterator for &'it StdMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
@ -166,7 +257,8 @@ where
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_slice().iter()
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,7 +270,8 @@ where
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_mut_slice().iter_mut()
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,11 +310,6 @@ where
|
||||
&mut self.initial
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn set_initial(&mut self, initial: T) {
|
||||
self.initial = initial;
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<T> {
|
||||
self.as_slice().to_vec()
|
||||
}
|
||||
@ -343,6 +431,32 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T, const N: usize> AsRefIterator<'it> for ConstMapObserver<'a, T, N>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T, const N: usize> AsMutIterator<'it> for ConstMapObserver<'a, T, N>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T, const N: usize> IntoIterator for &'it ConstMapObserver<'a, T, N>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
@ -351,7 +465,8 @@ where
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_slice().iter()
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
@ -363,7 +478,8 @@ where
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_mut_slice().iter_mut()
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
@ -383,11 +499,6 @@ where
|
||||
&mut self.initial
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn set_initial(&mut self, initial: T) {
|
||||
self.initial = initial;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get(&self, idx: usize) -> &T {
|
||||
&self.as_slice()[idx]
|
||||
@ -515,6 +626,32 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsRefIterator<'it> for VariableMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsMutIterator<'it> for VariableMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> IntoIterator for &'it VariableMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
@ -523,7 +660,8 @@ where
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_slice().iter()
|
||||
let cnt = self.usable_count();
|
||||
self.as_slice()[..cnt].iter()
|
||||
}
|
||||
}
|
||||
|
||||
@ -535,7 +673,8 @@ where
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.as_mut_slice().iter_mut()
|
||||
let cnt = self.usable_count();
|
||||
self.as_mut_slice()[..cnt].iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
@ -555,22 +694,17 @@ where
|
||||
&mut self.initial
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn set_initial(&mut self, initial: T) {
|
||||
self.initial = initial;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn usable_count(&self) -> usize {
|
||||
*self.size.as_ref()
|
||||
}
|
||||
|
||||
fn get(&self, idx: usize) -> &T {
|
||||
&self.as_slice()[idx]
|
||||
&self.map.as_slice()[idx]
|
||||
}
|
||||
|
||||
fn get_mut(&mut self, idx: usize) -> &mut T {
|
||||
&mut self.as_mut_slice()[idx]
|
||||
&mut self.map.as_mut_slice()[idx]
|
||||
}
|
||||
|
||||
fn hash(&self) -> u64 {
|
||||
@ -587,7 +721,8 @@ where
|
||||
{
|
||||
#[inline]
|
||||
fn as_slice(&self) -> &[T] {
|
||||
self.map.as_slice()
|
||||
let cnt = self.usable_count();
|
||||
&self.map.as_slice()[..cnt]
|
||||
}
|
||||
}
|
||||
impl<'a, T> AsMutSlice<T> for VariableMapObserver<'a, T>
|
||||
@ -661,6 +796,7 @@ static COUNT_CLASS_LOOKUP: [u8; 256] = [
|
||||
impl<I, S, M> Observer<I, S> for HitcountsMapObserver<M>
|
||||
where
|
||||
M: MapObserver<Entry = u8> + Observer<I, S>,
|
||||
for<'it> M: AsMutIterator<'it, Item = u8>,
|
||||
{
|
||||
#[inline]
|
||||
fn pre_exec(&mut self, state: &mut S, input: &I) -> Result<(), Error> {
|
||||
@ -669,9 +805,8 @@ where
|
||||
|
||||
#[inline]
|
||||
fn post_exec(&mut self, state: &mut S, input: &I, exit_kind: &ExitKind) -> Result<(), Error> {
|
||||
let cnt = self.usable_count();
|
||||
for i in 0..cnt {
|
||||
*self.get_mut(i) = COUNT_CLASS_LOOKUP[*self.get(i) as usize];
|
||||
for elem in self.as_mut_iter() {
|
||||
*elem = COUNT_CLASS_LOOKUP[*elem as usize];
|
||||
}
|
||||
self.base.post_exec(state, input, exit_kind)
|
||||
}
|
||||
@ -700,6 +835,7 @@ where
|
||||
impl<M> MapObserver for HitcountsMapObserver<M>
|
||||
where
|
||||
M: MapObserver<Entry = u8>,
|
||||
for<'it> M: AsMutIterator<'it, Item = u8>,
|
||||
{
|
||||
type Entry = u8;
|
||||
|
||||
@ -713,11 +849,6 @@ where
|
||||
self.base.initial_mut()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn set_initial(&mut self, initial: u8) {
|
||||
self.base.set_initial(initial);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn usable_count(&self) -> usize {
|
||||
self.base.usable_count()
|
||||
@ -769,6 +900,57 @@ where
|
||||
Self { base }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it, M> AsRefIterator<'it> for HitcountsMapObserver<M>
|
||||
where
|
||||
M: Named + Serialize + serde::de::DeserializeOwned + AsRefIterator<'it, Item = u8>,
|
||||
{
|
||||
type Item = u8;
|
||||
type IntoIter = <M as AsRefIterator<'it>>::IntoIter;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
self.base.as_ref_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it, M> AsMutIterator<'it> for HitcountsMapObserver<M>
|
||||
where
|
||||
M: Named + Serialize + serde::de::DeserializeOwned + AsMutIterator<'it, Item = u8>,
|
||||
{
|
||||
type Item = u8;
|
||||
type IntoIter = <M as AsMutIterator<'it>>::IntoIter;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
self.base.as_mut_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it, M> IntoIterator for &'it HitcountsMapObserver<M>
|
||||
where
|
||||
M: Named + Serialize + serde::de::DeserializeOwned,
|
||||
&'it M: IntoIterator<Item = &'it u8>,
|
||||
{
|
||||
type Item = &'it u8;
|
||||
type IntoIter = <&'it M as IntoIterator>::IntoIter;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.base.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it, M> IntoIterator for &'it mut HitcountsMapObserver<M>
|
||||
where
|
||||
M: Named + Serialize + serde::de::DeserializeOwned,
|
||||
&'it mut M: IntoIterator<Item = &'it mut u8>,
|
||||
{
|
||||
type Item = &'it mut u8;
|
||||
type IntoIter = <&'it mut M as IntoIterator>::IntoIter;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.base.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// The Multi Map Observer merge different maps into one observer
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(bound = "T: serde::de::DeserializeOwned")]
|
||||
@ -848,11 +1030,6 @@ where
|
||||
&mut self.initial
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn set_initial(&mut self, initial: T) {
|
||||
self.initial = initial;
|
||||
}
|
||||
|
||||
fn count_bytes(&self) -> u64 {
|
||||
let initial = self.initial();
|
||||
let mut res = 0;
|
||||
@ -953,14 +1130,28 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T>
|
||||
impl<'a, 'it, T> AsRefIterator<'it> for MultiMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
'a: 'it,
|
||||
{
|
||||
type Item = <IterMut<'it, T> as Iterator>::Item;
|
||||
type Item = T;
|
||||
type IntoIter = Flatten<Iter<'it, OwnedSliceMut<'a, T>>>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
self.maps.iter().flatten()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsMutIterator<'it> for MultiMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
'a: 'it,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = Flatten<IterMut<'it, OwnedSliceMut<'a, T>>>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
self.maps.iter_mut().flatten()
|
||||
}
|
||||
}
|
||||
@ -977,6 +1168,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> IntoIterator for &'it mut MultiMapObserver<'a, T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = <IterMut<'it, T> as Iterator>::Item;
|
||||
type IntoIter = Flatten<IterMut<'it, OwnedSliceMut<'a, T>>>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.maps.iter_mut().flatten()
|
||||
}
|
||||
}
|
||||
|
||||
/// Exact copy of `StdMapObserver` that owns its map
|
||||
/// Used for python bindings
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
@ -1022,6 +1225,30 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsRefIterator<'it> for OwnedMapObserver<T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = Iter<'it, T>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
self.as_slice().iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'it, T> AsMutIterator<'it> for OwnedMapObserver<T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
{
|
||||
type Item = T;
|
||||
type IntoIter = IterMut<'it, T>;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
self.as_mut_slice().iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it, T> IntoIterator for &'it OwnedMapObserver<T>
|
||||
where
|
||||
T: PrimInt + Default + Copy + 'static + Serialize + serde::de::DeserializeOwned + Debug,
|
||||
@ -1121,7 +1348,7 @@ where
|
||||
pub fn new(name: &'static str, map: Vec<T>) -> Self {
|
||||
let initial = if map.is_empty() { T::default() } else { map[0] };
|
||||
Self {
|
||||
map: map,
|
||||
map,
|
||||
name: name.to_string(),
|
||||
initial,
|
||||
}
|
||||
@ -1130,11 +1357,12 @@ where
|
||||
/// `MapObserver` Python bindings
|
||||
#[cfg(feature = "python")]
|
||||
pub mod pybind {
|
||||
use crate::bolts::{tuples::Named, HasLen};
|
||||
use crate::bolts::{tuples::Named, AsMutIterator, AsRefIterator, HasLen};
|
||||
use crate::observers::{map::OwnedMapObserver, MapObserver, Observer};
|
||||
use crate::Error;
|
||||
use pyo3::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
macro_rules! define_python_map_observer {
|
||||
($struct_name:ident, $py_name:tt, $struct_name_trait:ident, $py_name_trait:tt, $datatype:ty, $wrapper_name: ident) => {
|
||||
@ -1183,6 +1411,32 @@ pub mod pybind {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it> AsRefIterator<'it> for $struct_name_trait {
|
||||
type Item = $datatype;
|
||||
type IntoIter = Iter<'it, $datatype>;
|
||||
|
||||
fn as_ref_iter(&'it self) -> Self::IntoIter {
|
||||
match &self.map_observer {
|
||||
$wrapper_name::Owned(map_observer) => {
|
||||
map_observer.owned_map_observer.as_ref_iter()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'it> AsMutIterator<'it> for $struct_name_trait {
|
||||
type Item = $datatype;
|
||||
type IntoIter = IterMut<'it, $datatype>;
|
||||
|
||||
fn as_mut_iter(&'it mut self) -> Self::IntoIter {
|
||||
match &mut self.map_observer {
|
||||
$wrapper_name::Owned(map_observer) => {
|
||||
map_observer.owned_map_observer.as_mut_iter()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MapObserver for $struct_name_trait {
|
||||
type Entry = $datatype;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user