Skip to content

Commit

Permalink
rewritten train.py for more user-friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
Kawaeee committed Feb 10, 2022
1 parent bd6be4b commit 266d484
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 28 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
buttbread_resnet152_3.h5
*.h5
.DS_Store
__pycache__
__pycache__
datasets/
64 changes: 38 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
![Visitor Badge](https://visitor-badge.glitch.me/badge?page_id=Kawaeee.butt_or_bread.visitor-badge)
[![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/kawaeee/butt_or_bread/)

* We have seen a popular meme that tries to represent the similarity shared between animal and food such as **"Shiba Inu dog or toasted marshmallow?"** So, We would like to develop the deep learning model that removes the uncertainty of an image that could be like **a loaf of bread or corgi butt**. But for sure, We just do it for fun.
* We have seen a popular meme that tries to represent the similarity shared between animals and food such as **"Shiba Inu dog or toasted marshmallow?"** So, We would like to develop the deep learning model that removes the uncertainty of an image that could be like **a loaf of bread or corgi butt**. But for sure, We just do it for fun.

* We used PyTorch framework with GPU to develop our model using Google Colaboratory.
* We used the PyTorch framework with GPU to develop our model using Google Colaboratory.

<img src="https://img-9gag-fun.9cache.com/photo/aYeP537_700b_v2.jpg" width="500" height="500">

Expand Down Expand Up @@ -39,7 +39,7 @@
|**Valid**|0.0132|0.9969|
|**Test**|-|0.9968|

* We already know that in order to benchmark our model performance, we can't just use `accuracy` and `validation_loss` value as the only acceptable metrics.
* We already know that to benchmark our model performance, we can't just use `accuracy` and `validation_loss` value as the only acceptable metrics.

#### You can download our model weight here: [v1.2](https://github.com/Kawaeee/butt_or_bread/releases/download/v1.3/buttbread_resnet152_3.h5)

Expand All @@ -51,48 +51,60 @@
|Batch Size|32|
|Optimizer|ADAM|

## Model Reproduction
* In order to reproduce the model, it requires our datasets. You can send me an e-mail at `[email protected]` if you are interested.
## Dataset Preparation
* To reproduce the model, requires our datasets. You can send me an e-mail at `[email protected]` if you are interested.

- Prepare dataset in these following directory structure
```Bash
└───datasets/
│ butt/
│ bread/
```

- Clone this repository
- Install [dataset-split](https://github.com/muriloxyz/dataset-split) library
```bash
git clone https://github.com/Kawaeee/butt_or_bread.git
pip install dataset-split
```

- Install dependencies
```Bash
pip install -r requirements.txt
- Execute `dataset-split` command with following arguments on both category
```bash
dataset-split dataset/ -r 0.8 0.1 0.1
```

- Prepare dataset in these following directory structure
- The result will be in this format, and we are ready to proceed to model training
```Bash
butt_or_bread
│ Dockerfile
│ LICENSE
│ README.md
│ requirements.txt
│ train.py
└───datasets/
│ │
│ └───train
│ │ │ corgi/
│ │ │ butt/
│ │ │ bread/
│ └───test
│ │ │ corgi/
│ │ │ butt/
│ │ │ bread/
│ └───valid
│ │ │ corgi/
│ │ │ butt/
│ │ │ bread/
```

## Model Reproduction

- Clone this repository
```bash
git clone https://github.com/Kawaeee/butt_or_bread.git
```

- Install dependencies
```Bash
pip install -r requirements.txt
```

- Run the train.py python script
```Bash
python train.py
python train.py --dataset-path datasets/ --model-path buttbread_resnet152_3.h5
```

- Open and run the notebook for prediction: `predictor.ipynb`
- Check jupyter notebook for interactive prediction: `predictor.ipynb`

## Streamlit Local Reproduction

Expand All @@ -111,18 +123,18 @@
streamlit run streamlit_app.py
```

- Streamlit web application will be host on http://localhost:8501
- Streamlit web application will be hosted on http://localhost:8501

## Streamlit Docker Reproduction

- Build Docker image from Dockerfile
```Bash
docker build -t butt_or_bread -f Dockerfile .
docker build -t butt_or_bread -f Dockerfile .
```

- Run Docker with exposed port 8501
```Bash
docker run -p 8501:8501 butt_or_bread
```

- Streamlit web application will be host on http://localhost:8501
- Streamlit web application will be hosted on http://localhost:8501
182 changes: 182 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import argparse
import os
import time

from tqdm import tqdm

import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader


class ButtBreadModel:
def __init__(self, device):
self.model = None
self.device = device
self.criterion = None
self.optimizer = None

def initialize(self):
self.model = models.resnet152(pretrained=True).to(self.device)

for param in self.model.parameters():
param.requires_grad = False

self.model.fc = torch.nn.Sequential(
torch.nn.Linear(2048, 128),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(128, 2),
).to(self.device)

self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.Adam(self.model.fc.parameters())

def train(self, image_dataloaders, image_datasets, epochs=1):
for epoch in range(epochs):
time_start = time.monotonic()
print(f"Epoch {epoch + 1}/{epochs}")

# Phase check
for phase in ["train", "valid"]:

if phase == "train":
self.model.train()
else:
self.model.eval()

running_loss = 0.0
running_corrects = 0

# Iterate and try to predict input and check with output generate loss and correct label
for inputs, labels in tqdm(image_dataloaders[phase]):
inputs = inputs.to(self.device)
labels = labels.to(self.device)

outputs = self.model(inputs)
loss = self.criterion(outputs, labels)

if phase == "train":
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

_, preds = torch.max(outputs, 1)
running_loss += loss.detach() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)

epoch_loss = running_loss / len(image_datasets[phase])
epoch_acc = running_corrects.float() / len(image_datasets[phase])

print(f"{phase} loss: {epoch_loss.item():.4f}, acc: {epoch_acc.item():.4f}")

print("Runtime: (", "{0:.2f}".format(time.monotonic() - time_start), " seconds)", sep="")

return self.model

def test(self, image_dataloaders):
"""Test with test set"""
test_acc_count = 0

for k, (test_images, test_labels) in tqdm(enumerate(image_dataloaders["test"])):
test_outputs = self.model(test_images.to(self.device))
_, prediction = torch.max(test_outputs.data, 1)
test_acc_count += torch.sum(prediction == test_labels.to(self.device).data).item()

test_accuracy = test_acc_count / len(image_dataloaders["test"])
print(f"Test acc: {test_accuracy}")

return test_accuracy

def save(self, model_path):
"""Saving model weight"""
return torch.save(self.model.state_dict(), model_path)

def load(self, model_path):
"""Loading model weight"""
return self.model.load_state_dict(torch.load(model_path, map_location=self.device)).eval()


def get_dataset(dataset_path: str):
"""
Data transformation steps
Train set :: Resize -> Random affine -> Random horizontal flip -> To Tensor -> Normalize
Valid/test set :: Resize -> To Tensor -> Normalize
"""
data_transformers = {
"train": transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
),
"valid": transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
),
"test": transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
),
}

