Commit 37fca30

Merge pull request #33 from enesozeren/train_script_fix
Train script fix
2 parents 71e1329 + 70589a8 commit 37fca30

4 files changed, +16 -9 lines changed

README.md

+11 -4
@@ -86,7 +86,7 @@ Note: You need GCP bucket permissions to be able to run this command
 Predictions from this script are saved to outputs directory. To make a prediction, use
 ```bash
 python mlops_project/predict_model.py \
-    --model_path=/your/model/path.txt \
+    --model_path=/your/model/path.pth \
     --dataset_path=/your/data/path.txt
 ```

@@ -137,13 +137,20 @@ docker run -p 8080:8080 -e PORT=8080 inference_api:latest

 You can also use the predict_model docker image by mounting with your machine for your model weights and dataset
 ```bash
-docker run -v /home/user/models:/container/models \
-    -v /home/user/data:/container/data \
+docker run -v /to/your/model/weight/path/best-checkpoint.pth:/container/models/best-checkpoint.pth \
+    -v /to/your/test_path/test_text.txt:/container/data/test_text.txt \
+    -v /to/your/outputs/predictions:/lmu-mlops-project/outputs/predictions \
     predict_model:latest \
-    --model_path /container/models/model.pth \
+    --model_path /container/models/best-checkpoint.pth \
     --dataset_path /container/data/test_text.txt
 ```

+To run training docker container use:
+```bash
+docker run -e WANDB_API_KEY=your_wandb_api_key \
+    train_model:latest --config=mlops_project/config/config-defaults.yaml
+```
+
 ## Tests

 Unit tests for this repo can be found in the ``tests/`` directory.
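
Note on the README change: the new `docker run` example mounts individual files (checkpoint and test set) plus an output directory, rather than whole host directories. For anyone scripting that invocation, the mount layout can be sketched in Python. This is only an illustration: `predict_command` is a hypothetical helper that is not part of the repository, the host paths are the README's placeholders, and the function only builds the argument list without running Docker.

```python
import posixpath
import shlex
from typing import List


def predict_command(
    weights: str,
    test_file: str,
    predictions_dir: str,
    image: str = "predict_model:latest",
) -> List[str]:
    """Build the docker run argument list shown in the updated README."""
    # Keep the host file names inside the container, as the README example does.
    weights_in_container = posixpath.join("/container/models", posixpath.basename(weights))
    data_in_container = posixpath.join("/container/data", posixpath.basename(test_file))
    return [
        "docker", "run",
        "-v", f"{weights}:{weights_in_container}",
        "-v", f"{test_file}:{data_in_container}",
        # Container-side output path copied from the README example.
        "-v", f"{predictions_dir}:/lmu-mlops-project/outputs/predictions",
        image,
        "--model_path", weights_in_container,
        "--dataset_path", data_in_container,
    ]


cmd = predict_command(
    "/to/your/model/weight/path/best-checkpoint.pth",
    "/to/your/test_path/test_text.txt",
    "/to/your/outputs/predictions",
)
# Print a copy-pasteable shell command; pass cmd to subprocess.run to execute it.
print(shlex.join(cmd))
```

Building a list rather than a single string keeps each mount as one argument, so paths containing spaces would not need manual quoting.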

dockerfiles/train_model.dockerfile

+3 -2
@@ -1,10 +1,11 @@
 # Base image
 FROM hatespeech-base

+WORKDIR /lmu-mlops-project
+
 COPY pyproject.toml pyproject.toml
 COPY mlops_project/ mlops_project/
 COPY utils/ mlops_project/utils/
 COPY data/ data/

-WORKDIR /mlops_project
-ENTRYPOINT ["python3", "-u", "train_model.py"]
+ENTRYPOINT ["python3", "-u", "mlops_project/train_model.py"]
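
Note on the Dockerfile change: relative `COPY` destinations are resolved against the active `WORKDIR`, and the entrypoint process starts in that directory, so moving `WORKDIR` above the `COPY` lines changes where everything lands. A minimal sketch of that resolution using only `posixpath`; `resolve_in_workdir` is a hypothetical helper for illustration, it does not inspect a real image, and the config line assumes the training script opens the `--config` path relative to its working directory.

```python
import posixpath


def resolve_in_workdir(workdir: str, path: str) -> str:
    """Resolve a container path the way a relative path resolves against WORKDIR."""
    # posixpath.join discards workdir when path is already absolute.
    return posixpath.normpath(posixpath.join(workdir, path))


WORKDIR = "/lmu-mlops-project"

# Script path from the new ENTRYPOINT.
script = resolve_in_workdir(WORKDIR, "mlops_project/train_model.py")
# Destination of `COPY utils/ mlops_project/utils/`.
utils_dir = resolve_in_workdir(WORKDIR, "mlops_project/utils/")
# Config path passed in the README's training example.
config = resolve_in_workdir(WORKDIR, "mlops_project/config/config-defaults.yaml")

print(script)
print(utils_dir)
print(config)
```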

mlops_project/predict_model.py

+2 -1
@@ -26,7 +26,8 @@ def predict(model_path: str, dataset_path: str) -> None:
     )

     # Load the model weights
-    model.load_state_dict(torch.load(model_path))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.load_state_dict(torch.load(model_path, map_location=device))
     model.eval()

     # Set the device to GPU if available
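
Note on the `predict_model.py` change: passing `map_location` lets a checkpoint that was saved from GPU tensors be restored on a CPU-only host, which a plain `torch.load(model_path)` would reject. The pattern can be exercised in isolation as below. This is a sketch under assumptions: `load_weights` is a hypothetical helper, and `nn.Linear` stands in for the project's real model, which this diff does not show.

```python
import os
import tempfile

import torch
from torch import nn


def load_weights(model: nn.Module, model_path: str) -> torch.device:
    """Load a state dict onto whichever device is available and return that device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps saved tensors onto `device` while they are deserialized.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return device


# Round trip with a stand-in model: save one set of weights, load into another.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best-checkpoint.pth")
    source = nn.Linear(4, 2)
    torch.save(source.state_dict(), path)

    target = nn.Linear(4, 2)
    used_device = load_weights(target, path)
    weights_match = torch.equal(
        source.weight.detach().cpu(), target.weight.detach().cpu()
    )
```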

mlops_project/train_model.py

-2
@@ -68,15 +68,13 @@ def main():
     train_dataloader = DataLoader(
         train_set,
         worker_init_fn=seed_worker,
-        num_workers=7,
         generator=g,
         sampler=RandomSampler(train_set),
         batch_size=wandb.config.BATCH_SIZE,
     )
     validation_dataloader = DataLoader(
         val_set,
         worker_init_fn=seed_worker,
-        num_workers=7,
         generator=g,
         sampler=SequentialSampler(val_set),
         batch_size=wandb.config.BATCH_SIZE,
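
Note on the `train_model.py` change: with `num_workers` dropped, both loaders fall back to the `DataLoader` default of 0, so batches are loaded in the main process and `worker_init_fn` is never called because no worker processes exist. The commit does not say why the value 7 was removed. If the aim was to avoid requesting more workers than the container has CPUs, one alternative (not what this commit does) would be to derive the count at runtime. A sketch, where `pick_num_workers` is a hypothetical helper:

```python
import os
from typing import Optional


def pick_num_workers(requested: int, available_cpus: Optional[int] = None) -> int:
    """Clamp a requested DataLoader worker count to the CPUs actually present."""
    if available_cpus is None:
        # os.cpu_count() can return None on some platforms; treat that as one CPU.
        available_cpus = os.cpu_count() or 1
    # Reserve one core for the main process; 0 means "load in the main process".
    return max(0, min(requested, available_cpus - 1))


# The result could be passed as DataLoader(..., num_workers=pick_num_workers(7)).
print(pick_num_workers(7))
```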
