Skip to content

Commit 346d7a5

Browse files
committed
migrate link prediction train estimate
1 parent 92a6c8a commit 346d7a5

File tree

8 files changed

+149
-16
lines changed

8 files changed

+149
-16
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.compat.GdsVersionInfoProvider;
2626
import org.neo4j.gds.core.model.CatalogModelContainer;
2727
import org.neo4j.gds.core.model.Model;
28+
import org.neo4j.gds.core.model.ModelCatalog;
2829
import org.neo4j.gds.mem.MemoryEstimation;
2930
import org.neo4j.gds.mem.MemoryEstimations;
3031
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -40,6 +41,7 @@
4041
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
4142
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
4243
import org.neo4j.gds.ml.training.TrainingStatistics;
44+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
4345

4446
import java.util.ArrayList;
4547
import java.util.List;
@@ -103,21 +105,23 @@ public static Task progressTask(String taskName, LinkPredictionTrainingPipeline
103105
}
104106

105107
public static MemoryEstimation estimate(
106-
ExecutionContext executionContext,
107108
LinkPredictionTrainingPipeline pipeline,
108-
LinkPredictionTrainConfig configuration
109+
LinkPredictionTrainConfig configuration,
110+
ModelCatalog modelCatalog,
111+
AlgorithmsProcedureFacade algorithmsProcedureFacade,
112+
String username
109113
) {
110114
pipeline.validateTrainingParameterSpace();
111115

112116
var splitEstimations = splitEstimation(
113117
pipeline.splitConfig(),
114118
configuration.targetRelationshipType(),
115-
pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())
119+
pipeline.relationshipWeightProperty(modelCatalog, username)
116120
);
117121

