Commit 00486ee

'Refactoring' (#35)

* 'Refactoring'
* Refactoring
* One more
* Two more

1 parent c99bf30, commit 00486ee

35 files changed: +947 -343 lines

Diff for: README.md

+3 -1

@@ -17,5 +17,7 @@ Seminar project Unsupervised Image-to-Image translation using GANs
 ```
 ## Implementations
 
+### pix2pix
+### CycleGAN
 ### UNIT
-### pix2pix
+### TUNIT

Diff for: api/app/api.py

+3 -4

@@ -8,8 +8,8 @@
 api_blueprint = Blueprint("api", __name__)
 
 
-@api_blueprint.route("/coloring", methods=["POST"])
-def process_image_route():
+@api_blueprint.route("/<style>/coloring", methods=["POST"])
+def process_image_route(style):
     # if "image" not in request.files:
     #     return jsonify({"error": "No image provided"}), 400
 
@@ -18,10 +18,9 @@ def process_image_route():
     base64_image_data = request.json["image"]
     image_data = base64.b64decode(base64_image_data)
 
-    processed_image = process_image(io.BytesIO(image_data))
+    processed_image = process_image(io.BytesIO(image_data), style)
 
     # Convert the processed image to bytes
-
     processed_image.save(img_byte_array, format="JPEG")
     img_byte_array.seek(0)
 

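For context, a client for the reworked endpoint might look like the sketch below. Only the `<style>` path segment and the base64-encoded `image` JSON field come from the diff; the host/port, any blueprint URL prefix, and the assumption that the endpoint streams the processed JPEG back are illustrative.

```python
import base64

import requests  # already a project dependency per pyproject.toml

with open("input.jpg", "rb") as f:
    payload = {"image": base64.b64encode(f.read()).decode("ascii")}

# "anime" is the only style wired up in this commit; URL prefix and port are assumptions.
resp = requests.post("http://localhost:5000/anime/coloring", json=payload)
resp.raise_for_status()

with open("output.jpg", "wb") as f:
    f.write(resp.content)  # assumes the response body is the processed JPEG
```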
Diff for: api/app/image_processing.py

+19 -6

@@ -3,12 +3,25 @@
 
 from img2img.models.pix2pix.predictor import Pix2PixPredictor
 
-predictor = Pix2PixPredictor(
-    model_path="/home/omarabdelgawad/my_workspace/projects/github_repos/image2image/out/saved_models/anime_training/gen.pth.tar"
-)
+# from img2img.models.cyclegan.predictor import CycleGANPredictor
 
+# Initialize predictors
+anime_predictor = Pix2PixPredictor(model_path="./out/saved_models/anime_training/gen.pth.tar")
+# monet_predictor = CycleGANPredictor(model_path="./out/saved_models/monet_training/gen.pth.tar")
+# yukiyoe_predictor = CycleGANPredictor(model_path="./out/saved_models/yukiyoe_training/gen.pth.tar")
+# vangogh_predictor = CycleGANPredictor(model_path="./out/saved_models/vangogh_training/gen.pth.tar")
 
-def process_image(image_file):
+
+predictors = {
+    "anime": anime_predictor,
+
+    # "monet": monet_predictor,
+    # "yukiyoe": yukiyoe_predictor,
+    # "vangogh": vangogh_predictor,
+}
+
+
+def process_image(image_file, style):
     # Open the image
     image = Image.open(image_file)
 
@@ -18,8 +31,8 @@ def process_image(image_file):
     # Ensure the array shape is correct
     assert processed_image.shape[2] == 3
 
-    # Process the image using the Pix2Pix model
-    processed_image = predictor(processed_image)
+    # Process the image using the appropriate model
+    processed_image = predictors[style](processed_image)
 
     # Convert the processed image array back to PIL Image
     processed_image = Image.fromarray(processed_image)

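The module now keeps a style-keyed registry of predictors, and `process_image` indexes it with the `style` taken from the URL. A minimal, self-contained sketch of that dispatch pattern follows; the stand-in predictor and the explicit unknown-style check are illustrative, not part of the commit (the commit's code would raise a bare `KeyError` for an unknown style).

```python
import numpy as np

# Stand-in for Pix2PixPredictor: any callable mapping an HxWx3 uint8 array to another.
def fake_anime_predictor(image: np.ndarray) -> np.ndarray:
    return 255 - image  # placeholder "stylization"

predictors = {"anime": fake_anime_predictor}

def apply_style(image: np.ndarray, style: str) -> np.ndarray:
    if style not in predictors:
        raise ValueError(f"unsupported style: {style!r}")
    return predictors[style](image)

styled = apply_style(np.zeros((256, 256, 3), dtype=np.uint8), "anime")
print(styled.shape)  # (256, 256, 3)
```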
Diff for: pyproject.toml

+10 -3

@@ -6,6 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "img2img"
 dynamic = ["version"]
 dependencies = [
+    "enums",
     "requests",
     "torch >= 2.3.0 , < 3",
    "torchaudio >= 2.3.0, < 3",
@@ -18,14 +19,18 @@ requires-python = ">=3.10"
 authors = [
     {name = "Omar Abdelgawad", email = "[email protected]"},
     {name = "Eyad Hussein", email = "[email protected]"},
+    {name = "Ali Elsawy", email = "[email protected]"}
 ]
 description = "image to image translation using GANs"
 readme = "README.md"
 license = {file = "LICENSE"}
 keywords = ["image-to-image-translation", "GAN", "vision", "deep-learning"]
 classifiers = [
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.10"
+    "Programming Language :: Python :: 3.10",
+    "Operating System :: OS Independent",
+    "License :: OSI Approved :: MIT License",
+    "Topic :: Scientific/Research :: Artificial Intelligence"
 ]
 
 [project.optional-dependencies]
@@ -35,8 +40,10 @@ dev = [
     "pytest",
     "pytest-cov",
     "tox",
-    "flask", # api dependency
-    "requests", # api dependency
+]
+api = [
+    "flask",
+    "requests",
 ]
 
 [tool.setuptools.package-data]

Diff for: scripts/evaluate.py

+4 -4

@@ -1,10 +1,10 @@
 import torch
 from torch.utils.data import DataLoader
 
-from pix2pix import cfg
-from pix2pix.dataset import create_dataset
-from pix2pix.generator import Generator
-from pix2pix.utils import evaluate_val_set
+from img2img.cfg import pix2pix as cfg
+from img2img.data.pix2pix import create_dataset
+from img2img.models.pix2pix.generator import Generator
+from img2img.utils.pix2pix import evaluate_val_set
 
 
 def main() -> int:

Diff for: src/img2img/cfg/__init__.py

+6 -103

@@ -6,48 +6,28 @@
 
 from img2img.cli import get_main_parser
 
+from .transform import get_transforms
+
 from .enums import ActivationType, DatasetType, NormalizationType, PaddingType
 
 # TODO: Add logger instead of all the print statements.
 
+
 args = get_main_parser()
 DEVICE = "cuda" if cuda.is_available() else "cpu"
 LEARNING_RATE = args.rate
-BETA_OPTIM = (0.5, 0.999)
 BATCH_SIZE = args.batch_size
 NUM_WORKERS = args.num_workers
 IMAGE_SIZE = args.image_size
-CHANNELS_IMG = 3
-L_1_LAMBDA = 100
 NORM_MEAN = 0.5
 NORM_STD = 0.5
 CHECKPOINT_PERIOD = 5
 NUM_EPOCHS = args.num_epochs
 LOAD_MODEL = args.load_model
 SAVE_MODEL = args.save_model
-CHOSEN_DATASET = DatasetType.ANIME_DATASET
-TRAIN_DATASET_PATH = CHOSEN_DATASET.value / "train"
-VAL_DATASET_PATH = CHOSEN_DATASET.value / "val"
-OUT_PATH = Path("./out")
-NUM_IMAGES_DATASET = args.num_images_dataset
-VAL_BATCH_SIZE = args.val_batch_size
-
-# tunit config
-CHANNELS_MULTIPLIER = 64
-K = args.cluster_number
-
-# unit config
-WEIGHT_DECAY = 0.0001
-LR_POLICY = "step"
-STEP_SIZE = 100000
-GAMMA = 0.5
-INIT = "kaiming"
-GAN_WEIGHT = 1
-RECONSTRUCTION_X_WEIGHT = 10
-RECONSTRUCTION_H_WEIGHT = 0
-RECONSTRUCTION_KL_WEIGHT = 0.01
-RECONSTRUCTION_X_CYC_WEIGHT = 10
-RECONSTRUCTION_KL_CYC_WEIGHT = 0.01
+both_transform, transform_only_input, transform_only_mask, transforms, prediction_transform = get_transforms(IMAGE_SIZE,
+                                                                                                              NORM_MEAN,
+                                                                                                              NORM_STD)
 
 
 class GEN_HYPERPARAMS:
@@ -71,80 +51,3 @@ class DIS_HYPERPARAMS:
     GAN_TYPE = "lsgan"
     NUM_SCALES = 3
     PAD_TYPE = PaddingType.REFLECT
-
-
-import albumentations as A
-from albumentations.pytorch import ToTensorV2
-
-both_transform = A.Compose(
-    [
-        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
-        A.HorizontalFlip(p=0.5),
-    ],
-    additional_targets={"image0": "image"},
-)
-
-transform_only_input = A.Compose(
-    [
-        A.ColorJitter(p=0.1),
-        # TODO: calculate mean and std for the dataset instead of using these values.
-        A.Normalize(
-            mean=[NORM_MEAN, NORM_MEAN, NORM_MEAN],
-            std=[NORM_STD, NORM_STD, NORM_STD],
-            max_pixel_value=255.0,
-        ),
-        ToTensorV2(),
-    ]
-)
-
-transform_only_mask = A.Compose(
-    [
-        A.Normalize(
-            mean=[NORM_MEAN, NORM_MEAN, NORM_MEAN],
-            std=[NORM_STD, NORM_STD, NORM_STD],
-            max_pixel_value=255.0,
-        ),
-        ToTensorV2(),
-    ]
-)
-
-
-"""unit transforms
-
-# TODO: understand the augmentations below and improve them (maybe add more augmentations).
-both_transform = A.Compose(
-    [
-        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
-        A.HorizontalFlip(p=0.5),
-    ],
-    additional_targets={"image0": "image"},
-)
-
-# this is equivalent to first domain transform
-transform_only_input = A.Compose(
-    [
-        A.ColorJitter(p=0.1),
-        A.Normalize(
-            mean=[0.5, 0.5, 0.5],
-            std=[0.5, 0.5, 0.5],
-            max_pixel_value=255.0,
-        ),
-        A.HorizontalFlip(p=0.5),
-        ToTensorV2(),
-    ]
-)
-
-# this is equivalent to second domain transform
-transform_only_mask = A.Compose(
-    [
-        A.Normalize(
-            mean=[0.5, 0.5, 0.5],
-            std=[0.5, 0.5, 0.5],
-            max_pixel_value=255.0,
-        ),
-        ToTensorV2(),
-    ]
-)
-
-"""
-# TODO: make transforms differ from model to another

Diff for: src/img2img/cfg/cyclegan.py

+12 -0

@@ -0,0 +1,12 @@
+from . import *
+
+LAMBDA_IDENTITY = 0.0
+LAMBDA_CYCLE = 10
+BATCH_SIZE = 2
+CHECKPOINT_GEN_H = "genh.pth.tar"
+CHECKPOINT_GEN_Z = "genz.pth.tar"
+CHECKPOINT_CRITIC_H = "critich.pth.tar"
+CHECKPOINT_CRITIC_Z = "criticz.pth.tar"
+CHOSEN_DATASET = DatasetType.VANGOGH2PHOTO
+TRAIN_DATASET_PATH = CHOSEN_DATASET.value / "train"
+VAL_DATASET_PATH = CHOSEN_DATASET.value / "val"

Diff for: src/img2img/cfg/enums.py

+7 -10

@@ -7,16 +7,13 @@
 class DatasetType(Enum):
     """Enum for the dataset type."""
 
-    ANIME_DATASET = Path(
-        "/media/omarabdelgawad/New Volume/Datasets/image_coloring/anime_dataset/"
-    )
-    NATURAL_VIEW_DATASET = Path(
-        "/media/omarabdelgawad/New Volume/Datasets/image_coloring/natural_view/"
-    )
-    EDGES2SHOES = Path(
-        "/media/omarabdelgawad/New Volume/Datasets/image_coloring/edges2shoes/"
-    )
-    AFHQ_CATS_DATASET = Path("/home/eyad/Downloads/afhq/")
+    ANIME_DATASET = Path("/media/omarabdelgawad/New Volume/Datasets/Anime_Dataset")
+    NATURAL_VIEW_DATASET = Path("/media/omarabdelgawad/New Volume/Datasets/Natural_View")
+    EDGES2SHOES = Path("/media/omarabdelgawad/New Volume/Datasets/Edges2Shoes")
+    AFHQ_CATS_DATASET = Path("/media/omarabdelgawad/New Volume/Datasets/AFHQ_Cats")
+    VANGOGH2PHOTO = Path("/media/omarabdelgawad/New Volume/Datasets/vangogh2photo")
+    yukiyoe = Path("/media/omarabdelgawad/New Volume/Datasets/yukiyoe")
+    monet = Path("/media/omarabdelgawad/New Volume/Datasets/monet")
 
 
 class PaddingType(Enum):

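Because each DatasetType member wraps a pathlib.Path, the per-model config modules (cfg/cyclegan.py above, cfg/pix2pix.py below) can build dataset subpaths with the `/` operator directly on `.value`. A toy illustration, with a placeholder root path:

```python
from enum import Enum
from pathlib import Path

class DatasetType(Enum):
    ANIME_DATASET = Path("/data/Anime_Dataset")  # placeholder root

CHOSEN_DATASET = DatasetType.ANIME_DATASET
TRAIN_DATASET_PATH = CHOSEN_DATASET.value / "train"  # PosixPath('/data/Anime_Dataset/train')
```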
Diff for: src/img2img/cfg/pix2pix.py

+38 -0

@@ -0,0 +1,38 @@
+from . import *
+
+BETA_OPTIM = (0.5, 0.999)
+NUM_WORKERS = args.num_workers
+CHANNELS_IMG = 3
+L_1_LAMBDA = 100
+NORM_MEAN = 0.5
+NORM_STD = 0.5
+CHOSEN_DATASET = DatasetType.ANIME_DATASET
+TRAIN_DATASET_PATH = CHOSEN_DATASET.value / "train"
+VAL_DATASET_PATH = CHOSEN_DATASET.value / "val"
+OUT_PATH = Path("./out")
+NUM_IMAGES_DATASET = args.num_images_dataset
+VAL_BATCH_SIZE = args.val_batch_size
+
+
+# Hyperparameters for the generator and discriminator
+class GEN_HYPERPARAMS:
+    """Hyperparameters for the generator."""
+
+    DIM = 64
+    NORM = NormalizationType.INSTANCE
+    ACTIV = ActivationType.RELU
+    N_DOWNSAMPLE = 2
+    N_RES = 4
+    PAD_TYPE = PaddingType.REFLECT
+
+
+class DIS_HYPERPARAMS:
+    """Hyperparameters for the discriminator."""
+
+    DIM = 64
+    NORM = NormalizationType.NONE
+    ACTIV = ActivationType.LEAKY_RELU
+    N_LAYER = 4
+    GAN_TYPE = "lsgan"
+    NUM_SCALES = 3
+    PAD_TYPE = PaddingType.REFLECT

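With the config split into per-model modules, callers import the one they need, which is exactly what the updated scripts/evaluate.py does. A minimal sketch of that usage follows; note that cfg runs get_main_parser() at import time, so this assumes the CLI parser accepts the process's arguments.

```python
from img2img.cfg import pix2pix as cfg  # same import style as scripts/evaluate.py

print(cfg.L_1_LAMBDA)           # 100
print(cfg.TRAIN_DATASET_PATH)   # <ANIME_DATASET root>/train
print(cfg.GEN_HYPERPARAMS.DIM)  # 64
```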
Diff for: src/img2img/cfg/transform.py

+55 -0

@@ -0,0 +1,55 @@
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+
+def get_transforms(IMAGE_SIZE, NORM_MEAN, NORM_STD):
+    both_transform = A.Compose(
+        [
+            A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
+            A.HorizontalFlip(p=0.5),
+        ],
+        additional_targets={"image0": "image"},
+    )
+
+    transform_only_input = A.Compose(
+        [
+            A.ColorJitter(p=0.1),
+            A.Normalize(
+                mean=[NORM_MEAN, NORM_MEAN, NORM_MEAN],
+                std=[NORM_STD, NORM_STD, NORM_STD],
+                max_pixel_value=255.0,
+            ),
+            ToTensorV2(),
+        ]
+    )
+
+    transform_only_mask = A.Compose(
+        [
+            A.Normalize(
+                mean=[NORM_MEAN, NORM_MEAN, NORM_MEAN],
+                std=[NORM_STD, NORM_STD, NORM_STD],
+                max_pixel_value=255.0,
+            ),
+            ToTensorV2(),
+        ]
+    )
+
+    transforms = A.Compose(
+        [
+            A.Resize(width=256, height=256),
+            A.HorizontalFlip(p=0.5),
+            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
+            ToTensorV2(),
+        ],
+        additional_targets={"image0": "image"},
+    )
+
+    prediction_transform = A.Compose(
+        [
+            A.Resize(width=256, height=256),
+            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
+            ToTensorV2(),
+        ],
+    )
+
+    return both_transform, transform_only_input, transform_only_mask, transforms, prediction_transform

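The paired transforms rely on albumentations' `additional_targets`, so the input and target image receive the same random flip. A usage sketch with dummy arrays (the argument values are illustrative):

```python
import numpy as np

from img2img.cfg.transform import get_transforms

(both_transform, transform_only_input, transform_only_mask,
 transforms, prediction_transform) = get_transforms(IMAGE_SIZE=256, NORM_MEAN=0.5, NORM_STD=0.5)

input_img = np.zeros((300, 300, 3), dtype=np.uint8)
target_img = np.zeros((300, 300, 3), dtype=np.uint8)

# additional_targets={"image0": "image"} applies the same resize/flip to both images.
paired = both_transform(image=input_img, image0=target_img)
x = transform_only_input(image=paired["image"])["image"]   # torch.Tensor, 3x256x256
y = transform_only_mask(image=paired["image0"])["image"]
```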