Skip to content

Commit 4ae225a

Browse files
authored
[Blocking][jvm-packages] fix the early stopping feature (dmlc#3808)
* add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * add back train method but mark as deprecated * fix scalastyle error * fix scalastyle error * fix scalastyle error * fix scalastyle error * temp * add method for classifier and regressor * update tutorial * address the comments * update
1 parent e26b5d6 commit 4ae225a

File tree

7 files changed

+134
-14
lines changed

7 files changed

+134
-14
lines changed

doc/jvm/xgboost4j_spark_tutorial.rst

+9
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,15 @@ After we set XGBoostClassifier parameters and feature/label column, we can build
183183
184184
val xgbClassificationModel = xgbClassifier.fit(xgbInput)
185185
186+
Early Stopping
187+
----------------
188+
189+
Early stopping is a feature to prevent the unnecessary training iterations. By specifying ``num_early_stopping_rounds`` or directly call ``setNumEarlyStoppingRounds`` over a XGBoostClassifier or XGBoostRegressor, we can define number of rounds for the evaluation metric going to the unexpected direction to tolerate before stopping the training.
190+
191+
In additional to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics`` or call ``setMaximizeEvaluationMetrics`` to specify whether you want to maximize or minimize the metrics in training.
192+
193+
After specifying these two parameters, the training would stop when the metrics goes to the other direction against the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` iterations.
194+
186195
Prediction
187196
==========
188197

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ object XGBoost extends Serializable {
132132
try {
133133
val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
134134
.map(_.toString.toInt).getOrElse(0)
135+
if (numEarlyStoppingRounds > 0) {
136+
if (!params.contains("maximize_evaluation_metrics")) {
137+
throw new IllegalArgumentException("maximize_evaluation_metrics has to be specified")
138+
}
139+
}
135140
val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
136141
val booster = SXGBoost.train(watches.train, params, round,
137142
watches.toMap, metrics, obj, eval,

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala

+3
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ class XGBoostClassifier (
140140

141141
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
142142

143+
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
144+
set(maximizeEvaluationMetrics, value)
145+
143146
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
144147

145148
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala

+3
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ class XGBoostRegressor (
140140

141141
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
142142

143+
def setMaximizeEvaluationMetrics(value: Boolean): this.type =
144+
set(maximizeEvaluationMetrics, value)
145+
143146
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
144147

145148
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala

+7
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ private[spark] trait LearningTaskParams extends Params {
8787

8888
final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
8989

90+
91+
final val maximizeEvaluationMetrics = new BooleanParam(this, "maximizeEvaluationMetrics",
92+
"define the expected optimization to the evaluation metrics, true to maximize otherwise" +
93+
" minimize it")
94+
95+
final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
96+
9097
setDefault(objective -> "reg:linear", baseScore -> 0.5,
9198
trainTestRatio -> 1.0, numEarlyStoppingRounds -> 0)
9299
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java

+46-14
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ public static Booster train(
118118
* performance on the validation set.
119119
* @param metrics array containing the evaluation metrics for each matrix in watches for each
120120
* iteration
121-
* @param earlyStoppingRound if non-zero, training would be stopped
121+
* @param earlyStoppingRounds if non-zero, training would be stopped
122122
* after a specified number of consecutive
123-
* increases in any evaluation metric.
123+
* goes to the unexpected direction in any evaluation metric.
124124
* @param obj customized objective
125125
* @param eval customized evaluation
126126
* @param booster train from scratch if set to null; train from an existing booster if not null.
@@ -134,7 +134,7 @@ public static Booster train(
134134
float[][] metrics,
135135
IObjective obj,
136136
IEvaluation eval,
137-
int earlyStoppingRound,
137+
int earlyStoppingRounds,
138138
Booster booster) throws XGBoostError {
139139

140140
//collect eval matrixs
@@ -196,17 +196,14 @@ public static Booster train(
196196
for (int i = 0; i < metricsOut.length; i++) {
197197
metrics[i][iter] = metricsOut[i];
198198
}
199-
200-
boolean decreasing = true;
201-
float[] criterion = metrics[metrics.length - 1];
202-
for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
203-
decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
204-
}
205-
206-
if (!decreasing) {
207-
Rabit.trackerPrint(String.format(
208-
"early stopping after %d decreasing rounds", earlyStoppingRound));
209-
break;
199+
if (earlyStoppingRounds > 0) {
200+
boolean onTrack = judgeIfTrainingOnTrack(params, earlyStoppingRounds, metrics, iter);
201+
if (!onTrack) {
202+
String reversedDirection = getReversedDirection(params);
203+
Rabit.trackerPrint(String.format(
204+
"early stopping after %d %s rounds", earlyStoppingRounds, reversedDirection));
205+
break;
206+
}
210207
}
211208
if (Rabit.getRank() == 0) {
212209
Rabit.trackerPrint(evalInfo + '\n');
@@ -217,6 +214,41 @@ public static Booster train(
217214
return booster;
218215
}
219216

217+
static boolean judgeIfTrainingOnTrack(
218+
Map<String, Object> params, int earlyStoppingRounds, float[][] metrics, int iter) {
219+
boolean maximizeEvaluationMetrics = getMetricsExpectedDirection(params);
220+
boolean onTrack = false;
221+
float[] criterion = metrics[metrics.length - 1];
222+
for (int shift = 0; shift < Math.min(iter, earlyStoppingRounds) - 1; shift++) {
223+
onTrack |= maximizeEvaluationMetrics ?
224+
criterion[iter - shift] >= criterion[iter - shift - 1] :
225+
criterion[iter - shift] <= criterion[iter - shift - 1];
226+
}
227+
return onTrack;
228+
}
229+
230+
private static String getReversedDirection(Map<String, Object> params) {
231+
String reversedDirection = null;
232+
if (Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
233+
reversedDirection = "descending";
234+
} else if (!Boolean.valueOf(String.valueOf(params.get("maximize_evaluation_metrics")))) {
235+
reversedDirection = "ascending";
236+
}
237+
return reversedDirection;
238+
}
239+
240+
private static boolean getMetricsExpectedDirection(Map<String, Object> params) {
241+
try {
242+
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
243+
assert(maximize != null);
244+
return Boolean.valueOf(maximize);
245+
} catch (Exception ex) {
246+
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
247+
" allowed value: true/false", ex);
248+
throw ex;
249+
}
250+
}
251+
220252
/**
221253
* Cross-validation with given parameters.
222254
*

jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java

+61
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,66 @@ public float eval(float[][] predicts, DMatrix dmat) {
152152
}
153153
}
154154

155+
@Test
156+
public void testDescendMetrics() {
157+
Map<String, Object> paramMap = new HashMap<String, Object>() {
158+
{
159+
put("max_depth", 3);
160+
put("silent", 1);
161+
put("objective", "binary:logistic");
162+
put("maximize_evaluation_metrics", "false");
163+
}
164+
};
165+
float[][] metrics = new float[1][5];
166+
for (int i = 0; i < 5; i++) {
167+
metrics[0][i] = i;
168+
}
169+
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
170+
TestCase.assertFalse(onTrack);
171+
for (int i = 0; i < 5; i++) {
172+
metrics[0][i] = 5 - i;
173+
}
174+
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
175+
TestCase.assertTrue(onTrack);
176+
for (int i = 0; i < 5; i++) {
177+
metrics[0][i] = 5 - i;
178+
}
179+
metrics[0][0] = 1;
180+
metrics[0][2] = 5;
181+
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
182+
TestCase.assertTrue(onTrack);
183+
}
184+
185+
@Test
186+
public void testAscendMetrics() {
187+
Map<String, Object> paramMap = new HashMap<String, Object>() {
188+
{
189+
put("max_depth", 3);
190+
put("silent", 1);
191+
put("objective", "binary:logistic");
192+
put("maximize_evaluation_metrics", "true");
193+
}
194+
};
195+
float[][] metrics = new float[1][5];
196+
for (int i = 0; i < 5; i++) {
197+
metrics[0][i] = i;
198+
}
199+
boolean onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
200+
TestCase.assertTrue(onTrack);
201+
for (int i = 0; i < 5; i++) {
202+
metrics[0][i] = 5 - i;
203+
}
204+
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
205+
TestCase.assertFalse(onTrack);
206+
for (int i = 0; i < 5; i++) {
207+
metrics[0][i] = i;
208+
}
209+
metrics[0][0] = 6;
210+
metrics[0][2] = 1;
211+
onTrack = XGBoost.judgeIfTrainingOnTrack(paramMap, 5, metrics, 4);
212+
TestCase.assertTrue(onTrack);
213+
}
214+
155215
@Test
156216
public void testBoosterEarlyStop() throws XGBoostError, IOException {
157217
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
@@ -162,6 +222,7 @@ public void testBoosterEarlyStop() throws XGBoostError, IOException {
162222
put("max_depth", 3);
163223
put("silent", 1);
164224
put("objective", "binary:logistic");
225+
put("maximize_evaluation_metrics", "false");
165226
}
166227
};
167228
Map<String, DMatrix> watches = new LinkedHashMap<>();

0 commit comments

Comments
 (0)