118122
MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps(
119-
executionContext.algorithmsProcedureFacade(),
120-
executionContext.modelCatalog(),
123+
algorithmsProcedureFacade,
124+
modelCatalog,
121125
configuration.username(),
122126
pipeline.nodePropertySteps(),
123127
configuration.nodeLabels(),

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutorTest.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,13 @@ void estimateWithDifferentNodePropertySteps(
501501
);
502502

503503
var actualRange = LinkPredictionTrainPipelineExecutor
504-
.estimate(ImmutableExecutionContext.EMPTY, pipeline, config)
504+
.estimate(
505+
pipeline,
506+
config,
507+
ImmutableExecutionContext.EMPTY.modelCatalog(),
508+
ImmutableExecutionContext.EMPTY.algorithmsProcedureFacade(),
509+
ImmutableExecutionContext.EMPTY.username()
510+
)
505511
.estimate(graphDimensions, config.concurrency())
506512
.memoryUsage();
507513

@@ -523,9 +529,11 @@ void failEstimateOnEmptyParameterSpace() {
523529
LinkPredictionTrainingPipeline pipeline = new LinkPredictionTrainingPipeline();
524530

525531
assertThatThrownBy(() -> LinkPredictionTrainPipelineExecutor.estimate(
526-
ExecutionContext.EMPTY,
527532
pipeline,
528-
config
533+
config,
534+
ExecutionContext.EMPTY.modelCatalog(),
535+
ExecutionContext.EMPTY.algorithmsProcedureFacade(),
536+
ExecutionContext.EMPTY.username()
529537
))
530538
.hasMessage("Need at least one model candidate for training.");
531539
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainProc.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import org.neo4j.gds.BaseProc;
2323
import org.neo4j.gds.core.model.ModelCatalog;
2424
import org.neo4j.gds.executor.ExecutionContext;
25-
import org.neo4j.gds.executor.MemoryEstimationExecutor;
2625
import org.neo4j.gds.executor.ProcedureExecutor;
2726
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
2827
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
28+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
2929
import org.neo4j.procedure.Context;
3030
import org.neo4j.procedure.Description;
3131
import org.neo4j.procedure.Mode;
@@ -38,6 +38,8 @@
3838
import static org.neo4j.procedure.Mode.READ;
3939

4040
public class LinkPredictionPipelineTrainProc extends BaseProc {
41+
@Context
42+
public GraphDataScienceProcedures facade;
4143

4244
@Context
4345
public ModelCatalog modelCatalog;
@@ -61,12 +63,7 @@ public Stream<MemoryEstimateResult> estimate(
6163
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
6264
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
6365
) {
64-
PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
65-
return new MemoryEstimationExecutor<>(
66-
new LinkPredictionPipelineTrainSpec(),
67-
executionContext(),
68-
transactionContext()
69-
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
66+
return facade.pipelines().linkPrediction().trainEstimate(graphNameOrConfiguration, algoConfiguration);
7067
}
7168

7269
@Override

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,13 @@ public MemoryEstimation memoryEstimation(LinkPredictionTrainConfig configuration
9393
LinkPredictionTrainingPipeline.class
9494
);
9595

96-
return LinkPredictionTrainPipelineExecutor.estimate(executionContext, pipeline, configuration);
96+
return LinkPredictionTrainPipelineExecutor.estimate(
97+
pipeline,
98+
configuration,
99+
executionContext.modelCatalog(),
100+
executionContext.algorithmsProcedureFacade(),
101+
executionContext.username()
102+
);
97103
}
98104

99105
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.pipelines;
21+
22+
import org.neo4j.gds.api.User;
23+
import org.neo4j.gds.applications.algorithms.machinery.DimensionTransformer;
24+
import org.neo4j.gds.core.GraphDimensions;
25+
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
26+
27+
/**
 * {@link DimensionTransformer} used when estimating memory for link prediction training.
 *
 * It looks up the training pipeline named in the configuration and asks that pipeline's
 * split configuration for the graph dimensions expected for the target relationship type.
 */
class DimensionTransformerForLinkPredictionTrain implements DimensionTransformer {
28+
private final PipelineRepository pipelineRepository;
29+
private final LinkPredictionTrainConfig configuration;
30+
31+
/**
 * @param pipelineRepository source of the link prediction training pipeline
 * @param configuration train configuration supplying the username, the pipeline name
 *                      and the target relationship type
 */
DimensionTransformerForLinkPredictionTrain(
32+
PipelineRepository pipelineRepository,
33+
LinkPredictionTrainConfig configuration
34+
) {
35+
this.pipelineRepository = pipelineRepository;
36+
this.configuration = configuration;
37+
}
38+
39+
/**
 * @param graphDimensions dimensions of the graph the estimation runs against
 * @return the result of {@code splitConfig.expectedGraphDimensions(...)} for the
 *         configured target relationship type
 */
@Override
40+
public GraphDimensions transform(GraphDimensions graphDimensions) {
41+
// Inject the expected relationship set sizes that the TrainPipelineExecutor estimation relies on.
42+
// This allows the MemoryTree to be computed over a single GraphDimensions instance.
43+
// NOTE(review): the meaning of the second User argument (false) is not visible here - confirm.
var user = new User(configuration.username(), false);
44+
var pipelineName = PipelineName.parse(configuration.pipeline());
45+
46+
// The pipeline is resolved on every call rather than cached in a field.
var pipeline = pipelineRepository.getLinkPredictionTrainingPipeline(user, pipelineName);
47+
48+
var splitConfig = pipeline.splitConfig();
49+
var targetRelationshipType = configuration.targetRelationshipType();
50+
51+
return splitConfig.expectedGraphDimensions(graphDimensions, targetRelationshipType);
52+
}
53+
}

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,4 +196,20 @@ public Stream<MemoryEstimateResult> streamEstimate(
196196

197197
return Stream.of(result);
198198
}
199+
200+
/**
 * Memory estimation entry point for link prediction pipeline training.
 *
 * @param graphNameOrConfiguration graph name or graph configuration, passed through unchanged
 * @param rawConfiguration raw user configuration; parsed into a {@code LinkPredictionTrainConfig}
 * @return a stream holding exactly one {@link MemoryEstimateResult}
 */
public Stream<MemoryEstimateResult> trainEstimate(
201+
Object graphNameOrConfiguration,
202+
Map<String, Object> rawConfiguration
203+
) {
204+
// Runs before parsing and its return value is not used;
// presumably it adjusts rawConfiguration in place - TODO confirm.
PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration);
205+
206+
var configuration = pipelineConfigurationParser.parseLinkPredictionTrainConfig(rawConfiguration);
207+
208+
var result = pipelineApplications.linkPredictionTrainEstimate(
209+
graphNameOrConfiguration,
210+
configuration
211+
);
212+
213+
return Stream.of(result);
214+
}
199215
}

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory;
5353
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
5454
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfiguration;
55+
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
56+
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainPipelineExecutor;
5557
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
5658
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
5759
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationModelResult;
@@ -407,6 +409,26 @@ Stream<StreamResult> linkPredictionStream(GraphName graphName, Map<String, Objec
407409
);
408410
}
409411

