Skip to content

Commit 1bfb8fb

Browse files
FlorentinD and adamnsch committed
Correct the trainSize for LLR in LinkPrediction
Cherry-picked PR #3706

Co-authored-by: Adam Schill Collberg <[email protected]>
1 parent 1df5e37 commit 1bfb8fb

File tree

5 files changed

+33
-63
lines changed

5 files changed

+33
-63
lines changed

alpha/alpha-algo/src/main/java/org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionBase.java

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -51,13 +51,8 @@ protected Variable<Matrix> predictions(Constant<Matrix> features) {
5151
return new Sigmoid<>(MatrixMultiplyWithTransposedSecondOperand.of(features, modelData.weights()));
5252
}
5353

54-
protected Constant<Matrix> features(Graph graph, Batch batch) {
54+
protected Constant<Matrix> features(Graph graph, Batch batch, int rows) {
5555
var graphCopy = graph.concurrentCopy();
56-
// TODO: replace by MutableLong and throw an error saying reduce batchSize if larger than maxint
57-
var relationshipCount = new MutableInt();
58-
// assume batching has been done so that relationship count does not overflow int
59-
batch.nodeIds().forEach(nodeId -> relationshipCount.add(graph.degree(nodeId)));
60-
int rows = relationshipCount.intValue();
6156
int cols = modelData.linkFeatureDimension();
6257
double[] features = new double[rows * cols];
6358
var relationshipOffset = new MutableInt();

alpha/alpha-algo/src/main/java/org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionObjective.java

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -102,11 +102,14 @@ public List<Weights<? extends Tensor<?>>> weights() {
102102

103103
@Override
104104
public Variable<Scalar> loss(Batch batch, long trainSize) {
105-
var features = features(graph, batch);
105+
// assume batching has been done so that relationship count does not overflow int
106+
int rows = 0;
107+
for (var nodeId : batch.nodeIds()) {
108+
rows += graph.degree(nodeId);
109+
}
110+
111+
var features = features(graph, batch, rows);
106112
Variable<Matrix> predictions = predictions(features);
107-
var relationshipCount = new MutableInt();
108-
batch.nodeIds().forEach(nodeId -> relationshipCount.add(graph.degree(nodeId)));
109-
var rows = relationshipCount.getValue();
110113
var targets = makeTargetsArray(batch, rows);
111114
var penaltyVariable = new ConstantScale<>(new L2NormSquared(modelData.weights()), rows * penalty / trainSize);
112115
var unpenalizedLoss = new LogisticLoss(modelData.weights(), predictions, features, targets);

alpha/alpha-algo/src/main/java/org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionTrain.java

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,15 @@ public LinkLogisticRegressionData compute() {
6565
config.penalty(),
6666
graph
6767
);
68-
var training = new Training(config, progressLogger, graph.nodeCount());
68+
69+
70+
long trainSize = 0;
71+
for (long i = 0; i < trainSet.size(); i++) {
72+
trainSize += graph.degree(trainSet.get(i));
73+
}
74+
75+
var training = new Training(config, progressLogger, trainSize);
76+
6977
Supplier<BatchQueue> queueSupplier = () -> new HugeBatchQueue(trainSet, config.batchSize());
7078
training.train(objective, queueSupplier, config.concurrency());
7179
return objective.modelData;

alpha/alpha-algo/src/test/java/org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionBaseTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ void shouldComputeCorrectFeatures() {
6969
);
7070

7171
var allNodesBatch = new LazyBatch(0, (int) graph.nodeCount(), graph.nodeCount());
72-
var features = base.features(graph, allNodesBatch);
72+
var features = base.features(graph, allNodesBatch, (int) graph.relationshipCount());
7373
var expectedFeatures = new Matrix(new double[]{
7474
0.49, 0.49, 1.0,
7575
4.00, 2.56, 1.0,

alpha/alpha-algo/src/test/java/org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionTrainTest.java

Lines changed: 15 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@
2020
package org.neo4j.gds.ml.linkmodels.logisticregression;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.params.ParameterizedTest;
24+
import org.junit.jupiter.params.provider.ValueSource;
2325
import org.neo4j.gds.ml.core.features.FeatureExtraction;
2426
import org.neo4j.gds.ml.core.tensor.Matrix;
2527
import org.neo4j.graphalgo.api.Graph;
@@ -30,14 +32,11 @@
3032
import org.neo4j.graphalgo.extension.GdlExtension;
3133
import org.neo4j.graphalgo.extension.GdlGraph;
3234
import org.neo4j.graphalgo.extension.Inject;
33-
import org.neo4j.graphalgo.math.L2Norm;
3435

3536
import java.util.List;
3637
import java.util.Map;
3738

3839
import static org.assertj.core.api.Assertions.assertThat;
39-
import static org.neo4j.gds.ml.core.Dimensions.COLUMNS_INDEX;
40-
import static org.neo4j.gds.ml.core.Dimensions.ROWS_INDEX;
4140

4241
@GdlExtension
4342
class LinkLogisticRegressionTrainTest {
@@ -58,69 +57,34 @@ class LinkLogisticRegressionTrainTest {
5857
@Inject
5958
private Graph graph;
6059

61-
@Test
62-
void shouldComputeWithDefaultAdamOptimizerAndStreakStopper() {
60+
@ParameterizedTest
61+
@ValueSource(ints = {1, 4})
62+
void shouldComputeWithDefaultAdamOptimizerAndStreakStopper(int concurrency) {
6363
var featureProperties = List.of("a", "b");
6464
var config = new LinkLogisticRegressionTrainConfigImpl(
6565
featureProperties,
6666
CypherMapWrapper.create(Map.of(
67-
"maxEpochs", 100000,
67+
"maxEpochs", 10_000,
6868
"tolerance", 1e-4,
69-
"concurrency", 1
69+
"concurrency", concurrency
7070
))
7171
);
7272

7373
var extractors = FeatureExtraction.propertyExtractors(graph, featureProperties);
7474
var trainSet = HugeLongArray.newArray(graph.nodeCount(), AllocationTracker.empty());
7575
trainSet.setAll(i -> i);
76-
var linearRegression = new LinkLogisticRegressionTrain(graph, trainSet, extractors, config, ProgressLogger.NULL_LOGGER);
77-
78-
var result = linearRegression.compute();
79-
80-
assertThat(result).isNotNull();
81-
82-
var trainedWeights = result.weights();
83-
84-
var expected = new Matrix(new double[]{-1.0681821169962793, 1.0115009499444914, -0.1381213947059403}, 1, 3);
85-
assertThat(trainedWeights.data()).satisfies(matrix -> matrix.equals(expected, 1e-8));
86-
}
87-
88-
@Test
89-
void shouldComputeWithDefaultAdamOptimizerAndStreakStopperConcurrently() {
90-
var featureProperties = List.of("a", "b");
91-
var config = new LinkLogisticRegressionTrainConfigImpl(
92-
featureProperties,
93-
CypherMapWrapper.create(Map.of(
94-
"penalty", 1.0,
95-
"maxEpochs", 1000000,
96-
"tolerance", 1e-10,
97-
"concurrency", 4
98-
))
76+
var linearRegression = new LinkLogisticRegressionTrain(
77+
graph,
78+
trainSet,
79+
extractors,
80+
config,
81+
ProgressLogger.NULL_LOGGER
9982
);
10083

101-
var extractors = FeatureExtraction.propertyExtractors(graph, featureProperties);
102-
103-
var trainSet = HugeLongArray.newArray(graph.nodeCount(), AllocationTracker.empty());
104-
trainSet.setAll(i -> i);
105-
var linearRegression = new LinkLogisticRegressionTrain(graph, trainSet, extractors, config, ProgressLogger.NULL_LOGGER);
84+
var expected = new Matrix(new double[]{-1.0681821169962793, 1.0115009499444914, -0.1381213947059403}, 1, 3);
10685

10786
var result = linearRegression.compute();
108-
109-
assertThat(result).isNotNull();
110-
111-
var trainedWeights = result.weights();
112-
assertThat(trainedWeights.dimension(ROWS_INDEX)).isEqualTo(1);
113-
assertThat(trainedWeights.dimension(COLUMNS_INDEX)).isEqualTo(3);
114-
115-
var trainedData = trainedWeights.data().data();
116-
var expectedData = new double[]{-0.16207697085323056, 0.10360002113065836, 0.04906215177508012};
117-
118-
var deviation = new double[3];
119-
for (int i = 0; i < 3; i++) {
120-
deviation[i] = (trainedData[i] - expectedData[i]);
121-
}
122-
// could be flaky but passed 1327 times in a row
123-
assertThat(L2Norm.l2Norm(deviation) / L2Norm.l2Norm(expectedData)).isLessThan(0.05);
87+
assertThat(result.weights().data()).matches(matrix -> matrix.equals(expected, 1e-8));
12488
}
12589

12690
@Test

0 commit comments

Comments
 (0)