Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 5229765

Browse files
authored
Degree cutoff skip values (#880)
* consider skip values when checking degree cut off * d has no links so it gets filtered out by the degree cut off * d has no links so it gets filtered out by the degree cut off * typo
1 parent 2c41049 commit 5229765

File tree

7 files changed

+94
-49
lines changed

7 files changed

+94
-49
lines changed

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityInput.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,26 @@ static int[] indexesFor(long[] inputIds, ProcedureConfiguration configuration, S
6464
}
6565
}
6666

67+
68+
static List<Number> extractValues(Object rawValues) {
69+
if (rawValues == null) {
70+
return Collections.emptyList();
71+
}
72+
73+
List<Number> valueList = new ArrayList<>();
74+
if (rawValues instanceof long[]) {
75+
long[] values = (long[]) rawValues;
76+
for (long value : values) {
77+
valueList.add(value);
78+
}
79+
} else if (rawValues instanceof double[]) {
80+
double[] values = (double[]) rawValues;
81+
for (double value : values) {
82+
valueList.add(value);
83+
}
84+
} else {
85+
valueList = (List<Number>) rawValues;
86+
}
87+
return valueList;
88+
}
6789
}

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
134134
CategoricalInput[] ids = new CategoricalInput[data.size()];
135135
int idx = 0;
136136
for (Map<String, Object> row : data) {
137-
List<Number> targetIds = extractValues(row.get("categories"));
137+
List<Number> targetIds = SimilarityInput.extractValues(row.get("categories"));
138138
int size = targetIds.size();
139139
if (size > degreeCutoff) {
140140
long[] targets = new long[size];
@@ -156,32 +156,14 @@ WeightedInput[] prepareWeights(Object rawData, ProcedureConfiguration configurat
156156
return prepareSparseWeights(api, (String) rawData, skipValue, configuration);
157157
} else {
158158
List<Map<String, Object>> data = (List<Map<String, Object>>) rawData;
159-
return preparseDenseWeights(data, getDegreeCutoff(configuration), skipValue);
159+
return WeightedInput.prepareDenseWeights(data, getDegreeCutoff(configuration), skipValue);
160160
}
161161
}
162162

163163
// Looks up the "skipValue" entry of the procedure configuration, passing Double.NaN as the default argument.
Double readSkipValue(ProcedureConfiguration configuration) {
164164
return configuration.get("skipValue", Double.NaN);
165165
}
166166

167-
private WeightedInput[] preparseDenseWeights(List<Map<String, Object>> data, long degreeCutoff, Double skipValue) {
168-
WeightedInput[] inputs = new WeightedInput[data.size()];
169-
int idx = 0;
170-
for (Map<String, Object> row : data) {
171-
172-
List<Number> weightList = extractValues(row.get("weights"));
173-
174-
int size = weightList.size();
175-
if (size > degreeCutoff) {
176-
double[] weights = Weights.buildWeights(weightList);
177-
inputs[idx++] = skipValue == null ? WeightedInput.dense((Long) row.get("item"), weights) : WeightedInput.dense((Long) row.get("item"), weights, skipValue);
178-
}
179-
}
180-
if (idx != inputs.length) inputs = Arrays.copyOf(inputs, idx);
181-
Arrays.sort(inputs);
182-
return inputs;
183-
}
184-
185167
private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query, Double skipValue, ProcedureConfiguration configuration) throws Exception {
186168
Map<String, Object> params = configuration.getParams();
187169
Long degreeCutoff = getDegreeCutoff(configuration);
@@ -230,28 +212,6 @@ private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query,
230212
return inputs;
231213
}
232214

