Skip to content

Commit

Permalink
It works now on all 4 tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
omar-abdelgawad committed Jun 5, 2024
1 parent 603b2ed commit 228c741
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
6 changes: 3 additions & 3 deletions api/app/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import base64
import io

import base64
from flask import Blueprint, jsonify, request

from .image_processing import process_image

api_blueprint = Blueprint("api", __name__)


@api_blueprint.route("/<style>/coloring", methods=["POST"])
@api_blueprint.route("/<style>", methods=["POST"])
def process_image_route(style):
# if "image" not in request.files:
# return jsonify({"error": "No image provided"}), 400
Expand All @@ -21,6 +20,7 @@ def process_image_route(style):
processed_image = process_image(io.BytesIO(image_data), style)

# Convert the processed image to bytes

processed_image.save(img_byte_array, format="JPEG")
img_byte_array.seek(0)

Expand Down
28 changes: 18 additions & 10 deletions api/app/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,35 @@

from img2img.models.pix2pix.predictor import Pix2PixPredictor

# from img2img.models.cyclegan.predictor import CycleGANPredictor
from img2img.models.cyclegan.predictor import CycleGanPredictor

# Initialize predictors
anime_predictor = Pix2PixPredictor(model_path="./out/saved_models/anime_training/gen.pth.tar")
# monet_predictor = CycleGANPredictor(model_path="./out/saved_models/monet_training/gen.pth.tar")
# yukiyoe_predictor = CycleGANPredictor(model_path="./out/saved_models/yukiyoe_training/gen.pth.tar")
# vangogh_predictor = CycleGANPredictor(model_path="./out/saved_models/vangogh_training/gen.pth.tar")
print("Initializing predictors...")
anime_predictor = Pix2PixPredictor(
model_path="./out/saved_models/anime_training/gen.pth.tar"
)
monet_predictor = CycleGanPredictor(
model_path="./out/saved_models/monet_training/genh.pth.tar"
)
yukiyoe_predictor = CycleGanPredictor(
model_path="./out/saved_models/yukiyoe_training/genh.pth.tar"
)
vangogh_predictor = CycleGanPredictor(
model_path="./out/saved_models/vangogh_training/genh.pth.tar"
)


predictors = {
"anime": anime_predictor,

# "monet": monet_predictor,
# "yukiyoe": yukiyoe_predictor,
# "vangogh": vangogh_predictor,
"monet": monet_predictor,
"yukiyoe": yukiyoe_predictor,
"vangogh": vangogh_predictor,
}


def process_image(image_file, style):
# Open the image
image = Image.open(image_file)

# Example processing: convert to RGB array
processed_image = np.array(image.convert("RGB"))

Expand All @@ -37,5 +44,6 @@ def process_image(image_file, style):
# Convert the processed image array back to PIL Image
processed_image = Image.fromarray(processed_image)

processed_image.show()
# Return the processed image as a PIL Image object
return processed_image
1 change: 1 addition & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def main() -> int:
print("Starting server...")
app.run(debug=True)
return 0

Expand Down

0 comments on commit 228c741

Please sign in to comment.