Skip to content

Commit 114d9f4

Browse files
committed
Rename holdOutFraction to testFraction
1 parent 299bf13 commit 114d9f4

File tree

9 files changed

+41
-47
lines changed

9 files changed

+41
-47
lines changed

alpha/alpha-algo/src/main/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationSplitConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public interface NodeClassificationSplitConfig extends ToMapConvertible {
3737

3838
@Value.Default
3939
@Configuration.DoubleRange(min = 0, max = 1)
40-
default double holdoutFraction() {
40+
default double testFraction() {
4141
return 0.3;
4242
}
4343

alpha/alpha-proc/src/main/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationPipelineCompanion.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
public final class NodeClassificationPipelineCompanion {
3333
public static final String PREDICT_DESCRIPTION = "Predicts classes for all nodes based on a previously trained pipeline model";
3434
public static final String PIPELINE_MODEL_TYPE = "Node classification training pipeline";
35-
static final Map<String, Object> DEFAULT_SPLIT_CONFIG = Map.of("holdoutFraction", 0.3, "validationFolds", 3);
35+
static final Map<String, Object> DEFAULT_SPLIT_CONFIG = Map.of("testFraction", 0.3, "validationFolds", 3);
3636
static final List<Map<String, Object>> DEFAULT_PARAM_CONFIG = List.of(
3737
NodeLogisticRegressionTrainCoreConfig.defaultConfig().toMap()
3838
);

alpha/alpha-proc/src/main/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationTrainPipelineExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ NodeClassificationTrainConfig innerConfig() {
115115
.featureProperties(pipeline.featureProperties())
116116
.params(params)
117117
.randomSeed(config.randomSeed())
118-
.holdoutFraction(pipeline.splitConfig().holdoutFraction())
118+
.holdoutFraction(pipeline.splitConfig().testFraction())
119119
.validationFolds(pipeline.splitConfig().validationFolds())
120120
.nodeLabels(config.nodeLabels())
121121
.relationshipTypes(config.relationshipTypes())

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationPipelineConfigureSplitProcTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ void shouldOverrideSingleSplitField() {
6565
@Test
6666
void shouldOnlyKeepLastOverride() {
6767
var expectedSplitConfig = new HashMap<>(NodeClassificationPipelineCompanion.DEFAULT_SPLIT_CONFIG) {{
68-
put("holdoutFraction", 0.5);
68+
put("testFraction", 0.5);
6969
}};
7070
runQuery("CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('myPipeline', {validationFolds: 42})");
7171

7272
assertCypherResult(
73-
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('myPipeline', {holdoutFraction: 0.5})",
73+
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('myPipeline', {testFraction: 0.5})",
7474
List.of(Map.of(
7575
"name", "myPipeline",
7676
"splitConfig", expectedSplitConfig,
@@ -84,8 +84,8 @@ void shouldOnlyKeepLastOverride() {
8484
@Test
8585
void failOnInvalidKeys() {
8686
assertError(
87-
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('myPipeline', {invalidKey: 42, holdMyFraction: -0.51})",
88-
"Unexpected configuration keys: holdMyFraction (Did you mean [holdoutFraction]?), invalidKey"
87+
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('myPipeline', {invalidKey: 42, testMyFraction: -0.51})",
88+
"Unexpected configuration keys: invalidKey, testMyFraction (Did you mean [testFraction]?)"
8989
);
9090
}
9191
}

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationPipelineIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ void trainWithNodePropertyStepsAndFeatures() {
124124
runQuery("CALL gds.alpha.ml.pipeline.nodeClassification.addFeatures('p', ['b', 'deg'])");
125125

126126
runQuery("CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('p', {" +
127-
" holdoutFraction: 0.2, " +
127+
" testFraction: 0.2, " +
128128
" validationFolds: 5" +
129129
"})");
130130
runQuery("CALL gds.alpha.ml.pipeline.nodeClassification.configureParams('p', [" +

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationPipelineTest.java

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ void overridesTheParameterSpace() {
113113
@Test
114114
void canSetSplitConfig() {
115115
var pipeline = new NodeClassificationPipeline();
116-
var splitConfig = NodeClassificationSplitConfig.builder().holdoutFraction(0.555).build();
116+
var splitConfig = NodeClassificationSplitConfig.builder().testFraction(0.555).build();
117117
pipeline.setSplitConfig(splitConfig);
118118

119119
assertThat(pipeline)
@@ -123,10 +123,10 @@ void canSetSplitConfig() {
123123
@Test
124124
void overridesTheSplitConfig() {
125125
var pipeline = new NodeClassificationPipeline();
126-
var splitConfig = NodeClassificationSplitConfig.builder().holdoutFraction(0.5).build();
126+
var splitConfig = NodeClassificationSplitConfig.builder().testFraction(0.5).build();
127127
pipeline.setSplitConfig(splitConfig);
128128

129-
var splitConfigOverride = NodeClassificationSplitConfig.builder().holdoutFraction(0.7).build();
129+
var splitConfigOverride = NodeClassificationSplitConfig.builder().testFraction(0.7).build();
130130
pipeline.setSplitConfig(splitConfigOverride);
131131

132132
assertThat(pipeline)
@@ -141,14 +141,12 @@ void returnsCorrectDefaultsMap() {
141141
var pipeline = new NodeClassificationPipeline();
142142
assertThat(pipeline.toMap())
143143
.containsOnlyKeys("featurePipeline", "splitConfig", "trainingParameterSpace")
144-
.satisfies(pipelineMap -> {
145-
assertThat(pipelineMap.get("featurePipeline"))
146-
.isInstanceOf(Map.class)
147-
.asInstanceOf(InstanceOfAssertFactories.MAP)
148-
.containsOnlyKeys("nodePropertySteps", "featureSteps")
149-
.returns(List.of(), featurePipelineMap -> featurePipelineMap.get("nodePropertySteps"))
150-
.returns(List.of(), featurePipelineMap -> featurePipelineMap.get("featureSteps"));
151-
})
144+
.satisfies(pipelineMap -> assertThat(pipelineMap.get("featurePipeline"))
145+
.isInstanceOf(Map.class)
146+
.asInstanceOf(InstanceOfAssertFactories.MAP)
147+
.containsOnlyKeys("nodePropertySteps", "featureSteps")
148+
.returns(List.of(), featurePipelineMap -> featurePipelineMap.get("nodePropertySteps"))
149+
.returns(List.of(), featurePipelineMap -> featurePipelineMap.get("featureSteps")))
152150
.returns(
153151
NodeClassificationSplitConfig.DEFAULT_CONFIG.toMap(),
154152
pipelineMap -> pipelineMap.get("splitConfig")
@@ -173,25 +171,23 @@ void returnsCorrectMapWithFullConfiguration() {
173171
NodeLogisticRegressionTrainCoreConfig.of(Map.of("penalty", 1))
174172
));
175173

176-
var splitConfig = NodeClassificationSplitConfig.builder().holdoutFraction(0.5).build();
174+
var splitConfig = NodeClassificationSplitConfig.builder().testFraction(0.5).build();
177175
pipeline.setSplitConfig(splitConfig);
178176

179177
assertThat(pipeline.toMap())
180178
.containsOnlyKeys("featurePipeline", "splitConfig", "trainingParameterSpace")
181-
.satisfies(pipelineMap -> {
182-
assertThat(pipelineMap.get("featurePipeline"))
183-
.isInstanceOf(Map.class)
184-
.asInstanceOf(InstanceOfAssertFactories.MAP)
185-
.containsOnlyKeys("nodePropertySteps", "featureSteps")
186-
.returns(
187-
List.of(pageRankPropertyStep.toMap()),
188-
featurePipelineMap -> featurePipelineMap.get("nodePropertySteps")
189-
)
190-
.returns(
191-
List.of(fooStep.toMap()),
192-
featurePipelineMap -> featurePipelineMap.get("featureSteps")
193-
);
194-
})
179+
.satisfies(pipelineMap -> assertThat(pipelineMap.get("featurePipeline"))
180+
.isInstanceOf(Map.class)
181+
.asInstanceOf(InstanceOfAssertFactories.MAP)
182+
.containsOnlyKeys("nodePropertySteps", "featureSteps")
183+
.returns(
184+
List.of(pageRankPropertyStep.toMap()),
185+
featurePipelineMap -> featurePipelineMap.get("nodePropertySteps")
186+
)
187+
.returns(
188+
List.of(fooStep.toMap()),
189+
featurePipelineMap -> featurePipelineMap.get("featureSteps")
190+
))
195191
.returns(
196192
pipeline.splitConfig().toMap(),
197193
pipelineMap -> pipelineMap.get("splitConfig")
@@ -272,16 +268,14 @@ void deepCopiesParameterSpace() {
272268
@Test
273269
void doesntDeepCopySplitConfig() {
274270
var pipeline = new NodeClassificationPipeline();
275-
var splitConfig = NodeClassificationSplitConfig.builder().holdoutFraction(0.5).build();
271+
var splitConfig = NodeClassificationSplitConfig.builder().testFraction(0.5).build();
276272
pipeline.setSplitConfig(splitConfig);
277273

278274
var copy = pipeline.copy();
279275

280276
assertThat(copy)
281277
.isNotSameAs(pipeline)
282-
.satisfies(copiedPipeline -> {
283-
assertThat(copiedPipeline.splitConfig()).isSameAs(splitConfig);
284-
});
278+
.satisfies(copiedPipeline -> assertThat(copiedPipeline.splitConfig()).isSameAs(splitConfig));
285279
}
286280
}
287281
}

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationTrainPipelineExecutorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void trainsAModel() {
116116
)));
117117

118118
pipeline.setSplitConfig(ImmutableNodeClassificationSplitConfig.builder()
119-
.holdoutFraction(0.01)
119+
.testFraction(0.01)
120120
.validationFolds(2)
121121
.build()
122122
);

alpha/alpha-proc/src/test/java/org/neo4j/gds/ml/nodemodels/pipeline/predict/NodeClassificationPipelineTrainProcTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ void train() {
124124
pipe
125125
);
126126
runQuery(
127-
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit($pipeline, {holdoutFraction: 0.01, validationFolds: 2})",
127+
"CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit($pipeline, {testFraction: 0.01, validationFolds: 2})",
128128
pipe
129129
);
130130

@@ -137,7 +137,7 @@ void train() {
137137
Object.class
138138
);
139139

140-
var modelInfoCheck = new Condition<Object>(m -> {
140+
var modelInfoCheck = new Condition<>(m -> {
141141

142142
var modelInfo = assertThat(m).asInstanceOf(soMap)
143143
.containsEntry("modelName", MODEL_NAME)

doc/asciidoc/algorithms/alpha/nodeclassification-pipeline/nodeclassification.adoc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ CALL gds.alpha.ml.pipeline.nodeClassification.create('pipe')
102102
|===
103103
| name | nodePropertySteps | featureSteps | splitConfig | parameterSpace
104104
| "pipe" | [] | []
105-
| {validationFolds=3, holdoutFraction=0.3}
105+
| {validationFolds=3, testFraction=0.3}
106106
| [{maxEpochs=100, minEpochs=1, penalty=0.0, patience=1, batchSize=100, tolerance=0.001}]
107107
|===
108108
--
@@ -281,9 +281,9 @@ YIELD
281281
.Configuration
282282
[opts="header",cols="1,1,1,4"]
283283
|===
284-
| Name | Type | Default | Description
285-
| validationFolds | Integer | 3 | Number of divisions of the training graph used during <<algorithms-ml-nodeclassification-pipelines-train,model selection>>.
286-
| holdoutFraction | Double | 0.3 | Fraction of the graph reserved for testing. Must be in the range (0, 1).
284+
| Name | Type | Default | Description
285+
| validationFolds | Integer | 3 | Number of divisions of the training graph used during <<algorithms-ml-nodeclassification-pipelines-train,model selection>>.
286+
| testFraction | Double | 0.3 | Fraction of the graph reserved for testing. Must be in the range (0, 1). The fraction used for the training is `1 - testFraction`.
287287
|===
288288

289289
include::pipelineInfoResult.adoc[]
@@ -297,7 +297,7 @@ include::pipelineInfoResult.adoc[]
297297
[source, cypher, role=noplay]
298298
----
299299
CALL gds.alpha.ml.pipeline.nodeClassification.configureSplit('pipe', {
300-
holdoutFraction: 0.2,
300+
testFraction: 0.2,
301301
validationFolds: 5
302302
})
303303
YIELD splitConfig
@@ -307,7 +307,7 @@ YIELD splitConfig
307307
[opts="header",cols="1"]
308308
|===
309309
| splitConfig
310-
| {validationFolds=5, holdoutFraction=0.2}
310+
| {validationFolds=5, testFraction=0.2}
311311
|===
312312

313313
We now reconfigured the splitting of the pipeline, which will be applied during <<algorithms-ml-nodeclassification-pipelines-train, training>>.

0 commit comments

Comments
 (0)