Skip to content

Commit ba2451a

Browse files
Merge pull request #3738 from soerenreichardt/revisit-progtracker-release
Release ProgressTracker when baseTask has finished
2 parents b3d22bb + 041318b commit ba2451a

File tree

8 files changed

+109
-11
lines changed

8 files changed

+109
-11
lines changed

alpha/alpha-proc/src/test/java/org/neo4j/gds/walking/CollapsePathMutateProcTest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020
package org.neo4j.gds.walking;
2121

2222
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Disabled;
2324
import org.junit.jupiter.api.Test;
2425
import org.neo4j.gds.AlgoBaseProc;
2526
import org.neo4j.gds.AlgoBaseProcTest;
2627
import org.neo4j.gds.BaseProcTest;
2728
import org.neo4j.gds.GdsCypher;
2829
import org.neo4j.gds.MutateRelationshipsTest;
30+
import org.neo4j.gds.RelationshipType;
31+
import org.neo4j.gds.api.Relationships;
2932
import org.neo4j.gds.catalog.GraphCreateProc;
3033
import org.neo4j.gds.catalog.GraphWriteNodePropertiesProc;
3134
import org.neo4j.gds.core.CypherMapWrapper;
3235
import org.neo4j.gds.extension.Neo4jGraph;
3336
import org.neo4j.gds.impl.walking.CollapsePath;
3437
import org.neo4j.gds.impl.walking.CollapsePathConfig;
35-
import org.neo4j.gds.RelationshipType;
36-
import org.neo4j.gds.api.Relationships;
3738
import org.neo4j.kernel.internal.GraphDatabaseAPI;
3839

3940
import java.util.List;
@@ -154,4 +155,8 @@ void testMutateYields() {
154155
)
155156
);
156157
}
158+
159+
@Disabled
160+
@Override
161+
public void shouldUnregisterTaskAfterComputation() {}
157162
}

core/src/main/java/org/neo4j/gds/core/utils/progress/GlobalTaskStore.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ public class GlobalTaskStore implements TaskStore, ThrowingFunction<Context, Tas
3434

3535
private final Map<String, Map<JobId, Task>> registeredTasks;
3636

