
Commit

Refactor to slow solution
dinhanhx committed Oct 29, 2023
1 parent 9b9a7f2 commit c8cefcb
Showing 3 changed files with 68 additions and 13 deletions.
22 changes: 21 additions & 1 deletion exp/run_pretrain.py
@@ -2,6 +2,7 @@

import pytorch_lightning as pl
import torch
import torch_xla.core.xla_model as xm
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
LearningRateMonitor,
@@ -10,6 +11,7 @@
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers.optimization import get_linear_schedule_with_warmup

from src.data import PretrainCollator, PretrainTask
@@ -31,10 +33,21 @@ def __init__(
self.model.imagetext.econder.load_state_dict(
torch.load("assets/phobert-base-encoder.pt")
)
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()

loss = self.model(**batch).loss
self.log("train_loss", loss)
self.manual_backward(loss)

opt.step()
sch = self.lr_schedulers()
sch.step()

xm.mark_step()
return loss

def configure_optimizers(self):
@@ -69,18 +82,25 @@ def configure_optimizers(self):
pretrain_collator = PretrainCollator(
bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
)
sampler = DistributedSampler(
pretrain_task,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False,
)
dataloader = DataLoader(
pretrain_task,
batch_size=8,
num_workers=24,
collate_fn=pretrain_collator,
drop_last=True,
sampler=sampler
)

wrapper = Wrapper(config, warmup_ratio=0.2, learn_rate=5.0e-05, use_phobert=False)

do_every_n_steps = 1000
root_dir = "logs"
root_dir = "pls-logs"

trainer = Trainer(
enable_checkpointing=True,
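Note on the pattern above: with self.automatic_optimization = False, Lightning no longer calls backward(), optimizer.step(), or the scheduler for you, so training_step has to do all three explicitly, and xm.mark_step() is added to flush the pending XLA graph once per step. A minimal, self-contained sketch of the same pattern, using a placeholder model and scheduler rather than anything from this repository:

import pytorch_lightning as pl
import torch
import torch_xla.core.xla_model as xm


class ManualOptimizationSketch(pl.LightningModule):
    """Illustrative only: mirrors the manual-optimization loop in the diff above."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)  # placeholder model
        # Disable Lightning's automatic loop; backward/step are done by hand.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()

        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("train_loss", loss)
        self.manual_backward(loss)  # replaces loss.backward() under Lightning

        opt.step()
        sch = self.lr_schedulers()
        sch.step()  # stepped manually because automatic optimization is off

        # Force XLA to compile and execute the accumulated graph for this step.
        xm.mark_step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=5.0e-05)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

The fine-tuning scripts below apply the same structure, with the repository's own models and datasets.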
29 changes: 23 additions & 6 deletions exp/run_uit_viic.py
@@ -2,6 +2,7 @@

import pytorch_lightning as pl
import torch
import torch_xla.core.xla_model as xm
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
LearningRateMonitor,
@@ -10,6 +11,7 @@
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers.optimization import get_linear_schedule_with_warmup

from src.data import ImageCaptioningCollator, ImageTextPair
@@ -37,10 +39,20 @@ def __init__(
self.model.load_state_dict(
torch.load(self.pretrain_model_path_str), strict=False
)
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()

loss = self.model(**batch).loss
self.log("train_loss", loss)
self.manual_backward(loss)

opt.step()
sch = self.lr_schedulers()
sch.step()

xm.mark_step()
return loss

def configure_optimizers(self):
@@ -81,17 +93,23 @@ def configure_optimizers(self):
ic_collator = ImageCaptioningCollator(
bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
)

sampler = DistributedSampler(
train_val_ic,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False,
)
dataloader = DataLoader(
train_val_ic,
batch_size=8,
num_workers=24,
collate_fn=ic_collator,
drop_last=True,
sampler=sampler
)

pretrain_model_path_str = (
"logs/lightning_logs/version_0/checkpoints/imagetext-base.pt"
"pls-logs/lightning_logs/version_4/checkpoints/imagetext-base.pt"
)

wrapper = Wrapper(
@@ -102,8 +120,8 @@ def configure_optimizers(self):
pretrain_model_path_str=pretrain_model_path_str,
)

do_every_n_steps = 32
root_dir = "uit-viic-logs"
do_every_n_steps = 100
root_dir = "pls-uit-viic-logs"

trainer = Trainer(
enable_checkpointing=True,
@@ -115,8 +133,7 @@
callbacks=[
RichProgressBar(),
ModelCheckpoint(
every_n_epochs=1,
save_on_train_epoch_end=True,
every_n_train_steps=do_every_n_steps
),
LearningRateMonitor(logging_interval="step"),
],
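Each script also replaces plain shuffling with a DistributedSampler built from the XLA runtime, so that every TPU core reads a disjoint shard of the dataset and all cores run the same number of steps. A hedged sketch of that sharding in isolation, with a toy TensorDataset standing in for the repository's dataset classes:

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy stand-in for PretrainTask / ImageTextPair / VisualQuestionAnswer.
features = torch.arange(64, dtype=torch.float32).unsqueeze(1)
targets = torch.zeros(64, 1)
dataset = TensorDataset(features, targets)

sampler = DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),  # number of participating TPU cores
    rank=xm.get_ordinal(),             # index of the current core
    shuffle=False,
)
loader = DataLoader(
    dataset,
    batch_size=8,
    sampler=sampler,  # a sampler and shuffle=True are mutually exclusive
    drop_last=True,   # keeps the step count identical on every core
)

If shuffle=True were used instead, sampler.set_epoch(epoch) would also have to be called at the start of each epoch so that every core reshuffles consistently.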
30 changes: 24 additions & 6 deletions exp/run_vqa.py
@@ -2,6 +2,7 @@

import pytorch_lightning as pl
import torch
import torch_xla.core.xla_model as xm
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
LearningRateMonitor,
@@ -10,6 +11,7 @@
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers.optimization import get_linear_schedule_with_warmup

from src.data import VisualQuestionAnswer, VisualQuestionAnswerCollator
@@ -37,10 +39,21 @@ def __init__(
self.model.load_state_dict(
torch.load(self.pretrain_model_path_str), strict=False
)
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()
opt.zero_grad()

loss = self.model(**batch).loss
self.log("train_loss", loss)
self.manual_backward(loss)

opt.step()
sch = self.lr_schedulers()
sch.step()

xm.mark_step()
return loss

def configure_optimizers(self):
@@ -75,17 +88,23 @@ def configure_optimizers(self):
vqa_collator = VisualQuestionAnswerCollator(
bun_tokenizer, image_size=config.image_size, patch_size=config.patch_size
)

sampler = DistributedSampler(
vqa,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False,
)
dataloader = DataLoader(
vqa,
batch_size=8,
num_workers=24,
collate_fn=vqa_collator,
drop_last=True,
sampler=sampler
)

pretrain_model_path_str = (
"logs/lightning_logs/version_0/checkpoints/imagetext-base.pt"
"pls-logs/lightning_logs/version_4/checkpoints/imagetext-base.pt"
)

wrapper = Wrapper(
Expand All @@ -96,8 +115,8 @@ def configure_optimizers(self):
pretrain_model_path_str=pretrain_model_path_str,
)

do_every_n_steps = 32
root_dir = "vivqa-logs"
do_every_n_steps = 100
root_dir = "pls-vivqa-logs"

trainer = Trainer(
enable_checkpointing=True,
@@ -109,8 +128,7 @@
callbacks=[
RichProgressBar(),
ModelCheckpoint(
every_n_epochs=1,
save_on_train_epoch_end=True,
every_n_train_steps=do_every_n_steps
),
LearningRateMonitor(logging_interval="step"),
],
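The configure_optimizers bodies are collapsed in this view, but every script imports get_linear_schedule_with_warmup and passes a warmup_ratio to its Wrapper. Purely as an illustration of how such a ratio is commonly turned into a linear-warmup schedule (the step count, model, and optimizer below are placeholders, not the repository's actual code):

import torch
from transformers.optimization import get_linear_schedule_with_warmup

# Placeholder values; the real scripts derive the step count from their
# dataloaders and Trainer settings, which this diff does not show.
num_training_steps = 10_000
warmup_ratio = 0.2

model = torch.nn.Linear(4, 1)  # stand-in for the wrapped model
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-05)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(warmup_ratio * num_training_steps),
    num_training_steps=num_training_steps,
)

Under this schedule the learning rate ramps linearly to its peak over the first 20% of steps and then decays linearly to zero; sch.step() in training_step advances it once per optimizer step.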
