[train] Fold v2.XGBoostTrainer API into the public trainer class as an alternate constructor #50045
base: master
Conversation
Signed-off-by: Justin Yu <[email protected]>
This is very elegant!
@@ -67,7 +67,7 @@ def train_fn_per_worker(config: dict):
     train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
     eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
     trainer = XGBoostTrainer(
-        train_fn_per_worker,
+        train_loop_per_worker=train_fn_per_worker,
In the long term this is the main API change, right? That this needs to be a kwarg.
Actually can we just keep this as a required argument since the V1 will always populate it? Are there any problems if we do?
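The discussion above is about whether `train_loop_per_worker` must become a keyword argument or can stay positional, given that the legacy path always populates it internally. A minimal sketch of that folding pattern (hypothetical class and helper names, not Ray's actual implementation) could look like this:

```python
from typing import Any, Callable, Dict, Optional


class UnifiedTrainerSketch:
    """Hypothetical sketch of folding a v2 training-loop API into a public
    trainer class while still accepting the legacy keyword arguments."""

    def __init__(
        self,
        train_loop_per_worker: Optional[Callable[[dict], None]] = None,
        *,
        train_loop_config: Optional[Dict[str, Any]] = None,
        # Legacy arguments, accepted only for backwards compatibility:
        label_column: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> None:
        if train_loop_per_worker is not None:
            # v2 path: the caller supplies the per-worker training loop.
            self.api_version = 2
            self.train_loop_per_worker = train_loop_per_worker
        else:
            # Legacy path: synthesize a training loop from the old arguments,
            # so train_loop_per_worker is always populated internally.
            self.api_version = 1
            self.train_loop_per_worker = self._legacy_loop(label_column, params)

    @staticmethod
    def _legacy_loop(label_column, params):
        def loop(config: dict) -> None:
            # A real implementation would run xgboost.train(params, ...) here.
            pass

        return loop
```

Because the legacy path synthesizes the loop itself, keeping `train_loop_per_worker` as the first positional parameter (rather than keyword-only) would also work, which is the point raised above.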
    # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
    label_column: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    num_boost_round: Optional[int] = None,
Are these needed for the V2 API? These are not passed in from the V1 API.
    num_boost_round = num_boost_round or 10
    _log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)
Let's hold off on explicitly logging that it's deprecated until we have the GH issue and documentation for V2 published?
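One way to soften this until the v2 docs are published is to emit the warning only when a legacy argument is actually supplied, rather than unconditionally. A small sketch of that guard (the message text here is a placeholder; the real constant lives in Ray's codebase):

```python
import warnings

# Placeholder message; the real constant is LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE
# defined in Ray's codebase.
LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE = (
    "The legacy XGBoostTrainer API is deprecated; see the v2 migration guide."
)


def warn_if_legacy_args_used(**legacy_kwargs):
    """Emit the deprecation warning only when a legacy argument was
    actually passed, instead of on every construction."""
    used = {k: v for k, v in legacy_kwargs.items() if v is not None}
    if used:
        warnings.warn(
            LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE,
            DeprecationWarning,
            stacklevel=2,
        )
    return used
```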
Nice!
    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
Is this a staticmethod of RayTrainReportCallback? A side question of mine: is get_model best defined under a callback? Why not make it a utility under ray.train.xgboost, if it is not an instance method of RayTrainReportCallback that uses instance attributes?
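The two placements being compared can be sketched with a plain dict standing in for the checkpoint object (illustrative names only; not Ray's actual implementation). Neither variant touches instance state, which is the crux of the question: a staticmethod on the callback works, but so would a free function in a module namespace.

```python
class ReportCallbackSketch:
    """Stand-in for a report callback that also exposes a model loader."""

    @staticmethod
    def get_model(checkpoint: dict):
        # Reads only from the checkpoint; no `self` is needed.
        return checkpoint["model"]


def get_model(checkpoint: dict):
    """Module-level alternative, e.g. under a ray.train.xgboost-style
    namespace, as the reviewer suggests."""
    return checkpoint["model"]
```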
Summary
Currently, the new XGBoostTrainer API is only accessible via a separate import, ray.train.xgboost.v2.XGBoostTrainer. To avoid unnecessary import changes, this PR folds the new API, which accepts the new arguments (train_loop_per_worker, train_loop_config, xgboost_config), into the public ray.train.xgboost.XGBoostTrainer class.

This also makes some changes to the Ray Train v2 XGBoostTrainer class to improve the migration UX, since that class does not support the legacy XGBoostTrainer API at all.

TODO
- XGBoostTrainer and LightGBMTrainer API revamps #50042