Skip to content

Commit 548d076

Browse files
authored
Merge pull request #35 from enesozeren/model_weight_fix
model weight usage fix
2 parents 5a00191 + 00a6e59 commit 548d076

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ To build the docker image for inference api, use
128128
docker build -f dockerfiles/inference_api.dockerfile . -t inference_api:latest
129129
```
130130

131+
To build the docker image for prediction, use
132+
```bash
133+
docker build -f dockerfiles/predict_model.dockerfile . -t predict_model:latest
134+
```
135+
136+
To build the docker image for training, use
137+
```bash
138+
docker build -f dockerfiles/train_model.dockerfile . -t train_model:latest
139+
```
140+
131141
### Running Docker Containers
132142

133143
To run the docker image for inference api, use

api/main.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,19 @@
1212
@asynccontextmanager
1313
async def lifespan(app: FastAPI):
1414
"""Load and clean up model on startup and shutdown."""
15+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1516
# Load the tokenizer
1617
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
1718

18-
# Get the model from the saved checkpoint
19-
19+
# Load the model
2020
model = BertForSequenceClassification.from_pretrained(
21-
"bert-base-uncased", num_labels=2, output_attentions=False, output_hidden_states=False
21+
"bert-base-uncased",
22+
num_labels=2,
23+
output_attentions=False,
24+
output_hidden_states=False,
25+
state_dict=torch.load(MODEL_PATH, map_location=device),
2226
)
23-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24-
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
2527
model.eval()
26-
model.to(device)
2728

2829
# Set the model and tokenizer in the app state
2930
app.state.tokenizer = tokenizer

dockerfiles/predict_model.dockerfile

-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ COPY outputs/ outputs/
1414
# Set environment variable
1515
ENV PYTHONPATH=/lmu-mlops-project
1616

17-
# Do not set the directory to root
18-
RUN pip install . --no-deps --no-cache-dir
19-
2017
# Set the entrypoint to the python script
2118
ENTRYPOINT ["python3", "-u", "mlops_project/predict_model.py"]
2219

mlops_project/models/.gitkeep

Whitespace-only changes.

mlops_project/predict_model.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@ def predict(model_path: str, dataset_path: str) -> None:
1717
Tensor of shape [N, d] where N is the number of samples and d is the output dimension of the model
1818
"""
1919

20+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2021
# Load the tokenizer
2122
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
2223

2324
# Load the model
2425
model = BertForSequenceClassification.from_pretrained(
25-
"bert-base-uncased", num_labels=2, output_attentions=False, output_hidden_states=False
26+
"bert-base-uncased",
27+
num_labels=2,
28+
output_attentions=False,
29+
output_hidden_states=False,
30+
state_dict=torch.load(model_path, map_location=device),
2631
)
27-
28-
# Load the model weights
29-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30-
model.load_state_dict(torch.load(model_path, map_location=device))
3132
model.eval()
3233

3334
# Read the dataset

0 commit comments

Comments
 (0)