Skip to content

Commit 74ad4a1

Browse files
authored
Fix Equals and GetHashCode for types containing Lists and Arrays in C# (#2710)
1 parent ea418d5 commit 74ad4a1

33 files changed

+1935
-403
lines changed

crates/bindings-csharp/BSATN.Codegen/Type.cs

+282-49
Large diffs are not rendered by default.

crates/bindings-csharp/BSATN.Runtime.Tests/Tests.cs

+270-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
namespace SpacetimeDB;
22

3+
using System.Diagnostics.CodeAnalysis;
34
using CsCheck;
5+
using SpacetimeDB.BSATN;
46
using Xunit;
57

68
public static partial class BSATNRuntimeTests
@@ -296,9 +298,77 @@ public BasicDataRecord((int x, string y, int? z, string? w) data)
296298
(int X, string Y, int? Z, string? W) c2
297299
)> GenTwoBasic = Gen.Select(GenBasic, GenBasic, (c1, c2) => (c1, c2));
298300

301+
/// <summary>
302+
/// Count collisions when comparing hashcodes of non-equal structures.
303+
/// </summary>
304+
struct CollisionCounter
305+
{
306+
private uint Comparisons;
307+
private uint Collisions;
308+
309+
public void Add(bool collides)
310+
{
311+
Comparisons += 1;
312+
if (collides)
313+
{
314+
Collisions += 1;
315+
}
316+
}
317+
318+
public double CollisionFraction
319+
{
320+
get => (double)Collisions / (double)Comparisons;
321+
}
322+
323+
public void AssertCollisionsLessThan(double fraction)
324+
{
325+
Assert.True(
326+
CollisionFraction < fraction,
327+
$"Expected {fraction} portion of collisions, but got {CollisionFraction} = {Collisions} / {Comparisons}"
328+
);
329+
}
330+
}
331+
332+
static void TestRoundTrip<T, BSATN>(Gen<T> gen, BSATN serializer)
333+
where BSATN : IReadWrite<T>
334+
{
335+
gen.Sample(
336+
(value) =>
337+
{
338+
var stream = new MemoryStream();
339+
var writer = new BinaryWriter(stream);
340+
serializer.Write(writer, value);
341+
stream.Seek(0, SeekOrigin.Begin);
342+
var reader = new BinaryReader(stream);
343+
var result = serializer.Read(reader);
344+
Assert.Equal(value, result);
345+
},
346+
iter: 10_000
347+
);
348+
}
349+
350+
[Fact]
351+
public static void GeneratedProductRoundTrip()
352+
{
353+
TestRoundTrip(
354+
GenBasic.Select(value => new BasicDataClass(value)),
355+
new BasicDataClass.BSATN()
356+
);
357+
TestRoundTrip(
358+
GenBasic.Select(value => new BasicDataRecord(value)),
359+
new BasicDataRecord.BSATN()
360+
);
361+
TestRoundTrip(
362+
GenBasic.Select(value => new BasicDataStruct(value)),
363+
new BasicDataStruct.BSATN()
364+
);
365+
}
366+
299367
[Fact]
300-
public static void TestGeneratedEquals()
368+
public static void GeneratedProductEqualsWorks()
301369
{
370+
CollisionCounter collisionCounter = new();
371+
302372
GenTwoBasic.Sample(
303373
example =>
304374
{
@@ -355,10 +425,13 @@ public static void TestGeneratedEquals()
355425
// hash code should not depend on the type of object.
356426
Assert.Equal(class1.GetHashCode(), record1.GetHashCode());
357427
Assert.Equal(record1.GetHashCode(), struct1.GetHashCode());
428+
429+
collisionCounter.Add(class1.GetHashCode() == class2.GetHashCode());
358430
}
359431
},
360432
iter: 10_000
361433
);
434+
collisionCounter.AssertCollisionsLessThan(0.05);
362435
}
363436

364437
[Type]
@@ -395,22 +468,17 @@ BasicDataRecord W
395468
(e1, e2) => (e1, e2)
396469
);
397470

398-
[Type]
399-
public partial class ContainsList
471+
[Fact]
472+
public static void GeneratedSumRoundTrip()
400473
{
401-
public List<BasicEnum?> TheList = [];
402-
403-
public ContainsList() { }
404-
405-
public ContainsList(List<BasicEnum?> theList)
406-
{
407-
TheList = theList;
408-
}
474+
TestRoundTrip(GenBasicEnum, new BasicEnum.BSATN());
409475
}
410476

