Skip to content

Commit f6f10e0

Browse files
committed
Support sender tracking for reducing computations
1 parent d14088a commit f6f10e0

File tree

9 files changed

+381
-28
lines changed

9 files changed

+381
-28
lines changed

pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java

+28-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.jetbrains.annotations.NotNull;
2323

2424
import java.util.Iterator;
25+
import java.util.OptionalLong;
2526
import java.util.PrimitiveIterator;
2627

2728
public final class Messages implements Iterable<Double> {
@@ -33,7 +34,12 @@ public Iterator<Double> iterator() {
3334
}
3435

3536
public interface MessageIterator extends PrimitiveIterator.OfDouble {
37+
3638
boolean isEmpty();
39+
40+
default OptionalLong sender() {
41+
return OptionalLong.empty();
42+
}
3743
}
3844

3945
private final MessageIterator iterator;
@@ -42,12 +48,32 @@ public interface MessageIterator extends PrimitiveIterator.OfDouble {
4248
this.iterator = iterator;
4349
}
4450

51+
/**
52+
* Returns a iterator that can be used to iterate over the messages.
53+
*/
4554
@NotNull
4655
public PrimitiveIterator.OfDouble doubleIterator() {
47-
return iterator;
56+
return this.iterator;
4857
}
4958

59+
/**
60+
* Indicates if there are messages present.
61+
*/
5062
public boolean isEmpty() {
51-
return iterator.isEmpty();
63+
return this.iterator.isEmpty();
64+
}
65+
66+
/**
67+
* If the computation defined a {@link org.neo4j.gds.beta.pregel.Reducer}, this method will
68+
* return the sender of the aggregated message. Depending on the reducer implementation, the
69+
* sender is deterministically defined by the reducer, e.g., for Max or Min. In any other case,
70+
* the sender will be one of the node ids that sent messages to that node.
71+
* <p>
72+
* Note, that {@link PregelConfig#trackSender()} must return true to enable sender tracking.
73+
*
74+
* @return the sender of an aggregated message or an empty optional if no reducer is defined
75+
*/
76+
public OptionalLong sender() {
77+
return this.iterator.sender();
5278
}
5379
}

pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
*/
2020
package org.neo4j.gds.beta.pregel;
2121

22+
import java.util.OptionalLong;
23+
2224
public interface Messenger<ITERATOR extends Messages.MessageIterator> {
2325

2426
void initIteration(int iteration);
@@ -29,5 +31,9 @@ public interface Messenger<ITERATOR extends Messages.MessageIterator> {
2931

3032
void initMessageIterator(ITERATOR messageIterator, long nodeId, boolean isFirstIteration);
3133

34+
default OptionalLong sender(long nodeId) {
35+
return OptionalLong.empty();
36+
}
37+
3238
void release();
3339
}

pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java

+11-2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,15 @@ public static MemoryEstimation memoryEstimation(
110110
Map<String, ValueType> propertiesMap,
111111
boolean isQueueBased,
112112
boolean isAsync
113+
) {
114+
return memoryEstimation(propertiesMap, isQueueBased, isAsync, false);
115+
}
116+
117+
public static MemoryEstimation memoryEstimation(
118+
Map<String, ValueType> propertiesMap,
119+
boolean isQueueBased,
120+
boolean isAsync,
121+
boolean isTrackingSender
113122
) {
114123
var estimationBuilder = MemoryEstimations.builder(Pregel.class)
115124
.perNode("vote bits", HugeAtomicBitSet::memoryEstimation)
@@ -123,7 +132,7 @@ public static MemoryEstimation memoryEstimation(
123132
estimationBuilder.add("message queues", SyncQueueMessenger.memoryEstimation());
124133
}
125134
} else {
126-
estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation());
135+
estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation(isTrackingSender));
127136
}
128137

