Skip to content

Commit 7f578ce

Browse files
Refactor Yens
Co-authored-by: Veselin Nikolov <[email protected]>
1 parent c90296f commit 7f578ce

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java

+41-31
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import java.util.Comparator;
4040
import java.util.Optional;
4141
import java.util.PriorityQueue;
42+
import java.util.function.ToLongBiFunction;
4243
import java.util.stream.Stream;
4344

4445
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
@@ -50,9 +51,10 @@ public final class Yens extends Algorithm<DijkstraResult> {
5051
private final Graph graph;
5152
private final ShortestPathYensBaseConfig config;
5253
private final Dijkstra dijkstra;
53-
54-
private final LongScatterSet nodeBlackList;
55-
private final LongObjectScatterMap<LongHashSet> relationshipBlackList;
54+
private final LongScatterSet nodeAvoidList;
55+
private final LongObjectScatterMap<LongHashSet> relationshipAvoidList;
56+
private final ToLongBiFunction
57+
<MutablePathResult, Integer> relationshipAvoidMapper;
5658

5759
/**
5860
* Configure Yens to compute at most one source-target shortest path.
@@ -63,22 +65,15 @@ public static Yens sourceTarget(
6365
ProgressTracker progressTracker
6466
) {
6567
// If the input graph is a multi-graph, we need to track
66-
// parallel relationships. This is necessary since shortest
68+
// parallel relationships ids. This is necessary since shortest
6769
// paths can visit the same nodes via different relationships.
70+
//If not, we need to track which is the next neighbor.
6871

69-
System.out.println(graph.schema().relationshipSchema().toMap());
70-
graph.forEachNode(nodeId -> {
71-
graph.forEachRelationship(nodeId, 1.0, (s, t, w) -> {
72-
System.out.println(s + "-[" + w + "]->" + t);
73-
return true;
74-
});
75-
return true;
76-
});
77-
72+
boolean shouldTrackRelationships = graph.isMultiGraph();
7873
var newConfig = ImmutableShortestPathYensBaseConfig
7974
.builder()
8075
.from(config)
81-
.trackRelationships(graph.isMultiGraph())
76+
.trackRelationships(shouldTrackRelationships)
8277
.build();
8378
// Init dijkstra algorithm for computing shortest paths
8479
var dijkstra = Dijkstra.sourceTarget(graph, newConfig, Optional.empty(), progressTracker);
@@ -105,16 +100,35 @@ private Yens(Graph graph, Dijkstra dijkstra, ShortestPathYensBaseConfig config,
105100
this.config = config;
106101
// Track nodes and relationships that are skipped in a single iteration.
107102
// The content of these data structures is reset after each of k iterations.
108-
this.nodeBlackList = new LongScatterSet();
109-
this.relationshipBlackList = new LongObjectScatterMap<>();
110-
// set filter in Dijkstra to respect our blacklists
103+
this.nodeAvoidList = new LongScatterSet();
104+
this.relationshipAvoidList = new LongObjectScatterMap<>();
105+
// set filter in Dijkstra to respect our list of relationships to avoid
111106
this.dijkstra = dijkstra;
107+
108+
if (config.trackRelationships()) {
109+
// if we are in a multi-graph, we must store the relationships ids as they are
110+
//since two nodes may be connected by multiple relationships and we must know which to avoid
111+
relationshipAvoidMapper = (path, position) -> path.relationship(position);
112+
} else {
113+
//otherwise the graph has surely no parallel edges, we do not need to explicitly store relationship ids
114+
//we can just store endpoints, so that we know which nodes a node should avoid
115+
relationshipAvoidMapper = (path, position) -> path.node(position + 1);
116+
}
112117
dijkstra.withRelationshipFilter((source, target, relationshipId) ->
113-
!nodeBlackList.contains(target) &&
114-
!(relationshipBlackList.getOrDefault(source, EMPTY_SET).contains(relationshipId)) &&
115-
!(relationshipBlackList.getOrDefault(source, EMPTY_SET).contains(-target - 1)));
118+
!nodeAvoidList.contains(target)
119+
&& !shouldAvoidRelationship(source, target, relationshipId)
120+
121+
);
116122
}
117123

124+
private boolean shouldAvoidRelationship(long source, long target, long relationshipId) {
125+
long forbidden = target;
126+
if (config.trackRelationships()) {
127+
forbidden = relationshipId;
128+
}
129+
return relationshipAvoidList.getOrDefault(source, EMPTY_SET).contains(forbidden);
130+
131+
}
118132

119133
@Override
120134
public DijkstraResult compute() {
@@ -150,22 +164,21 @@ public DijkstraResult compute() {
150164
// Filter relationships that are part of the previous
151165
// shortest paths which share the same root path.
152166
if (rootPath.matchesExactly(path, n + 1)) {
153-
System.out.println(i + ": " + rootPath + " |" + prevPath);
154-
var relationshipId = graph.isMultiGraph() ? path.relationship(n) : -(1 + path.node(n + 1));
167+
var relationshipId = relationshipAvoidMapper.applyAsLong(path, n);
155168

156-
var neighbors = relationshipBlackList.get(spurNode);
169+
var neighbors = relationshipAvoidList.get(spurNode);
157170

158171
if (neighbors == null) {
159172
neighbors = new LongHashSet();
160-
relationshipBlackList.put(spurNode, neighbors);
173+
relationshipAvoidList.put(spurNode, neighbors);
161174
}
162175
neighbors.add(relationshipId);
163176
}
164177
}
165178

166179
// Filter nodes from root path to avoid cyclic path searches.
167180
for (int j = 0; j < n; j++) {
168-
nodeBlackList.add(rootPath.node(j));
181+
nodeAvoidList.add(rootPath.node(j));
169182
}
170183

171184
// Calculate the spur path from the spur node to the sink.
@@ -174,8 +187,8 @@ public DijkstraResult compute() {
174187
var spurPath = computeDijkstra(graph.toOriginalNodeId(spurNode));
175188

176189
// Clear filters for next spur node
177-
nodeBlackList.clear();
178-
relationshipBlackList.clear();
190+
nodeAvoidList.clear();
191+
relationshipAvoidList.clear();
179192

180193
// No new candidate from this spur node, continue with next node.
181194
if (spurPath.isEmpty()) {
@@ -201,10 +214,7 @@ public DijkstraResult compute() {
201214
progressTracker.endSubTask();
202215

203216
progressTracker.endSubTask();
204-
System.out.println("----");
205-
for (var path : kShortestPaths) {
206-
System.out.println(path);
207-
}
217+
208218
return new DijkstraResult(kShortestPaths.stream().map(MutablePathResult::toPathResult));
209219
}
210220

algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java

+1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ static Stream<List<String>> pathInput() {
168168
void compute(Collection<String> expectedPaths) {
169169

170170
assertResult(graph, idFunction, expectedPaths, false);
171+
171172
}
172173

173174
@Test

0 commit comments

Comments
 (0)