Skip to content

Commit 3c3b294

Browse files
breakanalysisFlorentinD
authored andcommitted
Make some PageRank and Eigenvector tests use tolerance in map assertions
Co-Authored-By: Florentin Dörre <[email protected]>
1 parent db4bcc2 commit 3c3b294

File tree

4 files changed

+44
-36
lines changed

4 files changed

+44
-36
lines changed

alpha/alpha-proc/src/test/java/org/neo4j/graphalgo/centrality/eigenvector/EigenvectorCentralityProcTest.java

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,13 @@
3838
import org.neo4j.graphdb.Label;
3939

4040
import java.io.File;
41-
import java.util.Collection;
4241
import java.util.HashMap;
43-
import java.util.HashSet;
4442
import java.util.Map;
4543
import java.util.stream.Stream;
4644

4745
import static org.junit.jupiter.api.Assertions.assertEquals;
4846
import static org.junit.jupiter.api.Assertions.assertNotEquals;
4947
import static org.junit.jupiter.api.Assertions.assertTrue;
50-
import static org.junit.jupiter.api.Assertions.fail;
5148
import static org.junit.jupiter.params.provider.Arguments.arguments;
5249
import static org.neo4j.graphalgo.compat.GraphDatabaseApiProxy.findNode;
5350
import static org.neo4j.graphalgo.compat.GraphDatabaseApiProxy.runInTransaction;
@@ -149,25 +146,8 @@ void setup() throws Exception {
149146
createExplicitGraph(EXPLICIT_GRAPH_NAME);
150147
}
151148

152-
@Override
153-
protected void assertMapEquals(Map<Long, Double> expected, Map<Long, Double> actual) {
154-
assertEquals(expected.size(), actual.size(), "number of elements");
155-
Collection<Long> expectedKeys = new HashSet<>(expected.keySet());
156-
for (Map.Entry<Long, Double> entry : actual.entrySet()) {
157-
assertTrue(
158-
expectedKeys.remove(entry.getKey()),
159-
"unknown key " + entry.getKey()
160-
);
161-
assertEquals(
162-
expected.get(entry.getKey()),
163-
entry.getValue(),
164-
0.1,
165-
"value for " + entry.getKey()
166-
);
167-
}
168-
for (Long expectedKey : expectedKeys) {
169-
fail("missing key " + expectedKey);
170-
}
149+
private void assertMapEqualsWithTolerance(Map<Long, Double> expected, Map<Long, Double> actual) {
150+
super.assertMapEqualsWithTolerance(expected, actual, 0.1);
171151
}
172152

173153
@ParameterizedTest(name = "Normalization: {0}")
@@ -192,7 +172,7 @@ void eigenvectorCentralityOnExplicitGraph(String normalizationType, Map<Long, Do
192172
(Double) row.get("score")
193173
)
194174
);
195-
assertMapEquals(expected, actual);
175+
assertMapEqualsWithTolerance(expected, actual);
196176
}
197177

198178
@ParameterizedTest(name = "Normalization: {0}")
@@ -227,7 +207,7 @@ void eigenvectorCentralityOnImplicitGraph(String normalizationType, Map<Long, Do
227207
(Double) row.get("score")
228208
)
229209
);
230-
assertMapEquals(expected, actual);
210+
assertMapEqualsWithTolerance(expected, actual);
231211
}
232212

233213
@ParameterizedTest(name = "Normalization: {0}")
@@ -260,7 +240,7 @@ void eigenvectorCentralityWriteOnExplicitGraph(String normalizationType, Map<Lon
260240
)
261241
);
262242

263-
assertMapEquals(expected, actual);
243+
assertMapEqualsWithTolerance(expected, actual);
264244
}
265245

266246
@ParameterizedTest(name = "Normalization: {0}")
@@ -297,7 +277,7 @@ void eigenvectorCentralityWriteOnImplicitGraph(String normalizationType, Map<Lon
297277
)
298278
);
299279

300-
assertMapEquals(expected, actual);
280+
assertMapEqualsWithTolerance(expected, actual);
301281
}
302282

303283
@ParameterizedTest(name = "Graph Creation: {0}")
@@ -318,7 +298,7 @@ void testStreamAllDefaults(String desc, GdsCypher.ModeBuildStage queryBuilder) {
318298
(Double) row.get("score")
319299
)
320300
);
321-
assertMapEquals(noNormExpected, actual);
301+
assertMapEqualsWithTolerance(noNormExpected, actual);
322302
}
323303

324304
@ParameterizedTest(name = "Graph Creation: {0}")
@@ -369,7 +349,7 @@ void testParallelStream(String desc, GdsCypher.ModeBuildStage queryBuilder) {
369349
actual.put(nodeId, (Double) row.get("score"));
370350
}
371351
);
372-
assertMapEquals(noNormExpected, actual);
352+
assertMapEqualsWithTolerance(noNormExpected, actual);
373353
}
374354

