[ADT] Simplify SparseBitVectorIterator. NFCI. #143885

Open. jayfoad wants to merge 3 commits into main.

@jayfoad jayfoad commented Jun 12, 2025

The old implementation admitted to being "a lot uglier than it would be, in order to be efficient". This new implementation aims to gain efficiency through simplicity.

@llvmbot

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-llvm-adt

Author: Jay Foad (jayfoad)

Changes

The old implementation admitted to being "a lot uglier than it would be, in order to be efficient". This new implementation aims to gain efficiency through simplicity.


Full diff: https://github.com/llvm/llvm-project/pull/143885.diff

1 Files Affected:

  • (modified) llvm/include/llvm/ADT/SparseBitVector.h (+24-104)
diff --git a/llvm/include/llvm/ADT/SparseBitVector.h b/llvm/include/llvm/ADT/SparseBitVector.h
index 7151af6146e6e..d3ac388871d9d 100644
--- a/llvm/include/llvm/ADT/SparseBitVector.h
+++ b/llvm/include/llvm/ADT/SparseBitVector.h
@@ -145,19 +145,14 @@ template <unsigned ElementSize = 128> struct SparseBitVectorElement {
 
   /// find_next - Returns the index of the next set bit starting from the
   /// "Curr" bit. Returns -1 if the next set bit is not found.
-  int find_next(unsigned Curr) const {
-    if (Curr >= BITS_PER_ELEMENT)
-      return -1;
+  int find_next(int Curr) const {
+    assert(Curr >= 0 && Curr < BITS_PER_ELEMENT);
 
     unsigned WordPos = Curr / BITWORD_SIZE;
     unsigned BitPos = Curr % BITWORD_SIZE;
-    BitWord Copy = Bits[WordPos];
-    assert(WordPos <= BITWORDS_PER_ELEMENT
-           && "Word Position outside of element");
 
     // Mask off previous bits.
-    Copy &= ~0UL << BitPos;
-
+    BitWord Copy = Bits[WordPos] & ~1UL << BitPos;
     if (Copy != 0)
       return WordPos * BITWORD_SIZE + llvm::countr_zero(Copy);
 
@@ -314,101 +309,34 @@ class SparseBitVector {
     return FindLowerBoundImpl(ElementIndex);
   }
 
-  // Iterator to walk set bits in the bitmap.  This iterator is a lot uglier
-  // than it would be, in order to be efficient.
+  // Iterator to walk set bits in the bitvector.
   class SparseBitVectorIterator {
   private:
-    bool AtEnd;
-
-    const SparseBitVector<ElementSize> *BitVector = nullptr;
+    // Current bit number within the current element, or -1 if we are at the
+    // end.
+    int BitPos = -1;
 
-    // Current element inside of bitmap.
+    // Iterators to the current element and the end of the bitvector. These are
+    // only valid when BitPos >= 0.
     ElementListConstIter Iter;
-
-    // Current bit number inside of our bitmap.
-    unsigned BitNumber;
-
-    // Current word number inside of our element.
-    unsigned WordNumber;
-
-    // Current bits from the element.
-    typename SparseBitVectorElement<ElementSize>::BitWord Bits;
-
-    // Move our iterator to the first non-zero bit in the bitmap.
-    void AdvanceToFirstNonZero() {
-      if (AtEnd)
-        return;
-      if (BitVector->Elements.empty()) {
-        AtEnd = true;
-        return;
-      }
-      Iter = BitVector->Elements.begin();
-      BitNumber = Iter->index() * ElementSize;
-      unsigned BitPos = Iter->find_first();
-      BitNumber += BitPos;
-      WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
-      Bits = Iter->word(WordNumber);
-      Bits >>= BitPos % BITWORD_SIZE;
-    }
-
-    // Move our iterator to the next non-zero bit.
-    void AdvanceToNextNonZero() {
-      if (AtEnd)
-        return;
-
-      while (Bits && !(Bits & 1)) {
-        Bits >>= 1;
-        BitNumber += 1;
-      }
-
-      // See if we ran out of Bits in this word.
-      if (!Bits) {
-        int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize) ;
-        // If we ran out of set bits in this element, move to next element.
-        if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
-          ++Iter;
-          WordNumber = 0;
-
-          // We may run out of elements in the bitmap.
-          if (Iter == BitVector->Elements.end()) {
-            AtEnd = true;
-            return;
-          }
-          // Set up for next non-zero word in bitmap.
-          BitNumber = Iter->index() * ElementSize;
-          NextSetBitNumber = Iter->find_first();
-          BitNumber += NextSetBitNumber;
-          WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
-          Bits = Iter->word(WordNumber);
-          Bits >>= NextSetBitNumber % BITWORD_SIZE;
-        } else {
-          WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
-          Bits = Iter->word(WordNumber);
-          Bits >>= NextSetBitNumber % BITWORD_SIZE;
-          BitNumber = Iter->index() * ElementSize;
-          BitNumber += NextSetBitNumber;
-        }
-      }
-    }
+    ElementListConstIter End;
 
   public:
     SparseBitVectorIterator() = default;
 
-    SparseBitVectorIterator(const SparseBitVector<ElementSize> *RHS,
-                            bool end = false):BitVector(RHS) {
-      Iter = BitVector->Elements.begin();
-      BitNumber = 0;
-      Bits = 0;
-      WordNumber = ~0;
-      AtEnd = end;
-      AdvanceToFirstNonZero();
+    SparseBitVectorIterator(const ElementList &Elements) {
+      if (Elements.empty())
+        return;
+      Iter = Elements.begin();
+      End = Elements.end();
+      BitPos = Iter->find_first();
     }
 
     // Preincrement.
     inline SparseBitVectorIterator& operator++() {
-      ++BitNumber;
-      Bits >>= 1;
-      AdvanceToNextNonZero();
+      BitPos = Iter->find_next(BitPos);
+      if (BitPos < 0 && ++Iter != End)
+        BitPos = Iter->find_first();
       return *this;
     }
 
@@ -421,16 +349,12 @@ class SparseBitVector {
 
     // Return the current set bit number.
     unsigned operator*() const {
-      return BitNumber;
+      assert(BitPos >= 0);
+      return Iter->index() * ElementSize + BitPos;
     }
 
     bool operator==(const SparseBitVectorIterator &RHS) const {
-      // If they are both at the end, ignore the rest of the fields.
-      if (AtEnd && RHS.AtEnd)
-        return true;
-      // Otherwise they are the same if they have the same bit number and
-      // bitmap.
-      return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
+      return BitPos == RHS.BitPos;
     }
 
     bool operator!=(const SparseBitVectorIterator &RHS) const {
@@ -807,13 +731,9 @@ class SparseBitVector {
     return BitCount;
   }
 
-  iterator begin() const {
-    return iterator(this);
-  }
+  iterator begin() const { return iterator(Elements); }
 
-  iterator end() const {
-    return iterator(this, true);
-  }
+  iterator end() const { return iterator(); }
 };
 
 // Convenience functions to allow Or and And without dereferencing in the user
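
For readers skimming the diff, the iteration scheme of the new `operator++` can be sketched outside LLVM roughly as follows. This is a simplified illustration, not the code in the PR: it assumes one 64-bit word per element, uses a `std::map` in place of the element list, and assumes no all-zero element is ever stored. The names `findFirst`, `findNext`, and `collectSetBits` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Assumption: one 64-bit word per element (the real element holds several words).
constexpr int kBitsPerElement = 64;

// Index of the lowest set bit in Word, or -1 if Word is zero.
inline int findFirst(std::uint64_t Word) {
  if (Word == 0)
    return -1;
  int Pos = 0;
  while (!(Word & 1)) {
    Word >>= 1;
    ++Pos;
  }
  return Pos;
}

// Index of the lowest set bit strictly above Curr, or -1 if there is none.
// The ~1 mask clears bit Curr and everything below it, as in the patched find_next.
inline int findNext(std::uint64_t Word, int Curr) {
  assert(Curr >= 0 && Curr < kBitsPerElement);
  return findFirst(Word & (~std::uint64_t{1} << Curr));
}

// Walks all set bits in ascending order. Elements maps element index -> word;
// the sketch assumes every stored word is nonzero.
inline std::vector<unsigned>
collectSetBits(const std::map<unsigned, std::uint64_t> &Elements) {
  std::vector<unsigned> Out;
  auto Iter = Elements.begin();
  auto End = Elements.end();
  int BitPos = (Iter == End) ? -1 : findFirst(Iter->second);
  while (BitPos >= 0) {
    Out.push_back(Iter->first * kBitsPerElement + BitPos);
    // Same shape as the new operator++: stay in the element if it has
    // another set bit, otherwise step to the next element (if any).
    BitPos = findNext(Iter->second, BitPos);
    if (BitPos < 0 && ++Iter != End)
      BitPos = findFirst(Iter->second);
  }
  return Out;
}
```

The only iterator state is the element position and a signed bit position, with -1 doubling as the end marker, which is what lets `end()` be a default-constructed iterator in the diff.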

@jayfoad

jayfoad commented Jun 12, 2025

No impact on overall compile time: https://llvm-compile-time-tracker.com/compare.php?from=756e7cfd86c7f2bf20aaa1a3f87b5aa72ec128b4&to=b5e202767931f2e331d34380117cd5b4225b937c&stat=instructions:u

Do we have any standard way of writing microbenchmarks for things like this?


     // Mask off previous bits.
-    Copy &= ~0UL << BitPos;
-
+    BitWord Copy = Bits[WordPos] & ~1UL << BitPos;

Using ~1UL instead of ~0UL is a bug fix. Previously this whole find_next method was unused.
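
To make the difference between the two masks concrete, here is a small standalone sketch. It assumes a 64-bit word (the diff uses `unsigned long`, whose width is platform dependent), and the helper names `maskInclusive` and `maskExclusive` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Old mask: keeps bit BitPos and everything above it.
inline std::uint64_t maskInclusive(std::uint64_t Word, unsigned BitPos) {
  assert(BitPos < 64);
  return Word & (~std::uint64_t{0} << BitPos);
}

// New mask: keeps only the bits strictly above BitPos.
inline std::uint64_t maskExclusive(std::uint64_t Word, unsigned BitPos) {
  assert(BitPos < 64);
  return Word & (~std::uint64_t{1} << BitPos);
}
```

With `Word = 0b1010` and `BitPos = 1`, the inclusive mask leaves `0b1010` (bit 1 survives, so a search would report the current bit again), while the exclusive mask leaves `0b1000`. The new `operator++` passes the current position to `find_next`, so it needs the exclusive form to make progress.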

@nikic

nikic commented Jun 12, 2025

> Do we have any standard way of writing microbenchmarks for things like this?

We have llvm/benchmarks using google/benchmark.

@jayfoad

jayfoad commented Jun 12, 2025

> We have llvm/benchmarks using google/benchmark.

Thanks! I added a benchmark. I get these numbers before:

---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_SparseBitVectorIterator/10          15.4 ns         15.4 ns     45277610
BM_SparseBitVectorIterator/100          229 ns          229 ns      3056142
BM_SparseBitVectorIterator/1000        1821 ns         1821 ns       384125
BM_SparseBitVectorIterator/10000       2975 ns         2975 ns       223294

And after:

---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_SparseBitVectorIterator/10          9.86 ns         9.86 ns     70871286
BM_SparseBitVectorIterator/100          100 ns          100 ns      7088720
BM_SparseBitVectorIterator/1000        1356 ns         1356 ns       512892
BM_SparseBitVectorIterator/10000      12426 ns        12425 ns        56256

This shows that the new implementation is faster for sparse SparseBitVectors but 4x slower when they get almost full.

@kuhar

kuhar commented Jun 12, 2025

> This shows that the new implementation is faster for sparse SparseBitVectors but 4x slower when they get almost full.

Do we know how full the existing SparseBitVectors are in LLVM on average?

@jayfoad

jayfoad commented Jun 12, 2025

> > This shows that the new implementation is faster for sparse SparseBitVectors but 4x slower when they get almost full.
>
> Do we know how full the existing SparseBitVectors are in LLVM on average?

I don't know, and I'm not sure I have the time or energy to gather that kind of data.

Interestingly the 4x slowdown seems to be entirely due to using llvm::countr_zero to find the next bit set in a word, instead of a loop like:

      while (Bits && !(Bits & 1)) {
        Bits >>= 1;
        BitNumber += 1;
      }

I don't understand this. I have checked that countr_zero codegens to a single tzcnt instruction which is supposed to be pretty fast. I'm running my benchmark on an AMD Ryzen 9 9950X 16-Core Processor.
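
For reference, the two ways of skipping to the next set bit can be put side by side as below. This is only a sketch: `__builtin_ctzll` is a GCC/Clang builtin standing in for `llvm::countr_zero`, the word is assumed to be 64 bits, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Shift-loop variant, as in the old iterator: count how many shifts are
// needed until the low bit is set. Bits must be nonzero.
inline unsigned skipZerosLoop(std::uint64_t Bits) {
  assert(Bits != 0);
  unsigned N = 0;
  while (!(Bits & 1)) {
    Bits >>= 1;
    ++N;
  }
  return N;
}

// Count-trailing-zeros variant. Bits must be nonzero because the builtin
// is undefined for zero.
inline unsigned skipZerosCtz(std::uint64_t Bits) {
  assert(Bits != 0);
  return static_cast<unsigned>(__builtin_ctzll(Bits));
}
```

Both return the same value for every nonzero input. On an almost-full word the loop body runs zero or one times per step, which might be relevant to the numbers above, but nothing in this thread establishes that as the cause.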

@kuhar

kuhar commented Jun 12, 2025

> I don't know, and I'm not sure I have the time or energy to gather that kind of data.

I usually measure these things by adding an fprintf(stderr, ...); in the destructor, running on any bitcode lying around, and then copying the output to a spreadsheet.
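
A minimal sketch of that kind of instrumentation, using a hypothetical stand-in type rather than the real `SparseBitVector` (one 64-bit word per element, `std::map` storage, made-up output format):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <map>

// Hypothetical stand-in for a sparse bit vector, only to show the logging idea.
struct InstrumentedBits {
  static constexpr unsigned BitsPerElement = 64;
  std::map<unsigned, std::uint64_t> Elements;

  void set(unsigned Idx) {
    Elements[Idx / BitsPerElement] |= std::uint64_t{1} << (Idx % BitsPerElement);
  }

  // Number of set bits across all elements.
  unsigned count() const {
    unsigned N = 0;
    for (const auto &KV : Elements)
      for (std::uint64_t W = KV.second; W != 0; W &= W - 1)
        ++N;
    return N;
  }

  // Number of bits the allocated elements could hold.
  unsigned capacity() const {
    return static_cast<unsigned>(Elements.size()) * BitsPerElement;
  }

  // Emit one line per vector on destruction; the lines can be collected
  // from stderr and pasted into a spreadsheet.
  ~InstrumentedBits() {
    std::fprintf(stderr, "SBV elements=%u set=%u capacity=%u\n",
                 static_cast<unsigned>(Elements.size()), count(), capacity());
  }
};
```

The ratio of `set` to `capacity` gives a per-vector fill level; in the real class the same line would go in the destructor of `SparseBitVector` itself.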
