Rewrite train.py to be more user-friendly
Showing 3 changed files with 223 additions and 28 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
buttbread_resnet152_3.h5 | ||
*.h5 | ||
.DS_Store | ||
__pycache__ | ||
__pycache__ | ||
datasets/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,9 @@ | |
 | ||
[](https://share.streamlit.io/kawaeee/butt_or_bread/) | ||
|
||
* We have seen a popular meme that tries to represent the similarity shared between animal and food such as **"Shiba Inu dog or toasted marshmallow?"** So, We would like to develop the deep learning model that removes the uncertainty of an image that could be like **a loaf of bread or corgi butt**. But for sure, We just do it for fun. | ||
* We have seen a popular meme that tries to represent the similarity shared between animals and food such as **"Shiba Inu dog or toasted marshmallow?"** So, We would like to develop the deep learning model that removes the uncertainty of an image that could be like **a loaf of bread or corgi butt**. But for sure, We just do it for fun. | ||
|
||
* We used PyTorch framework with GPU to develop our model using Google Colaboratory. | ||
* We used the PyTorch framework with GPU to develop our model using Google Colaboratory. | ||
|
||
<img src="https://img-9gag-fun.9cache.com/photo/aYeP537_700b_v2.jpg" width="500" height="500"> | ||
|
||
|
@@ -39,7 +39,7 @@ | |
|**Valid**|0.0132|0.9969| | ||
|**Test**|-|0.9968| | ||
|
||
* We already know that in order to benchmark our model performance, we can't just use `accuracy` and `validation_loss` value as the only acceptable metrics. | ||
* We already know that to benchmark our model performance, we can't just use `accuracy` and `validation_loss` value as the only acceptable metrics. | ||
|
||
#### You can download our model weight here: [v1.2](https://github.com/Kawaeee/butt_or_bread/releases/download/v1.3/buttbread_resnet152_3.h5) | ||
|
||
|
@@ -51,48 +51,60 @@ | |
|Batch Size|32| | ||
|Optimizer|ADAM| | ||
|
||
## Model Reproduction | ||
* In order to reproduce the model, it requires our datasets. You can send me an e-mail at `[email protected]` if you are interested. | ||
## Dataset Preparation | ||
* Reproducing the model requires our datasets. You can send me an e-mail at `[email protected]` if you are interested. | ||
|
||
- Prepare the dataset in the following directory structure | ||
```Bash | ||
└───datasets/ | ||
│ butt/ | ||
│ bread/ | ||
``` | ||
|
||
- Clone this repository | ||
- Install [dataset-split](https://github.com/muriloxyz/dataset-split) library | ||
```bash | ||
git clone https://github.com/Kawaeee/butt_or_bread.git | ||
pip install dataset-split | ||
``` | ||
|
||
- Install dependencies | ||
```Bash | ||
pip install -r requirements.txt | ||
- Execute the `dataset-split` command with the following arguments for both categories | ||
```bash | ||
dataset-split dataset/ -r 0.8 0.1 0.1 | ||
``` | ||
|
||
- Prepare dataset in these following directory structure | ||
- The result will be in this format, and we are ready to proceed to model training | ||
```Bash | ||
butt_or_bread | ||
│ Dockerfile | ||
│ LICENSE | ||
│ README.md | ||
│ requirements.txt | ||
│ train.py | ||
│ | ||
└───datasets/ | ||
│ │ | ||
│ └───train | ||
│ │ │ corgi/ | ||
│ │ │ butt/ | ||
│ │ │ bread/ | ||
│ └───test | ||
│ │ │ corgi/ | ||
│ │ │ butt/ | ||
│ │ │ bread/ | ||
│ └───valid | ||
│ │ │ corgi/ | ||
│ │ │ butt/ | ||
│ │ │ bread/ | ||
│ | ||
``` | ||
|
||
## Model Reproduction | ||
|
||
- Clone this repository | ||
```bash | ||
git clone https://github.com/Kawaeee/butt_or_bread.git | ||
``` | ||
|
||
- Install dependencies | ||
```Bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
- Run the train.py python script | ||
```Bash | ||
python train.py | ||
python train.py --dataset-path datasets/ --model-path buttbread_resnet152_3.h5 | ||
``` | ||
|
||
- Open and run the notebook for prediction: `predictor.ipynb` | ||
- Check jupyter notebook for interactive prediction: `predictor.ipynb` | ||
|
||
## Streamlit Local Reproduction | ||
|
||
|
@@ -111,18 +123,18 @@ | |
streamlit run streamlit_app.py | ||
``` | ||
|
||
- Streamlit web application will be host on http://localhost:8501 | ||
- Streamlit web application will be hosted on http://localhost:8501 | ||
|
||
## Streamlit Docker Reproduction | ||
|
||
- Build Docker image from Dockerfile | ||
```Bash | ||
docker build -t butt_or_bread -f Dockerfile . | ||
docker build -t butt_or_bread -f Dockerfile . | ||
``` | ||
|
||
- Run Docker with exposed port 8501 | ||
```Bash | ||
docker run -p 8501:8501 butt_or_bread | ||
``` | ||
|
||
- Streamlit web application will be host on http://localhost:8501 | ||
- Streamlit web application will be hosted on http://localhost:8501 |
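The 0.8/0.1/0.1 split described in the README changes above can be approximated with the standard library alone. The following is a minimal sketch, not the `dataset-split` tool itself: the function name `split_class_dir`, the fixed seed, and the copy-based approach are assumptions made for illustration. Its output mirrors the `train`/`valid`/`test` folders with `butt/` and `bread/` subfolders that `train.py` reads.

```python
import random
import shutil
from pathlib import Path


def split_class_dir(src, dst, class_name, ratios=(0.8, 0.1, 0.1), seed=0):
    """Copy files from src/<class_name> into dst/{train,valid,test}/<class_name>.

    Hypothetical helper for illustration; not part of this repository
    or of the dataset-split library.
    """
    files = sorted(p for p in (Path(src) / class_name).iterdir() if p.is_file())
    # Shuffle with a fixed seed so the split is reproducible
    random.Random(seed).shuffle(files)

    n_train = int(len(files) * ratios[0])
    n_valid = int(len(files) * ratios[1])
    parts = {
        "train": files[:n_train],
        "valid": files[n_train:n_train + n_valid],
        "test": files[n_train + n_valid:],  # remainder, so no file is dropped
    }

    for split, members in parts.items():
        target = Path(dst) / split / class_name
        target.mkdir(parents=True, exist_ok=True)
        for f in members:
            shutil.copy2(f, target / f.name)

    # Report how many files landed in each split
    return {split: len(members) for split, members in parts.items()}
```

Assuming the unsplit images sit in a folder such as `raw_images/butt` and `raw_images/bread` (a hypothetical location), calling `split_class_dir("raw_images", "datasets", name)` once for each of `"butt"` and `"bread"` would produce the layout expected by `--dataset-path datasets/`.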
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import argparse | ||
import os | ||
import time | ||
|
||
from tqdm import tqdm | ||
|
||
import torch | ||
from torchvision import datasets, models, transforms | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
class ButtBreadModel: | ||
    def __init__(self, device): | ||
        self.model = None | ||
        self.device = device | ||
        self.criterion = None | ||
        self.optimizer = None | ||
|
||
    def initialize(self): | ||
        self.model = models.resnet152(pretrained=True).to(self.device) | ||
|
||
        for param in self.model.parameters(): | ||
            param.requires_grad = False | ||
|
||
        self.model.fc = torch.nn.Sequential( | ||
            torch.nn.Linear(2048, 128), | ||
            torch.nn.ReLU(inplace=True), | ||
            torch.nn.Linear(128, 2), | ||
        ).to(self.device) | ||
|
||
        self.criterion = torch.nn.CrossEntropyLoss() | ||
        self.optimizer = torch.optim.Adam(self.model.fc.parameters()) | ||
|
||
    def train(self, image_dataloaders, image_datasets, epochs=1): | ||
        for epoch in range(epochs): | ||
            time_start = time.monotonic() | ||
            print(f"Epoch {epoch + 1}/{epochs}") | ||
|
||
            # Run a training phase, then a validation phase, in every epoch | ||
            for phase in ["train", "valid"]: | ||
|
||
                if phase == "train": | ||
                    self.model.train() | ||
                else: | ||
                    self.model.eval() | ||
|
||
                running_loss = 0.0 | ||
                running_corrects = 0 | ||
|
||
                # Forward each batch, compute the loss, and count correct predictions | ||
                for inputs, labels in tqdm(image_dataloaders[phase]): | ||
                    inputs = inputs.to(self.device) | ||
                    labels = labels.to(self.device) | ||
|
||
                    outputs = self.model(inputs) | ||
                    loss = self.criterion(outputs, labels) | ||
|
||
                    if phase == "train": | ||
                        self.optimizer.zero_grad() | ||
                        loss.backward() | ||
                        self.optimizer.step() | ||
|
||
                    _, preds = torch.max(outputs, 1) | ||
                    running_loss += loss.detach() * inputs.size(0) | ||
                    running_corrects += torch.sum(preds == labels.data) | ||
|
||
                epoch_loss = running_loss / len(image_datasets[phase]) | ||
                epoch_acc = running_corrects.float() / len(image_datasets[phase]) | ||
|
||
                print(f"{phase} loss: {epoch_loss.item():.4f}, acc: {epoch_acc.item():.4f}") | ||
|
||
            print("Runtime: (", "{0:.2f}".format(time.monotonic() - time_start), " seconds)", sep="") | ||
|
||
        return self.model | ||
|
||
    def test(self, image_dataloaders): | ||
        """Test with test set""" | ||
        test_acc_count = 0 | ||
|
||
        for k, (test_images, test_labels) in tqdm(enumerate(image_dataloaders["test"])): | ||
            test_outputs = self.model(test_images.to(self.device)) | ||
            _, prediction = torch.max(test_outputs.data, 1) | ||
            test_acc_count += torch.sum(prediction == test_labels.to(self.device).data).item() | ||
|
||
        # Divide by the number of samples, not batches, so the result holds for any batch size | ||
        test_accuracy = test_acc_count / len(image_dataloaders["test"].dataset) | ||
        print(f"Test acc: {test_accuracy}") | ||
|
||
        return test_accuracy | ||
|
||
    def save(self, model_path): | ||
        """Saving model weight""" | ||
        return torch.save(self.model.state_dict(), model_path) | ||
|
||
    def load(self, model_path): | ||
        """Loading model weight""" | ||
        # load_state_dict() returns a key-matching report, not the model, so eval() is called separately | ||
        self.model.load_state_dict(torch.load(model_path, map_location=self.device)) | ||
        return self.model.eval() | ||
|
||
|
||
def get_dataset(dataset_path: str): | ||
    """ | ||
    Data transformation steps | ||
    Train set :: Resize -> Random affine -> Random horizontal flip -> To Tensor -> Normalize | ||
    Valid/test set :: Resize -> To Tensor -> Normalize | ||
    """ | ||
    data_transformers = { | ||
        "train": transforms.Compose( | ||
            [ | ||
                transforms.Resize((224, 224)), | ||
                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)), | ||
                transforms.RandomHorizontalFlip(), | ||
                transforms.ToTensor(), | ||
                transforms.Normalize( | ||
                    mean=[0.485, 0.456, 0.406], | ||
                    std=[0.229, 0.224, 0.225], | ||
                ), | ||
            ] | ||
        ), | ||
        "valid": transforms.Compose( | ||
            [ | ||
                transforms.Resize((224, 224)), | ||
                transforms.ToTensor(), | ||
                transforms.Normalize( | ||
                    mean=[0.485, 0.456, 0.406], | ||
                    std=[0.229, 0.224, 0.225], | ||
                ), | ||
            ] | ||
        ), | ||
        "test": transforms.Compose( | ||
            [ | ||
                transforms.Resize((224, 224)), | ||
                transforms.ToTensor(), | ||
                transforms.Normalize( | ||
                    mean=[0.485, 0.456, 0.406], | ||
                    std=[0.229, 0.224, 0.225], | ||
                ), | ||
            ] | ||
        ), | ||
    } | ||
|
||
    image_datasets = { | ||
        "train": datasets.ImageFolder(os.path.join(dataset_path, "train"), data_transformers["train"]), | ||
        "valid": datasets.ImageFolder(os.path.join(dataset_path, "valid"), data_transformers["valid"]), | ||
        "test": datasets.ImageFolder(os.path.join(dataset_path, "test"), data_transformers["test"]), | ||
    } | ||
|
||
    image_dataloaders = { | ||
        "train": DataLoader(image_datasets["train"], batch_size=32, shuffle=True, num_workers=2), | ||
        "valid": DataLoader(image_datasets["valid"], batch_size=32, shuffle=False, num_workers=2), | ||
        "test": DataLoader(image_datasets["test"], batch_size=1, shuffle=False, num_workers=2), | ||
    } | ||
|
||
    return image_datasets, image_dataloaders | ||
|
||
|
||
def main(opt): | ||
    dataset_path, model_path, epochs = opt.dataset_path, opt.model_path, opt.epochs | ||
|
||
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
    image_datasets, image_dataloaders = get_dataset(dataset_path) | ||
|
||
    butt_bread_obj = ButtBreadModel(device=device) | ||
    butt_bread_obj.initialize() | ||
|
||
    butt_bread_obj.train( | ||
        image_dataloaders=image_dataloaders, | ||
        image_datasets=image_datasets, | ||
        epochs=epochs, | ||
    ) | ||
|
||
    butt_bread_obj.test(image_dataloaders=image_dataloaders) | ||
    butt_bread_obj.save(model_path=model_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument("--dataset-path", type=str, default="datasets/", help="Dataset path") | ||
    parser.add_argument("--model-path", type=str, default="buttbread_resnet152_3.h5", help="Output model name") | ||
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs") | ||
|
||
    args = parser.parse_args() | ||
    main(args) |
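The command-line flags defined at the end of `train.py` can be checked without starting a training run. This is a minimal sketch that assumes only the three flags shown in the diff; the helper name `build_parser` is introduced here for illustration and does not exist in the repository. It shows that argparse exposes `--dataset-path` as the attribute `dataset_path`, which is why `main` reads `opt.dataset_path`.

```python
import argparse


def build_parser():
    """Recreate the three flags from train.py (hypothetical helper name)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", type=str, default="datasets/", help="Dataset path")
    parser.add_argument("--model-path", type=str, default="buttbread_resnet152_3.h5", help="Output model name")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    return parser


# Passing an explicit list makes argparse parse it instead of sys.argv
defaults = build_parser().parse_args([])
custom = build_parser().parse_args(["--dataset-path", "data/", "--epochs", "5"])

# Dashes in flag names become underscores in attribute names
print(defaults.dataset_path, defaults.model_path, defaults.epochs)
print(custom.dataset_path, custom.model_path, custom.epochs)
```

Flags that are omitted fall back to their defaults, so `python train.py` with no arguments should behave like the fully spelled-out command in the README, apart from whatever `--epochs` value is passed.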