From 583c84ab4ee6c245f655d91dcb80c11179c03813 Mon Sep 17 00:00:00 2001 From: "Dongjia \"toka\" Zhang" Date: Mon, 2 Oct 2023 06:06:34 +0200 Subject: [PATCH] cmplog routines update & fix (#1592) * update * runtime * Update cmplog-routines-pass.cc (#1589) * rtm * fix * no link rt * fmt * let's change script in another pr * colon * adjust the checks * fix * more fixes * FMT --- libafl_cc/src/cmplog-routines-pass.cc | 211 +++++++++++++++++++++++++- libafl_cc/src/no-link-rt.c | 18 +++ libafl_targets/src/cmplog.c | 116 ++++++++++++-- libafl_targets/src/common.h | 16 ++ 4 files changed, 339 insertions(+), 22 deletions(-) diff --git a/libafl_cc/src/cmplog-routines-pass.cc b/libafl_cc/src/cmplog-routines-pass.cc index 437917d3f0..37cbd6026b 100644 --- a/libafl_cc/src/cmplog-routines-pass.cc +++ b/libafl_cc/src/cmplog-routines-pass.cc @@ -170,14 +170,17 @@ llvmGetPassPluginInfo() { #else char CmpLogRoutines::ID = 0; #endif +#include bool CmpLogRoutines::hookRtns(Module &M) { - std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC; - LLVMContext &C = M.getContext(); + std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC, + Memcmp, Strcmp, Strncmp; + LLVMContext &C = M.getContext(); Type *VoidTy = Type::getVoidTy(C); // PointerType *VoidPtrTy = PointerType::get(VoidTy, 0); IntegerType *Int8Ty = IntegerType::getInt8Ty(C); + IntegerType *Int64Ty = IntegerType::getInt64Ty(C); /* type of the length argument of the _n and _strn hooks */ PointerType *i8PtrTy = PointerType::get(Int8Ty, 0); #if LLVM_VERSION_MAJOR < 9 @@ -269,6 +272,60 @@ bool CmpLogRoutines::hookRtns(Module &M) { FunctionCallee cmplogGccStdC = c4; #endif +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee +#else + Constant * +#endif + c5 = M.getOrInsertFunction("__cmplog_rtn_hook_n", VoidTy, i8PtrTy, + i8PtrTy, Int64Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee cmplogHookFnN = c5; +#else + Function *cmplogHookFnN = cast<Function>(c5); +#endif + +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee +#else + Constant * +#endif + c6 = 
M.getOrInsertFunction("__cmplog_rtn_hook_strn", VoidTy, i8PtrTy, + i8PtrTy, Int64Ty +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee cmplogHookFnStrN = c6; +#else + Function *cmplogHookFnStrN = cast<Function>(c6); +#endif + +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee +#else + Constant * +#endif + c7 = M.getOrInsertFunction("__cmplog_rtn_hook_str", VoidTy, i8PtrTy, + i8PtrTy +#if LLVM_VERSION_MAJOR < 5 + , + NULL +#endif + ); +#if LLVM_VERSION_MAJOR >= 9 + FunctionCallee cmplogHookFnStr = c7; +#else + Function *cmplogHookFnStr = cast<Function>(c7); +#endif + /* iterate over all functions, bbs and instruction and add suitable calls */ for (auto &F : M) { if (isIgnoreFunction(&F)) { continue; } @@ -283,12 +340,87 @@ bool CmpLogRoutines::hookRtns(Module &M) { if (callInst->getCallingConv() != llvm::CallingConv::C) { continue; } FunctionType *FT = Callee->getFunctionType(); + std::string FuncName = Callee->getName().str(); bool isPtrRtn = FT->getNumParams() >= 2 && !FT->getReturnType()->isVoidTy() && FT->getParamType(0) == FT->getParamType(1) && FT->getParamType(0)->isPointerTy(); + bool isPtrRtnN = FT->getNumParams() >= 3 && + !FT->getReturnType()->isVoidTy() && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0)->isPointerTy() && + FT->getParamType(2)->isIntegerTy(); + if (isPtrRtnN) { /* only 32- or 64-bit integer third arguments qualify */ + auto intTyOp = + dyn_cast<IntegerType>(callInst->getArgOperand(2)->getType()); + if (intTyOp) { + if (intTyOp->getBitWidth() != 32 && + intTyOp->getBitWidth() != 64) { + isPtrRtnN = false; + } + } + } + + bool isMemcmp = + (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") || + !FuncName.compare("CRYPTO_memcmp") || + !FuncName.compare("OPENSSL_memcmp") || + !FuncName.compare("memcmp_const_time") || + !FuncName.compare("memcmpct")); + isMemcmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0)->isPointerTy() && + FT->getParamType(1)->isPointerTy() && + FT->getParamType(2)->isIntegerTy(); + + bool 
isStrcmp = + (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") || + !FuncName.compare("xmlStrEqual") || + !FuncName.compare("g_strcmp0") || + !FuncName.compare("curl_strequal") || + !FuncName.compare("strcsequal") || + !FuncName.compare("strcasecmp") || + !FuncName.compare("stricmp") || + !FuncName.compare("ap_cstr_casecmp") || + !FuncName.compare("OPENSSL_strcasecmp") || + !FuncName.compare("xmlStrcasecmp") || + !FuncName.compare("g_strcasecmp") || + !FuncName.compare("g_ascii_strcasecmp") || + !FuncName.compare("Curl_strcasecompare") || + !FuncName.compare("Curl_safe_strcasecompare") || + !FuncName.compare("cmsstrcasecmp") || + !FuncName.compare("strstr") || + !FuncName.compare("g_strstr_len") || + !FuncName.compare("ap_strcasestr") || + !FuncName.compare("xmlStrstr") || + !FuncName.compare("xmlStrcasestr") || + !FuncName.compare("g_str_has_prefix") || + !FuncName.compare("g_str_has_suffix")); + isStrcmp &= + FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == i8PtrTy; /* reuse the i8* type built at the top of hookRtns */ + + bool isStrncmp = (!FuncName.compare("strncmp") || + !FuncName.compare("xmlStrncmp") || + !FuncName.compare("curl_strnequal") || + !FuncName.compare("strncasecmp") || + !FuncName.compare("strnicmp") || + !FuncName.compare("ap_cstr_casecmpn") || + !FuncName.compare("OPENSSL_strncasecmp") || + !FuncName.compare("xmlStrncasecmp") || + !FuncName.compare("g_ascii_strncasecmp") || + !FuncName.compare("Curl_strncasecompare") || + !FuncName.compare("g_strncasecmp")); + isStrncmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == + i8PtrTy && + FT->getParamType(2)->isIntegerTy(); + bool isGccStdStringStdString = Callee->getName().find("__is_charIT_EE7__value") != std::string::npos && @@ -336,11 +468,17 @@ bool CmpLogRoutines::hookRtns(Module &M) 
{ */ if (isGccStdStringCString || isGccStdStringStdString || - isLlvmStdStringStdString || isLlvmStdStringCString) { - isPtrRtn = false; + isLlvmStdStringStdString || isLlvmStdStringCString || isMemcmp || + isStrcmp || isStrncmp) { + isPtrRtnN = isPtrRtn = false; } + if (isPtrRtnN) { isPtrRtn = false; } + if (isPtrRtn) { calls.push_back(callInst); } + if (isMemcmp || isPtrRtnN) { Memcmp.push_back(callInst); } + if (isStrcmp) { Strcmp.push_back(callInst); } + if (isStrncmp) { Strncmp.push_back(callInst); } if (isGccStdStringStdString) { gccStdStd.push_back(callInst); } if (isGccStdStringCString) { gccStdC.push_back(callInst); } if (isLlvmStdStringStdString) { llvmStdStd.push_back(callInst); } @@ -351,9 +489,9 @@ bool CmpLogRoutines::hookRtns(Module &M) { } if (!calls.size() && !gccStdStd.size() && !gccStdC.size() && - !llvmStdStd.size() && !llvmStdC.size()) { + !llvmStdStd.size() && !llvmStdC.size() && !Memcmp.size() && + !Strcmp.size() && !Strncmp.size()) /* no call site was collected for any hook */ return false; - } for (auto &callInst : calls) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); @@ -372,6 +510,67 @@ bool CmpLogRoutines::hookRtns(Module &M) { // errs() << callInst->getCalledFunction()->getName() << "\n"; } + for (auto &callInst : Memcmp) { + Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), + *v3P = callInst->getArgOperand(2); + + IRBuilder<> IRB(callInst->getParent()); + IRB.SetInsertPoint(callInst); + + std::vector<Value *> args; + Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); + Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); + Value *v3Pbitcast = IRB.CreateBitCast( + v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); + Value *v3Pcasted = + IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); + args.push_back(v1Pcasted); + args.push_back(v2Pcasted); + args.push_back(v3Pcasted); + + IRB.CreateCall(cmplogHookFnN, args); + + // errs() << callInst->getCalledFunction()->getName() << "\n"; + } + + for (auto 
&callInst : Strcmp) { + Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); + + IRBuilder<> IRB(callInst->getParent()); + IRB.SetInsertPoint(callInst); + std::vector<Value *> args; + Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); + Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); + args.push_back(v1Pcasted); + args.push_back(v2Pcasted); + + IRB.CreateCall(cmplogHookFnStr, args); /* emitted at the insert point set above, i.e. before callInst */ + + // errs() << callInst->getCalledFunction()->getName() << "\n"; + } + + for (auto &callInst : Strncmp) { + Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), + *v3P = callInst->getArgOperand(2); + + IRBuilder<> IRB(callInst->getParent()); + IRB.SetInsertPoint(callInst); + std::vector<Value *> args; + Value *v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); + Value *v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); + Value *v3Pbitcast = IRB.CreateBitCast( + v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); + Value *v3Pcasted = + IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); + args.push_back(v1Pcasted); + args.push_back(v2Pcasted); + args.push_back(v3Pcasted); + + IRB.CreateCall(cmplogHookFnStrN, args); + + // errs() << callInst->getCalledFunction()->getName() << "\n"; + } + for (auto &callInst : gccStdStd) { Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); diff --git a/libafl_cc/src/no-link-rt.c b/libafl_cc/src/no-link-rt.c index 08c069bd45..c2a972311a 100644 --- a/libafl_cc/src/no-link-rt.c +++ b/libafl_cc/src/no-link-rt.c @@ -16,6 +16,24 @@ void __cmplog_rtn_hook(uint8_t *ptr1, uint8_t *ptr2) { (void)ptr2; } +void __cmplog_rtn_hook_n(const uint8_t *ptr1, const uint8_t *ptr2, + uint64_t len) { + (void)ptr1; + (void)ptr2; + (void)len; +} + +void __cmplog_rtn_hook_str(const uint8_t *ptr1, uint8_t *ptr2) { + (void)ptr1; + (void)ptr2; +} + +void __cmplog_rtn_hook_strn(uint8_t *ptr1, uint8_t *ptr2, uint64_t len) { + (void)ptr1; + (void)ptr2; + (void)len; +} + void 
__cmplog_rtn_gcc_stdstring_cstring(uint8_t *stdstring, uint8_t *cstring) { (void)stdstring; (void)cstring; diff --git a/libafl_targets/src/cmplog.c b/libafl_targets/src/cmplog.c index d86689ae5b..7ae3a46164 100644 --- a/libafl_targets/src/cmplog.c +++ b/libafl_targets/src/cmplog.c @@ -3,6 +3,7 @@ #define CMPLOG_MODULE #include "common.h" #include "cmplog.h" +#include <string.h> #if defined(_WIN32) @@ -92,6 +93,7 @@ static long area_is_valid(const void *ptr, size_t len) { } } +// cmplog routines after area check void __libafl_targets_cmplog_routines_checked(uintptr_t k, const uint8_t *ptr1, const uint8_t *ptr2, size_t len) { libafl_cmplog_enabled = false; @@ -105,7 +107,8 @@ void __libafl_targets_cmplog_routines_checked(uintptr_t k, const uint8_t *ptr1, } else { hits = libafl_cmplog_map_ptr->headers[k].hits++; if (libafl_cmplog_map_ptr->headers[k].shape < len) { - libafl_cmplog_map_ptr->headers[k].shape = len; + libafl_cmplog_map_ptr->headers[k].shape = + len; // TODO: adjust len for AFL++'s cmplog protocol } } @@ -115,6 +118,7 @@ void __libafl_targets_cmplog_routines_checked(uintptr_t k, const uint8_t *ptr1, libafl_cmplog_enabled = true; } +// Very generic cmplog routines callback void __libafl_targets_cmplog_routines(uintptr_t k, const uint8_t *ptr1, const uint8_t *ptr2) { if (!libafl_cmplog_enabled) { return; } @@ -129,6 +133,7 @@ void __libafl_targets_cmplog_routines(uintptr_t k, const uint8_t *ptr1, __libafl_targets_cmplog_routines_checked(k, ptr1, ptr2, len); } +// cmplog routines but with len specified void __libafl_targets_cmplog_routines_len(uintptr_t k, const uint8_t *ptr1, const uint8_t *ptr2, size_t len) { if (!libafl_cmplog_enabled) { return; } @@ -149,6 +154,58 @@ void __cmplog_rtn_hook(const uint8_t *ptr1, const uint8_t *ptr2) { __libafl_targets_cmplog_routines(k, ptr1, ptr2); } +void __cmplog_rtn_hook_n(const uint8_t *ptr1, const uint8_t *ptr2, + uint64_t len) { + (void)(len); /* len is not used: the call is forwarded to the two-pointer hook */ + __cmplog_rtn_hook(ptr1, ptr2); +} + +/* hook for string functions, eg. 
strcmp, strcasecmp etc. */ +void __cmplog_rtn_hook_str(const uint8_t *ptr1, uint8_t *ptr2) { + if (!libafl_cmplog_enabled) { return; } + if (unlikely(!ptr1 || !ptr2)) return; + + // these strnlen could indeed fail. but if it fails here it will sigsegv in the following hooked function call anyways + int len1 = strnlen((const char *)ptr1, 30) + 1; + int len2 = strnlen((const char *)ptr2, 30) + 1; + int l = MAX(len1, len2); + + l = MIN(l, area_is_valid(ptr1, l + 1)); // can we really access it? check + l = MIN(l, area_is_valid(ptr2, l + 1)); // can we really access it? check + + if (l < 2) return; + + uintptr_t k = RETADDR; /* unsigned, like the std::string hooks: the shifts below must not overflow a signed type */ + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + + __libafl_targets_cmplog_routines_checked(k, ptr1, ptr2, l); +} + +/* hook for string with length functions, eg. strncmp, strncasecmp etc. + Note that len is honoured, but capped at 31 bytes. */ +void __cmplog_rtn_hook_strn(uint8_t *ptr1, uint8_t *ptr2, uint64_t len) { + if (!libafl_cmplog_enabled) { return; } + if (unlikely(!ptr1 || !ptr2)) return; + + int len0 = MIN(len, 31); // cap by 31 + // these strnlen could indeed fail. but if it fails here it will sigsegv in the following hooked function call anyways + int len1 = strnlen((const char *)ptr1, len0); + int len2 = strnlen((const char *)ptr2, len0); + int l = MAX(len1, len2); + + l = MIN(l, area_is_valid(ptr1, l + 1)); // can we really access it? check + l = MIN(l, area_is_valid(ptr2, l + 1)); // can we really access it? 
check + + if (l < 2) return; + + uintptr_t k = RETADDR; /* unsigned, like the std::string hooks: the shifts below must not overflow a signed type */ + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + + __libafl_targets_cmplog_routines_checked(k, ptr1, ptr2, l); +} + // gcc libstdc++ // _ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareEPKc static const uint8_t *get_gcc_stdstring(const uint8_t *string) { @@ -179,39 +236,66 @@ static const uint8_t *get_llvm_stdstring(const uint8_t *string) { void __cmplog_rtn_gcc_stdstring_cstring(const uint8_t *stdstring, const uint8_t *cstring) { if (!libafl_cmplog_enabled) { return; } - if (area_is_valid(stdstring, 32) <= 0) { return; } + int l1 = area_is_valid(stdstring, 32); + if (l1 <= 0) { return; } + int l2 = area_is_valid(cstring, 32); + if (l2 <= 0) { return; } - __cmplog_rtn_hook(get_gcc_stdstring(stdstring), cstring); + int len = MIN(31, MIN(l1, l2)); + + uintptr_t k = RETADDR; + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + __libafl_targets_cmplog_routines_checked(k, get_gcc_stdstring(stdstring), + cstring, len); } void __cmplog_rtn_gcc_stdstring_stdstring(const uint8_t *stdstring1, const uint8_t *stdstring2) { if (!libafl_cmplog_enabled) { return; } - if (area_is_valid(stdstring1, 32) <= 0 || - area_is_valid(stdstring2, 32) <= 0) { - return; - } + int l1 = area_is_valid(stdstring1, 32); + if (l1 <= 0) { return; } + int l2 = area_is_valid(stdstring2, 32); + if (l2 <= 0) { return; } - __cmplog_rtn_hook(get_gcc_stdstring(stdstring1), - get_gcc_stdstring(stdstring2)); + int len = MIN(31, MIN(l1, l2)); + uintptr_t k = RETADDR; + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + __libafl_targets_cmplog_routines_checked(k, get_gcc_stdstring(stdstring1), + get_gcc_stdstring(stdstring2), len); } void __cmplog_rtn_llvm_stdstring_cstring(const uint8_t *stdstring, const uint8_t *cstring) { if (!libafl_cmplog_enabled) { return; } if (area_is_valid(stdstring, 32) <= 0) { return; } + int l1 = area_is_valid(stdstring, 32); /* NOTE(review): the unchanged line just above already performs this same check */ + if (l1 <= 0) { return; } + int l2 = area_is_valid(cstring, 32); + if (l2 <= 0) { 
return; } - __cmplog_rtn_hook(get_llvm_stdstring(stdstring), cstring); + int len = MIN(31, MIN(l1, l2)); + uintptr_t k = RETADDR; + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + __libafl_targets_cmplog_routines_checked(k, get_llvm_stdstring(stdstring), + cstring, len); } void __cmplog_rtn_llvm_stdstring_stdstring(const uint8_t *stdstring1, const uint8_t *stdstring2) { if (!libafl_cmplog_enabled) { return; } - if (area_is_valid(stdstring1, 32) <= 0 || - area_is_valid(stdstring2, 32) <= 0) { - return; - } + int l1 = area_is_valid(stdstring1, 32); + if (l1 <= 0) { return; } + int l2 = area_is_valid(stdstring2, 32); + if (l2 <= 0) { return; } - __cmplog_rtn_hook(get_llvm_stdstring(stdstring1), - get_llvm_stdstring(stdstring2)); + int len = MIN(31, MIN(l1, l2)); /* never above 31, nor above either area_is_valid() result */ + + uintptr_t k = RETADDR; + k = (k >> 4) ^ (k << 8); + k &= CMPLOG_MAP_W - 1; + __libafl_targets_cmplog_routines_checked(k, get_llvm_stdstring(stdstring1), + get_llvm_stdstring(stdstring2), len); } diff --git a/libafl_targets/src/common.h b/libafl_targets/src/common.h index a13b47f0a9..0960b9719e 100644 --- a/libafl_targets/src/common.h +++ b/libafl_targets/src/common.h @@ -46,6 +46,22 @@ #define EXPORT_FN #endif +#if __GNUC__ < 6 /* plain pass-through macros; an undefined __GNUC__ evaluates as 0 in #if and also lands here */ + #ifndef likely + #define likely(_x) (_x) + #endif + #ifndef unlikely + #define unlikely(_x) (_x) + #endif +#else /* __GNUC__ >= 6: express the hints through __builtin_expect */ + #ifndef likely + #define likely(_x) __builtin_expect(!!(_x), 1) + #endif + #ifndef unlikely + #define unlikely(_x) __builtin_expect(!!(_x), 0) + #endif +#endif + #ifdef __GNUC__ #define MAX(a, b) \ ({ \