
Commit fc3a398

asanakoy authored and facebook-github-bot committed
Pass gradient clipping and mixed precision params to the lightning Trainer
Summary: Pull Request resolved: #374

Training with mixed precision (AMP) is implemented for the native D2Go Runner, but not for Lightning tasks. Now we pass the SOLVER.AMP* and SOLVER.CLIP_GRADIENTS* params to the Lightning Trainer as well.

Reviewed By: wat3rBro

Differential Revision: D39798007

fbshipit-source-id: e48560a91d37c21c56d953eed141876d8c759329
1 parent dc176d5 commit fc3a398
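For context, the fields this commit reads live under SOLVER in the D2Go (detectron2-style) config. A minimal sketch of those fields, with hypothetical values: only the keys the patch below touches are shown, and real defaults may differ.

    from detectron2.config import CfgNode

    # Hypothetical values for illustration; field names are taken from the diff below.
    cfg = CfgNode()
    cfg.SOLVER = CfgNode()
    cfg.SOLVER.MAX_ITER = 90000
    cfg.SOLVER.AMP = CfgNode()
    cfg.SOLVER.AMP.ENABLED = True                 # becomes Trainer precision="mixed"
    cfg.SOLVER.CLIP_GRADIENTS = CfgNode()
    cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"  # becomes gradient_clip_algorithm
    cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0    # becomes gradient_clip_val
    cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0     # only the L2 norm is accepted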

File tree: 1 file changed (+14, -1 lines)

tools/lightning_train_net.py

Lines changed: 14 additions & 1 deletion
@@ -63,7 +63,7 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     strategy = _get_strategy(cfg)
     accelerator = _get_accelerator(use_cpu)
 
-    return {
+    params = {
         "max_epochs": -1,
         "max_steps": cfg.SOLVER.MAX_ITER,
         "val_check_interval": cfg.TEST.EVAL_PERIOD
@@ -77,7 +77,20 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
+        "precision": "mixed" if cfg.SOLVER.AMP.ENABLED else 32,
     }
+    if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
+        if (
+            cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE.lower() == "norm"
+            and cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE != 2.0
+        ):
+            raise ValueError(
+                "D2Go Lightning backend supports only L2-norm for norm-based gradient clipping!"
+            )
+        params["gradient_clip_val"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+        params["gradient_clip_algorithm"] = cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE
+
+    return params
 
 
 def main(
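Downstream, the returned dict is presumably unpacked into the Lightning Trainer, so the new keys map one-to-one onto Trainer kwargs as named in Lightning 1.x (precision, gradient_clip_val, gradient_clip_algorithm). A minimal usage sketch, assuming the patched get_trainer_params above and a populated cfg:

    import pytorch_lightning as pl

    # Build Trainer kwargs from the config and unpack them, e.g.
    # {"precision": "mixed", "gradient_clip_val": 1.0,
    #  "gradient_clip_algorithm": "norm", ...}
    params = get_trainer_params(cfg)
    trainer = pl.Trainer(**params)

Lightning's built-in norm clipping applies the global L2 norm, which is presumably why the patch rejects any other NORM_TYPE rather than trying to forward it.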
