Skip to content

Commit 6e26b7e

Browse files
IoannisPanagiotas authored and vnickolov committed
Some more testing
1 parent 5dc7b08 commit 6e26b7e

File tree

2 files changed

+107
-15
lines changed

2 files changed

+107
-15
lines changed

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducerTest.java

+36-15
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22+
import org.assertj.core.data.Offset;
2223
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.collections.ha.HugeLongArray;
2325
import org.neo4j.gds.core.concurrency.Concurrency;
2426

25-
import java.util.Map;
26-
import java.util.function.Function;
27-
import java.util.stream.Collectors;
28-
import java.util.stream.IntStream;
27+
import java.util.HashMap;
2928

30-
import static org.junit.jupiter.api.Assertions.assertEquals;
29+
import static org.assertj.core.api.Assertions.assertThat;
3130

3231
class NegativeSampleProducerTest {
3332

3433
@Test
3534
void shouldProduceSamplesAccordingToNodeDistribution() {
35+
3636
var builder = new RandomWalkProbabilities.Builder(
3737
2,
3838
new Concurrency(4),
@@ -46,19 +46,40 @@ void shouldProduceSamplesAccordingToNodeDistribution() {
4646
builder.registerWalk(new long[]{1});
4747

4848
RandomWalkProbabilities probabilityComputer = builder.build();
49-
5049
var sampler = new NegativeSampleProducer(probabilityComputer.negativeSamplingDistribution(),0);
5150

52-
Map<Long, Integer> distribution = IntStream
53-
.range(0, 1300)
54-
.mapToObj(ignore -> sampler.next())
55-
.collect(Collectors.toMap(
56-
Function.identity(),
57-
ignore -> 1,
58-
Integer::sum
59-
));
51+
var distribution = new HashMap<Long,Integer>();
52+
int SAMPLES = 1300;
53+
for (int i=0;i<SAMPLES;++i){
54+
var next = sampler.next();
55+
distribution.put(next, distribution.getOrDefault(next,0) + 1);
56+
}
6057

6158
// We sample nodes with a probability proportional to their number of occurrences^0.75 (16^0.75=8, 1^0.75=1)
62-
assertEquals(1.0 / 12, distribution.get(1L).doubleValue() / distribution.get(0L), 0.1);
59+
assertThat(distribution.get(1L).doubleValue() / distribution.get(0L)).isCloseTo(1.0/12, Offset.offset(0.1));
6360
}
61+
62+
@Test
63+
void shouldProduceDifferentlyIfSeeded() {
64+
65+
var sampler1 = new NegativeSampleProducer(HugeLongArray.of(16,18),0);
66+
var sampler2 = new NegativeSampleProducer(HugeLongArray.of(16,18),1);
67+
68+
var distribution1 = new HashMap<Long,Integer>();
69+
var distribution2 = new HashMap<Long,Integer>();
70+
71+
int SAMPLES = 1300;
72+
for (int i=0;i<SAMPLES;++i){
73+
var next1 = sampler1.next();
74+
distribution1.put(next1, distribution1.getOrDefault(next1,0) + 1);
75+
//
76+
var next2 = sampler2.next();
77+
distribution2.put(next2, distribution1.getOrDefault(next2,0) + 1);
78+
}
79+
80+
assertThat(distribution1.get(0L)).isNotEqualTo(distribution2.get(0L));
81+
assertThat(distribution1.get(1L)).isNotEqualTo(distribution2.get(1L));
82+
83+
}
84+
6485
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
import org.assertj.core.data.Offset;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.core.concurrency.Concurrency;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
class RandomWalkProbabilitiesTest {
29+
30+
@Test
31+
void shouldProduceSamplesAccordingToNodeDistribution() {
32+
double positiveSamplingFactor = 0.001;
33+
double negativeSamplingExponent = 0.75;
34+
var builder = new RandomWalkProbabilities.Builder(
35+
2,
36+
new Concurrency(4),
37+
positiveSamplingFactor,
38+
negativeSamplingExponent
39+
);
40+
41+
builder
42+
.registerWalk(new long[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
43+
44+
builder.registerWalk(new long[]{1});
45+
46+
RandomWalkProbabilities probabilityComputer = builder.build();
47+
48+
var negSampling = probabilityComputer.negativeSamplingDistribution();
49+
var posSampling = probabilityComputer.positiveSamplingProbabilities();
50+
51+
double app0 = 16;
52+
double app1 = 1;
53+
double sum = 17;
54+
double freq0 = app0/sum;
55+
double freq1 = app1/sum;
56+
57+
var expectedPos0 = (Math.sqrt(freq0/positiveSamplingFactor) + 1) * (positiveSamplingFactor/freq0);
58+
var expectedPos1 = (Math.sqrt(freq1/positiveSamplingFactor) + 1) * (positiveSamplingFactor/freq1);
59+
60+
assertThat(posSampling.get(0)).isCloseTo(expectedPos0, Offset.offset(1e-6));
61+
assertThat(posSampling.get(1)).isCloseTo(expectedPos1, Offset.offset(1e-6));
62+
63+
//neg[i] = 2*pow(16,negativeSamplingExponent) + neg[i-1]
64+
long expectedNeg0 = 2 * (long) Math.pow(app0, negativeSamplingExponent);
65+
long expectedNeg1 = 2 * (long) Math.pow(app1, negativeSamplingExponent) + expectedNeg0;
66+
67+
assertThat(negSampling.get(0)).isEqualTo(expectedNeg0);
68+
assertThat(negSampling.get(1)).isEqualTo(expectedNeg1);
69+
70+
}
71+
}

0 commit comments

Comments
 (0)