//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a general divergence analysis for loop vectorization // and GPU programs. It determines which branches and values in a loop or GPU // program are divergent. It can help branch optimizations such as jump // threading and loop unswitching to make better decisions. // // GPU programs typically use the SIMD execution model, where multiple threads // in the same execution group have to execute in lock-step. Therefore, if the // code contains divergent branches (i.e., threads in a group do not agree on // which path of the branch to take), the group of threads has to execute all // the paths from that branch with different subsets of threads enabled until // they re-converge. // // Due to this execution model, some optimizations such as jump // threading and loop unswitching can interfere with thread re-convergence. // Therefore, an analysis that computes which branches in a GPU program are // divergent can help the compiler to selectively run these optimizations. // // This implementation is derived from the Vectorization Analysis of the // Region Vectorizer (RV). That implementation in turn is based on the approach // described in // // Improving Performance of OpenCL on CPUs // Ralf Karrenberg and Sebastian Hack // CC '12 // // This DivergenceAnalysis implementation is generic in the sense that it does // not itself identify original sources of divergence. 
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
//
// The generic implementation propagates divergence to variables that are data
// or sync dependent on a source of divergence.
//
// While data dependency is a well-known concept, the notion of sync dependency
// is worth more explanation. Sync dependence characterizes the control flow
// aspect of the propagation of branch divergence. For example,
//
//   %cond = icmp slt i32 %tid, 10
//   br i1 %cond, label %then, label %else
// then:
//   br label %merge
// else:
//   br label %merge
// merge:
//   %a = phi i32 [ 0, %then ], [ 1, %else ]
//
// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
// because %tid is not on its use-def chains, %a is sync dependent on %tid
// because the branch "br i1 %cond" depends on %tid and affects which value %a
// is assigned to.
//
// The sync dependence detection (which branch induces divergence in which join
// points) is implemented in the SyncDependenceAnalysis.
//
// The current DivergenceAnalysis implementation has the following limitations:
// 1. intra-procedural. It conservatively considers the arguments of a
//    non-kernel-entry function and the return value of a function call as
//    divergent.
// 2. memory as black box. It conservatively considers values loaded from
//    generic or local addresses as divergent. This can be improved by
//    leveraging pointer analysis and/or by modelling non-escaping memory
//    objects in SSA as done in RV.
// //===----------------------------------------------------------------------===// #include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; #define DEBUG_TYPE "divergence-analysis" // class DivergenceAnalysis DivergenceAnalysis::DivergenceAnalysis( const Function &F, const Loop *RegionLoop, const DominatorTree &DT, const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), IsLCSSAForm(IsLCSSAForm) {} bool DivergenceAnalysis::markDivergent(const Value &DivVal) { if (isAlwaysUniform(DivVal)) return false; assert(isa(DivVal) || isa(DivVal)); assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); return DivergentValues.insert(&DivVal).second; } void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { UniformOverrides.insert(&UniVal); } bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, const Value &Val) const { const auto *Inst = dyn_cast(&Val); if (!Inst) return false; // check whether any divergent loop carrying Val terminates before control // proceeds to ObservingBlock for (const auto *Loop = LI.getLoopFor(Inst->getParent()); Loop != RegionLoop && !Loop->contains(&ObservingBlock); Loop = Loop->getParentLoop()) { if (DivergentLoops.contains(Loop)) return true; } return false; } bool DivergenceAnalysis::inRegion(const Instruction &I) const { return I.getParent() && inRegion(*I.getParent()); } bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); } void 
DivergenceAnalysis::pushUsers(const Value &V) { const auto *I = dyn_cast(&V); if (I && I->isTerminator()) { analyzeControlDivergence(*I); return; } for (const auto *User : V.users()) { const auto *UserInst = dyn_cast(User); if (!UserInst) continue; // only compute divergent inside loop if (!inRegion(*UserInst)) continue; // All users of divergent values are immediate divergent if (markDivergent(*UserInst)) Worklist.push_back(UserInst); } } static const Instruction *getIfCarriedInstruction(const Use &U, const Loop &DivLoop) { const auto *I = dyn_cast(&U); if (!I) return nullptr; if (!DivLoop.contains(I)) return nullptr; return I; } void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I, const Loop &OuterDivLoop) { if (isAlwaysUniform(I)) return; if (isDivergent(I)) return; LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); assert((isa(I) || !IsLCSSAForm) && "In LCSSA form all users of loop-exiting defs are Phi nodes."); for (const Use &Op : I.operands()) { const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); if (!OpInst) continue; if (markDivergent(I)) pushUsers(I); return; } } // marks all users of loop-carried values of the loop headed by LoopHeader as // divergent void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit, const Loop &OuterDivLoop) { // All users are in immediate exit blocks if (IsLCSSAForm) { for (const auto &Phi : DivExit.phis()) { analyzeTemporalDivergence(Phi, OuterDivLoop); } return; } // For non-LCSSA we have to follow all live out edges wherever they may lead. 
const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); SmallVector TaintStack; TaintStack.push_back(&DivExit); // Otherwise potential users of loop-carried values could be anywhere in the // dominance region of DivLoop (including its fringes for phi nodes) DenseSet Visited; Visited.insert(&DivExit); do { auto *UserBlock = TaintStack.pop_back_val(); // don't spread divergence beyond the region if (!inRegion(*UserBlock)) continue; assert(!OuterDivLoop.contains(UserBlock) && "irreducible control flow detected"); // phi nodes at the fringes of the dominance region if (!DT.dominates(&LoopHeader, UserBlock)) { // all PHI nodes of UserBlock become divergent for (auto &Phi : UserBlock->phis()) { analyzeTemporalDivergence(Phi, OuterDivLoop); } continue; } // Taint outside users of values carried by OuterDivLoop. for (auto &I : *UserBlock) { analyzeTemporalDivergence(I, OuterDivLoop); } // visit all blocks in the dominance region for (auto *SuccBlock : successors(UserBlock)) { if (!Visited.insert(SuccBlock).second) { continue; } TaintStack.push_back(SuccBlock); } } while (!TaintStack.empty()); } void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit, const Loop &InnerDivLoop) { LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); // Find outer-most loop that does not contain \p DivExit const Loop *DivLoop = &InnerDivLoop; const Loop *OuterDivLoop = DivLoop; const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); const unsigned LoopExitDepth = ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0; while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { DivergentLoops.insert(DivLoop); // all crossed loops are divergent OuterDivLoop = DivLoop; DivLoop = DivLoop->getParentLoop(); } LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() << "\n"); analyzeLoopExitDivergence(DivExit, *OuterDivLoop); } // this is a divergent join point - mark all phi nodes as divergent and push // them onto the stack. 
void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() << "\n"); // ignore divergence outside the region if (!inRegion(JoinBlock)) { return; } // push non-divergent phi nodes in JoinBlock to the worklist for (const auto &Phi : JoinBlock.phis()) { if (isDivergent(Phi)) continue; // FIXME Theoretically ,the 'undef' value could be replaced by any other // value causing spurious divergence. if (Phi.hasConstantOrUndefValue()) continue; if (markDivergent(Phi)) Worklist.push_back(&Phi); } } void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) { LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() << "\n"); // Don't propagate divergence from unreachable blocks. if (!DT.isReachableFromEntry(Term.getParent())) return; const auto *BranchLoop = LI.getLoopFor(Term.getParent()); const auto &DivDesc = SDA.getJoinBlocks(Term); // Iterate over all blocks now reachable by a disjoint path join for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { taintAndPushPhiNodes(*JoinBlock); } assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); } } void DivergenceAnalysis::compute() { // Initialize worklist. auto DivValuesCopy = DivergentValues; for (const auto *DivVal : DivValuesCopy) { assert(isDivergent(*DivVal) && "Worklist invariant violated!"); pushUsers(*DivVal); } // All values on the Worklist are divergent. // Their users may not have been updated yed. 
while (!Worklist.empty()) { const Instruction &I = *Worklist.back(); Worklist.pop_back(); // propagate value divergence to users assert(isDivergent(I) && "Worklist invariant violated!"); pushUsers(I); } } bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const { return UniformOverrides.contains(&V); } bool DivergenceAnalysis::isDivergent(const Value &V) const { return DivergentValues.contains(&V); } bool DivergenceAnalysis::isDivergentUse(const Use &U) const { Value &V = *U.get(); Instruction &I = *cast(U.getUser()); return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); } void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { if (DivergentValues.empty()) return; // iterate instructions using instructions() to ensure a deterministic order. for (auto &I : instructions(F)) { if (isDivergent(I)) OS << "DIVERGENT:" << I << '\n'; } } // class GPUDivergenceAnalysis GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI, const TargetTransformInfo &TTI) : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) { for (auto &I : instructions(F)) { if (TTI.isSourceOfDivergence(&I)) { DA.markDivergent(I); } else if (TTI.isAlwaysUniform(&I)) { DA.addUniformOverride(I); } } for (auto &Arg : F.args()) { if (TTI.isSourceOfDivergence(&Arg)) { DA.markDivergent(Arg); } } DA.compute(); } bool GPUDivergenceAnalysis::isDivergent(const Value &val) const { return DA.isDivergent(val); } bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const { return DA.isDivergentUse(use); } void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const { OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n"; DA.print(OS, mod); OS << "}\n"; }