Skip to content

Commit 5dc7b08

Browse files
Enhance Node2Vec random seed tests
Co-authored-by: Ioannis Panagiotas <[email protected]>
1 parent 0a1f294 commit 5dc7b08

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java

+146
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22+
import org.assertj.core.api.SoftAssertions;
23+
import org.assertj.core.api.junit.jupiter.InjectSoftAssertions;
2224
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
2325
import org.assertj.core.data.Offset;
26+
import org.junit.jupiter.api.DisplayName;
2427
import org.junit.jupiter.api.Test;
2528
import org.junit.jupiter.api.extension.ExtendWith;
2629
import org.junit.jupiter.params.ParameterizedTest;
@@ -52,6 +55,7 @@
5255
import org.neo4j.gds.ml.core.tensor.FloatVector;
5356
import org.neo4j.gds.termination.TerminationFlag;
5457

58+
import java.util.Arrays;
5559
import java.util.List;
5660
import java.util.Optional;
5761
import java.util.SplittableRandom;
@@ -65,6 +69,9 @@
6569
@GdlExtension
6670
class Node2VecTest {
6771

72+
@InjectSoftAssertions
73+
private SoftAssertions assertions;
74+
6875
private static final List<Long> NO_SOURCE_NODES = List.of();
6976
private static final Optional<Long> NO_RANDOM_SEED = Optional.empty();
7077

@@ -190,6 +197,145 @@ void randomSeed(int concurrency) {
190197
}
191198
}
192199

200+
@Test
201+
@DisplayName("Should produce the same embeddings for the same randomSeed and single-threaded.")
202+
void twoRunsSingleThreadedWithTheSameRandomSeed() {
203+
204+
var concurrency = new Concurrency(1);
205+
int embeddingDimension = 8;
206+
var walkParameters = new SamplingWalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
207+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);
208+
209+
var firstRunEmbeddings = new Node2Vec(
210+
graph,
211+
concurrency,
212+
NO_SOURCE_NODES,
213+
Optional.of(1337L),
214+
1000,
215+
new Node2VecParameters(walkParameters, trainParameters),
216+
ProgressTracker.NULL_TRACKER,
217+
TerminationFlag.RUNNING_TRUE
218+
).compute().embeddings();
219+
220+
var secondRunEmbedding = new Node2Vec(
221+
graph,
222+
concurrency,
223+
NO_SOURCE_NODES,
224+
Optional.of(1337L),
225+
1000,
226+
new Node2VecParameters(walkParameters, trainParameters),
227+
ProgressTracker.NULL_TRACKER,
228+
TerminationFlag.RUNNING_TRUE
229+
).compute().embeddings();
230+
231+
for (long node = 0; node < graph.nodeCount(); node++) {
232+
var e1 = firstRunEmbeddings.get(node).data();
233+
var e2 = secondRunEmbedding.get(node).data();
234+
assertThat(e1)
235+
.isEqualTo(e2);
236+
}
237+
}
238+
239+
@ParameterizedTest(name = "Should produce similar embeddings for the same randomSeed and concurrency={0}")
240+
@ValueSource(ints = {4, 8})
241+
void twoRunsWithTheSameConcurrencyAndRandomSeed(int concurrency) {
242+
243+
int embeddingDimension = 8;
244+
var walkParameters = new SamplingWalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
245+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);
246+
247+
var firstRunEmbeddings = new Node2Vec(
248+
graph,
249+
new Concurrency(concurrency),
250+
NO_SOURCE_NODES,
251+
Optional.of(1337L),
252+
1000,
253+
new Node2VecParameters(walkParameters, trainParameters),
254+
ProgressTracker.NULL_TRACKER,
255+
TerminationFlag.RUNNING_TRUE
256+
).compute().embeddings();
257+
258+
var secondRunEmbedding = new Node2Vec(
259+
graph,
260+
new Concurrency(concurrency),
261+
NO_SOURCE_NODES,
262+
Optional.of(1337L),
263+
1000,
264+
new Node2VecParameters(walkParameters, trainParameters),
265+
ProgressTracker.NULL_TRACKER,
266+
TerminationFlag.RUNNING_TRUE
267+
).compute().embeddings();
268+
269+
for (long node = 0; node < graph.nodeCount(); node++) {
270+
var e1 = firstRunEmbeddings.get(node).data();
271+
var e2 = secondRunEmbedding.get(node).data();
272+
var cosine = Intersections.cosine(e1, e2, embeddingDimension);
273+
assertions.assertThat(cosine)
274+
.as(
275+
"""
276+
Cosine similarity of the embedding for node %s should be close to 1, it was %s.
277+
Actual embeddings are:
278+
e1 = %s,
279+
e2 = %s
280+
""",
281+
node,
282+
cosine,
283+
Arrays.toString(e1),
284+
Arrays.toString(e2)
285+
)
286+
.isCloseTo(1, Offset.offset(1e-3f));
287+
}
288+
}
289+
290+
@Test
291+
@DisplayName("Should produce similar embeddings for the same randomSeed and different concurrency values.")
292+
void twoRunsSameRandomSeedDifferentConcurrency() {
293+
int embeddingDimension = 8;
294+
var walkParameters = new SamplingWalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
295+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);
296+
297+
var firstRunEmbeddings = new Node2Vec(
298+
graph,
299+
new Concurrency(4),
300+
NO_SOURCE_NODES,
301+
Optional.of(1337L),
302+
1000,
303+
new Node2VecParameters(walkParameters, trainParameters),
304+
ProgressTracker.NULL_TRACKER,
305+
TerminationFlag.RUNNING_TRUE
306+
).compute().embeddings();
307+
308+
var secondRunEmbedding = new Node2Vec(
309+
graph,
310+
new Concurrency(8),
311+
NO_SOURCE_NODES,
312+
Optional.of(1337L),
313+
1000,
314+
new Node2VecParameters(walkParameters, trainParameters),
315+
ProgressTracker.NULL_TRACKER,
316+
TerminationFlag.RUNNING_TRUE
317+
).compute().embeddings();
318+
319+
for (long node = 0; node < graph.nodeCount(); node++) {
320+
var e1 = firstRunEmbeddings.get(node).data();
321+
var e2 = secondRunEmbedding.get(node).data();
322+
var cosine = Intersections.cosine(e1, e2, embeddingDimension);
323+
assertions.assertThat(cosine)
324+
.as(
325+
"""
326+
Cosine similarity of the embedding for node %s should be close to 1, it was %s.
327+
Actual embeddings are:
328+
e1 = %s,
329+
e2 = %s
330+
""",
331+
node,
332+
cosine,
333+
Arrays.toString(e1),
334+
Arrays.toString(e2)
335+
)
336+
.isCloseTo(1, Offset.offset(1e-3f));
337+
}
338+
}
193339
static Stream<Arguments> graphs() {
194340
return Stream.of(
195341
Arguments.of("All Labels", List.of(NodeLabel.of("Node1"), NodeLabel.of("Node2"), NodeLabel.of("Isolated"))),

0 commit comments

Comments
 (0)