-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy path green_to_transparency.py
76 lines (61 loc) · 2.79 KB
/
green_to_transparency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
class GreenScreenToTransparency:
    """Make strongly green pixels fully transparent.

    Declared with the ComfyUI custom-node attributes (``INPUT_TYPES``,
    ``RETURN_TYPES``, ``FUNCTION``, ``CATEGORY``). Images are channels-last
    float arrays with values nominally in ``[0, 1]``; the result always has
    four channels (RGBA), with alpha set to 0 wherever green dominates.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE", {}),
                "threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    FUNCTION = "remove_green_screen"
    RETURN_TYPES = ("IMAGE",)
    OUTPUT_NODE = True
    CATEGORY = "Bjornulf"

    def remove_green_screen(self, image, threshold=0.1, prompt=None, extra_pnginfo=None):
        """Key out green pixels of ``image``.

        Args:
            image: torch tensor, either a batch ``[B, H, W, C]`` or a single
                channels-last image (``[H, W, C]`` or ``[H, W]``).
            threshold: fraction of full range (0-1) by which the green channel
                must exceed BOTH red and blue for a pixel to become transparent.
            prompt: hidden input; accepted but not used.
            extra_pnginfo: hidden metadata dict; if given, it is mutated in
                place with ``"green_screen_removed": True``.

        Returns:
            A 1-tuple holding a float32 CPU tensor with the same leading
            dimensions as ``image`` and a last dimension of 4 (RGBA).

        Raises:
            ValueError: if an image has an unsupported channel count.
        """
        # detach() so tensors that require grad can still be converted to numpy.
        image_np = image.detach().cpu().numpy()

        if image_np.ndim == 4:
            if image_np.shape[0] == 0:
                # np.stack rejects an empty list, so build the empty batch directly.
                height, width = image_np.shape[1:3]
                processed_np = np.empty((0, height, width, 4), dtype=np.float32)
            else:
                processed_np = np.stack(
                    [self._process_single_image(frame, threshold) for frame in image_np]
                )
        else:
            # Unbatched input stays unbatched.
            processed_np = self._process_single_image(image_np, threshold)

        if extra_pnginfo is not None:
            extra_pnginfo["green_screen_removed"] = True

        # RETURN_TYPES declares exactly one IMAGE output, so return exactly one value.
        return (torch.from_numpy(processed_np),)

    def _process_single_image(self, img, threshold):
        """Return a float32 RGBA copy of one image with green pixels made transparent."""
        # Quantize to 8 bits (truncating, as before), but clip first so
        # out-of-range floats cannot wrap around in the uint8 cast.
        data = self._to_rgba8((np.clip(img, 0.0, 1.0) * 255).astype(np.uint8))

        r, g, b = data[..., 0], data[..., 1], data[..., 2]
        # Threshold is a fraction of full range; compare on the 8-bit scale.
        margin = threshold * 255
        mask = (g > r + margin) & (g > b + margin)
        data[..., 3] = np.where(mask, 0, data[..., 3])

        return data.astype(np.float32) / 255.0

    @staticmethod
    def _to_rgba8(data):
        """Convert a uint8 ``[H, W]`` or ``[H, W, C]`` array (C in 1..4) to a new ``[H, W, 4]`` array.

        Mirrors PIL's ``convert("RGBA")``: gray is replicated into R, G and B,
        and a missing alpha channel is filled with 255 (opaque).
        """
        if data.ndim == 2:
            data = data[..., np.newaxis]
        if data.ndim != 3 or data.shape[-1] not in (1, 2, 3, 4):
            raise ValueError(
                f"Unsupported image shape {data.shape}; expected [H, W] or [H, W, C] with C in 1..4"
            )

        channels = data.shape[-1]
        if channels <= 2:  # L or LA
            rgb = np.repeat(data[..., :1], 3, axis=-1)
        else:  # RGB or RGBA
            rgb = data[..., :3]

        if channels in (2, 4):
            alpha = data[..., -1:]
        else:
            alpha = np.full(data.shape[:2] + (1,), 255, dtype=np.uint8)

        # concatenate returns a fresh array, so the caller may write into it safely.
        return np.concatenate([rgb, alpha], axis=-1)