25
25
import org .neo4j .gds .core .concurrency .Concurrency ;
26
26
import org .neo4j .gds .core .utils .Intersections ;
27
27
import org .neo4j .gds .core .utils .progress .tasks .ProgressTracker ;
28
- import org .neo4j .gds .ml .core .helper .FloatVectorTestUtils ;
29
28
30
29
import java .util .Optional ;
31
30
import java .util .Random ;
@@ -38,7 +37,7 @@ class Node2VecModelTest {
38
37
39
38
@ Test
40
39
void testModel () {
41
- Random random = new Random (42 );
40
+ var random = new Random (42 );
42
41
int numberOfClusters = 10 ;
43
42
int clusterSize = 100 ;
44
43
int numberOfWalks = 10 ;
@@ -80,12 +79,15 @@ void testModel() {
80
79
// as the order of the randomWalks is not deterministic, we also have non-fixed losses
81
80
assertThat (trainResult .lossPerIteration ())
82
81
.hasSize (5 )
83
- .allMatch (loss -> loss > 0 && Double . isFinite (loss ));
82
+ .allSatisfy (loss -> assertThat ( loss ). isPositive (). isFinite ());
84
83
85
84
var embeddings = trainResult .embeddings ();
85
+ assertThat (embeddings .size ()).isEqualTo (nodeCount );
86
86
87
- for (long idx = 0 ; idx < embeddings .size (); idx ++) {
88
- assertThat (FloatVectorTestUtils .notContainsNaN (embeddings .get (idx ))).isTrue ();
87
+ for (long idx = 0 ; idx < nodeCount ; idx ++) {
88
+ assertThat (embeddings .get (idx ).data ())
89
+ .hasSize (trainParameters .embeddingDimension ())
90
+ .doesNotContain (Float .NaN );
89
91
}
90
92
91
93
double innerClusterSum = LongStream .range (0 , numberOfClusters )
0 commit comments