411477
[Fact]
412-
public static void GeneratedEnumsWork()
478+
public static void GeneratedSumEqualsWorks()
413479
{
480+
CollisionCounter collisionCounter = new();
481+
414482
GenTwoBasicEnum.Sample(
415483
example =>
416484
{
@@ -442,10 +510,186 @@ public static void GeneratedEnumsWork()
442510
Assert.False(example.e1 == example.e2);
443511
Assert.True(example.e1 != example.e2);
444512
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
513+
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
445514
}
446515
},
447516
iter: 10_000
448517
);
518+
collisionCounter.AssertCollisionsLessThan(0.05);
519+
}
520+
521+
[Type]
522+
public partial class ContainsList
523+
{
524+
public List<BasicEnum?>? TheList = [];
525+
526+
public ContainsList() { }
527+
528+
public ContainsList(List<BasicEnum?>? theList)
529+
{
530+
TheList = theList;
531+
}
532+
}
533+
534+
static readonly Gen<ContainsList> GenContainsList = GenBasicEnum
535+
.Null()
536+
.List[0, 2]
537+
.Null()
538+
.Select(list => new ContainsList(list));
539+
static readonly Gen<(ContainsList e1, ContainsList e2)> GenTwoContainsList = Gen.Select(
540+
GenContainsList,
541+
GenContainsList,
542+
(e1, e2) => (e1, e2)
543+
);
544+
545+
[Fact]
546+
public static void GeneratedListRoundTrip()
547+
{
548+
TestRoundTrip(GenContainsList, new ContainsList.BSATN());
549+
}
550+
551+
[Fact]
552+
public static void GeneratedListEqualsWorks()
553+
{
554+
CollisionCounter collisionCounter = new();
555+
GenTwoContainsList.Sample(
556+
example =>
557+
{
558+
var equal =
559+
example.e1.TheList == null
560+
? example.e2.TheList == null
561+
: (
562+
example.e2.TheList == null
563+
? false
564+
: example.e1.TheList.SequenceEqual(example.e2.TheList)
565+
);
566+
567+
if (equal)
568+
{
569+
Assert.Equal(example.e1, example.e2);
570+
Assert.True(example.e1 == example.e2);
571+
Assert.False(example.e1 != example.e2);
572+
Assert.Equal(example.e1.ToString(), example.e2.ToString());
573+
Assert.Equal(example.e1.GetHashCode(), example.e2.GetHashCode());
574+
}
575+
else
576+
{
577+
Assert.NotEqual(example.e1, example.e2);
578+
Assert.False(example.e1 == example.e2);
579+
Assert.True(example.e1 != example.e2);
580+
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
581+
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
582+
}
583+
},
584+
iter: 10_000
585+
);
586+
collisionCounter.AssertCollisionsLessThan(0.05);
587+
}
588+
589+
[Type]
590+
public partial class ContainsNestedList
591+
{
592+
public List<BasicEnum[][]> TheList = [];
593+
594+
public ContainsNestedList() { }
595+
596+
public ContainsNestedList(List<BasicEnum[][]> theList)
597+
{
598+
TheList = theList;
599+
}
600+
}
601+
602+
// For the serialization test, forbid nulls.
603+
static readonly Gen<ContainsNestedList> GenContainsNestedListNoNulls = GenBasicEnum
604+
.Array[0, 2]
605+
.Array[0, 2]
606+
.List[0, 2]
607+
.Select(list => new ContainsNestedList(list));
608+
609+
[Fact]
610+
public static void GeneratedNestedListRoundTrip()
611+
{
612+
TestRoundTrip(GenContainsNestedListNoNulls, new ContainsNestedList.BSATN());
613+
}
614+
615+
// However, for the equals + hashcode test, throw in some nulls, just to be paranoid.
616+
// The user might have constructed a bad one of these in-memory.
617+
618+
#pragma warning disable CS8620 // Argument cannot be used for parameter due to differences in the nullability of reference types.
619+
static readonly Gen<ContainsNestedList> GenContainsNestedList = GenBasicEnum
620+
.Null()
621+
.Array[0, 2]
622+
.Null()
623+
.Array[0, 2]
624+
.Null()
625+
.List[0, 2]
626+
.Select(list => new ContainsNestedList(list));
627+
#pragma warning restore CS8620 // Argument cannot be used for parameter due to differences in the nullability of reference types.
628+
629+
630+
static readonly Gen<(ContainsNestedList e1, ContainsNestedList e2)> GenTwoContainsNestedList =
631+
Gen.Select(GenContainsNestedList, GenContainsNestedList, (e1, e2) => (e1, e2));
632+
633+
class EnumerableEqualityComparer<T> : EqualityComparer<IEnumerable<T>>
634+
{
635+
private readonly EqualityComparer<T> EqualityComparer;
636+
637+
public EnumerableEqualityComparer(EqualityComparer<T> equalityComparer)
638+
{
639+
EqualityComparer = equalityComparer;
640+
}
641+
642+
public override bool Equals(IEnumerable<T>? x, IEnumerable<T>? y) =>
643+
x == null ? y == null : (y == null ? false : x.SequenceEqual(y, EqualityComparer));
644+
645+
public override int GetHashCode([DisallowNull] IEnumerable<T> obj)
646+
{
647+
var hashCode = 0;
648+
foreach (var item in obj)
649+
{
650+
if (item != null)
651+
{
652+
hashCode ^= EqualityComparer.GetHashCode(item);
653+
}
654+
}
655+
return hashCode;
656+
}
657+
}
658+
659+
[Fact]
660+
public static void GeneratedNestedListEqualsWorks()
661+
{
662+
var equalityComparer = new EnumerableEqualityComparer<IEnumerable<IEnumerable<BasicEnum>>>(
663+
new EnumerableEqualityComparer<IEnumerable<BasicEnum>>(
664+
new EnumerableEqualityComparer<BasicEnum>(EqualityComparer<BasicEnum>.Default)
665+
)
666+
);
667+
CollisionCounter collisionCounter = new();
668+
GenTwoContainsNestedList.Sample(
669+
example =>
670+
{
671+
var equal = equalityComparer.Equals(example.e1.TheList, example.e2.TheList);
672+
673+
if (equal)
674+
{
675+
Assert.Equal(example.e1, example.e2);
676+
Assert.True(example.e1 == example.e2);
677+
Assert.False(example.e1 != example.e2);
678+
Assert.Equal(example.e1.ToString(), example.e2.ToString());
679+
Assert.Equal(example.e1.GetHashCode(), example.e2.GetHashCode());
680+
}
681+
else
682+
{
683+
Assert.NotEqual(example.e1, example.e2);
684+
Assert.False(example.e1 == example.e2);
685+
Assert.True(example.e1 != example.e2);
686+
Assert.NotEqual(example.e1.ToString(), example.e2.ToString());
687+
collisionCounter.Add(example.e1.GetHashCode() == example.e2.GetHashCode());
688+
}
689+
},
690+
iter: 10_000
691+
);
692+
collisionCounter.AssertCollisionsLessThan(0.05);
449693
}
450694

