Skip to content

Commit 8906bad

Browse files
committed
Fix and refactor indexOfDiff
1 parent dcf6ff4 commit 8906bad

File tree

1 file changed

+89
-87
lines changed

1 file changed

+89
-87
lines changed

lib/std/mem.zig

Lines changed: 89 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -787,13 +787,13 @@ fn eqlBytes(a: []const u8, b: []const u8) bool {
787787
/// Compares two slices and returns the index of the first inequality.
788788
/// Returns null if the slices are equal.
789789
pub fn indexOfDiff(comptime T: type, a: []const T, b: []const T) ?usize {
790-
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T) and eqlBytes_allowed)
790+
if (!@inComptime() and @sizeOf(T) != 0 and std.meta.hasUniqueRepresentation(T))
791791
return if (indexOfDiffBytes(sliceAsBytes(a), sliceAsBytes(b))) |index| index / @sizeOf(T) else null;
792792

793793
const shortest = @min(a.len, b.len);
794794
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;
795-
var index: usize = 0;
796-
while (index < shortest) : (index += 1) if (a[index] != b[index]) return index;
795+
796+
for (0..shortest) |index| if (a[index] != b[index]) return index;
797797
return if (a.len == b.len) null else shortest;
798798
}
799799

@@ -803,8 +803,10 @@ test indexOfDiff {
803803
try testing.expectEqual(indexOfDiff(u8, "one", "one two"), 3);
804804
try testing.expectEqual(indexOfDiff(u8, "one twx", "one two"), 6);
805805
try testing.expectEqual(indexOfDiff(u8, "xne", "one"), 0);
806-
try testing.expectEqual(indexOfDiff(u8, "one two three four", "one two three"), 13);
807-
try testing.expectEqual(indexOfDiff(u8, "one two three four five six", "one two three four five"), 23);
806+
try testing.expectEqual(indexOfDiff(u16, &.{ 0x4e00, 0x4e8c, 0x4e09, 0x56db }, &.{ 0x4e00, 0x4e8c, 0x4e09 }), 3);
807+
try testing.expectEqual(indexOfDiff(u16, &.{ 0x96f6, 0x4e8c, 0x4e09, 0x56db }, &.{ 0x4e00, 0x4e8c, 0x4e09, 0x56db }), 0);
808+
try testing.expectEqual(indexOfDiff(f64, &.{ 0x8000000000000000, 0x0000000000000000 }, &.{ 0x0000000000000000, 0x0000000000000000 }), 0);
809+
try testing.expectEqual(indexOfDiff(u64, &.{ 0xaaaaaaaaaaaaaaaa, 0xaaaaaaaaaaaabbbb }, &.{ 0xaaaaaaaaaaaaaaaa, 0xaaaaaaaaaaaacccc }), 1);
808810
comptime {
809811
try testing.expectEqual(indexOfDiff(type, &.{ bool, f32 }, &.{ bool, f32 }), null);
810812
try testing.expectEqual(indexOfDiff(type, &.{ bool, f32 }, &.{ f32, bool }), 0);
@@ -815,106 +817,106 @@ test indexOfDiff {
815817
}
816818
try testing.expectEqual(indexOfDiff(void, &.{ {}, {} }, &.{ {}, {} }), null);
817819
try testing.expectEqual(indexOfDiff(void, &.{{}}, &.{ {}, {} }), 1);
818-
try testing.expectEqual(indexOfDiff(f64, &.{ 3.14, 2.71, 1.60 }, &.{ 3.14, 2.71, 1.60 }), null);
819-
try testing.expectEqual(indexOfDiff(u128, &.{ 1, 2, 3 }, &.{ 1, 2, 4 }), 2);
820820
}
821821

822822
/// std.mem.indexOfDiff heavily optimized for slices of bytes.
823823
fn indexOfDiffBytes(a: []const u8, b: []const u8) ?usize {
824-
comptime assert(eqlBytes_allowed);
825-
826824
const shortest = @min(a.len, b.len);
827825
if (a.ptr == b.ptr) return if (a.len == b.len) null else shortest;
828826

829-
if (shortest < 16) {
830-
if (shortest < @sizeOf(usize)) {
827+
const swar_thr = @sizeOf(usize) * 2;
828+
const max_vec_size = std.simd.suggestVectorLength(u8) orelse 0;
829+
const unroll_factor = 4;
830+
// Context used to generate corresponding scanning strategies (SWAR/SIMD) at compile time
831+
const Ctx = struct {
832+
fn Scan(vec_size: comptime_int) type {
833+
return if (vec_size != 0) struct { // SIMD path
834+
const size = vec_size;
835+
const Chunk = @Vector(size, u8);
836+
const Mask = @Type(.{ .int = .{ .bits = size, .signedness = .unsigned } });
837+
inline fn load(src: []const u8) Chunk {
838+
return @bitCast(src[0..size].*);
839+
}
840+
inline fn toMask(lhs: Chunk, rhs: Chunk) Mask {
841+
return @bitCast(lhs != rhs);
842+
}
843+
inline fn hasDiff(mask: Mask) bool {
844+
return mask != 0;
845+
}
846+
inline fn firstDiff(mask: Mask) usize {
847+
return @ctz(mask);
848+
}
849+
} else struct { // SWAR path
850+
const size = @sizeOf(usize);
851+
const Chunk = usize;
852+
const Mask = usize;
853+
inline fn load(src: []const u8) Chunk {
854+
return @bitCast(src[0..size].*);
855+
}
856+
inline fn toMask(lhs: Chunk, rhs: Chunk) Mask {
857+
return lhs ^ rhs;
858+
}
859+
inline fn hasDiff(mask: Mask) bool {
860+
return mask != 0;
861+
}
862+
inline fn firstDiff(mask: Mask) usize {
863+
// Endian-aware
864+
return (if (native_endian == .little) @ctz(mask) else @clz(mask)) / 8;
865+
}
866+
};
867+
}
868+
};
869+
// Small slices: [0, @sizeOf(usize) * 2]
870+
if (shortest <= swar_thr) {
871+
const Scan = Ctx.Scan(0);
872+
// [0, @sizeOf(usize)): too short for a single word-sized load, compare byte by byte
873+
if (shortest < Scan.size) {
831874
for (0..shortest) |index| if (a[index] != b[index]) return index;
832-
} else {
833-
var index: usize = 0;
834-
while (index + @sizeOf(usize) <= shortest) : (index += @sizeOf(usize)) {
835-
const a_chunk: usize = @bitCast(a[index..][0..@sizeOf(usize)].*);
836-
const b_chunk: usize = @bitCast(b[index..][0..@sizeOf(usize)].*);
837-
const diff = a_chunk ^ b_chunk;
838-
if (diff != 0)
839-
return index + @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
840-
}
841-
if (index < shortest) {
842-
const a_chunk: usize = @bitCast(a[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
843-
const b_chunk: usize = @bitCast(b[shortest - @sizeOf(usize) ..][0..@sizeOf(usize)].*);
844-
const diff = a_chunk ^ b_chunk;
845-
if (diff != 0)
846-
return shortest - @sizeOf(usize) + @divFloor(if (native_endian == .little) @ctz(diff) else @clz(diff), 8);
847-
}
875+
return if (a.len == b.len) null else shortest;
876+
}
877+
// [@sizeOf(usize), @sizeOf(usize) * 2]
878+
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
879+
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
880+
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
848881
}
849882
return if (a.len == b.len) null else shortest;
850883
}
851-
852-
const Scan = if (std.simd.suggestVectorLength(u8)) |vec_len| struct {
853-
const size = vec_len;
854-
855-
pub inline fn isNotZero(cur_size: comptime_int, mask: @Vector(cur_size, bool)) bool {
856-
return @reduce(.Or, mask);
857-
}
858-
859-
pub inline fn firstTrue(cur_size: comptime_int, mask: @Vector(cur_size, bool)) usize {
860-
return std.simd.firstTrue(mask).?;
861-
}
862-
} else struct {
863-
const size = @sizeOf(usize);
864-
865-
pub inline fn isNotZero(_: comptime_int, mask: usize) bool {
866-
return mask != 0;
867-
}
868-
pub inline fn firstTrue(_: comptime_int, mask: usize) usize {
869-
return @divFloor(if (native_endian == .little) @ctz(mask) else @clz(mask), 8);
870-
}
871-
};
872-
873-
// When the slice is smaller than the max vector length, reselect an appropriate vector length.
874-
if (shortest < Scan.size) {
875-
comptime var new_vec_len = 16;
876-
inline while (new_vec_len < Scan.size) : (new_vec_len *= 2) {
877-
if (new_vec_len < shortest and 2 * new_vec_len >= shortest) {
878-
inline for ([_]usize{ 0, shortest - new_vec_len }) |index| {
879-
const a_chunk: @Vector(new_vec_len, u8) = @bitCast(a[index..][0..new_vec_len].*);
880-
const b_chunk: @Vector(new_vec_len, u8) = @bitCast(b[index..][0..new_vec_len].*);
881-
const diff = a_chunk != b_chunk;
882-
if (Scan.isNotZero(new_vec_len, diff))
883-
return index + Scan.firstTrue(new_vec_len, diff);
884+
// Medium slices (@sizeOf(usize) * 2, max_vec_size)
885+
if (shortest < max_vec_size) {
886+
// Find the appropriate vector length by repeated doubling
887+
comptime var cur_vec_size = swar_thr;
888+
inline while (cur_vec_size < max_vec_size) : (cur_vec_size *= 2) {
889+
if (cur_vec_size < shortest and shortest <= cur_vec_size * 2) {
890+
const Scan = Ctx.Scan(cur_vec_size);
891+
inline for ([_]usize{ 0, shortest - Scan.size }) |index| {
892+
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
893+
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
884894
}
885-
break;
895+
return if (a.len == b.len) null else shortest;
886896
}
887897
}
888898
}
889-
// Using max vector length to perform SIMD scanning on slice
890-
else {
891-
var index: usize = 0;
892-
const unroll_factor = 4;
893-
while (index + Scan.size * unroll_factor <= shortest) : (index += Scan.size * unroll_factor) {
894-
inline for (0..unroll_factor) |i| {
895-
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[index + Scan.size * i ..][0..Scan.size].*);
896-
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[index + Scan.size * i ..][0..Scan.size].*);
897-
const diff = a_chunk != b_chunk;
898-
if (Scan.isNotZero(Scan.size, diff))
899-
return index + Scan.size * i + Scan.firstTrue(Scan.size, diff);
900-
}
901-
}
902-
while (index + Scan.size <= shortest) : (index += Scan.size) {
903-
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[index..][0..Scan.size].*);
904-
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[index..][0..Scan.size].*);
905-
const diff = a_chunk != b_chunk;
906-
if (Scan.isNotZero(Scan.size, diff))
907-
return index + Scan.firstTrue(Scan.size, diff);
908-
}
899+
// Large slices [max_vec_size, +∞)
900+
const Scan = Ctx.Scan(max_vec_size);
909901

910-
if (index < shortest) {
911-
const a_chunk: @Vector(Scan.size, u8) = @bitCast(a[shortest - Scan.size ..][0..Scan.size].*);
912-
const b_chunk: @Vector(Scan.size, u8) = @bitCast(b[shortest - Scan.size ..][0..Scan.size].*);
913-
const diff = a_chunk != b_chunk;
914-
if (Scan.isNotZero(Scan.size, diff))
915-
return shortest - Scan.size + Scan.firstTrue(Scan.size, diff);
902+
var index: usize = 0;
903+
// Main unrolled loop
904+
while (index + Scan.size * unroll_factor <= shortest) : (index += Scan.size * unroll_factor) {
905+
inline for (0..unroll_factor) |i| {
906+
const mask = Scan.toMask(Scan.load(a[index + Scan.size * i ..]), Scan.load(b[index + Scan.size * i ..]));
907+
if (Scan.hasDiff(mask)) return index + Scan.size * i + Scan.firstDiff(mask);
916908
}
917909
}
910+
// Residual iterations
911+
while (index + Scan.size <= shortest) : (index += Scan.size) {
912+
const mask = Scan.toMask(Scan.load(a[index..]), Scan.load(b[index..]));
913+
if (Scan.hasDiff(mask)) return index + Scan.firstDiff(mask);
914+
}
915+
// Final overlapping check
916+
if (index < shortest) {
917+
const mask = Scan.toMask(Scan.load(a[shortest - Scan.size ..]), Scan.load(b[shortest - Scan.size ..]));
918+
if (Scan.hasDiff(mask)) return shortest - Scan.size + Scan.firstDiff(mask);
919+
}
918920

919921
return if (a.len == b.len) null else shortest;
920922
}

0 commit comments

Comments
 (0)