Skip to content

Commit efbc06e

Browse files
Addressing review comments
Co-authored-by: Florentin Dörre <[email protected]>
1 parent 3091f91 commit efbc06e

File tree

4 files changed

+78
-74
lines changed

4 files changed

+78
-74
lines changed

algo/src/main/java/org/neo4j/gds/paths/yens/MutablePathResult.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
*/
3232
final class MutablePathResult {
3333

34+
private final long[] EMPTY_ARRAY = new long[0];
3435
private long index;
3536

3637
private final long sourceNode;
@@ -156,8 +157,7 @@ boolean matchesExactly(MutablePathResult path, int index) {
156157
* The cost value associated with the last value in this path, is added to
157158
* the costs for each node in the second path.
158159
*/
159-
160-
160+
161161
private void append(MutablePathResult path, long[] relationships) {
162162
// spur node is end of first and beginning of second path
163163
assert nodeIds[nodeIds.length - 1] == path.nodeIds[0];
@@ -213,7 +213,7 @@ void append(MutablePathResult path) {
213213
*/
214214
void appendWithoutRelationshipIds(MutablePathResult path) {
215215
// spur node is end of first and beginning of second path
216-
append(path, new long[0]);
216+
append(path, EMPTY_ARRAY);
217217
}
218218

219219

algo/src/main/java/org/neo4j/gds/paths/yens/Yens.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,9 @@ private Yens(Graph graph, Dijkstra dijkstra, ShortestPathYensBaseConfig config,
127127
}
128128

129129
private boolean shouldAvoidRelationship(long source, long target, long relationshipId) {
130-
long forbidden = target;
131-
if (config.trackRelationships()) {
132-
forbidden = relationshipId;
133-
}
130+
long forbidden = config.trackRelationships()
131+
? relationshipId
132+
: target;
134133
return relationshipAvoidList.getOrDefault(source, EMPTY_SET).contains(forbidden);
135134

136135
}

algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java

-2
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,7 @@ static Stream<List<String>> pathInput() {
166166
@ParameterizedTest
167167
@MethodSource("pathInput")
168168
void compute(Collection<String> expectedPaths) {
169-
170169
assertResult(graph, idFunction, expectedPaths, false);
171-
172170
}
173171

174172
@Test

proc/path-finding/src/test/java/org/neo4j/gds/paths/sourcetarget/YensTestWithDifferentProjections.java

+72-65
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import org.junit.jupiter.params.provider.ValueSource;
2525
import org.neo4j.gds.BaseProcTest;
2626
import org.neo4j.gds.catalog.GraphProjectProc;
27+
import org.neo4j.gds.extension.IdFunction;
28+
import org.neo4j.gds.extension.Inject;
2729
import org.neo4j.gds.extension.Neo4jGraph;
2830

2931
import java.util.Collection;
@@ -34,71 +36,76 @@
3436

3537
class YensTestWithDifferentProjections extends BaseProcTest {
3638

37-
@Neo4jGraph
38-
private static final String DB_CYPHER =
39-
"CREATE (a:CITY), " +
40-
"(b:CITY), " +
41-
"(c:CITY), " +
42-
"(d:CITY), " +
43-
"(e:CITY), " +
44-
"(f:CITY), " +
45-
"(a)-[:ROAD]->(b), " +
46-
"(a)-[:ROAD]->(b), " +
47-
"(b)-[:ROAD]->(c), " +
48-
"(b)-[:ROAD]->(d), " +
49-
"(c)-[:ROAD]->(f), " +
50-
"(d)-[:ROAD]->(e), " +
51-
"(e)-[:ROAD]->(c), " +
52-
"(e)-[:ROAD]->(f), " +
53-
"(a)-[:PATH]->(b), " +
54-
"(d)-[:PATH]->(e), " +
55-
"(d)-[:PATH]->(e)";
56-
57-
@BeforeEach
58-
void setup() throws Exception {
59-
registerProcedures(
60-
ShortestPathYensStreamProc.class,
61-
GraphProjectProc.class
62-
);
63-
}
64-
65-
66-
@ParameterizedTest
67-
@ValueSource(strings = {
68-
"CALL gds.graph.project('g', '*', {TYPE: {type: '*', aggregation: 'SINGLE'}})",
69-
"CALL gds.graph.project.cypher('g', 'MATCH (n) RETURN id(n) AS id', 'MATCH (n)-[r]->(m) RETURN DISTINCT id(n) AS source, id(m) AS target')"
70-
})
71-
void shouldWorkWithDifferentProjections(String projectionQuery) {
72-
73-
runQuery(projectionQuery);
74-
String yensQuery = "MATCH (source), (target) " +
75-
"WHERE id(source)=0 AND id(target)=5 " +
76-
"CALL gds.shortestPath.yens.stream(" +
77-
" 'g', " +
78-
" {sourceNode:source, targetNode:target, k:3} " +
79-
") " +
80-
"YIELD nodeIds RETURN nodeIds ";
81-
82-
Collection<long[]> encounteredPaths = new HashSet<>();
83-
runQuery(yensQuery, result -> {
84-
assertThat(result.columns()).containsExactlyInAnyOrder("nodeIds");
85-
86-
while (result.hasNext()) {
87-
var next = result.next();
88-
var currentPath = (List<Long>) next.get("nodeIds");
89-
long[] pathToArray = currentPath.stream().mapToLong(l -> l).toArray();
90-
encounteredPaths.add(pathToArray);
91-
}
92-
93-
return true;
94-
});
95-
96-
assertThat(encounteredPaths).containsExactlyInAnyOrder(
97-
new long[]{0l, 1l, 3l, 4l, 2l, 5l},
98-
new long[]{0l, 1l, 3l, 4l, 5l},
99-
new long[]{0l, 1l, 2l, 5l}
100-
);
101-
}
39+
@Neo4jGraph
40+
private static final String DB_CYPHER =
41+
"CREATE (a:CITY {cityid:0}), " +
42+
"(b:CITY {cityid:1}), " +
43+
"(c:CITY {cityid:2}), " +
44+
"(d:CITY {cityid:3}), " +
45+
"(e:CITY {cityid:4}), " +
46+
"(f:CITY {cityid:5}), " +
47+
"(a)-[:ROAD]->(b), " +
48+
"(a)-[:ROAD]->(b), " +
49+
"(b)-[:ROAD]->(c), " +
50+
"(b)-[:ROAD]->(d), " +
51+
"(c)-[:ROAD]->(f), " +
52+
"(d)-[:ROAD]->(e), " +
53+
"(e)-[:ROAD]->(c), " +
54+
"(e)-[:ROAD]->(f), " +
55+
"(a)-[:PATH]->(b), " +
56+
"(d)-[:PATH]->(e), " +
57+
"(d)-[:PATH]->(e)";
58+
59+
@Inject
60+
IdFunction idFunction;
61+
62+
@BeforeEach
63+
void setup() throws Exception {
64+
registerProcedures(
65+
ShortestPathYensStreamProc.class,
66+
GraphProjectProc.class
67+
);
68+
}
69+
70+
71+
@ParameterizedTest
72+
@ValueSource(strings = {
73+
"CALL gds.graph.project('g', '*', {TYPE: {type: '*', aggregation: 'SINGLE'}})",
74+
"CALL gds.graph.project.cypher('g', 'MATCH (n) RETURN id(n) AS id', 'MATCH (n)-[r]->(m) RETURN DISTINCT id(n) AS source, id(m) AS target')"
75+
})
76+
void shouldWorkWithDifferentProjections(String projectionQuery) {
77+
78+
runQuery(projectionQuery);
79+
String yensQuery = "MATCH (source), (target) " +
80+
"WHERE source.cityid=0 AND target.cityid=5 " +
81+
"CALL gds.shortestPath.yens.stream(" +
82+
" 'g', " +
83+
" {sourceNode:source, targetNode:target, k:3} " +
84+
") " +
85+
"YIELD nodeIds RETURN nodeIds ";
86+
87+
Collection<long[]> encounteredPaths = new HashSet<>();
88+
runQuery(yensQuery, result -> {
89+
assertThat(result.columns()).containsExactlyInAnyOrder("nodeIds");
90+
91+
while (result.hasNext()) {
92+
var next = result.next();
93+
var currentPath = (List<Long>) next.get("nodeIds");
94+
long[] pathToArray = currentPath.stream().mapToLong(l -> l).toArray();
95+
encounteredPaths.add(pathToArray);
96+
}
97+
98+
return true;
99+
});
100+
101+
long[] nodes = new long[]{idFunction.of("a"), idFunction.of("b"), idFunction.of("c"), idFunction.of("d"), idFunction.of(
102+
"e"), idFunction.of("f")};
103+
assertThat(encounteredPaths).containsExactlyInAnyOrder(
104+
new long[]{nodes[0], nodes[1], nodes[3], nodes[4], nodes[2], nodes[5]},
105+
new long[]{nodes[0], nodes[1], nodes[3], nodes[4], nodes[5]},
106+
new long[]{nodes[0], nodes[1], nodes[2], nodes[5]}
107+
);
108+
}
102109

103110
}
104111

0 commit comments

Comments
 (0)