233-
private List<Number> extractValues(Object rawValues) {
234-
if (rawValues == null) {
235-
return Collections.emptyList();
236-
}
237-
238-
List<Number> valueList = new ArrayList<>();
239-
if (rawValues instanceof long[]) {
240-
long[] values = (long[]) rawValues;
241-
for (long value : values) {
242-
valueList.add(value);
243-
}
244-
} else if (rawValues instanceof double[]) {
245-
double[] values = (double[]) rawValues;
246-
for (double value : values) {
247-
valueList.add(value);
248-
}
249-
} else {
250-
valueList = (List<Number>) rawValues;
251-
}
252-
return valueList;
253-
}
254-
255215
// Looks up the "topK" entry of the procedure configuration as an int, passing 0 as the default argument.
int getTopK(ProcedureConfiguration configuration) {
256216
return configuration.getInt("topK", 0);
257217
}

algo/src/main/java/org/neo4j/graphalgo/similarity/WeightedInput.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020

2121
import org.neo4j.graphalgo.core.utils.Intersections;
2222

23+
import java.util.Arrays;
24+
import java.util.List;
25+
import java.util.Map;
26+
2327
class WeightedInput implements Comparable<WeightedInput>, SimilarityInput {
2428
private final long id;
2529
private int itemCount;
@@ -62,6 +66,32 @@ public static WeightedInput dense(long id, double[] weights) {
6266
return new WeightedInput(id, weights);
6367
}
6468

69+
/**
 * Builds dense weighted inputs from a list of rows, keeping only rows whose number of
 * weights is strictly greater than the degree cutoff.
 *
 * Each row is read through its "item" key (cast to Long) and its "weights" key
 * (converted with SimilarityInput.extractValues). When skipValue is non-null, weights
 * matching the skip value are not counted towards the cutoff; the full weight list is
 * still passed to Weights.buildWeights.
 *
 * @param data         rows with "item" and "weights" entries
 * @param degreeCutoff a row is kept only if its counted weights exceed this value
 * @param skipValue    value to leave out of the count, or null to count every weight
 * @return the kept inputs, trimmed to their actual number and sorted with Arrays.sort
 */
static WeightedInput[] prepareDenseWeights(List<Map<String, Object>> data, long degreeCutoff, Double skipValue) {
70+
// Sized for the worst case (every row kept); trimmed after the loop.
WeightedInput[] inputs = new WeightedInput[data.size()];
71+
int idx = 0;
72+
73+
boolean skipAnything = skipValue != null;
74+
// NaN needs its own flag because it never compares equal with ==.
boolean skipNan = skipAnything && Double.isNaN(skipValue);
75+
76+
for (Map<String, Object> row : data) {
77+
List<Number> weightList = SimilarityInput.extractValues(row.get("weights"));
78+
79+
// Count only non-skipped weights when a skip value is configured.
long weightsSize = skipAnything ? skipSize(skipValue, skipNan, weightList) : weightList.size();
80+
81+
if (weightsSize > degreeCutoff) {
82+
double[] weights = Weights.buildWeights(weightList);
83+
inputs[idx++] = skipValue == null ? dense((Long) row.get("item"), weights) : dense((Long) row.get("item"), weights, skipValue);
84+
}
85+
}
86+
// Drop the unused tail left by filtered-out rows so the sort sees no null entries.
if (idx != inputs.length) inputs = Arrays.copyOf(inputs, idx);
87+
Arrays.sort(inputs);
88+
return inputs;
89+
}
90+
91+
// Counts the entries of weightList for which Intersections.shouldSkip returns false,
// i.e. the weights that are not skipped.
private static long skipSize(Double skipValue, boolean skipNan, List<Number> weightList) {
92+
return weightList.stream().filter(value -> !Intersections.shouldSkip(value.doubleValue(), skipValue, skipNan)).count();
93+
}
94+
6595
// Orders inputs by their id, using Long.compare on the two id fields.
public int compareTo(WeightedInput o) {
6696
return Long.compare(id, o.id);
6797
}

algo/src/test/java/org/neo4j/graphalgo/similarity/WeightedInputTest.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,50 @@
1919
package org.neo4j.graphalgo.similarity;
2020

2121
import org.junit.Test;
22+
import org.neo4j.helpers.collection.MapUtil;
23+
24+
import java.util.ArrayList;
25+
import java.util.Arrays;
26+
import java.util.List;
27+
import java.util.Map;
2228

2329
import static junit.framework.TestCase.assertEquals;
2430
import static junit.framework.TestCase.assertNull;
2531

2632
public class WeightedInputTest {
33+
// With NaN as the skip value, item 2 has only two counted weights, which does not
// exceed the cutoff of 2, so only one input is expected.
@Test
34+
public void degreeCutoffBasedOnSkipValue() {
35+
List<Map<String, Object>> data = new ArrayList<>();
36+
data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0)));
37+
data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, Double.NaN)));
38+
39+
WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, Double.NaN);
40+
41+
assertEquals(1, weightedInputs.length);
42+
}
43+
44+
// With a null skip value every weight is counted, NaN included: both items have three
// weights, which exceeds the cutoff of 2, so two inputs are expected.
@Test
45+
public void degreeCutoffWithoutSkipValue() {
46+
List<Map<String, Object>> data = new ArrayList<>();
47+
data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0)));
48+
data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, Double.NaN)));
49+
50+
WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, null);
51+
52+
assertEquals(2, weightedInputs.length);
53+
}
54+
55+
// With 5.0 as the skip value, item 2's 5.0 weight is not counted, leaving two weights,
// which does not exceed the cutoff of 2, so only one input is expected.
@Test
56+
public void degreeCutoffWithNumericSkipValue() {
57+
List<Map<String, Object>> data = new ArrayList<>();
58+
data.add(MapUtil.map("item", 1L,"weights", Arrays.asList(2.0, 3.0, 4.0)));
59+
data.add(MapUtil.map("item", 2L,"weights", Arrays.asList(2.0, 3.0, 5.0)));
60+
61+
WeightedInput[] weightedInputs = WeightedInput.prepareDenseWeights(data, 2L, 5.0);
62+
63+
assertEquals(1, weightedInputs.length);
64+
}
65+
2766
@Test
2867
public void pearsonNoCompression() {
2968
double[] weights1 = new double[]{1, 2, 3, 4, 4, 4, 4, 5, 6};

core/src/main/java/org/neo4j/graphalgo/core/utils/Intersections.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
269269
return Double.isNaN(result) ? 0 : result;
270270
}
271271

