Skip to content

Commit 5ee6999

Browse files
committed
Make ProgressTrackerCreator independent of AlgoBaseConfig
1 parent 7676a86 commit 5ee6999

File tree

16 files changed

+220
-92
lines changed

16 files changed

+220
-92
lines changed

applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.neo4j.gds.core.concurrency.Concurrency;
5151
import org.neo4j.gds.core.concurrency.DefaultPool;
5252
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
53+
import org.neo4j.gds.core.utils.progress.tasks.Task;
5354
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
5455
import org.neo4j.gds.degree.DegreeCentrality;
5556
import org.neo4j.gds.degree.DegreeCentralityConfig;
@@ -99,7 +100,7 @@ public CentralityAlgorithms(ProgressTrackerCreator progressTrackerCreator, Termi
99100

100101
PageRankResult articleRank(Graph graph, ArticleRankConfig configuration) {
101102
var task = Pregel.progressTask(graph, configuration, ArticleRank.asString());
102-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
103+
var progressTracker = createProgressTracker(task, configuration);
103104

104105
return articleRank(graph, configuration, progressTracker);
105106
}
@@ -123,7 +124,7 @@ public PageRankResult articleRank(Graph graph, ArticleRankConfig configuration,
123124
ArticulationPointsResult articulationPoints(Graph graph, AlgoBaseConfig configuration,boolean shouldComputeComponents) {
124125

125126
var task = ArticulationPointsProgressTaskCreator.progressTask(graph.nodeCount());
126-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
127+
var progressTracker = createProgressTracker(task, configuration);
127128

128129
var algorithm = ArticulationPoints.create(graph, progressTracker,shouldComputeComponents);
129130

@@ -142,7 +143,12 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral
142143
AlgorithmLabel.BetweennessCentrality.asString(),
143144
samplingSize.orElse(graph.nodeCount())
144145
);
145-
return progressTrackerCreator.createProgressTracker(configuration, task);
146+
return progressTrackerCreator.createProgressTracker(
147+
task,
148+
configuration.jobId(),
149+
configuration.concurrency(),
150+
configuration.logProgress()
151+
);
146152
}
147153
);
148154
}
@@ -186,7 +192,7 @@ public BetwennessCentralityResult betweennessCentrality(
186192
BridgeResult bridges(Graph graph, AlgoBaseConfig configuration, boolean shouldComputeComponents) {
187193

188194
var task = BridgeProgressTaskCreator.progressTask(graph.nodeCount());
189-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
195+
var progressTracker = createProgressTracker(task, configuration);
190196

191197
var algorithm = Bridges.create(graph, progressTracker, shouldComputeComponents);
192198

@@ -204,7 +210,7 @@ public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuratio
204210
Tasks.leaf("Greedy", graph.nodeCount()),
205211
Tasks.leaf("LazyForwarding", configuration.seedSetSize() - 1)
206212
);
207-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
213+
var progressTracker = createProgressTracker(task, configuration);
208214

209215
var algorithm = new CELF(graph, configuration.toParameters(), DefaultPool.INSTANCE, progressTracker);
210216

@@ -218,7 +224,7 @@ public CELFResult celf(Graph graph, InfluenceMaximizationBaseConfig configuratio
218224

219225
ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) {
220226
var task = ClosenessCentralityTask.create(graph.nodeCount());
221-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
227+
var progressTracker = createProgressTracker(task, configuration);
222228

223229
return closenessCentrality(graph, configuration, progressTracker);
224230
}
@@ -255,7 +261,7 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf
255261
var parameters = configuration.toParameters();
256262

257263
var task = Tasks.leaf(AlgorithmLabel.DegreeCentrality.asString(), graph.nodeCount());
258-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
264+
var progressTracker = createProgressTracker(task, configuration);
259265

260266
var algorithm = new DegreeCentrality(
261267
graph,
@@ -277,7 +283,7 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf
277283

278284
PageRankResult eigenVector(Graph graph, EigenvectorConfig configuration) {
279285
var task = Pregel.progressTask(graph, configuration, EigenVector.asString());
280-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
286+
var progressTracker = createProgressTracker(task, configuration);
281287

282288
return eigenVector(graph, configuration, progressTracker);
283289
}
@@ -304,7 +310,7 @@ public PageRankResult eigenVector(
304310

305311
HarmonicResult harmonicCentrality(Graph graph, AlgoBaseConfig configuration) {
306312
var task = Tasks.leaf(AlgorithmLabel.HarmonicCentrality.asString());
307-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
313+
var progressTracker = createProgressTracker(task, configuration);
308314

309315
return harmonicCentrality(graph, configuration, progressTracker);
310316
}
@@ -336,7 +342,7 @@ PregelResult hits(Graph graph, HitsConfig configuration) {
336342
configuration.maxIterations(),
337343
AlgorithmLabel.HITS.asString()
338344
);
339-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
345+
var progressTracker = createProgressTracker(task, configuration);
340346

341347
var algorithm = new Hits(
342348
graph,
@@ -359,7 +365,7 @@ IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig conf
359365
Tasks.leaf("TotalTransfers", graph.nodeCount()),
360366
Pregel.progressTask(graph, configuration, "ExposurePropagation")
361367
);
362-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
368+
var progressTracker = createProgressTracker(task, configuration);
363369

364370
var algorithm = new IndirectExposure(
365371
graph,
@@ -378,7 +384,7 @@ IndirectExposureResult indirectExposure(Graph graph, IndirectExposureConfig conf
378384

379385
public PageRankResult pageRank(Graph graph, PageRankConfig configuration) {
380386
var task = Pregel.progressTask(graph, configuration, PageRank.asString());
381-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
387+
var progressTracker = createProgressTracker(task, configuration);
382388

383389
return pageRank(graph, configuration, progressTracker);
384390
}
@@ -457,4 +463,13 @@ private PageRankComputation<PageRankConfig> pageRankComputation(Graph graph, Pag
457463

458464
return new PageRankComputation<>(configuration, mappedSourceNodes, degreeFunction);
459465
}
466+
467+
private ProgressTracker createProgressTracker(Task task, AlgoBaseConfig configuration) {
468+
return progressTrackerCreator.createProgressTracker(
469+
task,
470+
configuration.jobId(),
471+
configuration.concurrency(),
472+
configuration.logProgress()
473+
);
474+
}
460475
}

applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/HitsETLHook.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,12 @@ public void onGraphStoreLoaded(GraphStore graphStore) {
7272
var parameters = InverseRelationshipsConfigTransformer.toParameters(inverseConfig);
7373

7474
var task = InverseRelationshipsProgressTaskCreator.progressTask(graphStore.nodeCount(),relationshipTypes);
75-
var progressTracker = progressTrackerCreator.createProgressTracker(inverseConfig,task);
75+
var progressTracker = progressTrackerCreator.createProgressTracker(
76+
task,
77+
inverseConfig.jobId(),
78+
inverseConfig.concurrency(),
79+
inverseConfig.logProgress()
80+
);
7681

7782
var inverseRelationships=new InverseRelationships(graphStore,parameters,progressTracker, DefaultPool.INSTANCE,terminationFlag);
7883

applications/algorithms/community/src/main/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithms.java

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ ApproxMaxKCutResult approximateMaximumKCut(Graph graph, ApproxMaxKCutBaseConfig
114114
),
115115
configuration.iterations()
116116
);
117-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
117+
var progressTracker = createProgressTracker(task, configuration);
118118

119119
return approximateMaximumKCut(graph, configuration, progressTracker);
120120
}
@@ -147,7 +147,7 @@ ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration)
147147
Tasks.leaf("accumulate counts"),
148148
Tasks.leaf("perform conductance computations")
149149
);
150-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
150+
var progressTracker = createProgressTracker(task, configuration);
151151

152152
var parameters = ConductanceConfigTransformer.toParameters(configuration);
153153

@@ -172,7 +172,7 @@ ConductanceResult conductance(Graph graph, ConductanceBaseConfig configuration)
172172
public Labels hdbscan(Graph graph, HDBScanBaseConfig configuration) {
173173

174174
var task = HDBScanProgressTrackerCreator.hdbscanTask(AlgorithmLabel.HDBScan.asString(), graph.nodeCount());
175-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
175+
var progressTracker = createProgressTracker(task, configuration);
176176

177177
var hdbScan = new HDBScan(
178178
graph,
@@ -188,7 +188,7 @@ public Labels hdbscan(Graph graph, HDBScanBaseConfig configuration) {
188188
public K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configuration) {
189189
var parameters = configuration.toParameters();
190190
var task = K1ColoringProgressTrackerTaskCreator.progressTask(graph.nodeCount(), parameters.maxIterations());
191-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
191+
var progressTracker = createProgressTracker(task, configuration);
192192

193193
var k1ColoringStub = new K1ColoringStub(algorithmMachinery);
194194

@@ -204,7 +204,7 @@ public K1ColoringResult k1Coloring(Graph graph, K1ColoringBaseConfig configurati
204204

205205
KCoreDecompositionResult kCore(Graph graph, AlgoBaseConfig configuration) {
206206
var task = Tasks.leaf(AlgorithmLabel.KCore.asString(), graph.nodeCount());
207-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
207+
var progressTracker = createProgressTracker(task, configuration);
208208

209209
var algorithm = new KCoreDecomposition(graph, configuration.concurrency(), progressTracker, terminationFlag);
210210

@@ -226,7 +226,7 @@ public KmeansResult kMeans(Graph graph, KmeansBaseConfig configuration) {
226226
}
227227

228228
var task = constructKMeansProgressTask(graph, configuration);
229-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
229+
var progressTracker = createProgressTracker(task, configuration);
230230

231231
var kmeansContext = ImmutableKmeansContext.builder()
232232
.executor(DefaultPool.INSTANCE)
@@ -252,7 +252,7 @@ LabelPropagationResult labelPropagation(Graph graph, LabelPropagationBaseConfig
252252
configuration.maxIterations()
253253
)
254254
);
255-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
255+
var progressTracker = createProgressTracker(task, configuration);
256256

257257
var algorithm = new LabelPropagation(
258258
graph,
@@ -277,7 +277,7 @@ LocalClusteringCoefficientResult lcc(Graph graph, LocalClusteringCoefficientBase
277277
}
278278
tasks.add(Tasks.leaf("Calculate Local Clustering Coefficient", graph.nodeCount()));
279279
var task = Tasks.task(AlgorithmLabel.LCC.asString(), tasks);
280-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
280+
var progressTracker = createProgressTracker(task, configuration);
281281

282282
var parameters = configuration.toParameters();
283283

@@ -305,7 +305,7 @@ public LeidenResult leiden(Graph graph, LeidenBaseConfig configuration) {
305305
}
306306

307307
var task = LeidenTask.create(graph, configuration);
308-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
308+
var progressTracker = createProgressTracker(task, configuration);
309309

310310
var parameters = configuration.toParameters();
311311
var seedValues = Optional.ofNullable(parameters.seedProperty())
@@ -344,7 +344,7 @@ LouvainResult louvain(Graph graph, LouvainBaseConfig configuration) {
344344
parameters.maxLevels(),
345345
parameters.maxIterations()
346346
);
347-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
347+
var progressTracker = createProgressTracker(task, configuration);
348348

349349
var algorithm = new Louvain(
350350
graph,
@@ -387,7 +387,7 @@ ModularityOptimizationResult modularityOptimization(Graph graph, ModularityOptim
387387
parameters.maxIterations()
388388
);
389389

390-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
390+
var progressTracker = createProgressTracker(task, configuration);
391391

392392
var seedPropertyValues = configuration.seedProperty() != null ?
393393
CommunityCompanion.extractSeedingNodePropertyValues(
@@ -417,8 +417,10 @@ ModularityOptimizationResult modularityOptimization(Graph graph, ModularityOptim
417417

418418
HugeLongArray scc(Graph graph, AlgoBaseConfig configuration) {
419419
var progressTracker = progressTrackerCreator.createProgressTracker(
420-
configuration,
421-
Tasks.leaf(AlgorithmLabel.SCC.asString(), graph.nodeCount())
420+
Tasks.leaf(AlgorithmLabel.SCC.asString(), graph.nodeCount()),
421+
configuration.jobId(),
422+
configuration.concurrency(),
423+
configuration.logProgress()
422424
);
423425

424426
return scc(graph, configuration, progressTracker);
@@ -437,7 +439,7 @@ public HugeLongArray scc(Graph graph, ConcurrencyConfig configuration, ProgressT
437439

438440
TriangleCountResult triangleCount(Graph graph, TriangleCountBaseConfig configuration) {
439441
var task = Tasks.leaf(AlgorithmLabel.TriangleCount.asString(), graph.nodeCount());
440-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
442+
var progressTracker = createProgressTracker(task, configuration);
441443

442444
return triangleCount(graph, configuration, progressTracker);
443445
}
@@ -479,7 +481,7 @@ Stream<TriangleResult> triangles(Graph graph, ConcurrencyConfig configuration) {
479481

480482
public DisjointSetStruct wcc(Graph graph, WccBaseConfig configuration) {
481483
var task = Tasks.leaf(AlgorithmLabel.WCC.asString(), graph.relationshipCount());
482-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
484+
var progressTracker = createProgressTracker(task, configuration);
483485

484486
if (configuration.hasRelationshipWeightProperty() && configuration.threshold() == 0) {
485487
progressTracker.logWarning(
@@ -565,7 +567,7 @@ PregelResult speakerListenerLPA(Graph graph, SpeakerListenerLPAConfig configurat
565567
configuration.maxIterations(),
566568
AlgorithmLabel.SLLPA.asString()
567569
);
568-
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
570+
var progressTracker = createProgressTracker(task, configuration);
569571

570572
var algorithm = new SpeakerListenerLPA(
571573
graph,
@@ -582,4 +584,13 @@ PregelResult speakerListenerLPA(Graph graph, SpeakerListenerLPAConfig configurat
582584
configuration.concurrency()
583585
);
584586
}
587+
588+
private ProgressTracker createProgressTracker(Task task, AlgoBaseConfig configuration) {
589+
return progressTrackerCreator.createProgressTracker(
590+
task,
591+
configuration.jobId(),
592+
configuration.concurrency(),
593+
configuration.logProgress()
594+
);
595+
}
585596
}

applications/algorithms/community/src/test/java/org/neo4j/gds/applications/algorithms/community/CommunityAlgorithmsTest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import static org.junit.jupiter.api.Assertions.assertEquals;
8383
import static org.junit.jupiter.api.Assertions.assertTrue;
8484
import static org.mockito.ArgumentMatchers.any;
85+
import static org.mockito.ArgumentMatchers.anyBoolean;
8586
import static org.mockito.Mockito.mock;
8687
import static org.mockito.Mockito.when;
8788
import static org.neo4j.gds.TestSupport.fromGdl;
@@ -1176,11 +1177,16 @@ TestProgressTrackerCreator progressTrackerCreator(int concurrency, Log log) {
11761177

11771178
AtomicReference<TestProgressTracker> progressTrackerAtomicReference = new AtomicReference<>();
11781179
var progressTrackerCreator = mock(TestProgressTrackerCreator.class);
1179-
when(progressTrackerCreator.createProgressTracker(any(), any(Task.class))).then(
1180+
when(progressTrackerCreator.createProgressTracker(
1181+
any(Task.class),
1182+
any(),
1183+
any(),
1184+
anyBoolean()
1185+
)).then(
11801186
i ->
11811187
{
11821188
var taskProgressTracker = new TestProgressTracker(
1183-
i.getArgument(1),
1189+
i.getArgument(0, Task.class),
11841190
new LoggerForProgressTrackingAdapter(log),
11851191
new Concurrency(concurrency),
11861192
EmptyTaskRegistryFactory.INSTANCE

applications/algorithms/machine-learning/src/main/java/org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ public MachineLearningAlgorithms(ProgressTrackerCreator progressTrackerCreator,
4949

5050
KGEPredictResult kge(Graph graph, KGEPredictBaseConfig configuration) {
5151
var progressTracker = progressTrackerCreator.createProgressTracker(
52-
configuration,
53-
Tasks.leaf(AlgorithmLabel.KGE.asString())
52+
Tasks.leaf(AlgorithmLabel.KGE.asString()),
53+
configuration.jobId(),
54+
configuration.concurrency(),
55+
configuration.logProgress()
5456
);
5557

5658
return kge(graph, configuration, progressTracker);

applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/ProgressTrackerCreator.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
*/
2020
package org.neo4j.gds.applications.algorithms.machinery;
2121

22-
import org.neo4j.gds.config.AlgoBaseConfig;
22+
import org.neo4j.gds.core.concurrency.Concurrency;
23+
import org.neo4j.gds.core.utils.progress.JobId;
2324
import org.neo4j.gds.core.utils.progress.tasks.LoggerForProgressTracking;
2425
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2526
import org.neo4j.gds.core.utils.progress.tasks.Task;
@@ -38,13 +39,18 @@ public ProgressTrackerCreator(LoggerForProgressTracking log, RequestScopedDepend
3839
this.requestScopedDependencies = requestScopedDependencies;
3940
}
4041

41-
public ProgressTracker createProgressTracker(AlgoBaseConfig configuration, Task task) {
42-
if (configuration.logProgress()) {
42+
public ProgressTracker createProgressTracker(
43+
Task task,
44+
JobId jobId,
45+
Concurrency concurrency,
46+
boolean logProgress
47+
) {
48+
if (logProgress) {
4349
return new TaskProgressTracker(
4450
task,
4551
log,
46-
configuration.concurrency(),
47-
configuration.jobId(),
52+
concurrency,
53+
jobId,
4854
requestScopedDependencies.taskRegistryFactory(),
4955
requestScopedDependencies.userLogRegistryFactory()
5056
);
@@ -53,8 +59,8 @@ public ProgressTracker createProgressTracker(AlgoBaseConfig configuration, Task
5359
return new TaskTreeProgressTracker(
5460
task,
5561
log,
56-
configuration.concurrency(),
57-
configuration.jobId(),
62+
concurrency,
63+
jobId,
5864
requestScopedDependencies.taskRegistryFactory(),
5965
requestScopedDependencies.userLogRegistryFactory()
6066
);

0 commit comments

Comments
 (0)