diff --git a/mini_lightning/_mini_lightning.py b/mini_lightning/_mini_lightning.py index 4af65c2..5133036 100644 --- a/mini_lightning/_mini_lightning.py +++ b/mini_lightning/_mini_lightning.py @@ -850,10 +850,10 @@ def _train_epoch(self, dataloader: DataLoader, val_dataloader: Optional[DataLoad prog_bar_mes = self._get_res_mes(_mean_metrics, _rec_mes, "prog_bar") if self.rank >= 0: prog_bar_mes = self._reduce_mes(prog_bar_mes, device) - # - if self.version is not None: + # + if self.rank in {-1, 0}: prog_bar_mes["v"] = self.version - prog_bar_mes["global_step"] = str(prog_bar_mes["global_step"]) + prog_bar_mes["global_step"] = str(int(prog_bar_mes["global_step"])) prog_bar.set_postfix(prog_bar_mes, refresh=False) # rank > 0 disable. prog_bar.update(self.prog_bar_n_steps) # tensorboard @@ -865,7 +865,8 @@ def _train_epoch(self, dataloader: DataLoader, val_dataloader: Optional[DataLoad # val if mc.val_mode == "step" and self.global_step % mc.val_every_n == 0: res_mes = self._get_res_mes(_mean_metrics, _rec_mes, "result") - prog_bar.fp.write("\n") + if not prog_bar.disable: + prog_bar.fp.write("\n") prog_bar.refresh() self._val_and_save_after_train(val_dataloader, res_mes) #