Skip to content

Commit

Permalink
Created pix2pixpredictor class
Browse files Browse the repository at this point in the history
  • Loading branch information
omar-abdelgawad committed May 8, 2024
1 parent bb1b20c commit e0102d7
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/img2img/cfg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Configuration module Interface. """
"""Configuration module Interface."""

from pathlib import Path

from torch import cuda
Expand Down
51 changes: 51 additions & 0 deletions src/img2img/models/pix2pix/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
from img2img import cfg
from img2img.models.pix2pix.generator import Generator
from img2img.models.pix2pix.utils import remove_normalization
from pathlib import Path
import numpy as np
from PIL import Image


class Pix2PixPredictor:
def __init__(self, model_path: str | Path):
self.device = cfg.DEVICE
self.model = Generator(in_channels=3).to(self.device)
self.model.load_state_dict(
torch.load(model_path, map_location=self.device)["state_dict"]
)
self.model.eval()

def __call__(self, x: np.ndarray) -> np.ndarray:
augmentations = cfg.both_transform(image=x)
input_image = augmentations["image"]
out_input_image: torch.Tensor = cfg.transform_only_input(image=input_image)[
"image"
]
out_input_image = out_input_image.to(self.device)
with torch.inference_mode():
y = self.model(out_input_image.unsqueeze(0)) # must have a batch dimension
y = remove_normalization(y)
y = y.cpu().detach().numpy()
y = y.squeeze(0) * 255
y = y.astype(np.uint8)
assert y.shape == (3, 256, 256)
y = np.moveaxis(y, 0, -1)
return y


def test():
model_path = "./out/saved_models/anime_training/gen.pth.tar"
predictor = Pix2PixPredictor(model_path)
image_path = "out/evaluation/pix2pix_predictor_test_image.png"
# take x as an input image in numpy array format where x.shape = (anything, anything, 3)
x = np.array(Image.open(image_path)) # returns (429, 488, 4)
x = x[:, :, :3] # remove alpha channel
print(x.shape)
y = predictor(x)
image_y = Image.fromarray(y)
return image_y


if __name__ == "__main__":
image = test()
1 change: 1 addition & 0 deletions src/img2img/models/pix2pix/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main script for training the model. Can train from scratch or resume from a checkpoint."""

from img2img import cfg
from img2img.models.pix2pix.trainer import Pix2PixTrainer

Expand Down
3 changes: 2 additions & 1 deletion src/img2img/models/pix2pix/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Trainer class"""
"""Trainer class"""

from pathlib import Path

import torch
Expand Down
1 change: 0 additions & 1 deletion src/img2img/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def prepare_sub_directories(path: str | Path) -> tuple[Path, Path]:
path = Path(path)
eval_path = path / "evaluation"
weights_path = path / "last_trained_weights"
os.makedirs(path, exist_ok=True)
os.makedirs(eval_path, exist_ok=True)
os.makedirs(weights_path, exist_ok=True)
return weights_path, eval_path
Expand Down

0 comments on commit e0102d7

Please sign in to comment.