129138
return estimationBuilder.build();
@@ -169,7 +178,7 @@ private Pregel(
169178
var reducer = computation.reducer();
170179

171180
this.messenger = reducer.isPresent()
172-
? new ReducingMessenger(graph, config, reducer.get())
181+
? ReducingMessenger.create(graph, config, reducer.get())
173182
: config.isAsynchronous()
174183
? new AsyncQueueMessenger(graph.nodeCount())
175184
: new SyncQueueMessenger(graph.nodeCount());

pregel/src/main/java/org/neo4j/gds/beta/pregel/PregelConfig.java

+5
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,9 @@ default Partitioning partitioning() {
4747
default boolean useForkJoin() {
4848
return partitioning() == Partitioning.AUTO;
4949
}
50+
51+
@Configuration.Ignore
52+
default boolean trackSender() {
53+
return false;
54+
}
5055
}

pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java

+101-16
Original file line numberDiff line numberDiff line change
@@ -20,49 +20,70 @@
2020
package org.neo4j.gds.beta.pregel;
2121

2222
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.collections.ha.HugeLongArray;
2324
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
2425
import org.neo4j.gds.core.concurrency.ParallelUtil;
25-
import org.neo4j.gds.termination.TerminationFlag;
26+
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
2627
import org.neo4j.gds.mem.MemoryEstimation;
2728
import org.neo4j.gds.mem.MemoryEstimations;
28-
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
29+
import org.neo4j.gds.termination.TerminationFlag;
30+
31+
import java.util.OptionalLong;
2932

3033
/**
3134
* A messenger implementation that is backed by two double arrays used
3235
* to send and receive messages. The messenger can only be applied in
3336
* combination with a {@link Reducer}
3437
* which atomically reduces all incoming messages into a single one.
3538
*/
36-
public class ReducingMessenger implements Messenger<ReducingMessenger.SingleMessageIterator> {
39+
class ReducingMessenger implements Messenger<ReducingMessenger.SingleMessageIterator> {
3740

3841
private final Graph graph;
3942
private final PregelConfig config;
40-
private final Reducer reducer;
43+
final Reducer reducer;
4144

42-
private HugeAtomicDoubleArray sendArray;
43-
private HugeAtomicDoubleArray receiveArray;
45+
HugeAtomicDoubleArray sendArray;
46+
HugeAtomicDoubleArray receiveArray;
4447

45-
ReducingMessenger(Graph graph, PregelConfig config, Reducer reducer) {
46-
assert !Double.isNaN(reducer.identity()): "identity element must not be NaN";
48+
static ReducingMessenger create(Graph graph, PregelConfig config, Reducer reducer) {
49+
return config.trackSender()
50+
? new WithSender(graph, config, reducer)
51+
: new ReducingMessenger(graph, config, reducer);
52+
}
53+
54+
private ReducingMessenger(Graph graph, PregelConfig config, Reducer reducer) {
55+
assert !Double.isNaN(reducer.identity()) : "identity element must not be NaN";
4756

4857
this.graph = graph;
4958
this.config = config;
5059
this.reducer = reducer;
5160

52-
this.receiveArray = HugeAtomicDoubleArray.of(graph.nodeCount(), ParallelDoublePageCreator.passThrough(config.concurrency()));
53-
this.sendArray = HugeAtomicDoubleArray.of(graph.nodeCount(), ParallelDoublePageCreator.passThrough(config.concurrency()));
61+
this.receiveArray = HugeAtomicDoubleArray.of(
62+
graph.nodeCount(),
63+
ParallelDoublePageCreator.passThrough(config.concurrency())
64+
);
65+
this.sendArray = HugeAtomicDoubleArray.of(
66+
graph.nodeCount(),
67+
ParallelDoublePageCreator.passThrough(config.concurrency())
68+
);
5469
}
5570

56-
static MemoryEstimation memoryEstimation() {
57-
return MemoryEstimations.builder(ReducingMessenger.class)
71+
static MemoryEstimation memoryEstimation(boolean withSender) {
72+
var builder = MemoryEstimations.builder(ReducingMessenger.class)
5873
.perNode("send array", HugeAtomicDoubleArray::memoryEstimation)
59-
.perNode("receive array", HugeAtomicDoubleArray::memoryEstimation)
74+
.perNode("receive array", HugeAtomicDoubleArray::memoryEstimation);
75+
76+
if (withSender) {
77+
builder
78+
.perNode("send sender array", HugeLongArray::memoryEstimation)
79+
.perNode("receive sender array", HugeLongArray::memoryEstimation);
80+
}
81+
return builder
6082
.build();
6183
}
6284

6385
@Override
6486
public void initIteration(int iteration) {
65-
// Swap arrays
6687
var tmp = receiveArray;
6788
this.receiveArray = sendArray;
6889
this.sendArray = tmp;
@@ -96,7 +117,7 @@ public void initMessageIterator(
96117
boolean isInitialIteration
97118
) {
98119
var message = receiveArray.getAndReplace(nodeId, reducer.identity());
99-
messageIterator.init(message, message != reducer.identity());
120+
messageIterator.init(message, message != reducer.identity(), OptionalLong.empty());
100121
}
101122

102123
@Override
@@ -105,14 +126,73 @@ public void release() {
105126
receiveArray.release();
106127
}
107128

129+
static class WithSender extends ReducingMessenger {
130+
private HugeLongArray sendSenderArray;
131+
private HugeLongArray receiveSenderArray;
132+
133+
WithSender(Graph graph, PregelConfig config, Reducer reducer) {
134+
super(graph, config, reducer);
135+
this.sendSenderArray = HugeLongArray.newArray(graph.nodeCount());
136+
this.receiveSenderArray = HugeLongArray.newArray(graph.nodeCount());
137+
}
138+
139+
@Override
140+
public void initIteration(int iteration) {
141+
super.initIteration(iteration);
142+
// Swap sender arrays
143+
var tmp = receiveSenderArray;
144+
this.receiveSenderArray = sendSenderArray;
145+
this.sendSenderArray = tmp;
146+
}
147+
148+
@Override
149+
public void initMessageIterator(
150+
ReducingMessenger.SingleMessageIterator messageIterator,
151+
long nodeId,
152+
boolean isInitialIteration
153+
) {
154+
var message = receiveArray.getAndReplace(nodeId, reducer.identity());
155+
var sender = receiveSenderArray.get(nodeId);
156+
messageIterator.init(message, message != reducer.identity(), OptionalLong.of(sender));
157+
}
158+
159+
@Override
160+
public void sendTo(long sourceNodeId, long targetNodeId, double message) {
161+
sendArray.update(
162+
targetNodeId,
163+
currentMessage -> {
164+
var reducedMessage = reducer.reduce(currentMessage, message);
165+
if (Double.compare(reducedMessage, currentMessage) != 0) {
166+
sendSenderArray.set(targetNodeId, sourceNodeId);
167+
}
168+
return reducedMessage;
169+
}
170+
);
171+
}
172+
173+
@Override
174+
public OptionalLong sender(long nodeId) {
175+
return OptionalLong.of(receiveSenderArray.get(nodeId));
176+
}
177+
178+
@Override
179+
public void release() {
180+
sendSenderArray.release();
181+
receiveSenderArray.release();
182+
super.release();
183+
}
184+
}
185+
108186
static class SingleMessageIterator implements Messages.MessageIterator {
109187

110188
boolean hasNext;
111189
double message;
190+
OptionalLong sender;
112191

113-
void init(double value, boolean hasNext) {
192+
void init(double value, boolean hasNext, OptionalLong sender) {
114193
this.message = value;
115194
this.hasNext = hasNext;
195+
this.sender = sender;
116196
}
117197

118198
@Override
@@ -130,5 +210,10 @@ public double nextDouble() {
130210
hasNext = false;
131211
return message;
132212
}
213+
214+
@Override
215+
public OptionalLong sender() {
216+
return this.sender;
217+
}
133218
}
134219
}

pregel/src/main/java/org/neo4j/gds/beta/pregel/context/ComputeContext.java

+9
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ public long longNodeValue(String key) {
8686
return nodeValue.longValue(key, nodeId);
8787
}
8888

89+
/**
90+
* Returns the node value for the given node-id and node schema key.
91+
*
92+
* @throws IllegalArgumentException if the key does not exist or the value is not a long
93+
*/
94+
public long longNodeValue(String key, long nodeId) {
95+
return nodeValue.longValue(key, nodeId);
96+
}
97+
8998
/**
9099
* Returns the node value for the given node schema key.
91100
*

pregel/src/main/java/org/neo4j/gds/beta/pregel/context/NodeCentricContext.java

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.beta.pregel.PregelConfig;
2626
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2727

28+
import java.util.OptionalLong;
2829
import java.util.function.LongConsumer;
2930

3031
public abstract class NodeCentricContext<CONFIG extends PregelConfig> extends PregelContext<CONFIG> {
@@ -33,8 +34,11 @@ public abstract class NodeCentricContext<CONFIG extends PregelConfig> extends Pr
3334

3435
protected final Graph graph;
3536

37+
private OptionalLong sender = OptionalLong.empty();
38+
3639
long nodeId;
3740

41+
3842
NodeCentricContext(Graph graph, CONFIG config, NodeValue nodeValue, ProgressTracker progressTracker) {
3943
super(config, progressTracker);
4044
this.graph = graph;

0 commit comments

Comments
 (0)