Skip to content

Commit 05c406d

Browse files
adamnschFlorentinDbreakanalysis
committed
Simplify initialSampler in LP predict config
Co-Authored-By: Florentin Dörre <[email protected]> Co-Authored-By: Jacob Sznajdman <[email protected]>
1 parent cbe0f9b commit 05c406d

File tree

2 files changed

+5
-18
lines changed

2 files changed

+5
-18
lines changed

algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.Arrays;
2323
import java.util.List;
2424
import java.util.Locale;
25-
import java.util.Optional;
2625
import java.util.function.LongPredicate;
2726
import java.util.stream.Collectors;
2827

@@ -68,10 +67,6 @@ else if (input instanceof SamplerType) {
6867
));
6968
}
7069

71-
public static Optional<SamplerType> parseToOptional(String input) {
72-
return Optional.of(parse(input));
73-
}
74-
7570
public static String toString(SamplerType samplerType) {
7671
return samplerType.toString();
7772
}

alpha/alpha-proc/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineBaseConfig.java

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
public interface LinkPredictionPredictPipelineBaseConfig extends AlgoBaseConfig, SingleThreadedRandomSeedConfig, ModelConfig {
4545

4646
double DEFAULT_THRESHOLD = 0.0;
47-
String MISSING_INITIAL_SAMPLER = "MISSING_VALUE";
4847

4948
//TODO make this a parameter
5049
String graphName();
@@ -76,19 +75,12 @@ default double sampleRate() {
7675
@Configuration.IntegerRange(min = 0)
7776
Optional<Integer> randomJoins();
7877

79-
default String initialSampler() {
80-
return MISSING_INITIAL_SAMPLER;
81-
}
78+
Optional<String> initialSampler();
8279

8380
@Value.Derived
8481
@Configuration.Ignore
85-
default Optional<KnnSampler.SamplerType> derivedSampler() {
86-
String sampler = initialSampler();
87-
if (sampler.equals(MISSING_INITIAL_SAMPLER)) {
88-
return Optional.empty();
89-
}
90-
91-
return KnnSampler.SamplerType.parseToOptional(sampler);
82+
default Optional<KnnSampler.SamplerType> derivedInitialSampler() {
83+
return initialSampler().map(KnnSampler.SamplerType::parse);
9284
}
9385

9486
@Value.Check
@@ -105,7 +97,7 @@ default void validateParameterCombinations() {
10597
"deltaThreshold", deltaThreshold().isPresent(),
10698
"maxIterations", maxIterations().isPresent(),
10799
"randomJoins", randomJoins().isPresent(),
108-
"initialSampler", derivedSampler().isPresent());
100+
"initialSampler", derivedInitialSampler().isPresent());
109101
validateStrategySpecificParameters(approximateStrategyParameters, "less than 1");
110102

111103
topN().orElseThrow(()-> MissingParameterExceptions.missingValueFor("topN", Collections.emptyList()));
@@ -147,7 +139,7 @@ default KnnBaseConfig approximateConfig() {
147139
deltaThreshold().ifPresent(knnBuilder::deltaThreshold);
148140
maxIterations().ifPresent(knnBuilder::maxIterations);
149141
randomJoins().ifPresent(knnBuilder::randomJoins);
150-
derivedSampler().ifPresent(knnBuilder::initialSampler);
142+
derivedInitialSampler().ifPresent(knnBuilder::initialSampler);
151143
randomSeed().ifPresent(knnBuilder::randomSeed);
152144

153145
return knnBuilder.build();

0 commit comments

Comments
 (0)