375355
static Stream<Arguments> normalizations() {

proc/centrality/src/test/java/org/neo4j/graphalgo/pagerank/PageRankStreamProcTest.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ public PageRankStreamConfig createConfig(CypherMapWrapper mapWrapper) {
5252
return PageRankStreamConfig.of("", Optional.empty(), Optional.empty(), mapWrapper);
5353
}
5454

55+
private void assertMapEqualsWithTolerance(Map<Long, Double> expected, Map<Long, Double> actual) {
56+
super.assertMapEqualsWithTolerance(expected, actual, 0.1);
57+
}
58+
5559
@ParameterizedTest(name = "{1}")
5660
@MethodSource("org.neo4j.graphalgo.pagerank.PageRankProcTest#graphVariations")
5761
void testPageRankParallelExecution(ModeBuildStage queryBuilder, String testName) {
@@ -64,7 +68,7 @@ void testPageRankParallelExecution(ModeBuildStage queryBuilder, String testName)
6468
actual.put(nodeId, (Double) row.get("score"));
6569
}
6670
);
67-
assertMapEquals(expected, actual);
71+
assertMapEqualsWithTolerance(expected, actual);
6872
}
6973

7074
@ParameterizedTest(name = "{1}")
@@ -79,7 +83,7 @@ void testWeightedPageRankWithAllRelationshipsEqual(ModeBuildStage queryBuilder,
7983
runQueryWithRowConsumer(query,
8084
row -> actual.put((Long) row.get("nodeId"), (Double) row.get("score"))
8185
);
82-
assertMapEquals(expected, actual);
86+
assertMapEqualsWithTolerance(expected, actual);
8387
}
8488

8589
@ParameterizedTest(name = "{1}")
@@ -129,7 +133,7 @@ void testWeightedPageRankWithCachedWeights(ModeBuildStage queryBuilder, String t
129133
runQueryWithRowConsumer(query,
130134
row -> actual.put((Long) row.get("nodeId"), (Double) row.get("score"))
131135
);
132-
assertMapEquals(weightedExpected, actual);
136+
assertMapEqualsWithTolerance(weightedExpected, actual);
133137
}
134138

135139
@ParameterizedTest(name = "{1}")
@@ -144,7 +148,7 @@ void testPageRank(ModeBuildStage queryBuilder, String testCaseName) {
144148
runQueryWithRowConsumer(query,
145149
row -> actual.put((Long) row.get("nodeId"), (Double) row.get("score"))
146150
);
147-
assertMapEquals(expected, actual);
151+
assertMapEqualsWithTolerance(expected, actual);
148152
}
149153

150154
@ParameterizedTest(name = "{1}")
@@ -159,7 +163,7 @@ void testWeightedPageRank(ModeBuildStage queryBuilder, String testCaseName) {
159163
runQueryWithRowConsumer(query,
160164
row -> actual.put((Long) row.get("nodeId"), (Double) row.get("score"))
161165
);
162-
assertMapEquals(weightedExpected, actual);
166+
assertMapEqualsWithTolerance(weightedExpected, actual);
163167
}
164168

165169
@ParameterizedTest(name = "{1}")

proc/centrality/src/test/java/org/neo4j/graphalgo/pagerank/PersonalizedPageRankProcTest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ Map<Long, Double> createExpectedResults(final GraphDatabaseService db) {
107107
return expected;
108108
}
109109

110+
private void assertMapEqualsWithTolerance(Map<Long, Double> expected, Map<Long, Double> actual) {
111+
super.assertMapEqualsWithTolerance(expected, actual, 0.1);
112+
}
113+
110114
@Test
111115
void personalizedPageRankOnImplicitGraph() {
112116
List<Node> startNodes = new ArrayList<>();
@@ -144,7 +148,7 @@ void personalizedPageRankOnImplicitGraph() {
144148
)
145149
);
146150

147-
assertMapEquals(expected, actual);
151+
assertMapEqualsWithTolerance(expected, actual);
148152
}
149153

150154
@Test
@@ -236,7 +240,7 @@ void personalizedPageRankOnExplicitGraph() {
236240
)
237241
);
238242

239-
assertMapEquals(expected, actual);
243+
assertMapEqualsWithTolerance(expected, actual);
240244
}
241245

242246
@Test
@@ -285,7 +289,7 @@ void testStreamRunsOnLoadedGraphWithNodeLabelFilter() throws Exception {
285289
)
286290
);
287291

288-
assertMapEquals(expected, actual);
292+
assertMapEqualsWithTolerance(expected, actual);
289293
} finally {
290294
db.shutdown();
291295
GraphStoreCatalog.removeAllLoadedGraphs();

test-utils/src/main/java/org/neo4j/graphalgo/BaseProcTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,26 @@ protected void assertMapEquals(Map<Long, Double> expected, Map<Long, Double> act
237237
assertThat(actual, mapEquals(expected));
238238
}
239239

240+
protected void assertMapEqualsWithTolerance(Map<Long, Double> expected, Map<Long, Double> actual, Double tolerance) {
241+
assertEquals(expected.size(), actual.size(), "number of elements");
242+
Collection<Long> expectedKeys = new HashSet<>(expected.keySet());
243+
for (Map.Entry<Long, Double> entry : actual.entrySet()) {
244+
assertTrue(
245+
expectedKeys.remove(entry.getKey()),
246+
"unknown key " + entry.getKey()
247+
);
248+
assertEquals(
249+
expected.get(entry.getKey()),
250+
entry.getValue(),
251+
tolerance,
252+
"value for " + entry.getKey()
253+
);
254+
}
255+
for (Long expectedKey : expectedKeys) {
256+
fail("missing key " + expectedKey);
257+
}
258+
}
259+
240260
protected void assertResult(String scoreProperty, Map<Long, Double> expected) {
241261
runInTransaction(db, tx -> {
242262
for (Map.Entry<Long, Double> entry : expected.entrySet()) {

0 commit comments

Comments
 (0)