Introducing Launcher::overcommit, improving CI formatting (#2670)

* introducing Launcher::overcommit

* removing unnecessary cfg restrictions and clippy allows

* improving warning for wrong clang-format version

* installing black in the format CI

* Enforcing python formatting in CI

* extending formatting using black on all python files

* printing diff on black failure

* preferring python's black over system black

* moving to LLVM 19 for formatting
Authored by Valentin Huber on 2024-11-09 19:13:51 +01:00; committed by GitHub
parent 8617fa6603
commit e32b3eae93
11 changed files with 345 additions and 299 deletions

@@ -198,6 +198,8 @@ jobs:
         run: rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt
       - uses: Swatinem/rust-cache@v2
         with: { shared-key: "ubuntu" }
+      - name: Installing black
+        run: python3 -m pip install black
       - name: Format Check
         run: ./scripts/fmt_all.sh check

@@ -3,5 +3,7 @@ import ctypes
 import platform

 print("Starting to fuzz from python!")
-fuzzer = sugar.InMemoryBytesCoverageSugar(input_dirs=["./in"], output_dir="out", broker_port=1337, cores=[0,1])
+fuzzer = sugar.InMemoryBytesCoverageSugar(
+    input_dirs=["./in"], output_dir="out", broker_port=1337, cores=[0, 1]
+)
 fuzzer.run(lambda b: print("foo"))

@@ -4,31 +4,32 @@ from pylibafl import sugar, qemu
 import lief

 MAX_SIZE = 0x100
-BINARY_PATH = './a.out'
+BINARY_PATH = "./a.out"

-emu = qemu.Qemu(['qemu-x86_64', BINARY_PATH], [])
+emu = qemu.Qemu(["qemu-x86_64", BINARY_PATH], [])

 elf = lief.parse(BINARY_PATH)
 test_one_input = elf.get_function_address("LLVMFuzzerTestOneInput")
 if elf.is_pie:
     test_one_input += emu.load_addr()
-print('LLVMFuzzerTestOneInput @ 0x%x' % test_one_input)
+print("LLVMFuzzerTestOneInput @ 0x%x" % test_one_input)

 emu.set_breakpoint(test_one_input)
 emu.run()

 sp = emu.read_reg(qemu.regs.Rsp)
-print('SP = 0x%x' % sp)
+print("SP = 0x%x" % sp)

-retaddr = int.from_bytes(emu.read_mem(sp, 8), 'little')
-print('RET = 0x%x' % retaddr)
+retaddr = int.from_bytes(emu.read_mem(sp, 8), "little")
+print("RET = 0x%x" % retaddr)

 inp = emu.map_private(0, MAX_SIZE, qemu.mmap.ReadWrite)
-assert(inp > 0)
+assert inp > 0

 emu.remove_breakpoint(test_one_input)
 emu.set_breakpoint(retaddr)

+
 def harness(b):
     if len(b) > MAX_SIZE:
         b = b[:MAX_SIZE]
@@ -39,5 +40,6 @@ def harness(b):
     emu.write_reg(qemu.regs.Rip, test_one_input)
     emu.run()

-fuzz = sugar.QemuBytesCoverageSugar(['./in'], './out', 3456, [0,1,2,3])
+
+fuzz = sugar.QemuBytesCoverageSugar(["./in"], "./out", 3456, [0, 1, 2, 3])
 fuzz.run(emu, harness)

@@ -4,16 +4,17 @@ import os
 import json
 import sys

+
 def concatenate_json_files(input_dir):
     json_files = []
     for root, dirs, files in os.walk(input_dir):
         for file in files:
-            if file.endswith('.json'):
+            if file.endswith(".json"):
                 json_files.append(os.path.join(root, file))

     data = dict()
     for json_file in json_files:
-        with open(json_file, 'r') as file:
+        with open(json_file, "r") as file:
             if os.stat(json_file).st_size == 0:
                 # skip empty file else json.load() fails
                 continue
@@ -21,13 +22,14 @@ def concatenate_json_files(input_dir):
             print(type(json_data), file)
             data = data | json_data

-    output_file = os.path.join(os.getcwd(), 'concatenated.json')
-    with open(output_file, 'w') as file:
+    output_file = os.path.join(os.getcwd(), "concatenated.json")
+    with open(output_file, "w") as file:
         json.dump([data], file)

     print(f"JSON files concatenated successfully! Output file: {output_file}")

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     if len(sys.argv) != 2:
         print("Usage: python script.py <directory_path>")
         sys.exit(1)

@@ -108,24 +108,27 @@ pub struct Launcher<'a, CF, MT, SP> {
     broker_port: u16,
     /// The list of cores to run on
     cores: &'a Cores,
+    /// The number of clients to spawn on each core
+    #[builder(default = 1)]
+    overcommit: usize,
     /// A file name to write all client output to
-    #[cfg(all(unix, feature = "std"))]
+    #[cfg(unix)]
     #[builder(default = None)]
     stdout_file: Option<&'a str>,
     /// The time in milliseconds to delay between child launches
     #[builder(default = 10)]
     launch_delay: u64,
     /// The actual, opened, `stdout_file` - so that we keep it open until the end
-    #[cfg(all(unix, feature = "std", feature = "fork"))]
+    #[cfg(all(unix, feature = "fork"))]
     #[builder(setter(skip), default = None)]
     opened_stdout_file: Option<File>,
     /// A file name to write all client stderr output to. If not specified, output is sent to
     /// `stdout_file`.
-    #[cfg(all(unix, feature = "std"))]
+    #[cfg(unix)]
     #[builder(default = None)]
     stderr_file: Option<&'a str>,
     /// The actual, opened, `stdout_file` - so that we keep it open until the end
-    #[cfg(all(unix, feature = "std", feature = "fork"))]
+    #[cfg(all(unix, feature = "fork"))]
     #[builder(setter(skip), default = None)]
     opened_stderr_file: Option<File>,
     /// The `ip:port` address of another broker to connect our new broker to for multi-machine
@@ -172,17 +175,10 @@ where
     SP: ShMemProvider,
 {
     /// Launch the broker and the clients and fuzz
-    #[cfg(all(unix, feature = "std", feature = "fork"))]
-    pub fn launch<S>(&mut self) -> Result<(), Error>
-    where
-        S: State + HasExecutions,
-        CF: FnOnce(Option<S>, LlmpRestartingEventManager<(), S, SP>, CoreId) -> Result<(), Error>,
-    {
-        Self::launch_with_hooks(self, tuple_list!())
-    }
-
-    /// Launch the broker and the clients and fuzz
-    #[cfg(all(feature = "std", any(windows, not(feature = "fork"))))]
+    #[cfg(all(
+        feature = "std",
+        any(windows, not(feature = "fork"), all(unix, feature = "fork"))
+    ))]
     #[allow(unused_mut, clippy::match_wild_err_arm)]
     pub fn launch<S>(&mut self) -> Result<(), Error>
     where
@@ -200,9 +196,8 @@ where
     SP: ShMemProvider,
 {
     /// Launch the broker and the clients and fuzz with a user-supplied hook
-    #[cfg(all(unix, feature = "std", feature = "fork"))]
-    #[allow(clippy::similar_names)]
-    #[allow(clippy::too_many_lines)]
+    #[cfg(all(unix, feature = "fork"))]
+    #[allow(clippy::similar_names, clippy::too_many_lines)]
     pub fn launch_with_hooks<EMH, S>(&mut self, hooks: EMH) -> Result<(), Error>
     where
         S: State + HasExecutions,
@@ -221,8 +216,7 @@ where
             ));
         }

-        let core_ids = get_core_ids().unwrap();
-        let num_cores = core_ids.len();
+        let core_ids = get_core_ids()?;
         let mut handles = vec![];

         log::info!("spawning on cores: {:?}", self.cores);
@@ -234,66 +228,63 @@ where
             .stderr_file
             .map(|filename| File::create(filename).unwrap());

-        #[cfg(feature = "std")]
         let debug_output = std::env::var(LIBAFL_DEBUG_OUTPUT).is_ok();

         // Spawn clients
         let mut index = 0_u64;
-        for (id, bind_to) in core_ids.iter().enumerate().take(num_cores) {
+        for (id, bind_to) in core_ids.iter().enumerate() {
             if self.cores.ids.iter().any(|&x| x == id.into()) {
-                index += 1;
-                self.shmem_provider.pre_fork()?;
-                // # Safety
-                // Fork is safe in general, apart from potential side effects to the OS and other threads
-                match unsafe { fork() }? {
-                    ForkResult::Parent(child) => {
-                        self.shmem_provider.post_fork(false)?;
-                        handles.push(child.pid);
-                        #[cfg(feature = "std")]
-                        log::info!("child spawned and bound to core {id}");
-                    }
-                    ForkResult::Child => {
-                        // # Safety
-                        // A call to `getpid` is safe.
-                        log::info!("{:?} PostFork", unsafe { libc::getpid() });
-                        self.shmem_provider.post_fork(true)?;
-
-                        #[cfg(feature = "std")]
-                        std::thread::sleep(Duration::from_millis(index * self.launch_delay));
-
-                        #[cfg(feature = "std")]
-                        if !debug_output {
-                            if let Some(file) = &self.opened_stdout_file {
-                                dup2(file.as_raw_fd(), libc::STDOUT_FILENO)?;
-                                if let Some(stderr) = &self.opened_stderr_file {
-                                    dup2(stderr.as_raw_fd(), libc::STDERR_FILENO)?;
-                                } else {
-                                    dup2(file.as_raw_fd(), libc::STDERR_FILENO)?;
-                                }
-                            }
-                        }
-
-                        // Fuzzer client. keeps retrying the connection to broker till the broker starts
-                        let builder = RestartingMgr::<EMH, MT, S, SP>::builder()
-                            .shmem_provider(self.shmem_provider.clone())
-                            .broker_port(self.broker_port)
-                            .kind(ManagerKind::Client {
-                                cpu_core: Some(*bind_to),
-                            })
-                            .configuration(self.configuration)
-                            .serialize_state(self.serialize_state)
-                            .hooks(hooks);
-                        let builder = builder.time_ref(self.time_ref.clone());
-                        let (state, mgr) = builder.build().launch()?;
-
-                        return (self.run_client.take().unwrap())(state, mgr, *bind_to);
-                    }
-                };
+                for _ in 0..self.overcommit {
+                    index += 1;
+                    self.shmem_provider.pre_fork()?;
+                    // # Safety
+                    // Fork is safe in general, apart from potential side effects to the OS and other threads
+                    match unsafe { fork() }? {
+                        ForkResult::Parent(child) => {
+                            self.shmem_provider.post_fork(false)?;
+                            handles.push(child.pid);
+                            log::info!("child spawned and bound to core {id}");
+                        }
+                        ForkResult::Child => {
+                            // # Safety
+                            // A call to `getpid` is safe.
+                            log::info!("{:?} PostFork", unsafe { libc::getpid() });
+                            self.shmem_provider.post_fork(true)?;
+
+                            std::thread::sleep(Duration::from_millis(index * self.launch_delay));
+
+                            if !debug_output {
+                                if let Some(file) = &self.opened_stdout_file {
+                                    dup2(file.as_raw_fd(), libc::STDOUT_FILENO)?;
+                                    if let Some(stderr) = &self.opened_stderr_file {
+                                        dup2(stderr.as_raw_fd(), libc::STDERR_FILENO)?;
+                                    } else {
+                                        dup2(file.as_raw_fd(), libc::STDERR_FILENO)?;
+                                    }
+                                }
+                            }
+
+                            // Fuzzer client. keeps retrying the connection to broker till the broker starts
+                            let builder = RestartingMgr::<EMH, MT, S, SP>::builder()
+                                .shmem_provider(self.shmem_provider.clone())
+                                .broker_port(self.broker_port)
+                                .kind(ManagerKind::Client {
+                                    cpu_core: Some(*bind_to),
+                                })
+                                .configuration(self.configuration)
+                                .serialize_state(self.serialize_state)
+                                .hooks(hooks);
+                            let builder = builder.time_ref(self.time_ref.clone());
+                            let (state, mgr) = builder.build().launch()?;
+
+                            return (self.run_client.take().unwrap())(state, mgr, *bind_to);
+                        }
+                    };
+                }
             }
         }

         if self.spawn_broker {
-            #[cfg(feature = "std")]
             log::info!("I am broker!!.");

             // TODO we don't want always a broker here, think about using different laucher process to spawn different configurations
@@ -337,7 +328,7 @@ where
     }

     /// Launch the broker and the clients and fuzz
-    #[cfg(all(feature = "std", any(windows, not(feature = "fork"))))]
+    #[cfg(any(windows, not(feature = "fork")))]
     #[allow(unused_mut, clippy::match_wild_err_arm, clippy::too_many_lines)]
     pub fn launch_with_hooks<EMH, S>(&mut self, hooks: EMH) -> Result<(), Error>
     where
@@ -381,7 +372,7 @@ where
         log::info!("spawning on cores: {:?}", self.cores);
         let debug_output = std::env::var("LIBAFL_DEBUG_OUTPUT").is_ok();

-        #[cfg(all(feature = "std", unix))]
+        #[cfg(unix)]
        {
             // Set own stdout and stderr as set by the user
             if !debug_output {
@@ -404,32 +395,34 @@ where
                 //spawn clients
                 for (id, _) in core_ids.iter().enumerate().take(num_cores) {
                     if self.cores.ids.iter().any(|&x| x == id.into()) {
-                        // Forward own stdio to child processes, if requested by user
-                        let (mut stdout, mut stderr) = (Stdio::null(), Stdio::null());
-                        #[cfg(all(feature = "std", unix))]
-                        {
-                            if self.stdout_file.is_some() || self.stderr_file.is_some() {
-                                stdout = Stdio::inherit();
-                                stderr = Stdio::inherit();
-                            };
-                        }
-
-                        #[cfg(feature = "std")]
-                        std::thread::sleep(Duration::from_millis(id as u64 * self.launch_delay));
-
-                        std::env::set_var(_AFL_LAUNCHER_CLIENT, id.to_string());
-                        let mut child = startable_self()?;
-                        let child = (if debug_output {
-                            &mut child
-                        } else {
-                            child.stdout(stdout);
-                            child.stderr(stderr)
-                        })
-                        .spawn()?;
-                        handles.push(child);
+                        for _ in 0..self.overcommit {
+                            // Forward own stdio to child processes, if requested by user
+                            let (mut stdout, mut stderr) = (Stdio::null(), Stdio::null());
+                            #[cfg(unix)]
+                            {
+                                if self.stdout_file.is_some() || self.stderr_file.is_some() {
+                                    stdout = Stdio::inherit();
+                                    stderr = Stdio::inherit();
+                                };
+                            }
+
+                            std::thread::sleep(Duration::from_millis(
+                                id as u64 * self.launch_delay,
+                            ));
+
+                            std::env::set_var(_AFL_LAUNCHER_CLIENT, id.to_string());
+                            let mut child = startable_self()?;
+                            let child = (if debug_output {
+                                &mut child
+                            } else {
+                                child.stdout(stdout);
+                                child.stderr(stderr)
+                            })
+                            .spawn()?;
+                            handles.push(child);
+                        }
                     }
                 }
                 handles
             }
             Err(_) => panic!("Env variables are broken, received non-unicode!"),
@@ -444,7 +437,6 @@ where
         }

         if self.spawn_broker {
-            #[cfg(feature = "std")]
             log::info!("I am broker!!.");

             let builder = RestartingMgr::<EMH, MT, S, SP>::builder()
@@ -620,8 +612,7 @@ where
     /// Launch a Centralized-based fuzzer.
     /// - `main_inner_mgr_builder` will be called to build the inner manager of the main node.
     /// - `secondary_inner_mgr_builder` will be called to build the inner manager of the secondary nodes.
-    #[allow(clippy::similar_names)]
-    #[allow(clippy::too_many_lines)]
+    #[allow(clippy::similar_names, clippy::too_many_lines)]
     pub fn launch_generic<EM, EMB, S>(
         &mut self,
         main_inner_mgr_builder: EMB,
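
With the new field, `Launcher` forks (or, on the non-fork path, respawns) `overcommit` clients per selected core instead of exactly one, and `#[builder(default = 1)]` keeps the old behavior for existing callers. A minimal usage sketch, not from this commit, assuming the usual LibAFL companions (`StdShMemProvider`, `MultiMonitor`, `EventConfig`, `Cores`) around the builder fields visible in the diff:

    // Hypothetical sketch: spawn two clients on each of cores 0-3 (eight total).
    let mut run_client = |_state, _mgr, core_id| {
        // A real client would build and run its fuzzer here.
        println!("client bound to {core_id:?}");
        Ok(())
    };

    match Launcher::builder()
        .shmem_provider(StdShMemProvider::new()?)
        .configuration(EventConfig::from_name("default"))
        .monitor(MultiMonitor::new(|s| println!("{s}")))
        .run_client(&mut run_client)
        .cores(&Cores::from_cmdline("0-3")?)
        .broker_port(1337)
        .overcommit(2) // new knob: clients per core, defaults to 1
        .build()
        .launch()
    {
        Ok(()) | Err(Error::ShuttingDown) => {}
        Err(e) => panic!("{e:?}"),
    }

Overcommitting can help when clients spend time blocked (e.g. on I/O or syscall-heavy targets), at the cost of scheduler contention when they are CPU-bound.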

@@ -11,16 +11,21 @@ else
    cargo run --manifest-path "$LIBAFL_DIR/utils/libafl_fmt/Cargo.toml" --release -- --verbose || exit 1
 fi

-if command -v black > /dev/null; then
-    echo "[*] Formatting python files"
-    if ! black "$SCRIPT_DIR"
-    then
-        echo "Python format failed."
-        exit 1
-    fi
+if python3 -m black --version > /dev/null; then
+    BLACK_COMMAND="python3 -m black"
+elif command -v black > /dev/null; then
+    BLACK_COMMAND="black"
+fi
+
+if [ -n "$BLACK_COMMAND" ]; then
+    echo "[*] Formatting python files"
+    if [ "$1" = "check" ]; then
+        $BLACK_COMMAND --check --diff "$LIBAFL_DIR" || exit 1
+    else
+        $BLACK_COMMAND "$LIBAFL_DIR" || exit 1
+    fi
 else
-    echo "Warning: python black not found. Formatting skipped for python."
+    echo -e "\n\033[1;33mWarning\033[0m: python black not found. Formatting skipped for python.\n"
 fi

 if [ "$1" != "check" ]; then

@@ -7,7 +7,7 @@ import sys
 cfg = dict()

-if 'CFG_OUTPUT_PATH' not in os.environ:
+if "CFG_OUTPUT_PATH" not in os.environ:
     sys.exit("CFG_OUTPUT_PATH not set")

 input_path = os.environ["CFG_OUTPUT_PATH"]
@@ -31,7 +31,7 @@ for mname, module in cfg.items():
     fnname2SG = dict()

     # First, add all the intra-procedural edges
-    for (fname, v) in module['edges'].items():
+    for fname, v in module["edges"].items():
         if fname not in fname2id:
             GG.add_node(f_ids, label=fname)
@@ -41,8 +41,7 @@ for mname, module in cfg.items():
         sz = len(v)
         for idx in range(node_ids, node_ids + sz):
             G.add_node(idx)
-            G.nodes[idx]['label'] = mname + ' ' + \
-                fname + ' ' + str(idx - node_ids)
+            G.nodes[idx]["label"] = mname + " " + fname + " " + str(idx - node_ids)
         node_id_list = list(range(node_ids, node_ids + sz))
         node_ids += sz
         SG = G.subgraph(node_id_list)
@@ -52,14 +51,14 @@ for mname, module in cfg.items():
                 G.add_edge(node_id_list[src], node_id_list[item])

     # Next, build inter-procedural edges
-    for (fname, calls) in module['calls'].items():
-        for (idx, target_fns) in calls.items():
+    for fname, calls in module["calls"].items():
+        for idx, target_fns in calls.items():
             # G.nodes isn't sorted
             src = sorted(fnname2SG[fname].nodes())[0] + int(idx)
             for target_fn in target_fns:
                 if target_fn in fnname2SG:
-                    offset = module['entries'][target_fn]
+                    offset = module["entries"][target_fn]
                     dst = sorted(fnname2SG[target_fn].nodes)[0] + offset

@@ -8,6 +8,7 @@ import sys
 import json
 import re
 from collections import defaultdict
+
 # import pygraphviz as pgv

 gram_data = None
@@ -24,20 +25,20 @@ stack_limit = None
 # Holds the set of unexpanded rules owing to the user-passed stack constraint limit
 unexpanded_rules = set()

 def main(grammar, limit):
     global worklist, gram_data, stack_limit
-    current = '0'
+    current = "0"
     stack_limit = limit
     if stack_limit:
-        print ('[X] Operating in bounded stack mode')
+        print("[X] Operating in bounded stack mode")

-    with open(grammar, 'r') as fd:
+    with open(grammar, "r") as fd:
         gram_data = json.load(fd)
     start_symbol = gram_data["Start"][0]
     worklist.append([current, [start_symbol]])
     # print (grammar)
-    filename = (grammar.split('/')[-1]).split('.')[0]
+    filename = (grammar.split("/")[-1]).split(".")[0]

     while worklist:
         # Take an element from the worklist
@@ -46,45 +47,54 @@ def main(grammar, limit):
         element = worklist.pop(0)
         prep_transitions(element)

-    pda_file = filename + '_transition.json'
-    graph_file = filename + '.png'
+    pda_file = filename + "_transition.json"
+    graph_file = filename + ".png"
     # print ('XXXXXXXXXXXXXXXX')
     # print ('PDA file:%s Png graph file:%s' % (pda_file, graph_file))
     # XXX Commented out because visualization of current version of PHP causes segfault
     # Create the graph and dump the transitions to a file
     # create_graph(filename)
     transformed = postprocess()
-    with open(filename + '_automata.json', 'w+') as fd:
+    with open(filename + "_automata.json", "w+") as fd:
         json.dump(transformed, fd)
-    with open(filename + '_transition.json', 'w+') as fd:
+    with open(filename + "_transition.json", "w+") as fd:
         json.dump(pda, fd)
     if not unexpanded_rules:
-        print ('[X] No unexpanded rules, absolute FSA formed')
+        print("[X] No unexpanded rules, absolute FSA formed")
         exit(0)
     else:
-        print ('[X] Certain rules were not expanded due to stack size limit. Inexact approximation has been created and the disallowed rules have been put in {}_disallowed.json'.format(filename))
-        print ('[X] Number of unexpanded rules:', len(unexpanded_rules))
-        with open(filename + '_disallowed.json', 'w+') as fd:
+        print(
+            "[X] Certain rules were not expanded due to stack size limit. Inexact approximation has been created and the disallowed rules have been put in {}_disallowed.json".format(
+                filename
+            )
+        )
+        print("[X] Number of unexpanded rules:", len(unexpanded_rules))
+        with open(filename + "_disallowed.json", "w+") as fd:
             json.dump(list(unexpanded_rules), fd)

+
 def create_graph(filename):
-    '''
+    """
     Creates a DOT representation of the PDA
-    '''
+    """
     global pda
-    G = pgv.AGraph(strict = False, directed = True)
+    G = pgv.AGraph(strict=False, directed=True)
     for transition in pda:
-        print ('Transition:', transition)
-        G.add_edge(transition['source'], transition['dest'],
-            label = 'Term:{}'.format(transition['terminal']))
-    G.layout(prog = 'dot')
-    print ('Do it up 2')
-    G.draw(filename + '.png')
+        print("Transition:", transition)
+        G.add_edge(
+            transition["source"],
+            transition["dest"],
+            label="Term:{}".format(transition["terminal"]),
+        )
+    G.layout(prog="dot")
+    print("Do it up 2")
+    G.draw(filename + ".png")

+
 def prep_transitions(element):
-    '''
+    """
     Generates transitions
-    '''
+    """
     global gram_data, state_count, pda, worklist, state_stacks, stack_limit, unexpanded_rules
     state = element[0]
     try:
@@ -95,18 +105,18 @@ def prep_transitions(element):
     rules = gram_data[nonterminal]
     count = 1
     for rule in rules:
         isRecursive = False
         # print ('Current state:', state)
         terminal, ss, termIsRegex = tokenize(rule)
         transition = get_template()
-        transition['trigger'] = '_'.join([state, str(count)])
-        transition['source'] = state
-        transition['dest'] = str(state_count)
-        transition['ss'] = ss
-        transition['terminal'] = terminal
-        transition['rule'] = "{} -> {}".format(nonterminal, rule )
+        transition["trigger"] = "_".join([state, str(count)])
+        transition["source"] = state
+        transition["dest"] = str(state_count)
+        transition["ss"] = ss
+        transition["terminal"] = terminal
+        transition["rule"] = "{} -> {}".format(nonterminal, rule)
         if termIsRegex:
-            transition['termIsRegex'] = True
+            transition["termIsRegex"] = True

         # Creating a state stack for the new state
         try:
@@ -118,7 +128,7 @@ def prep_transitions(element):
         if ss:
             for symbol in ss[::-1]:
                 state_stack.insert(0, symbol)
-        transition['stack'] = state_stack
+        transition["stack"] = state_stack

         # Check if a recursive transition state being created, if so make a backward
         # edge and don't add anything to the worklist
@@ -128,7 +138,7 @@ def prep_transitions(element):
                 # print ('Stack:', sorted(stack))
                 # print ('State stack:', sorted(state_stack))
                 if sorted(stack) == sorted(state_stack):
-                    transition['dest'] = state_element
+                    transition["dest"] = state_element
                     # print ('Recursive:', transition)
                     pda.append(transition)
                     count += 1
@@ -142,24 +152,25 @@ def prep_transitions(element):
         # If the generated state has a stack size > stack_limit then that state is abandoned
         # and not added to the FSA or the worklist for further expansion
         if stack_limit:
-            if (len(transition['stack']) > stack_limit):
-                unexpanded_rules.add(transition['rule'])
+            if len(transition["stack"]) > stack_limit:
+                unexpanded_rules.add(transition["rule"])
                 continue

         # Create transitions for the non-recursive relations and add to the worklist
         # print ('Normal:', transition)
         # print ('State2:', state)
         pda.append(transition)
-        worklist.append([transition['dest'], transition['stack']])
-        state_stacks[transition['dest']] = state_stack
+        worklist.append([transition["dest"], transition["stack"]])
+        state_stacks[transition["dest"]] = state_stack
         state_count += 1
         count += 1

+
 def tokenize(rule):
-    '''
+    """
     Gets the terminal and the corresponding stack symbols from a rule in GNF form
-    '''
-    pattern = re.compile("([r])*\'([\s\S]+)\'([\s\S]*)")
+    """
+    pattern = re.compile("([r])*'([\s\S]+)'([\s\S]*)")
     terminal = None
     ss = None
     termIsRegex = False
@@ -176,34 +187,35 @@ def tokenize(rule):
     return terminal, ss, termIsRegex

+
 def get_template():
     transition_template = {
-        'trigger':None,
-        'source': None,
-        'dest': None,
-        'termIsRegex': False,
-        'terminal' : None,
-        'stack': []
+        "trigger": None,
+        "source": None,
+        "dest": None,
+        "termIsRegex": False,
+        "terminal": None,
+        "stack": [],
     }
     return transition_template

 def postprocess1():
-    '''
+    """
     Creates a representation to be passed on to the C-module
-    '''
+    """
     global pda
     final_struct = {}
     # Supporting data structures for if stack limit is imposed
     culled_pda = []
     culled_final = []
-    num_transitions = 0 # Keep track of number of transitions
+    num_transitions = 0  # Keep track of number of transitions
     states, final, initial = _get_states()
     memoized = [[]] * len(states)
-    print (initial)
-    assert len(initial) == 1, 'More than one init state found'
+    print(initial)
+    assert len(initial) == 1, "More than one init state found"

     # Cull transitions to states which were not expanded owing to the stack limit
     if stack_limit:
@@ -211,7 +223,9 @@ def postprocess1():
         blocklist = []
         for final_state in final:
             for transition in pda:
-                if (transition["dest"] == final_state) and (len(transition["stack"]) > 0):
+                if (transition["dest"] == final_state) and (
+                    len(transition["stack"]) > 0
+                ):
                     blocklist.append(transition["dest"])
                     continue
                 else:
@@ -219,55 +233,57 @@ def postprocess1():
         culled_final = [state for state in final if state not in blocklist]
-        assert len(culled_final) == 1, 'More than one final state found'
+        assert len(culled_final) == 1, "More than one final state found"

         for transition in culled_pda:
             state = transition["source"]
             if transition["dest"] in blocklist:
                 continue
             num_transitions += 1
-            memoized[int(state)].append((transition["trigger"],
-                int(transition["dest"]), transition["terminal"]))
+            memoized[int(state)].append(
+                (transition["trigger"], int(transition["dest"]), transition["terminal"])
+            )
         final_struct["init_state"] = int(initial)
         final_struct["final_state"] = int(culled_final[0])
         # The reason we do this is because when states are culled, the indexing is
         # still relative to the actual number of states hence we keep numstates recorded
         # as the original number of states
-        print ('[X] Actual Number of states:', len(memoized))
-        print ('[X] Number of transitions:', num_transitions)
-        print ('[X] Original Number of states:', len(states))
+        print("[X] Actual Number of states:", len(memoized))
+        print("[X] Number of transitions:", num_transitions)
+        print("[X] Original Number of states:", len(states))
         final_struct["pda"] = memoized
         return final_struct

     # Running FSA construction in exact approximation mode and postprocessing it like so
     for transition in pda:
         state = transition["source"]
-        memoized[int(state)].append((transition["trigger"],
-            int(transition["dest"]), transition["terminal"]))
+        memoized[int(state)].append(
+            (transition["trigger"], int(transition["dest"]), transition["terminal"])
+        )
     final_struct["init_state"] = int(initial)
     final_struct["final_state"] = int(final[0])
-    print ('[X] Actual Number of states:', len(memoized))
+    print("[X] Actual Number of states:", len(memoized))
     final_struct["pda"] = memoized
     return final_struct

+
 def postprocess():
-    '''
+    """
     Creates a representation to be passed on to the C-module
-    '''
+    """
     global pda
     final_struct = {}
     memoized = defaultdict(list)
     # Supporting data structures for if stack limit is imposed
     culled_pda = []
     culled_final = []
-    num_transitions = 0 # Keep track of number of transitions
+    num_transitions = 0  # Keep track of number of transitions
     states, final, initial = _get_states()
-    print (initial)
-    assert len(initial) == 1, 'More than one init state found'
+    print(initial)
+    assert len(initial) == 1, "More than one init state found"

     # Cull transitions to states which were not expanded owing to the stack limit
     if stack_limit:
@@ -275,7 +291,9 @@ def postprocess():
         blocklist = []
         for final_state in final:
             for transition in pda:
-                if (transition["dest"] == final_state) and (len(transition["stack"]) > 0):
+                if (transition["dest"] == final_state) and (
+                    len(transition["stack"]) > 0
+                ):
                     blocklist.append(transition["dest"])
                     continue
                 else:
@@ -283,40 +301,40 @@ def postprocess():
         culled_final = [state for state in final if state not in blocklist]
-        assert len(culled_final) == 1, 'More than one final state found'
+        assert len(culled_final) == 1, "More than one final state found"

         for transition in culled_pda:
             state = transition["source"]
             if transition["dest"] in blocklist:
                 continue
             num_transitions += 1
-            memoized[int(state)].append([transition["trigger"], int(transition["dest"]),
-                transition["terminal"]])
+            memoized[int(state)].append(
+                [transition["trigger"], int(transition["dest"]), transition["terminal"]]
+            )
         final_struct["init_state"] = int(initial)
         final_struct["final_state"] = int(culled_final[0])
         # The reason we do this is because when states are culled, the indexing is
         # still relative to the actual number of states hence we keep numstates recorded
         # as the original number of states
-        print ('[X] Actual Number of states:', len(memoized.keys()))
-        print ('[X] Number of transitions:', num_transitions)
-        print ('[X] Original Number of states:', len(states))
-        #final_struct["numstates"] = len(states)
-        memoized_list = [[]]*len(states)
+        print("[X] Actual Number of states:", len(memoized.keys()))
+        print("[X] Number of transitions:", num_transitions)
+        print("[X] Original Number of states:", len(states))
+        # final_struct["numstates"] = len(states)
+        memoized_list = [[]] * len(states)
     else:
         # Running FSA construction in exact approximation mode and postprocessing it like so
         for transition in pda:
             state = transition["source"]
-            memoized[int(state)].append([transition["trigger"], int(transition["dest"]),
-                transition["terminal"]])
+            memoized[int(state)].append(
+                [transition["trigger"], int(transition["dest"]), transition["terminal"]]
+            )
         final_struct["init_state"] = int(initial)
         final_struct["final_state"] = int(final[0])
-        print ('[X] Actual Number of states:', len(memoized.keys()))
-        #final_struct["numstates"] = len(memoized.keys())
-        memoized_list = [[]]*len(memoized.keys())
+        print("[X] Actual Number of states:", len(memoized.keys()))
+        # final_struct["numstates"] = len(memoized.keys())
+        memoized_list = [[]] * len(memoized.keys())
     for k in memoized.keys():
         memoized_list[k] = memoized[k]
@@ -333,19 +351,23 @@ def _get_states():
             dest.add(transition["dest"])
     source_copy = source.copy()
     source_copy.update(dest)
-    return list(source_copy), list(dest.difference(source)), str(''.join(list(source.difference(dest))))
+    return (
+        list(source_copy),
+        list(dest.difference(source)),
+        str("".join(list(source.difference(dest)))),
+    )

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse

-    parser = argparse.ArgumentParser(description = 'Script to convert GNF grammar to PDA')
-    parser.add_argument(
-        '--gf',
-        type = str,
-        help = 'Location of GNF grammar')
-    parser.add_argument(
-        '--limit',
-        type = int,
-        default = None,
-        help = 'Specify the upper bound for the stack size')
+    parser = argparse.ArgumentParser(description="Script to convert GNF grammar to PDA")
+    parser.add_argument("--gf", type=str, help="Location of GNF grammar")
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=None,
+        help="Specify the upper bound for the stack size",
+    )
     args = parser.parse_args()
     main(args.gf, args.limit)

@@ -16,17 +16,18 @@ DEBUG = False
 NONTERMINALSET = []
 COUNT = 1

+
 def convert_to_gnf(grammar, start):
     if DEBUG:
-        with open('debug_preprocess.json', 'w+') as fd:
+        with open("debug_preprocess.json", "w+") as fd:
             json.dump(grammar, fd)
-    grammar = remove_unit(grammar) # eliminates unit productions
+    grammar = remove_unit(grammar)  # eliminates unit productions
     if DEBUG:
-        with open('debug_unit.json', 'w+') as fd:
+        with open("debug_unit.json", "w+") as fd:
             json.dump(grammar, fd)
-    grammar = remove_mixed(grammar) # eliminate terminals existing with non-terminals
+    grammar = remove_mixed(grammar)  # eliminate terminals existing with non-terminals
     if DEBUG:
-        with open('debug_mixed.json', 'w+') as fd:
+        with open("debug_mixed.json", "w+") as fd:
             json.dump(grammar, fd)
     grammar = gnf(grammar)
@@ -35,12 +36,13 @@ def convert_to_gnf(grammar, start):
     # with open('debug_gnf_reachable.json', 'w+') as fd:
     #     json.dump(reachable_grammar, fd)
     if DEBUG:
-        with open('debug_gnf.json', 'w+') as fd:
+        with open("debug_gnf.json", "w+") as fd:
             json.dump(grammar, fd)
     grammar["Start"] = [start]
     return grammar

+
 def remove_left_recursion(grammar):
     # Remove the left recursion in the grammar rules.
     # This algorithm is adopted from
@@ -69,10 +71,10 @@ def remove_left_recursion(grammar):
             r.append(new_rule)
             left_recursion = [r[1:] + [new_rule] for r in left_recursion]
             left_recursion.append(["' '"])
-            new_grammar[lhs] = [' '.join(rule) for rule in others]
-            new_grammar[new_rule] = [' '.join(rule) for rule in left_recursion]
+            new_grammar[lhs] = [" ".join(rule) for rule in others]
+            new_grammar[new_rule] = [" ".join(rule) for rule in left_recursion]
         else:
-            new_grammar[lhs] = [' '.join(rule) for rule in others]
+            new_grammar[lhs] = [" ".join(rule) for rule in others]
     no_left_recursion = True
     for lhs, rules in old_grammar.items():
         for rule in rules:
@@ -88,10 +90,11 @@ def remove_left_recursion(grammar):
         new_grammar = defaultdict(list)
     return new_grammar

+
 def get_reachable(grammar, start):
-    '''
+    """
     Returns a grammar without dead rules
-    '''
+    """
     reachable_nt = set()
     worklist = list()
     processed = set()
@@ -113,9 +116,10 @@ def get_reachable(grammar, start):
 def gettokens(rule):
-    pattern = re.compile("([^\s\"\']+)|\"([^\"]*)\"|\'([^\']*)\'")
+    pattern = re.compile("([^\s\"']+)|\"([^\"]*)\"|'([^']*)'")
     return [matched.group(0) for matched in pattern.finditer(rule)]

+
 def gnf(grammar):
     old_grammar = copy.deepcopy(grammar)
     new_grammar = defaultdict(list)
@@ -129,7 +133,7 @@ def gnf(grammar):
                 new_grammar[lhs].append(rule)
                 continue
             startoken = tokens[0]
-            assert(startoken != lhs)
+            assert startoken != lhs
             endrule = tokens[1:]
             if not isTerminal(startoken):
                 newrules = []
@@ -139,7 +143,7 @@ def gnf(grammar):
                     temprule.insert(0, extension)
                     newrules.append(temprule)
                 for newnew in newrules:
-                    new_grammar[lhs].append(' '.join(newnew))
+                    new_grammar[lhs].append(" ".join(newnew))
             else:
                 new_grammar[lhs].append(rule)
     isgnf = True
@@ -163,7 +167,7 @@ def process_antlr4_grammar(data):
     productions = []
     production = []
     for line in data:
-        if line != '\n':
+        if line != "\n":
             production.append(line)
         else:
             productions.append(production)
@@ -172,16 +176,17 @@ def process_antlr4_grammar(data):
     for production in productions:
         rules = []
         init = production[0]
-        nonterminal = init.split(':')[0]
-        rules.append(strip_chars(init.split(':')[1]).strip('| '))
+        nonterminal = init.split(":")[0]
+        rules.append(strip_chars(init.split(":")[1]).strip("| "))
         for production_rule in production[1:]:
-            rules.append(strip_chars(production_rule.split('|')[0]))
+            rules.append(strip_chars(production_rule.split("|")[0]))
         final_rule_set[nonterminal] = rules
     # for line in data:
     #     if line != '\n':
     #         production.append(line)
     return final_rule_set

+
 def remove_unit(grammar):
     nounitproductions = False
     old_grammar = copy.deepcopy(grammar)
@@ -213,19 +218,21 @@ def remove_unit(grammar):
         new_grammar = defaultdict(list)
     return new_grammar

+
 def isTerminal(rule):
     # pattern = re.compile("([r]*\'[\s\S]+\')")
-    pattern = re.compile("\'(.*?)\'")
+    pattern = re.compile("'(.*?)'")
     match = pattern.match(rule)
     if match:
         return True
     else:
         return False

+
 def remove_mixed(grammar):
-    '''
+    """
     Remove rules where there are terminals mixed in with non-terminals
-    '''
+    """
     new_grammar = defaultdict(list)
     for lhs, rules in grammar.items():
         for rhs in rules:
@@ -248,17 +255,20 @@ def remove_mixed(grammar):
                     regen_rule.append(new_nonterm)
                 else:
                     regen_rule.append(token)
-            new_grammar[lhs].append(' '.join(regen_rule))
+            new_grammar[lhs].append(" ".join(regen_rule))
     return new_grammar

+
 def strip_chars(rule):
-    return rule.strip('\n\t ')
+    return rule.strip("\n\t ")

+
 def get_nonterminal():
     global COUNT
     COUNT += 1
     return f"GeneratedTermVar{COUNT}"

+
 def terminal_exist(token, grammar):
     for nonterminal, rules in grammar.items():
         if token in rules and len(token) == 1:
@@ -269,42 +279,37 @@ def terminal_exist(token, grammar):
 def main(grammar_file, out, start):
     grammar = None
     # If grammar file is a preprocessed NT file, then skip preprocessing
-    if '.json' in grammar_file:
-        with open(grammar_file, 'r') as fd:
+    if ".json" in grammar_file:
+        with open(grammar_file, "r") as fd:
             grammar = json.load(fd)
-    elif '.g4' in grammar_file:
-        with open(grammar_file, 'r') as fd:
+    elif ".g4" in grammar_file:
+        with open(grammar_file, "r") as fd:
             data = fd.readlines()
         grammar = process_antlr4_grammar(data)
     else:
-        raise('Unknwown file format passed. Accepts (.g4/.json)')
+        raise ("Unknwown file format passed. Accepts (.g4/.json)")

     grammar = convert_to_gnf(grammar, start)
-    with open(out, 'w+') as fd:
+    with open(out, "w+") as fd:
         json.dump(grammar, fd)

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse

-    parser = argparse.ArgumentParser(description = 'Script to convert grammar to GNF form')
+    parser = argparse.ArgumentParser(
+        description="Script to convert grammar to GNF form"
+    )
     parser.add_argument(
-        '--gf',
-        type = str,
-        required = True,
-        help = 'Location of grammar file')
+        "--gf", type=str, required=True, help="Location of grammar file"
+    )
     parser.add_argument(
-        '--out',
-        type = str,
-        required = True,
-        help = 'Location of output file')
+        "--out", type=str, required=True, help="Location of output file"
+    )
+    parser.add_argument("--start", type=str, required=True, help="Start token")
     parser.add_argument(
-        '--start',
-        type = str,
-        required = True,
-        help = 'Start token')
-    parser.add_argument(
-        '--debug',
-        action='store_true',
-        help = 'Write intermediate states to debug files')
+        "--debug", action="store_true", help="Write intermediate states to debug files"
+    )
     args = parser.parse_args()
     DEBUG = args.debug

@@ -20,3 +20,4 @@ tokio = { version = "1.38", features = [
 clap = { version = "4.5", features = ["derive"] }
 exitcode = "1.1"
 which = "6.0"
+colored = "2.1.0"

@@ -78,12 +78,13 @@ use std::{
 };

 use clap::Parser;
+use colored::Colorize;
 use regex::RegexSet;
 use tokio::{process::Command, task::JoinSet};
 use walkdir::{DirEntry, WalkDir};
 use which::which;

-const REF_LLVM_VERSION: u32 = 18;
+const REF_LLVM_VERSION: u32 = 19;

 fn is_workspace_toml(path: &Path) -> bool {
     for line in read_to_string(path).unwrap().lines() {
@@ -249,20 +250,29 @@ async fn main() -> io::Result<()> {
         tokio_joinset.spawn(run_cargo_fmt(project, cli.check, cli.verbose));
     }

-    let ref_clang_format = format!("clang-format-{REF_LLVM_VERSION}");
-
-    let (clang, warning) = if which(ref_clang_format.clone()).is_ok() {
-        // can't use 18 for ci.
-        (Some(ref_clang_format), None)
-    } else if which("clang-format").is_ok() {
+    let reference_clang_format = format!("clang-format-{REF_LLVM_VERSION}");
+    let unspecified_clang_format = "clang-format";
+    let (clang, warning) = if which(&reference_clang_format).is_ok() {
+        (Some(reference_clang_format.as_str()), None)
+    } else if which(unspecified_clang_format).is_ok() {
+        let version = Command::new(unspecified_clang_format)
+            .arg("--version")
+            .output()
+            .await?
+            .stdout;
         (
-            Some("clang-format".to_string()),
-            Some("using clang-format, could provide a different result from clang-format-17"),
+            Some(unspecified_clang_format),
+            Some(format!(
+                "using {}, could provide a different result from clang-format-17",
+                from_utf8(&version).unwrap().replace('\n', "")
+            )),
         )
     } else {
         (
             None,
-            Some("clang-format not found. Skipping C formatting..."),
+            Some("clang-format not found. Skipping C formatting...".to_string()),
         )
     };
     // println!("Using {:#?} to format...", clang);
@@ -277,7 +287,12 @@ async fn main() -> io::Result<()> {
             .collect();

         for c_file in c_files_to_fmt {
-            tokio_joinset.spawn(run_clang_fmt(c_file, clang.clone(), cli.check, cli.verbose));
+            tokio_joinset.spawn(run_clang_fmt(
+                c_file,
+                clang.to_string(),
+                cli.check,
+                cli.verbose,
+            ));
         }
     }
@@ -292,7 +307,7 @@ async fn main() -> io::Result<()> {
     }

     if let Some(warning) = warning {
-        println!("Warning: {warning}");
+        println!("\n{}: {}\n", "Warning".yellow().bold(), warning);
     }

     if cli.check {
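
The detection logic above reads as a three-way probe: prefer the pinned `clang-format-19`, fall back to an unpinned `clang-format` while capturing its `--version` banner for the warning, and otherwise skip C formatting. A standalone sketch of the same pattern, assuming the `which` crate and using a blocking `std::process::Command` where the real tool uses tokio's async `Command`:

    use std::{process::Command, str::from_utf8};

    use which::which;

    const REF_LLVM_VERSION: u32 = 19;

    /// Returns the clang-format binary to use plus an optional warning.
    fn find_clang_format() -> (Option<String>, Option<String>) {
        let pinned = format!("clang-format-{REF_LLVM_VERSION}");
        if which(&pinned).is_ok() {
            // Exact pinned version found: no warning needed.
            return (Some(pinned), None);
        }
        if which("clang-format").is_ok() {
            // Fall back to the unpinned binary, but name its version in the warning.
            let out = Command::new("clang-format")
                .arg("--version")
                .output()
                .expect("failed to run clang-format --version");
            let banner = from_utf8(&out.stdout).unwrap_or("").replace('\n', "");
            return (
                Some("clang-format".to_string()),
                Some(format!("using {banner}, which may format differently than {pinned}")),
            );
        }
        (
            None,
            Some("clang-format not found. Skipping C formatting...".to_string()),
        )
    }

Echoing the discovered version in the warning (rather than a fixed string) tells CI users exactly which formatter produced a divergent result.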