Skip to content

[BOLT][NFCI] Simplify DataAggregator using traces #143289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

aaupov
Copy link
Contributor

@aaupov aaupov commented Jun 7, 2025

Consistently apply traces as defined in #127125 for branch profile
aggregation. This combines branches and fall-through records into one.

With large input binaries/profiles, the speed up in aggregation time
(-time-aggr, wall time):

  • perf.data, pre-BOLT input: 154.5528s -> 144.0767s
  • pre-aggregated data, pre-BOLT input: 15.1026s -> 9.0711s
  • pre-aggregated data, BOLTed input: 15.4871s -> 10.0077s

Test Plan: NFC

aaupov added 2 commits June 7, 2025 15:58
Created using spr 1.3.4
Copy link

github-actions bot commented Jun 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

aaupov added 4 commits June 7, 2025 21:11
Created using spr 1.3.4

[skip ci]
Created using spr 1.3.4

[skip ci]
Created using spr 1.3.4
@aaupov aaupov changed the base branch from users/aaupov/spr/main.boltnfci-simplify-dataaggregator-using-traces-1 to main June 9, 2025 05:52
@aaupov aaupov marked this pull request as ready for review June 9, 2025 05:52
@llvmbot llvmbot added the BOLT label Jun 9, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 9, 2025

@llvm/pr-subscribers-bolt

Author: Amir Ayupov (aaupov)

Changes

Consistently apply traces as defined in #127125 for branch profile
aggregation. This combines branches and fall-through records into one.

Test Plan: NFC


Full diff: https://github.com/llvm/llvm-project/pull/143289.diff

2 Files Affected:

  • (modified) bolt/include/bolt/Profile/DataAggregator.h (+33-14)
  • (modified) bolt/lib/Profile/DataAggregator.cpp (+53-107)
diff --git a/bolt/include/bolt/Profile/DataAggregator.h b/bolt/include/bolt/Profile/DataAggregator.h
index 3f07a6dc03a4f..1e115b0231055 100644
--- a/bolt/include/bolt/Profile/DataAggregator.h
+++ b/bolt/include/bolt/Profile/DataAggregator.h
@@ -99,26 +99,31 @@ class DataAggregator : public DataReader {
     uint64_t Addr;
   };
 
+  /// Container for the unit of branch data.
+  /// Backwards compatible with legacy use for branches and fall-throughs:
+  /// - if \p Branch is FT_ONLY or FT_EXTERNAL_ORIGIN, the trace only contains
+  ///   fall-through data,
+  /// - if \p To is EXTERNAL, the trace only contains branch data.
   struct Trace {
+    static constexpr const uint64_t EXTERNAL = 0ULL;
+    static constexpr const uint64_t FT_ONLY = -1ULL;
+    static constexpr const uint64_t FT_EXTERNAL_ORIGIN = -2ULL;
+
+    uint64_t Branch;
     uint64_t From;
     uint64_t To;
-    Trace(uint64_t From, uint64_t To) : From(From), To(To) {}
     bool operator==(const Trace &Other) const {
-      return From == Other.From && To == Other.To;
+      return Branch == Other.Branch && From == Other.From && To == Other.To;
     }
   };
+  friend raw_ostream &operator<<(raw_ostream &OS, const Trace &);
 
   struct TraceHash {
     size_t operator()(const Trace &L) const {
-      return std::hash<uint64_t>()(L.From << 32 | L.To);
+      return llvm::hash_combine(L.Branch, L.From, L.To);
     }
   };
 