image_datasets = {
"train": datasets.ImageFolder(os.path.join(dataset_path, "train"), data_transformers["train"]),
"valid": datasets.ImageFolder(os.path.join(dataset_path, "valid"), data_transformers["valid"]),
"test": datasets.ImageFolder(os.path.join(dataset_path, "test"), data_transformers["test"]),
}

image_dataloaders = {
"train": DataLoader(image_datasets["train"], batch_size=32, shuffle=True, num_workers=2),
"valid": DataLoader(image_datasets["valid"], batch_size=32, shuffle=False, num_workers=2),
"test": DataLoader(image_datasets["test"], batch_size=1, shuffle=False, num_workers=2),
}

return image_datasets, image_dataloaders


def main(opt):
dataset_path, model_path, epochs = opt.dataset_path, opt.model_path, opt.epochs

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_datasets, image_dataloaders = get_dataset(dataset_path)

butt_bread_obj = ButtBreadModel(device=device)
butt_bread_obj.initialize()

butt_bread_obj.train(
image_dataloaders=image_dataloaders,
image_datasets=image_datasets,
epochs=epochs,
)

butt_bread_obj.test(image_dataloaders=image_dataloaders)
butt_bread_obj.save(model_path=model_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="datasets/", help="Dataset path")
parser.add_argument("--model-path", type=str, default="buttbread_resnet152_3.h5", help="Output model name")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")

args = parser.parse_args()
main(args)

0 comments on commit 266d484

Please sign in to comment.