diff --git a/exp/run_pretrain.py b/exp/run_pretrain.py
index ad049e2..7f20dc8 100755
--- a/exp/run_pretrain.py
+++ b/exp/run_pretrain.py
@@ -2,6 +2,7 @@

 import pytorch_lightning as pl
 import torch
+import torch_xla.core.xla_model as xm
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import (
     LearningRateMonitor,
@@ -10,6 +11,7 @@
 )
 from pytorch_lightning.loggers import TensorBoardLogger
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers.optimization import get_linear_schedule_with_warmup

 from src.data import PretrainCollator, PretrainTask
@@ -31,10 +33,21 @@ def __init__(
         self.model.imagetext.econder.load_state_dict(
             torch.load("assets/phobert-base-encoder.pt")
         )
+        self.automatic_optimization = False

     def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+
         loss = self.model(**batch).loss
         self.log("train_loss", loss)
+        self.manual_backward(loss)
+
+        opt.step()
+        sch = self.lr_schedulers()
+        sch.step()
+
+        xm.mark_step()
         return loss

     def configure_optimizers(self):
@@ -69,18 +82,25 @@ def configure_optimizers(self):
     pretrain_collator = PretrainCollator(
         bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
     )
+    sampler = DistributedSampler(
+        pretrain_task,
+        num_replicas=xm.xrt_world_size(),
+        rank=xm.get_ordinal(),
+        shuffle=False,
+    )
     dataloader = DataLoader(
         pretrain_task,
         batch_size=8,
         num_workers=24,
         collate_fn=pretrain_collator,
         drop_last=True,
+        sampler=sampler
     )

     wrapper = Wrapper(config, warmup_ratio=0.2, learn_rate=5.0e-05, use_phobert=False)

     do_every_n_steps = 1000
-    root_dir = "logs"
+    root_dir = "pls-logs"

     trainer = Trainer(
         enable_checkpointing=True,
diff --git a/exp/run_uit_viic.py b/exp/run_uit_viic.py
index f10e152..4b64e89 100644
--- a/exp/run_uit_viic.py
+++ b/exp/run_uit_viic.py
@@ -2,6 +2,7 @@

 import pytorch_lightning as pl
 import torch
+import torch_xla.core.xla_model as xm
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import (
     LearningRateMonitor,
@@ -10,6 +11,7 @@
 )
 from pytorch_lightning.loggers import TensorBoardLogger
 from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers.optimization import get_linear_schedule_with_warmup

 from src.data import ImageCaptioningCollator, ImageTextPair
@@ -37,10 +39,21 @@ def __init__(
         self.model.load_state_dict(
             torch.load(self.pretrain_model_path_str), strict=False
         )
+        self.automatic_optimization = False

     def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+
         loss = self.model(**batch).loss
         self.log("train_loss", loss)
+        self.manual_backward(loss)
+
+        opt.step()
+        sch = self.lr_schedulers()
+        sch.step()
+
+        xm.mark_step()
         return loss

     def configure_optimizers(self):
@@ -81,17 +94,23 @@ def configure_optimizers(self):
     ic_collator = ImageCaptioningCollator(
         bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
     )
-
+    sampler = DistributedSampler(
+        train_val_ic,
+        num_replicas=xm.xrt_world_size(),
+        rank=xm.get_ordinal(),
+        shuffle=False,
+    )
     dataloader = DataLoader(
         train_val_ic,
         batch_size=8,
         num_workers=24,
         collate_fn=ic_collator,
         drop_last=True,
+        sampler=sampler
     )

     pretrain_model_path_str = (
-        "logs/lightning_logs/version_0/checkpoints/imagetext-base.pt"
+        "pls-logs/lightning_logs/version_4/checkpoints/imagetext-base.pt"
     )

     wrapper = Wrapper(
@@ -102,8 +121,8 @@ def configure_optimizers(self):
         pretrain_model_path_str=pretrain_model_path_str,
     )

-    do_every_n_steps = 32
-    root_dir = "uit-viic-logs"
+    do_every_n_steps = 100
+    root_dir = "pls-uit-viic-logs"

     trainer = Trainer(
         enable_checkpointing=True,
@@ -115,8 +134,7 @@ def configure_optimizers(self):
         callbacks=[
             RichProgressBar(),
             ModelCheckpoint(
-                every_n_epochs=1,
-                save_on_train_epoch_end=True,
+                every_n_train_steps=do_every_n_steps
             ),
             LearningRateMonitor(logging_interval="step"),
         ],
diff --git a/exp/run_vqa.py b/exp/run_vqa.py
index 24519a2..35f5f98 100755
--- a/exp/run_vqa.py
+++ b/exp/run_vqa.py
@@ -2,6 +2,7 @@

 import pytorch_lightning as pl
 import torch
+import torch_xla.core.xla_model as xm
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import (
     LearningRateMonitor,
@@ -10,6 +11,7 @@
 )
 from pytorch_lightning.loggers import TensorBoardLogger
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from transformers.optimization import get_linear_schedule_with_warmup

 from src.data import VisualQuestionAnswer, VisualQuestionAnswerCollator
@@ -37,10 +39,21 @@ def __init__(
         self.model.load_state_dict(
             torch.load(self.pretrain_model_path_str), strict=False
         )
+        self.automatic_optimization = False

     def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+
         loss = self.model(**batch).loss
         self.log("train_loss", loss)
+        self.manual_backward(loss)
+
+        opt.step()
+        sch = self.lr_schedulers()
+        sch.step()
+
+        xm.mark_step()
         return loss

     def configure_optimizers(self):
@@ -75,17 +88,23 @@ def configure_optimizers(self):
     vqa_collator = VisualQuestionAnswerCollator(
         bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
     )
-
+    sampler = DistributedSampler(
+        vqa,
+        num_replicas=xm.xrt_world_size(),
+        rank=xm.get_ordinal(),
+        shuffle=False,
+    )
     dataloader = DataLoader(
         vqa,
         batch_size=8,
         num_workers=24,
         collate_fn=vqa_collator,
         drop_last=True,
+        sampler=sampler
     )

     pretrain_model_path_str = (
-        "logs/lightning_logs/version_0/checkpoints/imagetext-base.pt"
+        "pls-logs/lightning_logs/version_4/checkpoints/imagetext-base.pt"
     )

     wrapper = Wrapper(
@@ -96,8 +115,8 @@ def configure_optimizers(self):
         pretrain_model_path_str=pretrain_model_path_str,
     )

-    do_every_n_steps = 32
-    root_dir = "vivqa-logs"
+    do_every_n_steps = 100
+    root_dir = "pls-vivqa-logs"

     trainer = Trainer(
         enable_checkpointing=True,
@@ -109,8 +128,7 @@ def configure_optimizers(self):
         callbacks=[
             RichProgressBar(),
             ModelCheckpoint(
-                every_n_epochs=1,
-                save_on_train_epoch_end=True,
+                every_n_train_steps=do_every_n_steps
            ),
             LearningRateMonitor(logging_interval="step"),
         ],
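All three scripts make the same switch to manual optimization: automatic_optimization is set to False, and the backward pass, optimizer step, scheduler step, and xm.mark_step() are issued explicitly inside training_step. The sketch below is a minimal, TPU-free illustration of that pattern; the linear layer, loss, and warmup/step counts are placeholders rather than the project's real imagetext model or schedule, and the xm.mark_step() call that the scripts add at the end of each step is only noted in a comment because it requires torch_xla on a TPU host.

import pytorch_lightning as pl
import torch
from transformers.optimization import get_linear_schedule_with_warmup


class ManualOptimizationSketch(pl.LightningModule):
    """Stand-in module illustrating the manual-optimization pattern used above."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(16, 1)  # placeholder for the real imagetext model
        self.automatic_optimization = False  # Lightning no longer calls backward/step for us

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()

        x, y = batch
        loss = torch.nn.functional.mse_loss(self.model(x), y)
        self.log("train_loss", loss)

        # Required with manual optimization: without manual_backward the
        # parameters never receive gradients, so opt.step() would do nothing useful.
        self.manual_backward(loss)
        opt.step()
        self.lr_schedulers().step()

        # The scripts above additionally call xm.mark_step() here to flush the XLA graph on TPU.
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=5.0e-05)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=100, num_training_steps=1000
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]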