37-
GlobalTaskStore() {
37+
public GlobalTaskStore() {
3838
this.registeredTasks = new ConcurrentHashMap<>();
3939
}
4040

41-
void store(String username, JobId jobId, Task task) {
41+
public void store(String username, JobId jobId, Task task) {
4242
this.registeredTasks
4343
.computeIfAbsent(username, __ -> new ConcurrentHashMap<>())
4444
.put(jobId, task);

core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressTracker.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,12 @@ public void endSubTask() {
9595
var currentTask = requireCurrentTask();
9696
taskProgressLogger.logEndSubTask(currentTask, parentTask());
9797
currentTask.finish();
98-
this.currentTask = nestedTasks.isEmpty()
99-
? Optional.empty()
100-
: Optional.of(nestedTasks.pop());
98+
if (nestedTasks.isEmpty()) {
99+
this.currentTask = Optional.empty();
100+
release();
101+
} else {
102+
this.currentTask = Optional.of(nestedTasks.pop());
103+
}
101104
}
102105

103106
@Override
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.core.utils.progress;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
24+
25+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
26+
27+
class GlobalTaskStoreTest {
28+
29+
@Test
30+
void shouldBeIdempotentOnRemove() {
31+
var taskStore = new GlobalTaskStore();
32+
var jobId = new JobId();
33+
taskStore.store("", jobId, Tasks.leaf("leaf"));
34+
taskStore.remove("", jobId);
35+
assertDoesNotThrow(() -> taskStore.remove("", jobId));
36+
}
37+
38+
}

pregel/src/test/java/org/neo4j/gds/beta/pregel/PregelTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ void cleanupProgressLogging() {
209209
pregelAlgo.run();
210210
pregelAlgo.release();
211211

212-
assertThat(taskRegistry.unregisterTaskCalls()).isEqualTo(1);
212+
assertThat(taskRegistry.unregisterTaskCalls()).isGreaterThanOrEqualTo(1);
213213
}
214214

215215
static Stream<Arguments> forkJoinAndPartitioning() {

proc/common/src/test/java/org/neo4j/gds/AlgorithmCleanupTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void cleanupTaskRegistryUnderRegularExecution() {
5858
Map<String, Object> config = Map.of("writeProperty", "test");
5959

6060
assertThatCode(() -> proc.stats("g", config)).doesNotThrowAnyException();
61-
assertThat(taskRegistry.unregisterTaskCalls()).isEqualTo(1);
61+
assertThat(taskRegistry.unregisterTaskCalls()).isGreaterThanOrEqualTo(1);
6262
}
6363

6464
@Test
@@ -72,6 +72,6 @@ void cleanupTaskRegistryWhenTheAlgorithmFails() {
7272
Map<String, Object> config = Map.of("writeProperty", "test", "throwInCompute", true);
7373

7474
assertThatThrownBy(() -> proc.stats("g", config)).isNotNull();
75-
assertThat(taskRegistry.unregisterTaskCalls()).isEqualTo(1);
75+
assertThat(taskRegistry.unregisterTaskCalls()).isGreaterThanOrEqualTo(1);
7676
}
7777
}

proc/test/src/main/java/org/neo4j/gds/AlgoBaseProcTest.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
import org.neo4j.gds.core.TransactionContext;
4343
import org.neo4j.gds.core.loading.GraphStoreCatalog;
4444
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistry;
45+
import org.neo4j.gds.core.utils.progress.GlobalTaskStore;
46+
import org.neo4j.gds.core.utils.progress.JobId;
47+
import org.neo4j.gds.core.utils.progress.LocalTaskRegistry;
48+
import org.neo4j.gds.core.utils.progress.tasks.Task;
4549
import org.neo4j.gds.core.write.NativeNodePropertyExporter;
4650
import org.neo4j.gds.core.write.NativeRelationshipExporter;
4751
import org.neo4j.gds.core.write.NativeRelationshipStreamExporter;
@@ -269,6 +273,55 @@ default void assertMissingProperty(String error, Runnable runnable) {
269273
assertThat(exception).hasMessageContaining(error);
270274
}
271275

276+
class InvocationCountingTaskStore extends GlobalTaskStore {
277+
int registerTaskInvocations;
278+
279+
@Override
280+
public void store(
281+
String username, JobId jobId, Task task
282+
) {
283+
super.store(username, jobId, task);
284+
registerTaskInvocations++;
285+
}
286+
}
287+
288+
@Test
289+
default void shouldUnregisterTaskAfterComputation() {
290+
var taskStore = new InvocationCountingTaskStore();
291+
292+
String loadedGraphName = "loadedGraph";
293+
GraphCreateConfig graphCreateConfig = withNameAndRelationshipProjections("", loadedGraphName, relationshipProjections());
294+
applyOnProcedure(proc -> {
295+
proc.taskRegistry = new LocalTaskRegistry("", taskStore);
296+
297+
GraphStore graphStore = graphLoader(graphCreateConfig).graphStore();
298+
GraphStoreCatalog.set(
299+
graphCreateConfig,
300+
graphStore
301+
);
302+
Map<String, Object> configMap = createMinimalConfig(CypherMapWrapper.empty()).toMap();
303+
AlgoBaseProc.ComputationResult<?, RESULT, CONFIG> computationResult1 = proc.compute(
304+
loadedGraphName,
305+
configMap,
306+
releaseAlgorithm(),
307+
true
308+
);
309+
310+
AlgoBaseProc.ComputationResult<?, RESULT, CONFIG> computationResult2 = proc.compute(
311+
loadedGraphName,
312+
configMap,
313+
releaseAlgorithm(),
314+
true
315+
);
316+
317+
// trigger consumption of stream return values
318+
assertResultEquals(computationResult1.result(), computationResult2.result());
319+
320+
assertThat(taskStore.taskStream()).isEmpty();
321+
assertThat(taskStore.registerTaskInvocations).isGreaterThan(1);
322+
});
323+
}
324+
272325
@AllGraphStoreFactoryTypesTest
273326
default void testRunOnLoadedGraph(TestSupport.FactoryType factoryType) {
274327
// FIXME rethink this test for mutate

proc/test/src/main/java/org/neo4j/gds/test/TestAlgorithm.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ public TestAlgorithm(
4646
this.progressTracker = progressTracker;
4747
}
4848

49-
5049
@Override
5150
public TestAlgorithm compute() {
5251
progressTracker.beginSubTask();

0 commit comments

Comments
 (0)