Cmplog New Pass Manager & LLVM 14 Fixes (#626)

* wip

* more

* match aflpp

* llvm14

* fix

* more llvm14

* check llvm version in libafl_cc

* safe access

* more

* fmt

* no windows

* no windows
This commit is contained in:
Dongjia Zhang 2022-05-17 15:45:48 +09:00 committed by GitHub
parent 2ead2c398e
commit afb32fb351
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 144 additions and 21 deletions

View File

@ -112,6 +112,21 @@ fn main() {
let llvm_config = find_llvm_config(); let llvm_config = find_llvm_config();
// Get LLVM version.
let llvm_version = match llvm_config.split('-').collect::<Vec<&str>>().get(2) {
Some(ver) => ver.parse::<usize>().ok(),
None => None,
};
match llvm_version {
Some(ver) => {
if ver >= 14 {
custom_flags.push("-DUSE_NEW_PM".to_string());
}
}
None => (),
}
if let Ok(output) = Command::new(&llvm_config).args(&["--bindir"]).output() { if let Ok(output) = Command::new(&llvm_config).args(&["--bindir"]).output() {
let llvm_bindir = Path::new( let llvm_bindir = Path::new(
str::from_utf8(&output.stdout) str::from_utf8(&output.stdout)
@ -132,11 +147,15 @@ fn main() {
/// The size of the accounting maps /// The size of the accounting maps
pub const ACCOUNTING_MAP_SIZE: usize = {}; pub const ACCOUNTING_MAP_SIZE: usize = {};
/// The llvm version used to build llvm passes
pub const LIBAFL_CC_LLVM_VERSION: Option<usize> = {:?};
", ",
llvm_bindir.join("clang"), llvm_bindir.join("clang"),
llvm_bindir.join("clang++"), llvm_bindir.join("clang++"),
edges_map_size, edges_map_size,
acc_map_size acc_map_size,
llvm_version,
) )
.expect("Could not write file"); .expect("Could not write file");
@ -206,7 +225,6 @@ fn main() {
.args(&cxxflags) .args(&cxxflags)
.args(&custom_flags) .args(&custom_flags)
.arg(src_dir.join("autotokens-pass.cc")) .arg(src_dir.join("autotokens-pass.cc"))
//.arg("-DUSE_NEW_PM")
.args(&ldflags) .args(&ldflags)
.args(&["-fPIC", "-shared", "-o"]) .args(&["-fPIC", "-shared", "-o"])
.arg(out_dir.join(format!("autotokens-pass.{}", dll_extension()))) .arg(out_dir.join(format!("autotokens-pass.{}", dll_extension())))

View File

@ -140,7 +140,9 @@ llvmGetPassPluginInfo() {
/* lambda to insert our pass into the pass pipeline. */ /* lambda to insert our pass into the pass pipeline. */
[](PassBuilder &PB) { [](PassBuilder &PB) {
#if 1 #if 1
#if LLVM_VERSION_MAJOR <= 13
using OptimizationLevel = typename PassBuilder::OptimizationLevel; using OptimizationLevel = typename PassBuilder::OptimizationLevel;
#endif
PB.registerOptimizerLastEPCallback( PB.registerOptimizerLastEPCallback(
[](ModulePassManager &MPM, OptimizationLevel OL) { [](ModulePassManager &MPM, OptimizationLevel OL) {
MPM.addPass(AFLCoverage()); MPM.addPass(AFLCoverage());
@ -433,7 +435,11 @@ bool AFLCoverage::runOnModule(Module &M) {
if (instrument_ctx && &BB == &F.getEntryBlock()) { if (instrument_ctx && &BB == &F.getEntryBlock()) {
#ifdef HAVE_VECTOR_INTRINSICS #ifdef HAVE_VECTOR_INTRINSICS
if (CtxK) { if (CtxK) {
PrevCaller = IRB.CreateLoad(AFLPrevCaller); PrevCaller = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
PrevCallerTy,
#endif
AFLPrevCaller);
PrevCaller->setMetadata(M.getMDKindID("nosanitize"), PrevCaller->setMetadata(M.getMDKindID("nosanitize"),
MDNode::get(C, None)); MDNode::get(C, None));
PrevCtx = PrevCtx =
@ -445,7 +451,12 @@ bool AFLCoverage::runOnModule(Module &M) {
// load the context ID of the previous function and write to to a // load the context ID of the previous function and write to to a
// local variable on the stack // local variable on the stack
LoadInst *PrevCtxLoad = IRB.CreateLoad(AFLContext); LoadInst *PrevCtxLoad = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
IRB.getInt32Ty(),
#endif
AFLContext
);
PrevCtxLoad->setMetadata(M.getMDKindID("nosanitize"), PrevCtxLoad->setMetadata(M.getMDKindID("nosanitize"),
MDNode::get(C, None)); MDNode::get(C, None));
PrevCtx = PrevCtxLoad; PrevCtx = PrevCtxLoad;
@ -573,7 +584,11 @@ bool AFLCoverage::runOnModule(Module &M) {
/* Load prev_loc */ /* Load prev_loc */
LoadInst *PrevLoc = IRB.CreateLoad(AFLPrevLoc); LoadInst *PrevLoc = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
PrevLocTy,
#endif
AFLPrevLoc);
PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *PrevLocTrans; Value *PrevLocTrans;
@ -597,20 +612,31 @@ bool AFLCoverage::runOnModule(Module &M) {
/* Load SHM pointer */ /* Load SHM pointer */
LoadInst *MapPtr = IRB.CreateLoad(AFLMapPtr); LoadInst *MapPtr = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
PointerType::get(Int8Ty, 0),
#endif
AFLMapPtr);
MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *MapPtrIdx; Value *MapPtrIdx;
#ifdef HAVE_VECTOR_INTRINSICS #ifdef HAVE_VECTOR_INTRINSICS
if (Ngram) if (Ngram)
MapPtrIdx = IRB.CreateGEP( MapPtrIdx = IRB.CreateGEP(
#if LLVM_VERSION_MAJOR >= 14
Int8Ty,
#endif
MapPtr, MapPtr,
IRB.CreateZExt( IRB.CreateZExt(
IRB.CreateXor(PrevLocTrans, IRB.CreateZExt(CurLoc, Int32Ty)), IRB.CreateXor(PrevLocTrans, IRB.CreateZExt(CurLoc, Int32Ty)),
Int32Ty)); Int32Ty));
else else
#endif #endif
MapPtrIdx = IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocTrans, CurLoc)); MapPtrIdx = IRB.CreateGEP(
#if LLVM_VERSION_MAJOR >= 14
Int8Ty,
#endif
MapPtr, IRB.CreateXor(PrevLocTrans, CurLoc));
/* Update bitmap */ /* Update bitmap */
@ -643,7 +669,11 @@ bool AFLCoverage::runOnModule(Module &M) {
*/ */
} else { } else {
LoadInst *Counter = IRB.CreateLoad(MapPtrIdx); LoadInst *Counter = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
IRB.getInt8Ty(),
#endif
MapPtrIdx);
Counter->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); Counter->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *Incr = IRB.CreateAdd(Counter, One); Value *Incr = IRB.CreateAdd(Counter, One);

View File

@ -349,6 +349,14 @@ impl ClangWrapper {
/// Create a new Clang Wrapper /// Create a new Clang Wrapper
#[must_use] #[must_use]
pub fn new() -> Self { pub fn new() -> Self {
#[cfg(unix)]
let use_new_pm = match LIBAFL_CC_LLVM_VERSION {
Some(ver) => ver >= 14,
None => false,
};
#[cfg(not(unix))]
let use_new_pm = false;
Self { Self {
optimize: true, optimize: true,
wrapped_cc: CLANG_PATH.into(), wrapped_cc: CLANG_PATH.into(),
@ -361,7 +369,7 @@ impl ClangWrapper {
bit_mode: 0, bit_mode: 0,
need_libafl_arg: false, need_libafl_arg: false,
has_libafl_arg: false, has_libafl_arg: false,
use_new_pm: false, use_new_pm,
parse_args_called: false, parse_args_called: false,
base_args: vec![], base_args: vec![],
cc_args: vec![], cc_args: vec![],

View File

@ -25,9 +25,16 @@
#include <sys/time.h> #include <sys/time.h>
#include "llvm/Config/llvm-config.h" #include "llvm/Config/llvm-config.h"
#if USE_NEW_PM
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/IR/PassManager.h"
#else
#include "llvm/IR/LegacyPassManager.h"
#endif
#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@ -111,22 +118,33 @@ bool isIgnoreFunction(const llvm::Function *F) {
return false; return false;
} }
#if USE_NEW_PM
class CmpLogRoutines : public PassInfoMixin<CmpLogRoutines> {
public:
CmpLogRoutines() {
#else
class CmpLogRoutines : public ModulePass { class CmpLogRoutines : public ModulePass {
public: public:
static char ID; static char ID;
CmpLogRoutines() : ModulePass(ID) { CmpLogRoutines() : ModulePass(ID) {
#endif
} }
bool runOnModule(Module &M) override; #if USE_NEW_PM
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
#if LLVM_VERSION_MAJOR < 4
const char *getPassName() const override {
#else #else
bool runOnModule(Module &M) override;
#if LLVM_VERSION_MAJOR < 4
const char *getPassName() const override {
#else
StringRef getPassName() const override { StringRef getPassName() const override {
#endif #endif
return "cmplog routines"; return "cmplog routines";
} }
#endif
private: private:
bool hookRtns(Module &M); bool hookRtns(Module &M);
@ -134,7 +152,23 @@ class CmpLogRoutines : public ModulePass {
} // namespace } // namespace
#if USE_NEW_PM
extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo() {
return {LLVM_PLUGIN_API_VERSION, "CmpLogRoutines", "v0.1",
[](PassBuilder &PB) {
#if LLVM_VERSION_MAJOR <= 13
using OptimizationLevel = typename PassBuilder::OptimizationLevel;
#endif
PB.registerOptimizerLastEPCallback(
[](ModulePassManager &MPM, OptimizationLevel OL) {
MPM.addPass(CmpLogRoutines());
});
}};
}
#else
char CmpLogRoutines::ID = 0; char CmpLogRoutines::ID = 0;
#endif
bool CmpLogRoutines::hookRtns(Module &M) { bool CmpLogRoutines::hookRtns(Module &M) {
std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC; std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC;
@ -407,13 +441,27 @@ bool CmpLogRoutines::hookRtns(Module &M) {
return true; return true;
} }
#if USE_NEW_PM
PreservedAnalyses CmpLogRoutines::run(Module &M, ModuleAnalysisManager &MAM) {
#else
bool CmpLogRoutines::runOnModule(Module &M) { bool CmpLogRoutines::runOnModule(Module &M) {
#endif
hookRtns(M); hookRtns(M);
#if USE_NEW_PM
auto PA = PreservedAnalyses::all();
#endif
verifyModule(M); verifyModule(M);
#if USE_NEW_PM
return PA;
#else
return true; return true;
#endif
} }
#if USE_NEW_PM
#else
static void registerCmpLogRoutinesPass(const PassManagerBuilder &, static void registerCmpLogRoutinesPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) { legacy::PassManagerBase &PM) {
auto p = new CmpLogRoutines(); auto p = new CmpLogRoutines();
@ -426,8 +474,10 @@ static RegisterStandardPasses RegisterCmpLogRoutinesPass(
static RegisterStandardPasses RegisterCmpLogRoutinesPass0( static RegisterStandardPasses RegisterCmpLogRoutinesPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogRoutinesPass); PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogRoutinesPass);
#if LLVM_VERSION_MAJOR >= 11 #if LLVM_VERSION_MAJOR >= 11
static RegisterStandardPasses RegisterCmpLogRoutinesPassLTO( static RegisterStandardPasses RegisterCmpLogRoutinesPassLTO(
PassManagerBuilder::EP_FullLinkTimeOptimizationLast, PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
registerCmpLogRoutinesPass); registerCmpLogRoutinesPass);
#endif
#endif #endif

View File

@ -140,7 +140,9 @@ llvmGetPassPluginInfo() {
/* lambda to insert our pass into the pass pipeline. */ /* lambda to insert our pass into the pass pipeline. */
[](PassBuilder &PB) { [](PassBuilder &PB) {
#if 1 #if 1
#if LLVM_VERSION_MAJOR <= 13
using OptimizationLevel = typename PassBuilder::OptimizationLevel; using OptimizationLevel = typename PassBuilder::OptimizationLevel;
#endif
PB.registerOptimizerLastEPCallback( PB.registerOptimizerLastEPCallback(
[](ModulePassManager &MPM, OptimizationLevel OL) { [](ModulePassManager &MPM, OptimizationLevel OL) {
MPM.addPass(AFLCoverage()); MPM.addPass(AFLCoverage());
@ -254,20 +256,35 @@ bool AFLCoverage::runOnModule(Module &M) {
/* Load prev_loc */ /* Load prev_loc */
LoadInst *PrevLoc = IRB.CreateLoad(AFLPrevLoc); LoadInst *PrevLoc = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
Int32Ty,
#endif
AFLPrevLoc);
PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
/* Load SHM pointer */ /* Load SHM pointer */
LoadInst *MemReadPtr = IRB.CreateLoad(AFLMemOpPtr); LoadInst *MemReadPtr = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
PointerType::get(Int32Ty, 0),
#endif
AFLMemOpPtr);
MemReadPtr->setMetadata(M.getMDKindID("nosanitize"), MemReadPtr->setMetadata(M.getMDKindID("nosanitize"),
MDNode::get(C, None)); MDNode::get(C, None));
Value *MemReadPtrIdx = Value *MemReadPtrIdx = IRB.CreateGEP(
IRB.CreateGEP(MemReadPtr, IRB.CreateXor(PrevLoc, CurLoc)); #if LLVM_VERSION_MAJOR >= 14
Int32Ty,
#endif
MemReadPtr, IRB.CreateXor(PrevLoc, CurLoc));
/* Update bitmap */ /* Update bitmap */
LoadInst *MemReadCount = IRB.CreateLoad(MemReadPtrIdx); LoadInst *MemReadCount = IRB.CreateLoad(
#if LLVM_VERSION_MAJOR >= 14
Int32Ty,
#endif
MemReadPtrIdx);
MemReadCount->setMetadata(M.getMDKindID("nosanitize"), MemReadCount->setMetadata(M.getMDKindID("nosanitize"),
MDNode::get(C, None)); MDNode::get(C, None));
Value *MemReadIncr = Value *MemReadIncr =