27
27
train_set = TensorDataset (train_token_ids , train_attention_masks , train_labels )
28
28
val_set = TensorDataset (val_token_ids , val_attention_masks , val_labels )
29
29
30
+ CLOUD_BUCKET = "data_bucket_lmu"
31
+ checkpoint_path = (
32
+ os .path .join ("/gcs" , CLOUD_BUCKET , "checkpoints" )
33
+ if os .path .exists ("/gcs/data_bucket_lmu/" )
34
+ else "mlops_project/checkpoints"
35
+ )
30
36
31
37
# Reproducibility
32
38
seed_everything (47 , workers = True )
@@ -62,13 +68,15 @@ def main():
62
68
train_dataloader = DataLoader (
63
69
train_set ,
64
70
worker_init_fn = seed_worker ,
71
+ num_workers = 7 ,
65
72
generator = g ,
66
73
sampler = RandomSampler (train_set ),
67
74
batch_size = wandb .config .BATCH_SIZE ,
68
75
)
69
76
validation_dataloader = DataLoader (
70
77
val_set ,
71
78
worker_init_fn = seed_worker ,
79
+ num_workers = 7 ,
72
80
generator = g ,
73
81
sampler = SequentialSampler (val_set ),
74
82
batch_size = wandb .config .BATCH_SIZE ,
@@ -78,7 +86,7 @@ def main():
78
86
model = HatespeechModel (wandb .config .LEARNING_RATE )
79
87
80
88
checkpoint_callback = ModelCheckpoint (
81
- monitor = "val_loss" , dirpath = "mlops_project/checkpoints" , filename = "best-checkpoint" , save_top_k = 1 , mode = "min"
89
+ monitor = "val_loss" , dirpath = checkpoint_path , filename = "best-checkpoint" , save_top_k = 1 , mode = "min"
82
90
)
83
91
# early_stopping_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min")
84
92
@@ -97,12 +105,7 @@ def main():
97
105
# Train the model
98
106
trainer .fit (model , train_dataloader , validation_dataloader )
99
107
# save best model as model weights
100
- CLOUD_BUCKET = "data_bucket_lmu"
101
- checkpoint_path = (
102
- os .path .join ("/gcs" , CLOUD_BUCKET , "checkpoints" )
103
- if os .path .exists ("/gcs/data_bucket_lmu/" )
104
- else "mlops_project/checkpoints"
105
- )
108
+
106
109
checkpoint = torch .load (os .path .join (checkpoint_path , "best-checkpoint.ckpt" ))
107
110
state = {key [6 :]: value for key , value in checkpoint ["state_dict" ].items ()}
108
111
weight_path = os .path .join (checkpoint_path , "best-checkpoint.pth" )
0 commit comments