272-
private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {
272+
/**
 * Decides whether a weight is to be skipped.
 *
 * A weight is skipped when it equals skipValue, or when skipNan is set and the weight
 * is NaN. The separate NaN check is required because NaN never compares equal with ==.
 *
 * @param weight    the weight to test
 * @param skipValue the value that marks a weight as skipped
 * @param skipNan   whether NaN weights are skipped as well
 * @return true if the weight is to be skipped
 */
public static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {
273273
return weight == skipValue || (skipNan && Double.isNaN(weight));
274274
}
275275

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/CosineTest.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,7 @@ public void cosineSkipStreamTest() {
238238
assertTrue(results.hasNext());
239239
assert01Skip(results.next());
240240
assert02Skip(results.next());
241-
assert03Skip(results.next());
242241
assert12Skip(results.next());
243-
assert13Skip(results.next());
244-
assert23Skip(results.next());
245242
assertFalse(results.hasNext());
246243
}
247244

tests/src/test/java/org/neo4j/graphalgo/algo/similarity/EuclideanTest.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,7 @@ public void eucideanSkipStreamTest() {
259259
assertTrue(results.hasNext());
260260
assert01Skip(results.next());
261261
assert02Skip(results.next());
262-
assert03Skip(results.next());
263262
assert12Skip(results.next());
264-
assert13Skip(results.next());
265-
assert23Skip(results.next());
266263
assertFalse(results.hasNext());
267264
}
268265

0 commit comments

Comments
 (0)