-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathimage_mask_cutter.py
46 lines (35 loc) · 1.44 KB
/
image_mask_cutter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
class ImageMaskCutter:
    """Cut an image out with a mask by using the mask as the alpha channel.

    The output keeps the colour channels of ``image`` and replaces (or adds)
    the alpha channel with ``mask`` clamped to [0, 1]. Wherever the mask is 0,
    the result is fully transparent.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cut_image"
    CATEGORY = "Bjornulf"

    def cut_image(self, image, mask):
        """Return ``image`` with ``mask`` written into its alpha channel.

        Args:
            image: 4-D tensor. If the last dimension is 3 or 4 it is treated as
                channels-last ``(B, H, W, C)``. Otherwise it is assumed to
                already be channels-first ``(B, C, H, W)``.
            mask: Tensor of shape ``(B, H, W)`` or ``(H, W)``. Its batch size
                must be 1 (broadcast to every image) or equal the image's.

        Returns:
            A one-element tuple with the cut-out image in channels-last layout
            ``(B, H, W, C)``, where ``C`` is 4 for RGB/RGBA input. The input
            tensors are not modified.

        Raises:
            ValueError: if the image is not 4-D, has fewer than 3 channels, or
                its spatial/batch size is incompatible with the mask.
        """
        if image.dim() != 4:
            raise ValueError(f"Image must be a 4-D tensor, got shape {image.shape}")

        # Channels-last input is detected by its channel count. This is
        # ambiguous for channels-first images whose width is 3 or 4.
        if image.shape[-1] in (3, 4):
            image = image.permute(0, 3, 1, 2)

        # An un-batched mask is accepted and treated as a batch of one.
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)

        if mask.dim() != 3 or image.shape[2:] != mask.shape[1:]:
            raise ValueError(f"Image and mask must have compatible dimensions. Got image shape {image.shape} and mask shape {mask.shape}")

        batch, channels = image.shape[0], image.shape[1]
        if mask.shape[0] not in (1, batch):
            raise ValueError(
                f"Mask batch size {mask.shape[0]} must be 1 or match image batch size {batch}"
            )
        if channels < 3:
            raise ValueError(f"Image must have at least 3 channels, got {channels}")

        # Clamp in float first (as before), then match the image's device and
        # dtype so the output keeps the input's precision.
        alpha = mask.float().clamp(0, 1).to(device=image.device, dtype=image.dtype)
        # Shape (B, 1, H, W); expand broadcasts a single mask over the batch.
        alpha = alpha.unsqueeze(1).expand(batch, 1, -1, -1)

        if channels == 3:
            # torch.cat allocates a new tensor, so the caller's image is untouched.
            image = torch.cat([image, alpha], dim=1)
        else:
            # permute() returned a view sharing storage with the caller's tensor;
            # clone before writing so the upstream output is not mutated.
            image = image.clone()
            image[:, 3:4, :, :] = alpha

        # Move channels back to the last dimension.
        cut_image = image.permute(0, 2, 3, 1)
        return (cut_image,)