451695
[Fact]
@@ -516,5 +760,19 @@ public static void GeneratedToString()
516760
]
517761
).ToString()
518762
);
763+
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
764+
Assert.Equal(
765+
"ContainsNestedList { TheList = [ [ [ X(1), null ], null ], null ] }",
766+
new ContainsNestedList(
767+
[
768+
[
769+
[new BasicEnum.X(1), null],
770+
null,
771+
],
772+
null,
773+
]
774+
).ToString()
775+
);
776+
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
519777
}
520778
}

crates/bindings-csharp/Codegen.Tests/Tests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ static IEnumerable<Diagnostic> GetCompilationErrors(Compilation compilation)
124124
.Emit(Stream.Null)
125125
.Diagnostics.Where(diag => diag.Severity != DiagnosticSeverity.Hidden)
126126
// The order of diagnostics is not predictable, sort them by location to make the test deterministic.
127-
.OrderBy(diag => diag.Location.ToString());
127+
.OrderBy(diag => diag.GetMessage() + diag.Location.ToString());
128128
}
129129

130130
[Fact]

crates/bindings-csharp/Codegen.Tests/fixtures/client/snapshots/Type#CustomClass.verified.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@ SpacetimeDB.BSATN.ITypeRegistrar registrar
5050

5151
public override int GetHashCode()
5252
{
53-
return IntField.GetHashCode() ^ StringField.GetHashCode();
53+
var ___hashIntField = IntField.GetHashCode();
54+
var ___hashStringField = StringField == null ? 0 : StringField.GetHashCode();
55+
return ___hashIntField ^ ___hashStringField;
5456
}
5557

5658
#nullable enable
5759
public bool Equals(CustomClass that)
5860
{
59-
return IntField.Equals(that.IntField) && StringField.Equals(that.StringField);
61+
var ___eqIntField = this.IntField.Equals(that.IntField);
62+
var ___eqStringField =
63+
this.StringField == null
64+
? that.StringField == null
65+
: this.StringField.Equals(that.StringField);
66+
return ___eqIntField && ___eqStringField;
6067
}
6168

6269
public override bool Equals(object? that)

crates/bindings-csharp/Codegen.Tests/fixtures/client/snapshots/Type#CustomStruct.verified.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,20 @@ SpacetimeDB.BSATN.ITypeRegistrar registrar
5252

5353
public override int GetHashCode()
5454
{
55-
return IntField.GetHashCode() ^ StringField.GetHashCode();
55+
var ___hashIntField = IntField.GetHashCode();
56+
var ___hashStringField = StringField == null ? 0 : StringField.GetHashCode();
57+
return ___hashIntField ^ ___hashStringField;
5658
}
5759

5860
#nullable enable
5961
public bool Equals(CustomStruct that)
6062
{
61-
return IntField.Equals(that.IntField) && StringField.Equals(that.StringField);
63+
var ___eqIntField = this.IntField.Equals(that.IntField);
64+
var ___eqStringField =
65+
this.StringField == null
66+
? that.StringField == null
67+
: this.StringField.Equals(that.StringField);
68+
return ___eqIntField && ___eqStringField;
6269
}
6370

6471
public override bool Equals(object? that)

0 commit comments

Comments
 (0)