44
44
public interface LinkPredictionPredictPipelineBaseConfig extends AlgoBaseConfig , SingleThreadedRandomSeedConfig , ModelConfig {
45
45
46
46
double DEFAULT_THRESHOLD = 0.0 ;
47
+ String MISSING_INITIAL_SAMPLER = "MISSING_VALUE" ;
47
48
48
49
//TODO make this a parameter
49
50
String graphName ();
@@ -75,11 +76,19 @@ default double sampleRate() {
75
76
@ Configuration .IntegerRange (min = 0 )
76
77
Optional <Integer > randomJoins ();
77
78
78
- @ Value .Default
79
- @ Configuration .ConvertWith ("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#parse" )
80
- @ Configuration .ToMapValue ("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#toString" )
81
- default KnnSampler .SamplerType initialSampler () {
82
- return KnnSampler .SamplerType .UNIFORM ;
79
+ default String initialSampler () {
80
+ return MISSING_INITIAL_SAMPLER ;
81
+ }
82
+
83
+ @ Value .Derived
84
+ @ Configuration .Ignore
85
+ default Optional <KnnSampler .SamplerType > derivedSampler () {
86
+ String sampler = initialSampler ();
87
+ if (sampler .equals (MISSING_INITIAL_SAMPLER )) {
88
+ return Optional .empty ();
89
+ }
90
+
91
+ return KnnSampler .SamplerType .parseToOptional (sampler );
83
92
}
84
93
85
94
@ Value .Check
@@ -96,7 +105,7 @@ default void validateParameterCombinations() {
96
105
"deltaThreshold" , deltaThreshold ().isPresent (),
97
106
"maxIterations" , maxIterations ().isPresent (),
98
107
"randomJoins" , randomJoins ().isPresent (),
99
- "initialSampler" , randomJoins ().isPresent ());
108
+ "initialSampler" , derivedSampler ().isPresent ());
100
109
validateStrategySpecificParameters (approximateStrategyParameters , "less than 1" );
101
110
102
111
topN ().orElseThrow (()-> MissingParameterExceptions .missingValueFor ("topN" , Collections .emptyList ()));
@@ -132,13 +141,13 @@ default KnnBaseConfig approximateConfig() {
132
141
.sampleRate (sampleRate ())
133
142
.nodeProperties (List .of (new KnnNodePropertySpec ("NotUsedInLP" )))
134
143
.minBatchSize (LinkPrediction .MIN_NODE_BATCH_SIZE )
135
- .initialSampler (initialSampler ())
136
144
.concurrency (concurrency ());
137
145
138
146
topK ().ifPresent (knnBuilder ::topK );
139
147
deltaThreshold ().ifPresent (knnBuilder ::deltaThreshold );
140
148
maxIterations ().ifPresent (knnBuilder ::maxIterations );
141
149
randomJoins ().ifPresent (knnBuilder ::randomJoins );
150
+ derivedSampler ().ifPresent (knnBuilder ::initialSampler );
142
151
randomSeed ().ifPresent (knnBuilder ::randomSeed );
143
152
144
153
return knnBuilder .build ();
0 commit comments