[train] Fold v2.XGBoostTrainer API into the public trainer class as an alternate constructor #50045

Open · wants to merge 4 commits into master
Conversation

justinvyu (Contributor)

Summary

Currently, the new XGBoostTrainer API is only accessible via a separate import: ray.train.xgboost.v2.XGBoostTrainer.

To avoid unnecessary import changes, this PR folds the new API, which accepts new arguments (train_loop_per_worker, train_loop_config, xgboost_config), into the public ray.train.xgboost.XGBoostTrainer class.

This PR also makes some changes to the Ray Train v2 XGBoostTrainer class to improve the migration UX, since the v2 class does not support the legacy XGBoostTrainer API at all.
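
For illustration, here is a minimal sketch of what folding the v2 API behind the public class could look like. The _v2_enabled helper and the RAY_TRAIN_V2_ENABLED toggle are assumptions of this sketch, not necessarily how the PR wires it:

import os

def _v2_enabled() -> bool:
    # Assumption for this sketch: Ray Train v2 is toggled via an env var.
    return os.environ.get("RAY_TRAIN_V2_ENABLED", "0") == "1"

class XGBoostTrainer:
    """Sketch of the public ray.train.xgboost.XGBoostTrainer facade."""

    def __new__(cls, *args, **kwargs):
        if _v2_enabled():
            # Route to the v2 trainer, which accepts train_loop_per_worker,
            # train_loop_config, and xgboost_config.
            from ray.train.xgboost.v2 import XGBoostTrainer as V2XGBoostTrainer
            return V2XGBoostTrainer(*args, **kwargs)
        return super().__new__(cls)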

TODO

Signed-off-by: Justin Yu <[email protected]>

@matthewdeng (Contributor) left a comment


This is very elegant!

@@ -67,7 +67,7 @@ def train_fn_per_worker(config: dict):
 train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
 eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
 trainer = XGBoostTrainer(
-    train_fn_per_worker,
+    train_loop_per_worker=train_fn_per_worker,
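
For reference, a self-contained version of the post-change call style, assuming the folded v2 API described in the summary (the dataset contents mirror the diff above):

import xgboost
import ray.data
import ray.train
from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer

def train_fn_per_worker(config: dict):
    # Each worker reads its shard of the "train" dataset.
    train_df = ray.train.get_dataset_shard("train").materialize().to_pandas()
    dtrain = xgboost.DMatrix(train_df[["x"]], label=train_df["y"])
    # The callback reports metrics and checkpoints back to Ray Train.
    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=10,
        callbacks=[RayTrainReportCallback()],
    )

train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = XGBoostTrainer(
    train_loop_per_worker=train_fn_per_worker,
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()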

In the long term, this is the main API change, right? Namely, that this needs to be a kwarg.


Actually, can we just keep this as a required argument, since V1 will always populate it? Are there any problems if we do?
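
To make the two options concrete, a toy sketch of the signature question (the function names are illustrative, not the PR's code):

from typing import Callable

# Option A: keyword-only, matching the diff above; callers must spell out
# train_loop_per_worker=... at the call site.
def build_trainer_kwonly(*, train_loop_per_worker: Callable) -> None:
    pass

# Option B: a required positional argument; the V1 shim can still populate
# it internally, and external callers may also pass it positionally.
def build_trainer_positional(train_loop_per_worker: Callable) -> None:
    pass

def train_fn(config: dict) -> None:
    pass

build_trainer_kwonly(train_loop_per_worker=train_fn)
build_trainer_positional(train_fn)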

Comment on lines +130 to +133
    # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
    label_column: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    num_boost_round: Optional[int] = None,

Are these needed for the V2 API? These are not passed in from the V1 API.
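
For context, a sketch of how a v2 constructor might accept these legacy keywords only to surface migration messaging. The message text is an assumption standing in for the LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE constant referenced in the excerpt below:

import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# Assumed stand-in for the real deprecation message constant.
LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE = (
    "The legacy XGBoostTrainer arguments (label_column, params, "
    "num_boost_round) are deprecated; pass train_loop_per_worker instead."
)

def _warn_on_legacy_args(
    label_column: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
    num_boost_round: Optional[int] = None,
) -> None:
    # Only warn when the caller actually used one of the legacy keywords.
    if any(v is not None for v in (label_column, params, num_boost_round)):
        logger.warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)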


num_boost_round = num_boost_round or 10

_log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)

Let's hold off on explicitly logging that it's deprecated until we have the GH issue and documentation for V2 published?

@hongpeng-guo (Contributor) left a comment


Nice!

    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
@hongpeng-guo (Contributor) commented on Jan 31, 2025:

Is this a staticmethod of RayTrainReportCallback? A side question of mine: is this get_model function best defined under a callback? Why not make it a utility under ray.train.xgboost, if it is not an instance method of RayTrainReportCallback that uses instance attributes?
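
For context, a minimal sketch of what such a helper can look like. It only reads files from the checkpoint and touches no instance state, which is why it could live as a staticmethod/classmethod on the callback or as a module-level utility; the "model.ubj" file name is an assumption of this sketch:

import os
import xgboost
from ray.train import Checkpoint

def get_model(checkpoint: Checkpoint) -> xgboost.Booster:
    # Materialize the checkpoint locally, then load the saved booster.
    # No instance attributes are needed, so this works equally well as a
    # utility under ray.train.xgboost.
    with checkpoint.as_directory() as checkpoint_dir:
        booster = xgboost.Booster()
        booster.load_model(os.path.join(checkpoint_dir, "model.ubj"))
        return booster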
