Skip to content

Commit cd9bc4a

Browse files
Tidy up Node2VecModelTest
Co-authored-by: Ioannis Panagiotas <[email protected]>
1 parent 6e26b7e commit cd9bc4a

File tree

2 files changed

+7
-43
lines changed

2 files changed

+7
-43
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

+7-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.neo4j.gds.core.concurrency.Concurrency;
2626
import org.neo4j.gds.core.utils.Intersections;
2727
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
28-
import org.neo4j.gds.ml.core.helper.FloatVectorTestUtils;
2928

3029
import java.util.Optional;
3130
import java.util.Random;
@@ -38,7 +37,7 @@ class Node2VecModelTest {
3837

3938
@Test
4039
void testModel() {
41-
Random random = new Random(42);
40+
var random = new Random(42);
4241
int numberOfClusters = 10;
4342
int clusterSize = 100;
4443
int numberOfWalks = 10;
@@ -80,12 +79,15 @@ void testModel() {
8079
// as the order of the randomWalks is not deterministic, we also have non-fixed losses
8180
assertThat(trainResult.lossPerIteration())
8281
.hasSize(5)
83-
.allMatch(loss -> loss > 0 && Double.isFinite(loss));
82+
.allSatisfy(loss -> assertThat(loss).isPositive().isFinite());
8483

8584
var embeddings = trainResult.embeddings();
85+
assertThat(embeddings.size()).isEqualTo(nodeCount);
8686

87-
for (long idx = 0; idx < embeddings.size(); idx++) {
88-
assertThat(FloatVectorTestUtils.notContainsNaN(embeddings.get(idx))).isTrue();
87+
for (long idx = 0; idx < nodeCount; idx++) {
88+
assertThat(embeddings.get(idx).data())
89+
.hasSize(trainParameters.embeddingDimension())
90+
.doesNotContain(Float.NaN);
8991
}
9092

9193
double innerClusterSum = LongStream.range(0, numberOfClusters)

ml/ml-test-utils/src/main/java/org/neo4j/gds/ml/core/helper/FloatVectorTestUtils.java

-38
This file was deleted.

0 commit comments

Comments
 (0)