From e0102d765d3dc9a71788da2f6deede67e1d01ca3 Mon Sep 17 00:00:00 2001 From: omar-abdelgawad Date: Wed, 8 May 2024 03:38:08 +0300 Subject: [PATCH] Created pix2pixpredictor class --- src/img2img/cfg/__init__.py | 3 +- src/img2img/models/pix2pix/predictor.py | 51 +++++++++++++++++++++++++ src/img2img/models/pix2pix/train.py | 1 + src/img2img/models/pix2pix/trainer.py | 3 +- src/img2img/utils/__init__.py | 1 - 5 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 src/img2img/models/pix2pix/predictor.py diff --git a/src/img2img/cfg/__init__.py b/src/img2img/cfg/__init__.py index 8b80c3f..fcbabc8 100644 --- a/src/img2img/cfg/__init__.py +++ b/src/img2img/cfg/__init__.py @@ -1,4 +1,5 @@ -""" Configuration module Interface. """ +"""Configuration module Interface.""" + from pathlib import Path from torch import cuda diff --git a/src/img2img/models/pix2pix/predictor.py b/src/img2img/models/pix2pix/predictor.py new file mode 100644 index 0000000..8aaa310 --- /dev/null +++ b/src/img2img/models/pix2pix/predictor.py @@ -0,0 +1,51 @@ +import torch +from img2img import cfg +from img2img.models.pix2pix.generator import Generator +from img2img.models.pix2pix.utils import remove_normalization +from pathlib import Path +import numpy as np +from PIL import Image + + +class Pix2PixPredictor: + def __init__(self, model_path: str | Path): + self.device = cfg.DEVICE + self.model = Generator(in_channels=3).to(self.device) + self.model.load_state_dict( + torch.load(model_path, map_location=self.device)["state_dict"] + ) + self.model.eval() + + def __call__(self, x: np.ndarray) -> np.ndarray: + augmentations = cfg.both_transform(image=x) + input_image = augmentations["image"] + out_input_image: torch.Tensor = cfg.transform_only_input(image=input_image)[ + "image" + ] + out_input_image = out_input_image.to(self.device) + with torch.inference_mode(): + y = self.model(out_input_image.unsqueeze(0)) # must have a batch dimension + y = remove_normalization(y) + y = y.cpu().detach().numpy() + y = y.squeeze(0) * 255 + y = y.astype(np.uint8) + assert y.shape == (3, 256, 256) + y = np.moveaxis(y, 0, -1) + return y + + +def test(): + model_path = "./out/saved_models/anime_training/gen.pth.tar" + predictor = Pix2PixPredictor(model_path) + image_path = "out/evaluation/pix2pix_predictor_test_image.png" + # take x as an input image in numpy array format where x.shape = (anything, anything, 3) + x = np.array(Image.open(image_path)) # returns (429, 488, 4) + x = x[:, :, :3] # remove alpha channel + print(x.shape) + y = predictor(x) + image_y = Image.fromarray(y) + return image_y + + +if __name__ == "__main__": + image = test() diff --git a/src/img2img/models/pix2pix/train.py b/src/img2img/models/pix2pix/train.py index a3c6efe..00406ae 100644 --- a/src/img2img/models/pix2pix/train.py +++ b/src/img2img/models/pix2pix/train.py @@ -1,4 +1,5 @@ """Main script for training the model. Can train from scratch or resume from a checkpoint.""" + from img2img import cfg from img2img.models.pix2pix.trainer import Pix2PixTrainer diff --git a/src/img2img/models/pix2pix/trainer.py b/src/img2img/models/pix2pix/trainer.py index 95d410b..f15c5a1 100644 --- a/src/img2img/models/pix2pix/trainer.py +++ b/src/img2img/models/pix2pix/trainer.py @@ -1,4 +1,5 @@ -""" Trainer class""" +"""Trainer class""" + from pathlib import Path import torch diff --git a/src/img2img/utils/__init__.py b/src/img2img/utils/__init__.py index 6612015..f6f25fd 100644 --- a/src/img2img/utils/__init__.py +++ b/src/img2img/utils/__init__.py @@ -31,7 +31,6 @@ def prepare_sub_directories(path: str | Path) -> tuple[Path, Path]: path = Path(path) eval_path = path / "evaluation" weights_path = path / "last_trained_weights" - os.makedirs(path, exist_ok=True) os.makedirs(eval_path, exist_ok=True) os.makedirs(weights_path, exist_ok=True) return weights_path, eval_path