Skip to content

Commit cbe0f9b

Browse files
adamnsch, breakanalysis, FlorentinD
committed
Make initialSampler optional for LP predict
Co-Authored-By: Jacob Sznajdman <[email protected]>
Co-Authored-By: Florentin Dörre <[email protected]>
1 parent 047fe28 commit cbe0f9b

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Arrays;
2323
import java.util.List;
2424
import java.util.Locale;
25+
import java.util.Optional;
2526
import java.util.function.LongPredicate;
2627
import java.util.stream.Collectors;
2728

@@ -67,6 +68,10 @@ else if (input instanceof SamplerType) {
6768
));
6869
}
6970

71+
public static Optional<SamplerType> parseToOptional(String input) {
72+
return Optional.of(parse(input));
73+
}
74+
7075
public static String toString(SamplerType samplerType) {
7176
return samplerType.toString();
7277
}

alpha/alpha-proc/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineBaseConfig.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
public interface LinkPredictionPredictPipelineBaseConfig extends AlgoBaseConfig, SingleThreadedRandomSeedConfig, ModelConfig {
4545

4646
double DEFAULT_THRESHOLD = 0.0;
47+
String MISSING_INITIAL_SAMPLER = "MISSING_VALUE";
4748

4849
//TODO make this a parameter
4950
String graphName();
@@ -75,11 +76,19 @@ default double sampleRate() {
7576
@Configuration.IntegerRange(min = 0)
7677
Optional<Integer> randomJoins();
7778

78-
@Value.Default
79-
@Configuration.ConvertWith("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#parse")
80-
@Configuration.ToMapValue("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#toString")
81-
default KnnSampler.SamplerType initialSampler() {
82-
return KnnSampler.SamplerType.UNIFORM;
79+
default String initialSampler() {
80+
return MISSING_INITIAL_SAMPLER;
81+
}
82+
83+
@Value.Derived
84+
@Configuration.Ignore
85+
default Optional<KnnSampler.SamplerType> derivedSampler() {
86+
String sampler = initialSampler();
87+
if (sampler.equals(MISSING_INITIAL_SAMPLER)) {
88+
return Optional.empty();
89+
}
90+
91+
return KnnSampler.SamplerType.parseToOptional(sampler);
8392
}
8493

8594
@Value.Check
@@ -96,7 +105,7 @@ default void validateParameterCombinations() {
96105
"deltaThreshold", deltaThreshold().isPresent(),
97106
"maxIterations", maxIterations().isPresent(),
98107
"randomJoins", randomJoins().isPresent(),
99-
"initialSampler", randomJoins().isPresent());
108+
"initialSampler", derivedSampler().isPresent());
100109
validateStrategySpecificParameters(approximateStrategyParameters, "less than 1");
101110

102111
topN().orElseThrow(()-> MissingParameterExceptions.missingValueFor("topN", Collections.emptyList()));
@@ -132,13 +141,13 @@ default KnnBaseConfig approximateConfig() {
132141
.sampleRate(sampleRate())
133142
.nodeProperties(List.of(new KnnNodePropertySpec("NotUsedInLP")))
134143
.minBatchSize(LinkPrediction.MIN_NODE_BATCH_SIZE)
135-
.initialSampler(initialSampler())
136144
.concurrency(concurrency());
137145

138146
topK().ifPresent(knnBuilder::topK);
139147
deltaThreshold().ifPresent(knnBuilder::deltaThreshold);
140148
maxIterations().ifPresent(knnBuilder::maxIterations);
141149
randomJoins().ifPresent(knnBuilder::randomJoins);
150+
derivedSampler().ifPresent(knnBuilder::initialSampler);
142151
randomSeed().ifPresent(knnBuilder::randomSeed);
143152

144153
return knnBuilder.build();

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProcTest.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ void shouldPredictWithTopN(int concurrency, String nodeLabel) {
5050
" modelName: 'model'," +
5151
" threshold: 0," +
5252
" topN: $topN," +
53-
" concurrency:" +
54-
" $concurrency" +
53+
" concurrency: $concurrency" +
5554
"})" +
5655
"YIELD node1, node2, probability" +
5756
" RETURN node1, node2, probability" +
@@ -111,8 +110,7 @@ void shouldPredictWithInitialSamplerSet() {
111110
" randomSeed: 42," +
112111
" topK: $topK," +
113112
" initialSampler: 'randomWalk'," +
114-
" concurrency:" +
115-
" $concurrency" +
113+
" concurrency: $concurrency" +
116114
"})" +
117115
"YIELD node1, node2, probability" +
118116
" RETURN node1, node2, probability" +

0 commit comments

Comments
 (0)