-  struct FTInfo {
-    uint64_t InternCount{0};
-    uint64_t ExternCount{0};
-  };
-
   struct TakenBranchInfo {
     uint64_t TakenCount{0};
     uint64_t MispredCount{0};
@@ -126,8 +131,8 @@ class DataAggregator : public DataReader {
 
   /// Intermediate storage for profile data. We save the results of parsing
   /// and use them later for processing and assigning profile.
-  std::unordered_map<Trace, TakenBranchInfo, TraceHash> BranchLBRs;
-  std::unordered_map<Trace, FTInfo, TraceHash> FallthroughLBRs;
+  std::unordered_map<Trace, TakenBranchInfo, TraceHash> TraceMap;
+  std::vector<std::pair<Trace, TakenBranchInfo>> Traces;
   std::unordered_map<uint64_t, uint64_t> BasicSamples;
   std::vector<PerfMemSample> MemSamples;
 
@@ -200,8 +205,8 @@ class DataAggregator : public DataReader {
   /// Return a vector of offsets corresponding to a trace in a function
   /// if the trace is valid, std::nullopt otherwise.
   std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
-  getFallthroughsInTrace(BinaryFunction &BF, const LBREntry &First,
-                         const LBREntry &Second, uint64_t Count = 1) const;
+  getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
+                         uint64_t Count) const;
 
   /// Record external entry into the function \p BF.
   ///
@@ -265,8 +270,7 @@ class DataAggregator : public DataReader {
   bool doBranch(uint64_t From, uint64_t To, uint64_t Count, uint64_t Mispreds);
 
   /// Register a trace between two LBR entries supplied in execution order.
-  bool doTrace(const LBREntry &First, const LBREntry &Second,
-               uint64_t Count = 1);
+  bool doTrace(const Trace &Trace, uint64_t Count);
 
   /// Parser helpers
   /// Return false if we exhausted our parser buffer and finished parsing
@@ -516,6 +520,21 @@ inline raw_ostream &operator<<(raw_ostream &OS,
   OS << formatv("{0:x} -> {1:x}/{2}", L.From, L.To, L.Mispred ? 'M' : 'P');
   return OS;
 }
+
+inline raw_ostream &operator<<(raw_ostream &OS,
+                               const DataAggregator::Trace &T) {
+  switch (T.Branch) {
+  case DataAggregator::Trace::FT_ONLY:
+  case DataAggregator::Trace::FT_EXTERNAL_ORIGIN:
+    break;
+  default:
+    OS << Twine::utohexstr(T.Branch) << " -> ";
+  }
+  OS << Twine::utohexstr(T.From);
+  if (T.To)
+    OS << " ... " << Twine::utohexstr(T.To);
+  return OS;
+}
 } // namespace bolt
 } // namespace llvm
 
diff --git a/bolt/lib/Profile/DataAggregator.cpp b/bolt/lib/Profile/DataAggregator.cpp
index 4022212bcf1b6..bd7e550569140 100644
--- a/bolt/lib/Profile/DataAggregator.cpp
+++ b/bolt/lib/Profile/DataAggregator.cpp
@@ -589,8 +589,7 @@ void DataAggregator::processProfile(BinaryContext &BC) {
     llvm::stable_sort(MemEvents.second.Data);
 
   // Release intermediate storage.
-  clear(BranchLBRs);
-  clear(FallthroughLBRs);
+  clear(Traces);
   clear(BasicSamples);
   clear(MemSamples);
 }
@@ -771,37 +770,18 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
   return doInterBranch(FromFunc, ToFunc, From, To, Count, Mispreds);
 }
 
-bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
-                             uint64_t Count) {
-  BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(First.To);
-  BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(Second.From);
+bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
+  const uint64_t Branch = Trace.Branch, From = Trace.From, To = Trace.To;
+  BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(From);
+  BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(To);
   if (!FromFunc || !ToFunc) {
-    LLVM_DEBUG({
-      dbgs() << "Out of range trace starting in ";
-      if (FromFunc)
-        dbgs() << formatv("{0} @ {1:x}", *FromFunc,
-                          First.To - FromFunc->getAddress());
-      else
-        dbgs() << Twine::utohexstr(First.To);
-      dbgs() << " and ending in ";
-      if (ToFunc)
-        dbgs() << formatv("{0} @ {1:x}", *ToFunc,
-                          Second.From - ToFunc->getAddress());
-      else
-        dbgs() << Twine::utohexstr(Second.From);
-      dbgs() << '\n';
-    });
+    LLVM_DEBUG(dbgs() << "Out of range trace " << Trace << '\n');
     NumLongRangeTraces += Count;
     return false;
   }
   if (FromFunc != ToFunc) {
     NumInvalidTraces += Count;
-    LLVM_DEBUG({
-      dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
-             << formatv(" @ {0:x}", First.To - FromFunc->getAddress())
-             << " and ending in " << ToFunc->getPrintName()
-             << formatv(" @ {0:x}\n", Second.From - ToFunc->getAddress());
-    });
+    LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
     return false;
   }
 
@@ -809,28 +789,21 @@ bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
   BinaryFunction *ParentFunc = getBATParentFunction(*FromFunc);
   if (!ParentFunc)
     ParentFunc = FromFunc;
-  ParentFunc->SampleCountInBytes += Count * (Second.From - First.To);
+  ParentFunc->SampleCountInBytes += Count * (To - From);
 
   const uint64_t FuncAddress = FromFunc->getAddress();
   std::optional<BoltAddressTranslation::FallthroughListTy> FTs =
       BAT && BAT->isBATFunction(FuncAddress)
-          ? BAT->getFallthroughsInTrace(FuncAddress, First.To, Second.From)
-          : getFallthroughsInTrace(*FromFunc, First, Second, Count);
+          ? BAT->getFallthroughsInTrace(FuncAddress, From, To)
+          : getFallthroughsInTrace(*FromFunc, Trace, Count);
   if (!FTs) {
-    LLVM_DEBUG(
-        dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
-               << " @ " << Twine::utohexstr(First.To - FromFunc->getAddress())
-               << " and ending in " << ToFunc->getPrintName() << " @ "
-               << ToFunc->getPrintName() << " @ "
-               << Twine::utohexstr(Second.From - ToFunc->getAddress()) << '\n');
+    LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
     NumInvalidTraces += Count;
     return false;
   }
 
   LLVM_DEBUG(dbgs() << "Processing " << FTs->size() << " fallthroughs for "
-                    << FromFunc->getPrintName() << ":"
-                    << Twine::utohexstr(First.To) << " to "
-                    << Twine::utohexstr(Second.From) << ".\n");
+                    << FromFunc->getPrintName() << ":" << Trace << '\n');
   for (auto [From, To] : *FTs) {
     if (BAT) {
       From = BAT->translate(FromFunc->getAddress(), From, /*IsBranchSrc=*/true);
@@ -843,17 +816,15 @@ bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
 }
 
 std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
-DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
-                                       const LBREntry &FirstLBR,
-                                       const LBREntry &SecondLBR,
+DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
                                        uint64_t Count) const {
   SmallVector<std::pair<uint64_t, uint64_t>, 16> Branches;
 
   BinaryContext &BC = BF.getBinaryContext();
 
   // Offsets of the trace within this function.
-  const uint64_t From = FirstLBR.To - BF.getAddress();
-  const uint64_t To = SecondLBR.From - BF.getAddress();
+  const uint64_t From = Trace.From - BF.getAddress();
+  const uint64_t To = Trace.To - BF.getAddress();
 
   if (From > To)
     return std::nullopt;
@@ -880,8 +851,9 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
 
   // Adjust FromBB if the first LBR is a return from the last instruction in
   // the previous block (that instruction should be a call).
-  if (From == FromBB->getOffset() && !BF.containsAddress(FirstLBR.From) &&
-      !FromBB->isEntryPoint() && !FromBB->isLandingPad()) {
+  if (Trace.Branch != Trace::FT_ONLY && !BF.containsAddress(Trace.Branch) &&
+      From == FromBB->getOffset() && !FromBB->isEntryPoint() &&
+      !FromBB->isLandingPad()) {
     const BinaryBasicBlock *PrevBB =
         BF.getLayout().getBlock(FromBB->getIndex() - 1);
     if (PrevBB->getSuccessor(FromBB->getLabel())) {
@@ -889,10 +861,9 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
       if (Instr && BC.MIB->isCall(*Instr))
         FromBB = PrevBB;
       else
-        LLVM_DEBUG(dbgs() << "invalid incoming LBR (no call): " << FirstLBR
-                          << '\n');
+        LLVM_DEBUG(dbgs() << "invalid trace (no call): " << Trace << '\n');
     } else {
-      LLVM_DEBUG(dbgs() << "invalid incoming LBR: " << FirstLBR << '\n');
+      LLVM_DEBUG(dbgs() << "invalid trace: " << Trace << '\n');
     }
   }
 
@@ -911,9 +882,7 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
 
     // Check for bad LBRs.
     if (!BB->getSuccessor(NextBB->getLabel())) {
-      LLVM_DEBUG(dbgs() << "no fall-through for the trace:\n"
-                        << "  " << FirstLBR << '\n'
-                        << "  " << SecondLBR << '\n');
+      LLVM_DEBUG(dbgs() << "no fall-through for the trace: " << Trace << '\n');
       return std::nullopt;
     }
 
@@ -1307,29 +1276,24 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
   if (ToFunc)
     ToFunc->setHasProfileAvailable();
 
-  Trace Trace(FromOffset, ToOffset);
-  // Taken trace
-  if (Type == TRACE || Type == BRANCH) {
-    TakenBranchInfo &Info = BranchLBRs[Trace];
-    Info.TakenCount += Count;
-    Info.MispredCount += Mispreds;
-
-    NumTotalSamples += Count;
+  if (Type == FT || Type == FT_EXTERNAL_ORIGIN) {
+    Addr[2] = Location(Addr[1]->Offset);
+    Addr[1] = Location(Addr[0]->Offset);
+    Addr[0] = Location(Type == FT ? Trace::FT_ONLY : Trace::FT_EXTERNAL_ORIGIN);
   }
-  // Construct fallthrough part of the trace
-  if (Type == TRACE) {
-    const uint64_t TraceFtEndOffset = Addr[2]->Offset;
-    Trace.From = ToOffset;
-    Trace.To = TraceFtEndOffset;
-    Type = FromFunc == ToFunc ? FT : FT_EXTERNAL_ORIGIN;
+
+  if (Type == BRANCH) {
+    Addr[2] = Location(Trace::EXTERNAL);
   }
-  // Add fallthrough trace
-  if (Type != BRANCH) {
-    FTInfo &Info = FallthroughLBRs[Trace];
-    (Type == FT ? Info.InternCount : Info.ExternCount) += Count;
 
+  Trace T{Addr[0]->Offset, Addr[1]->Offset, Addr[2]->Offset};
+  TakenBranchInfo TI{(uint64_t)Count, (uint64_t)Mispreds};
+
+  Traces.emplace_back(T, TI);
+
+  if (Addr[2]->Offset)
     NumTraces += Count;
-  }
+  NumTotalSamples += Count;
 
   return std::error_code();
 }
@@ -1377,12 +1341,9 @@ std::error_code DataAggregator::printLBRHeatMap() {
   // Register basic samples and perf LBR addresses not covered by fallthroughs.
   for (const auto &[PC, Hits] : BasicSamples)
     HM.registerAddress(PC, Hits);
-  for (const auto &LBR : FallthroughLBRs) {
-    const Trace &Trace = LBR.first;
-    const FTInfo &Info = LBR.second;
-    HM.registerAddressRange(Trace.From, Trace.To,
-                            Info.InternCount + Info.ExternCount);
-  }
+  for (const auto &[Trace, Info] : Traces)
+    if (Trace.To)
+      HM.registerAddressRange(Trace.From, Trace.To, Info.TakenCount);
 
   if (HM.getNumInvalidRanges())
     outs() << "HEATMAP: invalid traces: " << HM.getNumInvalidRanges() << '\n';
@@ -1428,22 +1389,14 @@ void DataAggregator::parseLBRSample(const PerfBranchSample &Sample,
     // chronological order)
     if (NeedsSkylakeFix && NumEntry <= 2)
       continue;
+    uint64_t TraceTo = Trace::EXTERNAL;
     if (NextLBR) {
-      // Record fall-through trace.
-      const uint64_t TraceFrom = LBR.To;
-      const uint64_t TraceTo = NextLBR->From;
-      const BinaryFunction *TraceBF =
-          getBinaryFunctionContainingAddress(TraceFrom);
-      FTInfo &Info = FallthroughLBRs[Trace(TraceFrom, TraceTo)];
-      if (TraceBF && TraceBF->containsAddress(LBR.From))
-        ++Info.InternCount;
-      else
-        ++Info.ExternCount;
+      TraceTo = NextLBR->From;
       ++NumTraces;
     }
     NextLBR = &LBR;
 
-    TakenBranchInfo &Info = BranchLBRs[Trace(LBR.From, LBR.To)];
+    TakenBranchInfo &Info = TraceMap[Trace{LBR.From, LBR.To, TraceTo}];
     ++Info.TakenCount;
     Info.MispredCount += LBR.Mispred;
   }
@@ -1554,10 +1507,14 @@ std::error_code DataAggregator::parseBranchEvents() {
     parseLBRSample(Sample, NeedsSkylakeFix);
   }
 
-  for (const Trace &Trace : llvm::make_first_range(BranchLBRs))
-    for (const uint64_t Addr : {Trace.From, Trace.To})
+  Traces.reserve(TraceMap.size());
+  for (const auto &[Trace, Info] : TraceMap) {
+    Traces.emplace_back(Trace, Info);
+    for (const uint64_t Addr : {Trace.Branch, Trace.From})
       if (BinaryFunction *BF = getBinaryFunctionContainingAddress(Addr))
         BF->setHasProfileAvailable();
+  }
+  clear(TraceMap);
 
   outs() << "PERF2BOLT: read " << NumSamples << " samples and " << NumEntries
          << " LBR entries\n";
@@ -1582,23 +1539,12 @@ void DataAggregator::processBranchEvents() {
   NamedRegionTimer T("processBranch", "Processing branch events",
                      TimerGroupName, TimerGroupDesc, opts::TimeAggregator);
 
-  for (const auto &AggrLBR : FallthroughLBRs) {
-    const Trace &Loc = AggrLBR.first;
-    const FTInfo &Info = AggrLBR.second;
-    LBREntry First{Loc.From, Loc.From, false};
-    LBREntry Second{Loc.To, Loc.To, false};
-    if (Info.InternCount)
-      doTrace(First, Second, Info.InternCount);
-    if (Info.ExternCount) {
-      First.From = 0;
-      doTrace(First, Second, Info.ExternCount);
-    }
-  }
-
-  for (const auto &AggrLBR : BranchLBRs) {
-    const Trace &Loc = AggrLBR.first;
-    const TakenBranchInfo &Info = AggrLBR.second;
-    doBranch(Loc.From, Loc.To, Info.TakenCount, Info.MispredCount);
+  for (const auto &[Trace, Info] : Traces) {
+    if (Trace.To)
+      doTrace(Trace, Info.TakenCount);
+    if (Trace.Branch != Trace::FT_ONLY &&
+        Trace.Branch != Trace::FT_EXTERNAL_ORIGIN)
+      doBranch(Trace.Branch, Trace.From, Info.TakenCount, Info.MispredCount);
   }
   printBranchSamplesDiagnostics();
 }

aaupov added 4 commits June 9, 2025 15:57
Created using spr 1.3.4
Created using spr 1.3.4
Created using spr 1.3.4
Created using spr 1.3.4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants