reduce memory usage of the construct_automata script (#1481)

* remove unneeded loop in `SpliceMutator::mutate`

previously we searched for the first and the last difference
between exactly the same 2 inputs 3 times in a loop

* remove unused struct fields

* avoid allocating strings for `Transition`s

* avoid allocating `String`s for `Stack`s

* avoid allocating Strings for `Element`s

* apply some clippy lints

* some more clippy lints

* simplify regex

* remove superfluous if condition

* remove the Rc<_> in `Element`

* small cleanups and regex fix

* avoid allocating a vector for the culled pda

* bug fix

* bug fix

* reintroduce the Rc, but make it use the *one* allocated VecDeque this time

* slim down dependencies

* use Box<[&str]> for sorted state stacks

this saves us a whopping 8 bytes ;), since we don't have to store
the capacity

* revert the changes from 9ffa715c10089f157e4e20563143a2df890c8ffe

fixes a bug

* apply clippy lint

---------

Co-authored-by: Andrea Fioraldi <andreafioraldi@gmail.com>
This commit is contained in:
lenawanel 2023-09-05 16:29:24 +02:00 committed by GitHub
parent 4c0e01c4aa
commit c791a23456
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 63 deletions

View File

@ -15,9 +15,9 @@ categories = ["development-tools::testing", "emulators", "embedded", "os", "no-s
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
libafl = { path = "../../../libafl" } libafl = { path = "../../../libafl", default-features = false }
serde_json = "1.0" serde_json = "1.0"
regex = "1" regex = "1"
postcard = { version = "1.0", features = ["alloc"], default-features = false } # no_std compatible serde serialization format postcard = { version = "1.0", features = ["alloc"], default-features = false } # no_std compatible serde serialization format
clap = { version = "4.0", features = ["derive"] } clap = { version = "4.0", features = ["derive"] }
log = "0.4.20" # log = "0.4.20"

View File

@ -1,5 +1,5 @@
use std::{ use std::{
collections::{HashMap, HashSet, VecDeque}, collections::{HashSet, VecDeque},
fs, fs,
io::{BufReader, Write}, io::{BufReader, Write},
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -49,51 +49,52 @@ fn read_grammar_from_file<P: AsRef<Path>>(path: P) -> Value {
} }
#[derive(Debug)] #[derive(Debug)]
struct Element { struct Element<'src> {
pub state: usize, pub state: usize,
pub items: Rc<VecDeque<String>>, pub items: Rc<VecDeque<&'src str>>,
} }
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
struct Transition { struct Transition<'src> {
pub source: usize, pub source: usize,
pub dest: usize, pub dest: usize,
pub ss: Vec<String>, // pub ss: Vec<String>,
pub terminal: String, pub terminal: &'src str,
pub is_regex: bool, // pub is_regex: bool,
pub stack: Rc<VecDeque<String>>, pub stack_len: usize,
} }
#[derive(Default)] #[derive(Default)]
struct Stacks { struct Stacks<'src> {
pub q: HashMap<usize, VecDeque<String>>, pub q: Vec<Rc<VecDeque<&'src str>>>,
pub s: HashMap<usize, Vec<String>>, pub s: Vec<Box<[&'src str]>>,
} }
fn tokenize(rule: &str) -> (String, Vec<String>, bool) { fn tokenize(rule: &str) -> (&str, Vec<&str>) {
let re = RE.get_or_init(|| Regex::new(r"([r])*'([\s\S]+)'([\s\S]*)").unwrap()); let re = RE.get_or_init(|| Regex::new(r"([r])*'([\s\S]+)'([\s\S]*)").unwrap());
// let re = RE.get_or_init(|| Regex::new(r"'([\s\S]+)'([\s\S]*)").unwrap());
let cap = re.captures(rule).unwrap(); let cap = re.captures(rule).unwrap();
let is_regex = cap.get(1).is_some(); // let is_regex = cap.get(1).is_some();
let terminal = cap.get(2).unwrap().as_str().to_owned(); let terminal = cap.get(2).unwrap().as_str();
let ss = cap.get(3).map_or(vec![], |m| { let ss = cap.get(3).map_or(vec![], |m| {
m.as_str() m.as_str()
.split_whitespace() .split_whitespace()
.map(ToOwned::to_owned) // .map(ToOwned::to_owned)
.collect() .collect()
}); });
if terminal == "\\n" { if terminal == "\\n" {
("\n".into(), ss, is_regex) ("\n", ss /*is_regex*/)
} else { } else {
(terminal, ss, is_regex) (terminal, ss /*is_regex*/)
} }
} }
fn prepare_transitions( fn prepare_transitions<'pda, 'src: 'pda>(
grammar: &Value, grammar: &'src Value,
pda: &mut Vec<Transition>, pda: &'pda mut Vec<Transition<'src>>,
state_stacks: &mut Stacks, state_stacks: &mut Stacks<'src>,
state_count: &mut usize, state_count: &mut usize,
worklist: &mut VecDeque<Element>, worklist: &mut VecDeque<Element<'src>>,
element: &Element, element: &Element,
stack_limit: usize, stack_limit: usize,
) { ) {
@ -102,12 +103,12 @@ fn prepare_transitions(
} }
let state = element.state; let state = element.state;
let nonterminal = &element.items[0]; let nonterminal = element.items[0];
let rules = grammar[nonterminal].as_array().unwrap(); let rules = grammar[nonterminal].as_array().unwrap();
// let mut i = 0; // let mut i = 0;
'rules_loop: for rule in rules { 'rules_loop: for rule in rules {
let rule = rule.as_str().unwrap(); let rule = rule.as_str().unwrap();
let (terminal, ss, is_regex) = tokenize(rule); let (terminal, ss /*_is_regex*/) = tokenize(rule);
let dest = *state_count; let dest = *state_count;
// log::trace!("Rule \"{}\", {} over {}", &rule, i, rules.len()); // log::trace!("Rule \"{}\", {} over {}", &rule, i, rules.len());
@ -115,33 +116,33 @@ fn prepare_transitions(
// Creating a state stack for the new state // Creating a state stack for the new state
let mut state_stack = state_stacks let mut state_stack = state_stacks
.q .q
.get(&state) .get(state.wrapping_sub(1))
.map_or(VecDeque::new(), Clone::clone); .map_or(VecDeque::new(), |state_stack| (**state_stack).clone());
if !state_stack.is_empty() {
state_stack.pop_front(); state_stack.pop_front();
for symbol in ss.into_iter().rev() {
state_stack.push_front(symbol);
} }
for symbol in ss.iter().rev() { let mut state_stack_sorted: Box<_> = state_stack.iter().copied().collect();
state_stack.push_front(symbol.clone()); state_stack_sorted.sort_unstable();
}
let mut state_stack_sorted: Vec<_> = state_stack.iter().cloned().collect();
state_stack_sorted.sort();
let mut transition = Transition { let mut transition = Transition {
source: state, source: state,
dest, dest,
ss, // ss,
terminal, terminal,
is_regex, // is_regex,
stack: Rc::new(state_stack.clone()), // stack: Rc::new(state_stack.clone()),
stack_len: state_stack.len(),
}; };
// Check if a recursive transition state being created, if so make a backward // Check if a recursive transition state being created, if so make a backward
// edge and don't add anything to the worklist // edge and don't add anything to the worklist
for (key, val) in &state_stacks.s { for (dest, stack) in state_stacks.s.iter().enumerate() {
if state_stack_sorted == *val { if state_stack_sorted == *stack {
transition.dest = *key; transition.dest = dest + 1;
// i += 1; // i += 1;
pda.push(transition.clone()); pda.push(transition);
// If a recursive transition exercised don't add the same transition as a new // If a recursive transition exercised don't add the same transition as a new
// edge, continue onto the next transitions // edge, continue onto the next transitions
@ -151,18 +152,23 @@ fn prepare_transitions(
// If the generated state has a stack size > stack_limit then that state is abandoned // If the generated state has a stack size > stack_limit then that state is abandoned
// and not added to the FSA or the worklist for further expansion // and not added to the FSA or the worklist for further expansion
if stack_limit > 0 && transition.stack.len() > stack_limit { if stack_limit > 0 && transition.stack_len > stack_limit {
// TODO add to unexpanded_rules // TODO add to unexpanded_rules
continue; continue;
} }
let state_stack = Rc::new(state_stack);
// Create transitions for the non-recursive relations and add to the worklist // Create transitions for the non-recursive relations and add to the worklist
worklist.push_back(Element { worklist.push_back(Element {
state: dest, state: dest,
items: transition.stack.clone(), items: Rc::clone(&state_stack),
}); });
state_stacks.q.insert(dest, state_stack);
state_stacks.s.insert(dest, state_stack_sorted); // since each index corresponds to `state_count - 1`
// index with `dest - 1`
state_stacks.q.push(state_stack);
state_stacks.s.push(state_stack_sorted);
pda.push(transition); pda.push(transition);
println!("worklist size: {}", worklist.len()); println!("worklist size: {}", worklist.len());
@ -209,7 +215,7 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
for final_state in &finals { for final_state in &finals {
for transition in pda { for transition in pda {
if transition.dest == *final_state && transition.stack.len() > 0 { if transition.dest == *final_state && transition.stack_len > 0 {
blocklist.insert(transition.dest); blocklist.insert(transition.dest);
} else { } else {
culled_pda.push(transition); culled_pda.push(transition);
@ -223,7 +229,9 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
let culled_finals: HashSet<usize> = finals.difference(&blocklist).copied().collect(); let culled_finals: HashSet<usize> = finals.difference(&blocklist).copied().collect();
assert!(culled_finals.len() == 1); assert!(culled_finals.len() == 1);
for transition in &culled_pda { let culled_pda_len = culled_pda.len();
for transition in culled_pda {
if blocklist.contains(&transition.dest) { if blocklist.contains(&transition.dest) {
continue; continue;
} }
@ -234,15 +242,11 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
} }
memoized[state].push(Trigger { memoized[state].push(Trigger {
dest: transition.dest, dest: transition.dest,
term: transition.terminal.clone(), term: transition.terminal.to_string(),
}); });
if num_transition % 4096 == 0 { if num_transition % 4096 == 0 {
println!( println!("processed {num_transition} transitions over {culled_pda_len}",);
"processed {} transitions over {}",
num_transition,
culled_pda.len()
);
} }
} }
@ -261,8 +265,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
*/ */
Automaton { Automaton {
init_state: initial.iter().next().copied().unwrap(), init_state: initial.into_iter().next().unwrap(),
final_state: culled_finals.iter().next().copied().unwrap(), final_state: culled_finals.into_iter().next().unwrap(),
pda: memoized, pda: memoized,
} }
} else { } else {
@ -275,7 +279,7 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
} }
memoized[state].push(Trigger { memoized[state].push(Trigger {
dest: transition.dest, dest: transition.dest,
term: transition.terminal.clone(), term: transition.terminal.to_string(),
}); });
if num_transition % 4096 == 0 { if num_transition % 4096 == 0 {
@ -288,8 +292,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
} }
Automaton { Automaton {
init_state: initial.iter().next().copied().unwrap(), init_state: initial.into_iter().next().unwrap(),
final_state: finals.iter().next().copied().unwrap(), final_state: finals.into_iter().next().unwrap(),
pda: memoized, pda: memoized,
} }
} }
@ -308,7 +312,7 @@ fn main() {
let mut pda = vec![]; let mut pda = vec![];
let grammar = read_grammar_from_file(grammar_file); let grammar = read_grammar_from_file(grammar_file);
let start_symbol = grammar["Start"][0].as_str().unwrap().to_owned(); let start_symbol = grammar["Start"][0].as_str().unwrap();
let mut start_vec = VecDeque::new(); let mut start_vec = VecDeque::new();
start_vec.push_back(start_symbol); start_vec.push_back(start_symbol);
worklist.push_back(Element { worklist.push_back(Element {
@ -328,8 +332,7 @@ fn main() {
); );
} }
state_stacks.q.clear(); drop(state_stacks);
state_stacks.s.clear();
let transformed = postprocess(&pda, stack_limit); let transformed = postprocess(&pda, stack_limit);
let serialized = postcard::to_allocvec(&transformed).unwrap(); let serialized = postcard::to_allocvec(&transformed).unwrap();