diff --git a/utils/gramatron/construct_automata/Cargo.toml b/utils/gramatron/construct_automata/Cargo.toml
index 2bc2f87666..60853c61d7 100644
--- a/utils/gramatron/construct_automata/Cargo.toml
+++ b/utils/gramatron/construct_automata/Cargo.toml
@@ -15,9 +15,9 @@ categories = ["development-tools::testing", "emulators", "embedded", "os", "no-std"]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-libafl = { path = "../../../libafl" }
+libafl = { path = "../../../libafl", default-features = false }
 serde_json = "1.0"
 regex = "1"
 postcard = { version = "1.0", features = ["alloc"], default-features = false } # no_std compatible serde serialization format
 clap = { version = "4.0", features = ["derive"] }
-log = "0.4.20"
+# log = "0.4.20"
diff --git a/utils/gramatron/construct_automata/src/main.rs b/utils/gramatron/construct_automata/src/main.rs
index 0c94b337fa..dd3dd403e0 100644
--- a/utils/gramatron/construct_automata/src/main.rs
+++ b/utils/gramatron/construct_automata/src/main.rs
@@ -1,5 +1,5 @@
 use std::{
-    collections::{HashMap, HashSet, VecDeque},
+    collections::{HashSet, VecDeque},
     fs,
     io::{BufReader, Write},
     path::{Path, PathBuf},
@@ -49,51 +49,52 @@ fn read_grammar_from_file<P: AsRef<Path>>(path: P) -> Value {
 }
 
 #[derive(Debug)]
-struct Element {
+struct Element<'src> {
     pub state: usize,
-    pub items: Rc<VecDeque<String>>,
+    pub items: Rc<VecDeque<&'src str>>,
 }
 
 #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
-struct Transition {
+struct Transition<'src> {
     pub source: usize,
     pub dest: usize,
-    pub ss: Vec<String>,
-    pub terminal: String,
-    pub is_regex: bool,
-    pub stack: Rc<VecDeque<String>>,
+    // pub ss: Vec<&'src str>,
+    pub terminal: &'src str,
+    // pub is_regex: bool,
+    pub stack_len: usize,
 }
 
 #[derive(Default)]
-struct Stacks {
-    pub q: HashMap<usize, VecDeque<String>>,
-    pub s: HashMap<usize, Vec<String>>,
+struct Stacks<'src> {
+    pub q: Vec<Rc<VecDeque<&'src str>>>,
+    pub s: Vec<Box<[&'src str]>>,
 }
 
-fn tokenize(rule: &str) -> (String, Vec<String>, bool) {
+fn tokenize(rule: &str) -> (&str, Vec<&str>) {
     let re = RE.get_or_init(|| Regex::new(r"([r])*'([\s\S]+)'([\s\S]*)").unwrap());
+    // let re = RE.get_or_init(|| Regex::new(r"'([\s\S]+)'([\s\S]*)").unwrap());
     let cap = re.captures(rule).unwrap();
-    let is_regex = cap.get(1).is_some();
-    let terminal = cap.get(2).unwrap().as_str().to_owned();
+    // let is_regex = cap.get(1).is_some();
+    let terminal = cap.get(2).unwrap().as_str();
     let ss = cap.get(3).map_or(vec![], |m| {
         m.as_str()
             .split_whitespace()
-            .map(ToOwned::to_owned)
+            // .map(ToOwned::to_owned)
             .collect()
     });
     if terminal == "\\n" {
-        ("\n".into(), ss, is_regex)
+        ("\n", ss /*is_regex*/)
     } else {
-        (terminal, ss, is_regex)
+        (terminal, ss /*is_regex*/)
     }
 }
 
-fn prepare_transitions(
-    grammar: &Value,
-    pda: &mut Vec<Transition>,
-    state_stacks: &mut Stacks,
+fn prepare_transitions<'pda, 'src: 'pda>(
+    grammar: &'src Value,
+    pda: &'pda mut Vec<Transition<'src>>,
+    state_stacks: &mut Stacks<'src>,
     state_count: &mut usize,
-    worklist: &mut VecDeque<Element>,
+    worklist: &mut VecDeque<Element<'src>>,
     element: &Element,
     stack_limit: usize,
 ) {
@@ -102,12 +103,12 @@ fn prepare_transitions(
     }
 
     let state = element.state;
-    let nonterminal = &element.items[0];
+    let nonterminal = element.items[0];
     let rules = grammar[nonterminal].as_array().unwrap();
     // let mut i = 0;
     'rules_loop: for rule in rules {
         let rule = rule.as_str().unwrap();
-        let (terminal, ss, is_regex) = tokenize(rule);
+        let (terminal, ss /*_is_regex*/) = tokenize(rule);
         let dest = *state_count;
 
         // log::trace!("Rule \"{}\", {} over {}", &rule, i, rules.len());
@@ -115,33 +116,33 @@ fn prepare_transitions(
         // Creating a state stack for the new state
         let mut state_stack = state_stacks
             .q
-            .get(&state)
-            .map_or(VecDeque::new(), Clone::clone);
-        if !state_stack.is_empty() {
-            state_stack.pop_front();
+            .get(state.wrapping_sub(1))
+            .map_or(VecDeque::new(), |state_stack| (**state_stack).clone());
+
+        state_stack.pop_front();
+        for symbol in ss.into_iter().rev() {
+            state_stack.push_front(symbol);
         }
-        for symbol in ss.iter().rev() {
-            state_stack.push_front(symbol.clone());
-        }
-        let mut state_stack_sorted: Vec<_> = state_stack.iter().cloned().collect();
-        state_stack_sorted.sort();
+        let mut state_stack_sorted: Box<_> = state_stack.iter().copied().collect();
+        state_stack_sorted.sort_unstable();
 
         let mut transition = Transition {
             source: state,
             dest,
-            ss,
+            // ss,
             terminal,
-            is_regex,
-            stack: Rc::new(state_stack.clone()),
+            // is_regex,
+            // stack: Rc::new(state_stack.clone()),
+            stack_len: state_stack.len(),
         };
 
         // Check if a recursive transition state being created, if so make a backward
         // edge and don't add anything to the worklist
-        for (key, val) in &state_stacks.s {
-            if state_stack_sorted == *val {
-                transition.dest = *key;
+        for (dest, stack) in state_stacks.s.iter().enumerate() {
+            if state_stack_sorted == *stack {
+                transition.dest = dest + 1;
                 // i += 1;
-                pda.push(transition.clone());
+                pda.push(transition);
 
                 // If a recursive transition exercised don't add the same transition as a new
                 // edge, continue onto the next transitions
@@ -151,18 +152,23 @@ fn prepare_transitions(
 
         // If the generated state has a stack size > stack_limit then that state is abandoned
        // and not added to the FSA or the worklist for further expansion
-        if stack_limit > 0 && transition.stack.len() > stack_limit {
+        if stack_limit > 0 && transition.stack_len > stack_limit {
            // TODO add to unexpanded_rules
            continue;
        }
 
+        let state_stack = Rc::new(state_stack);
+
         // Create transitions for the non-recursive relations and add to the worklist
         worklist.push_back(Element {
             state: dest,
-            items: transition.stack.clone(),
+            items: Rc::clone(&state_stack),
         });
-        state_stacks.q.insert(dest, state_stack);
-        state_stacks.s.insert(dest, state_stack_sorted);
+
+        // since each index corresponds to `state_count - 1`
+        // index with `dest - 1`
+        state_stacks.q.push(state_stack);
+        state_stacks.s.push(state_stack_sorted);
         pda.push(transition);
 
         println!("worklist size: {}", worklist.len());
@@ -205,11 +211,11 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
     if stack_limit > 0 {
         let mut culled_pda = Vec::with_capacity(pda.len());
         let mut blocklist = HashSet::new();
-        //let mut culled_pda_unique = HashSet::new();
+        // let mut culled_pda_unique = HashSet::new();
 
         for final_state in &finals {
             for transition in pda {
-                if transition.dest == *final_state && transition.stack.len() > 0 {
+                if transition.dest == *final_state && transition.stack_len > 0 {
                     blocklist.insert(transition.dest);
                 } else {
                     culled_pda.push(transition);
@@ -223,7 +229,9 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         let culled_finals: HashSet<usize> = finals.difference(&blocklist).copied().collect();
         assert!(culled_finals.len() == 1);
 
-        for transition in &culled_pda {
+        let culled_pda_len = culled_pda.len();
+
+        for transition in culled_pda {
             if blocklist.contains(&transition.dest) {
                 continue;
             }
@@ -234,15 +242,11 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
             }
             memoized[state].push(Trigger {
                 dest: transition.dest,
-                term: transition.terminal.clone(),
+                term: transition.terminal.to_string(),
             });
 
             if num_transition % 4096 == 0 {
-                println!(
-                    "processed {} transitions over {}",
-                    num_transition,
-                    culled_pda.len()
-                );
+                println!("processed {num_transition} transitions over {culled_pda_len}",);
             }
         }
 
@@ -261,8 +265,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         */
 
         Automaton {
-            init_state: initial.iter().next().copied().unwrap(),
-            final_state: culled_finals.iter().next().copied().unwrap(),
+            init_state: initial.into_iter().next().unwrap(),
+            final_state: culled_finals.into_iter().next().unwrap(),
             pda: memoized,
         }
     } else {
@@ -275,7 +279,7 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
             }
             memoized[state].push(Trigger {
                 dest: transition.dest,
-                term: transition.terminal.clone(),
+                term: transition.terminal.to_string(),
             });
 
             if num_transition % 4096 == 0 {
@@ -288,8 +292,8 @@ fn postprocess(pda: &[Transition], stack_limit: usize) -> Automaton {
         }
 
         Automaton {
-            init_state: initial.iter().next().copied().unwrap(),
-            final_state: finals.iter().next().copied().unwrap(),
+            init_state: initial.into_iter().next().unwrap(),
+            final_state: finals.into_iter().next().unwrap(),
             pda: memoized,
         }
     }
@@ -308,7 +312,7 @@ fn main() {
     let mut pda = vec![];
 
     let grammar = read_grammar_from_file(grammar_file);
-    let start_symbol = grammar["Start"][0].as_str().unwrap().to_owned();
+    let start_symbol = grammar["Start"][0].as_str().unwrap();
     let mut start_vec = VecDeque::new();
     start_vec.push_back(start_symbol);
     worklist.push_back(Element {
@@ -328,8 +332,7 @@ fn main() {
         );
     }
 
-    state_stacks.q.clear();
-    state_stacks.s.clear();
+    drop(state_stacks);
 
    let transformed = postprocess(&pda, stack_limit);
    let serialized = postcard::to_allocvec(&transformed).unwrap();