Skip to content

Commit 4e6ca37

Browse files
Compute sign inside SignedProbabilities and replace 0 probabilities
Since random forest may produce a probability of 0, we replace such values with 1e-100 to retain the sign (positivity vs negativity of LP edges). Adding the sign as an argument to SignedProbabilities::add makes it impossible to add a 0 value with an ambiguous sign. If the logic was at the call-site, then future call-sites would be unprotected from such faults and would have to know about the 1e-100 constant that replaces 0.0. Co-Authored-By: Mats Rydberg <[email protected]>
1 parent 73c234b commit 4e6ca37

File tree

6 files changed

+69
-50
lines changed

6 files changed

+69
-50
lines changed

alpha/alpha-algo/src/main/java/org/neo4j/gds/ml/linkmodels/SignedProbabilities.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
* Represents a sorted set of doubles, sorted according to their absolute value in increasing order.
3636
*/
3737
public final class SignedProbabilities {
38+
public static double ALMOST_ZERO = 1e-100;
3839
private static final Comparator<Double> ABSOLUTE_VALUE_COMPARATOR = Comparator.comparingDouble(Math::abs);
3940

4041
private final Optional<TreeSet<Double>> tree;
@@ -71,13 +72,15 @@ public static SignedProbabilities create(long capacity) {
7172
return new SignedProbabilities(tree, list, isTree);
7273
}
7374

74-
public synchronized void add(double value) {
75-
if (value > 0) positiveCount++;
75+
public synchronized void add(double probability, boolean isPositive) {
76+
var nonZeroProbability = probability == 0 ? ALMOST_ZERO : probability;
77+
var signedProbability = isPositive ? nonZeroProbability : -1 * nonZeroProbability;
78+
if (signedProbability > 0) positiveCount++;
7679
else negativeCount++;
7780
if (isTree) {
78-
tree.get().add(value);
81+
tree.get().add(signedProbability);
7982
} else {
80-
list.get().add(value);
83+
list.get().add(signedProbability);
8184
}
8285
}
8386

alpha/alpha-algo/src/test/java/org/neo4j/gds/ml/linkmodels/metrics/LinkMetricTest.java

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import java.util.stream.Stream;
3636

3737
import static org.assertj.core.api.Assertions.assertThat;
38-
import static org.neo4j.gds.ml.metrics.LinkMetric.AUCPR;
38+
import static org.neo4j.gds.ml.linkmodels.metrics.LinkMetric.AUCPR;
3939

4040
class LinkMetricTest {
4141

@@ -46,16 +46,16 @@ class LinkMetricTest {
4646
@Test
4747
void shouldComputeAUCPR() {
4848
var signedProbabilities = SignedProbabilities.create(10);
49-
signedProbabilities.add(10);
50-
signedProbabilities.add(9);
51-
signedProbabilities.add(-8);
52-
signedProbabilities.add(7);
53-
signedProbabilities.add(-6);
54-
signedProbabilities.add(5);
55-
signedProbabilities.add(-4);
56-
signedProbabilities.add(-3);
57-
signedProbabilities.add(-2);
58-
signedProbabilities.add(1);
49+
signedProbabilities.add(10, true);
50+
signedProbabilities.add(9, true);
51+
signedProbabilities.add(8, false);
52+
signedProbabilities.add(7, true);
53+
signedProbabilities.add(6, false);
54+
signedProbabilities.add(5, true);
55+
signedProbabilities.add(4, false);
56+
signedProbabilities.add(3, false);
57+
signedProbabilities.add(2, false);
58+
signedProbabilities.add(1, true);
5959
var r10 = 1.0;
6060
var p10 = 0.5;
6161
var r9 = 0.8;
@@ -93,17 +93,17 @@ void shouldComputeAUCPR() {
9393
double expectedAUCScore = area10to9 + area9to8 + area8to7 + area7to6 + area6to5
9494
+ area5to4 + area4to3 + area3to2 + area2to1 + area1to0;
9595

96-
var aucScore = LinkMetric.AUCPR.compute(signedProbabilities, 1.0);
96+
var aucScore = AUCPR.compute(signedProbabilities, 1.0);
9797
assertThat(aucScore).isCloseTo(expectedAUCScore, Offset.offset(1e-24));
9898
}
9999

100100
@Test
101101
void shouldComputeAUCPRWithNegativeClassWeight() {
102102
var signedProbabilities = SignedProbabilities.create(4);
103-
signedProbabilities.add(4);
104-
signedProbabilities.add(-3);
105-
signedProbabilities.add(-2);
106-
signedProbabilities.add(1);
103+
signedProbabilities.add(4, true);
104+
signedProbabilities.add(3, false);
105+
signedProbabilities.add(2, false);
106+
signedProbabilities.add(1, true);
107107
// r4 means recall when extracting 4 examples , r3 means recall when extracting 3 examples etc
108108
var r4 = 1.0;
109109
var p4 = 2.0/22.0;
@@ -122,21 +122,21 @@ void shouldComputeAUCPRWithNegativeClassWeight() {
122122
var area1to0 = (r1 - r0) * (p1 + p0) / 2.0;
123123

124124
double expectedAUCScore = area4to3 + area3to2 + area2to1 + area1to0;
125-
var aucScore = LinkMetric.AUCPR.compute(signedProbabilities, 10.0);
125+
var aucScore = AUCPR.compute(signedProbabilities, 10.0);
126126
assertThat(aucScore).isCloseTo(expectedAUCScore, Offset.offset(1e-24));
127127
}
128128

129129
@Test
130130
void shouldComputeAUCPRRepeatedScores() {
131131
var signedProbabilities = SignedProbabilities.create(7);
132-
signedProbabilities.add(-4);
133-
signedProbabilities.add(4);
134-
signedProbabilities.add(-4);
135-
signedProbabilities.add(3);
136-
signedProbabilities.add(2);
137-
signedProbabilities.add(-2);
138-
signedProbabilities.add(1);
139-
var aucScore = LinkMetric.AUCPR.compute(signedProbabilities, 1.0);
132+
signedProbabilities.add(4, false);
133+
signedProbabilities.add(4, true);
134+
signedProbabilities.add(4, false);
135+
signedProbabilities.add(3, true);
136+
signedProbabilities.add(2, true);
137+
signedProbabilities.add(2, false);
138+
signedProbabilities.add(1, true);
139+
var aucScore = AUCPR.compute(signedProbabilities, 1.0);
140140
// r4 means recall when extracting 4 groups , r3 means recall when extracting 3 groups etc
141141
var r4 = 1.0;
142142
var p4 = 4.0/7.0;
@@ -161,11 +161,11 @@ void shouldComputeAUCPRRepeatedScores() {
161161
@Test
162162
void shouldComputeAUCPRHighestPriorityElementIsNegative() {
163163
var signedProbabilities = SignedProbabilities.create(4);
164-
signedProbabilities.add(-4);
165-
signedProbabilities.add(3);
166-
signedProbabilities.add(-2);
167-
signedProbabilities.add(1);
168-
var aucScore = LinkMetric.AUCPR.compute(signedProbabilities, 1.0);
164+
signedProbabilities.add(4, false);
165+
signedProbabilities.add(3, true);
166+
signedProbabilities.add(2, false);
167+
signedProbabilities.add(1, true);
168+
var aucScore = AUCPR.compute(signedProbabilities, 1.0);
169169
// r4 means recall when extracting 4 examples , r3 means recall when extracting 3 examples etc
170170
var r4 = 1.0;
171171
var p4 = 0.5;
@@ -192,9 +192,9 @@ void shouldComputeSklearnAUC() throws IOException {
192192
IOUtils.readLines(resourceAsStream, StandardCharsets.UTF_8).forEach(line -> {
193193
var split = line.split(",");
194194
var prob = Float.parseFloat(split[1]);
195-
signedProbabilites.add(split[0].equals("+") ? prob : -prob);
195+
signedProbabilites.add(prob, split[0].equals("+"));
196196
});
197-
assertThat(LinkMetric.AUCPR.compute(signedProbabilites, 1.0)).isEqualTo(expectedAUC, Offset.offset(1e-24));
197+
assertThat(AUCPR.compute(signedProbabilites, 1.0)).isEqualTo(expectedAUC, Offset.offset(1e-24));
198198
}
199199

200200
@ParameterizedTest
@@ -207,8 +207,7 @@ void shouldProduceAUCPRBetween0And1(long randomSeed) {
207207
SignedProbabilities signedProbabilities = SignedProbabilities.create(examples);
208208
for (int i = 0; i < examples; i++) {
209209
double prob = (double) rng.nextInt(numberOfTrees + 1) / numberOfTrees;
210-
double signedP = rng.nextBoolean() ? prob : -prob;
211-
signedProbabilities.add(signedP);
210+
signedProbabilities.add(prob, rng.nextBoolean());
212211
}
213212

214213
double compute = AUCPR.compute(signedProbabilities, 200);

alpha/alpha-algo/src/test/java/org/neo4j/gds/ml/linkmodels/metrics/SignedProbabilitiesTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.ml.linkmodels.metrics;
2121

22+
import org.junit.jupiter.api.Test;
2223
import org.junit.jupiter.params.ParameterizedTest;
2324
import org.junit.jupiter.params.provider.Arguments;
2425
import org.junit.jupiter.params.provider.MethodSource;
@@ -30,11 +31,13 @@
3031
import java.util.ArrayList;
3132
import java.util.Map;
3233
import java.util.Optional;
34+
import java.util.stream.Collectors;
3335
import java.util.stream.DoubleStream;
3436
import java.util.stream.LongStream;
3537
import java.util.stream.Stream;
3638

3739
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.neo4j.gds.ml.linkmodels.SignedProbabilities.ALMOST_ZERO;
3841

3942
public class SignedProbabilitiesTest {
4043

@@ -62,6 +65,21 @@ void shouldEstimateCorrectly(long nodeCount, long relationshipCount, double rela
6265

6366
}
6467

68+
@Test
69+
void shouldAddWithCorrectSignsAndReplaceZeroValues() {
70+
var signedProbabilities = SignedProbabilities.create(4);
71+
signedProbabilities.add(0.8, true);
72+
signedProbabilities.add(0.0, false);
73+
signedProbabilities.add(0.4, false);
74+
signedProbabilities.add(0.0, true);
75+
assertThat(signedProbabilities.stream().boxed().collect(Collectors.toList())).containsExactly(
76+
-ALMOST_ZERO,
77+
ALMOST_ZERO,
78+
-0.4,
79+
0.8
80+
);
81+
}
82+
6583
static Stream<Arguments> parameters() {
6684
return LongStream.of(42, 1339).boxed().flatMap(
6785
nodeCount -> LongStream.of(100, 1000).boxed().flatMap(

doc/asciidoc/machine-learning/linkprediction-pipeline/predict.adoc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,14 @@ We specified `threshold` to filter out predictions with probability less than 45
186186
[opts="header"]
187187
|===
188188
| person1 | person2 | probability
189-
| "Alice" | "Chris" | 0.754705631406466
190-
| "Chris" | "Mark" | 0.720865853276495
191-
| "Alice" | "Mark" | 0.569785164796211
192-
| "Alice" | "Karin" | 0.565318409460237
193-
| "Alice" | "Greg" | 0.563396306698756
189+
| "Michael" | "Veselin" | 0.8
190+
| "Alice" | "Mark" | 0.6
191+
| "Alice" | "Will" | 0.6
192+
| "Greg" | "Veselin" | 0.6
193+
| "Karin" | "Greg" | 0.6
194194
|===
195195

196-
We can see, that our model predicts the most likely link is between Alice and Chris.
196+
We can see that our model predicts that the most likely link is between Michael and Veselin.
197197
--
198198

199199

@@ -284,7 +284,7 @@ Because we are using the `UNDIRECTED` orientation, we will write twice as many r
284284
[opts="header",cols="3,7"]
285285
|===
286286
| relationshipsWritten | samplingStats
287-
| 16 | {linksConsidered=44, didConverge=true, strategy=approximate, ranIterations=2}
287+
| 16 | {linksConsidered=49, didConverge=true, strategy=approximate, ranIterations=3}
288288
|===
289289
--
290290

doc/asciidoc/machine-learning/linkprediction-pipeline/training.adoc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ RETURN
233233
.Results
234234
[opts="header", cols="6, 2, 2, 2"]
235235
|===
236-
| winningModel | avgTrainScore | outerTrainScore | testScore
237-
| {maxEpochs=100, minEpochs=1, penalty=0.0625, patience=1, methodName=LogisticRegression, batchSize=100, tolerance=0.001} | 0.3721560846560847 | 0.3801587301587301 | 0.7638888888888888
236+
| winningModel | avgTrainScore | outerTrainScore | testScore
237+
| {maxDepth=2147483647, minSplitSize=2, numberOfDecisionTrees=5, methodName=RandomForest, numberOfSamplesRatio=1.0} | 0.904365079365079 | 0.971428571428572 | 0.583333333333333
238238
|===
239239

240-
We can see the model configuration with `tolerance = 0.001` (and defaults filled for remaining parameters) was selected, and has a score of `0.76` on the test set.
240+
We can see the RandomForest model configuration with `numberOfDecisionTrees = 5` (and defaults filled for remaining parameters) was selected, and has a score of `0.58` on the test set.
241241
The score is computed as the <<linkprediction-pipelines-metrics, AUCPR>> metric, which is in the range [0, 1].
242242
A model which gives higher score to all links than non-links will have a score of 1.0, and a model that assigns random scores will on average have a score of 0.5.
243243

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionEvaluationMetricComputer.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ static Map<LinkMetric, Double> computeMetric(
6363
offset += 1;
6464
boolean isEdge = targets.get(relationshipIdx) == EdgeSplitter.POSITIVE;
6565

66-
var signedProbability = isEdge ? probabilityOfPositiveEdge : -1 * probabilityOfPositiveEdge;
67-
signedProbabilities.add(signedProbability);
66+
signedProbabilities.add(probabilityOfPositiveEdge, isEdge);
6867
}
6968

7069
progressTracker.logProgress(batch.size());

0 commit comments

Comments
 (0)