Skip to content

Commit 71e1329

Browse files
authored
Merge pull request #32 from enesozeren/bugfix_vertex
Bugfix vertex
2 parents e8f269d + 776ba98 commit 71e1329

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

cloudbuild/config_gpu.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ workerPoolSpecs:
99
env:
1010
- name: WANDB_API_KEY
1111
value: $WANDB_API_KEY
12-
args: ["--config", "config/config-defaults-sweep.yaml"]
12+
args: ["--config", "config/config-defaults.yaml"]
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
program: train_model.py
2-
name: sweep_XXX # Change sweep name
2+
name: sweep_best # Change sweep name
33
method: grid
44
metric:
55
goal: minimize
@@ -8,6 +8,6 @@ parameters:
88
BATCH_SIZE:
99
values: [16]
1010
EPOCHS:
11-
values: [5]
11+
values: [7]
1212
LEARNING_RATE:
1313
values: [0.000005]

mlops_project/models/.gitkeep

Whitespace-only changes.

mlops_project/train_model.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
train_set = TensorDataset(train_token_ids, train_attention_masks, train_labels)
2828
val_set = TensorDataset(val_token_ids, val_attention_masks, val_labels)
2929

30+
CLOUD_BUCKET = "data_bucket_lmu"
31+
checkpoint_path = (
32+
os.path.join("/gcs", CLOUD_BUCKET, "checkpoints")
33+
if os.path.exists("/gcs/data_bucket_lmu/")
34+
else "mlops_project/checkpoints"
35+
)
3036

3137
# Reproducibility
3238
seed_everything(47, workers=True)
@@ -62,13 +68,15 @@ def main():
6268
train_dataloader = DataLoader(
6369
train_set,
6470
worker_init_fn=seed_worker,
71+
num_workers=7,
6572
generator=g,
6673
sampler=RandomSampler(train_set),
6774
batch_size=wandb.config.BATCH_SIZE,
6875
)
6976
validation_dataloader = DataLoader(
7077
val_set,
7178
worker_init_fn=seed_worker,
79+
num_workers=7,
7280
generator=g,
7381
sampler=SequentialSampler(val_set),
7482
batch_size=wandb.config.BATCH_SIZE,
@@ -78,7 +86,7 @@ def main():
7886
model = HatespeechModel(wandb.config.LEARNING_RATE)
7987

8088
checkpoint_callback = ModelCheckpoint(
81-
monitor="val_loss", dirpath="mlops_project/checkpoints", filename="best-checkpoint", save_top_k=1, mode="min"
89+
monitor="val_loss", dirpath=checkpoint_path, filename="best-checkpoint", save_top_k=1, mode="min"
8290
)
8391
# early_stopping_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min")
8492

@@ -97,12 +105,7 @@ def main():
97105
# Train the model
98106
trainer.fit(model, train_dataloader, validation_dataloader)
99107
# save best model as model weights
100-
CLOUD_BUCKET = "data_bucket_lmu"
101-
checkpoint_path = (
102-
os.path.join("/gcs", CLOUD_BUCKET, "checkpoints")
103-
if os.path.exists("/gcs/data_bucket_lmu/")
104-
else "mlops_project/checkpoints"
105-
)
108+
106109
checkpoint = torch.load(os.path.join(checkpoint_path, "best-checkpoint.ckpt"))
107110
state = {key[6:]: value for key, value in checkpoint["state_dict"].items()}
108111
weight_path = os.path.join(checkpoint_path, "best-checkpoint.pth")

0 commit comments

Comments
 (0)