412+
/**
 * Produces the memory estimate for training a link prediction pipeline.
 *
 * @param graphNameOrConfiguration graph name or graph configuration, handed to the estimation template
 * @param configuration parsed link prediction train configuration
 * @return the result of {@code algorithmEstimationTemplate.estimate(...)}
 */
MemoryEstimateResult linkPredictionTrainEstimate(
413+
Object graphNameOrConfiguration,
414+
LinkPredictionTrainConfig configuration
415+
) {
416+
var estimate = linkPredictionTrainMemoryEstimation(configuration);
417+
418+
// Wrap the train estimation as the "Pipeline executor" component of an outer estimation.
var memoryEstimation = MemoryEstimations.builder("Link Prediction Pipeline Executor")
419+
.add("Pipeline executor", estimate)
420+
.build();
421+
422+
// Supplies graph dimensions derived from the pipeline's split configuration.
var dimensionTransformer = new DimensionTransformerForLinkPredictionTrain(pipelineRepository, configuration);
423+
424+
return algorithmEstimationTemplate.estimate(
425+
configuration,
426+
graphNameOrConfiguration,
427+
memoryEstimation,
428+
dimensionTransformer
429+
);
430+
}
431+
410432
MemoryEstimateResult nodeClassificationPredictEstimate(
411433
Object graphNameOrConfiguration,
412434
NodeClassificationPredictPipelineBaseConfig configuration
@@ -626,6 +648,28 @@ private MemoryEstimation linkPredictionMemoryEstimation(LinkPredictionPredictPip
626648
return linkPredictionPipelineEstimator.estimate(model, configuration);
627649
}
628650

651+
/**
 * Builds the memory estimation for link prediction training: resolves the training pipeline
 * named in the configuration and delegates to {@code LinkPredictionTrainPipelineExecutor.estimate}.
 *
 * @param configuration parsed link prediction train configuration
 * @return the executor's estimation wrapped in a "LinkPredictionPipelineTrain" estimation
 */
private MemoryEstimation linkPredictionTrainMemoryEstimation(LinkPredictionTrainConfig configuration) {
652+
var specifiedUser = new User(configuration.username(), false);
653+
var pipelineName = PipelineName.parse(configuration.pipeline());
654+
655+
var pipeline = pipelineRepository.getLinkPredictionTrainingPipeline(
656+
specifiedUser,
657+
pipelineName
658+
);
659+
660+
// NOTE(review): the pipeline is looked up for the user named in the configuration, while the
// estimate receives user.getUsername() (user is declared outside this block) - confirm that
// the two usernames are intended to come from different sources.
var estimate = LinkPredictionTrainPipelineExecutor.estimate(
661+
pipeline,
662+
configuration,
663+
modelCatalog,
664+
algorithmsProcedureFacade,
665+
user.getUsername()
666+
);
667+
668+
return MemoryEstimations.builder("LinkPredictionPipelineTrain")
669+
.add(estimate)
670+
.build();
671+
}
672+
629673
private MemoryEstimation nodeClassificationPredictMemoryEstimation(NodeClassificationPredictPipelineBaseConfig configuration) {
630674
var modelName = configuration.modelName();
631675
var username = configuration.username();

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
3232
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfiguration;
3333
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfigurationImpl;
34+
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
3435
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
3536
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
3637

@@ -76,6 +77,10 @@ LinkPredictionPredictPipelineStreamConfig parseLinkPredictionPredictPipelineStre
7677
return parseConfiguration(LinkPredictionPredictPipelineStreamConfig::of, configuration);
7778
}
7879

80+
/**
 * Parses a raw configuration map into a {@link LinkPredictionTrainConfig}
 * by passing {@code LinkPredictionTrainConfig::of} to {@code parseConfiguration}.
 *
 * @param configuration raw user-supplied configuration
 * @return the parsed train configuration
 */
LinkPredictionTrainConfig parseLinkPredictionTrainConfig(Map<String, Object> configuration) {
81+
return parseConfiguration(LinkPredictionTrainConfig::of, configuration);
82+
}
83+
7984
TunableTrainerConfig parseLogisticRegressionTrainerConfig(Map<String, Object> configuration) {
8085
return parseTrainerConfiguration(
8186
configuration,

0 commit comments

Comments
 (0)