Skip to content

Commit cc3dbf6

Browse files
feat: Update Tile Pre-Processor to support more modes
1 parent 10076fb commit cc3dbf6

File tree

3 files changed

+470
-141
lines changed

3 files changed

+470
-141
lines changed

invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Invocations for ControlNet image preprocessors
22
# initial implementation by Gregg Helt, 2023
33
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
4+
import random
45
from builtins import bool, float
56
from pathlib import Path
6-
from typing import Dict, List, Literal, Union
7+
from typing import Any, Dict, List, Literal, Union
78

89
import cv2
910
import numpy as np
@@ -39,6 +40,7 @@
3940
from invokeai.backend.image_util.canny import get_canny_edges
4041
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
4142
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
43+
from invokeai.backend.image_util.fast_guided_filter.fast_guided_filter import FastGuidedFilter
4244
from invokeai.backend.image_util.hed import HEDProcessor
4345
from invokeai.backend.image_util.lineart import LineartProcessor
4446
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
@@ -483,30 +485,67 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
483485

484486
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
485487
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
488+
mode: Literal["regular", "blur", "var", "super"] = InputField(
489+
default="regular", description="The controlnet tile model being used"
490+
)
491+
492+
def apply_gaussian_blur(self, image_np: np.ndarray[Any, Any], ksize: int = 5, sigmaX: float = 1.0):
493+
if ksize % 2 == 0:
494+
ksize += 1 # ksize must be odd
495+
blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)
496+
return blurred_image
497+
498+
def apply_guided_filter(self, image_np: np.ndarray[Any, Any], radius: int, eps: float, scale: int):
499+
filter = FastGuidedFilter(image_np, radius, eps, scale)
500+
return filter.filter(image_np)
501+
502+
# based off https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
503+
def tile_resample(self, np_img: np.ndarray[Any, Any]):
504+
height, width, _ = np_img.shape
505+
506+
if self.mode == "regular":
507+
np_img = HWC3(np_img)
508+
if self.down_sampling_rate < 1.1:
509+
return np_img
486510

487-
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
488-
def tile_resample(
489-
self,
490-
np_img: np.ndarray,
491-
res=512, # never used?
492-
down_sampling_rate=1.0,
493-
):
494-
np_img = HWC3(np_img)
495-
if down_sampling_rate < 1.1:
511+
new_height = int(float(height) / float(self.down_sampling_rate))
512+
new_width = int(float(width) / float(self.down_sampling_rate))
513+
np_img = cv2.resize(np_img, (new_width, new_height), interpolation=cv2.INTER_AREA)
496514
return np_img
497-
H, W, C = np_img.shape
498-
H = int(float(H) / float(down_sampling_rate))
499-
W = int(float(W) / float(down_sampling_rate))
500-
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
515+
516+
ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
517+
518+
resize_w, resize_h = int(width * ratio), int(height * ratio)
519+
520+
if self.mode == "super":
521+
resize_w, resize_h = int(width * ratio) // 48 * 48, int(height * ratio) // 48 * 48
522+
523+
np_img = cv2.resize(np_img, (resize_w, resize_h))
524+
525+
if self.mode == "blur":
526+
blur_strength = random.sample([i / 10.0 for i in range(10, 201, 2)], k=1)[0]
527+
radius = random.sample([i for i in range(1, 40, 2)], k=1)[0] # noqa: C416
528+
eps = random.sample([i / 1000.0 for i in range(1, 101, 2)], k=1)[0]
529+
scale_factor = random.sample([i / 10.0 for i in range(10, 181, 5)], k=1)[0]
530+
531+
if random.random() > 0.5:
532+
np_img = self.apply_gaussian_blur(np_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
533+
534+
if random.random() > 0.5:
535+
np_img = self.apply_guided_filter(np_img, radius, eps, int(scale_factor))
536+
537+
np_img = cv2.resize(
538+
np_img, (int(resize_w / scale_factor), int(resize_h / scale_factor)), interpolation=cv2.INTER_AREA
539+
)
540+
np_img = cv2.resize(np_img, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC)
541+
542+
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
543+
501544
return np_img
502545

503546
def run_processor(self, image: Image.Image) -> Image.Image:
504547
np_img = np.array(image, dtype=np.uint8)
505-
processed_np_image = self.tile_resample(
506-
np_img,
507-
# res=self.tile_size,
508-
down_sampling_rate=self.down_sampling_rate,
509-
)
548+
processed_np_image = self.tile_resample(np_img)
510549
processed_image = Image.fromarray(processed_np_image)
511550
return processed_image
512551

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
# ruff: noqa: E741
2+
# -*- coding: utf-8 -*-
3+
## @package guided_filter.core.filters
4+
#
5+
# Implementation of guided filter.
6+
# * GuidedFilter: Original guided filter.
7+
# * FastGuidedFilter: Fast version of the guided filter.
8+
# @author tody
9+
# @date 2015/08/26
10+
11+
12+
import cv2
13+
import numpy as np
14+
15+
16+
## Convert image into float32 type.
17+
def to32F(img):
18+
if img.dtype == np.float32:
19+
return img
20+
return (1.0 / 255.0) * np.float32(img)
21+
22+
23+
## Convert image into uint8 type.
24+
def to8U(img):
25+
if img.dtype == np.uint8:
26+
return img
27+
return np.clip(np.uint8(255.0 * img), 0, 255)
28+
29+
30+
## Return if the input image is gray or not.
31+
def _isGray(I):
32+
return len(I.shape) == 2
33+
34+
35+
## Return down sampled image.
36+
# @param scale (w/s, h/s) image will be created.
37+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
38+
def _downSample(I, scale=4, shape=None):
39+
if shape is not None:
40+
h, w = shape
41+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
42+
43+
h, w = I.shape[:2]
44+
return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
45+
46+
47+
## Return up sampled image.
48+
# @param scale (w*s, h*s) image will be created.
49+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
50+
def _upSample(I, scale=2, shape=None):
51+
if shape is not None:
52+
h, w = shape
53+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
54+
55+
h, w = I.shape[:2]
56+
return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
57+
58+
59+
## Fast guide filter.
60+
class FastGuidedFilter:
61+
## Constructor.
62+
# @param I Input guidance image. Color or gray.
63+
# @param radius Radius of Guided Filter.
64+
# @param epsilon Regularization term of Guided Filter.
65+
# @param scale Down sampled scale.
66+
def __init__(self, I, radius=5, epsilon=0.4, scale=4):
67+
I_32F = to32F(I)
68+
self._I = I_32F
69+
h, w = I.shape[:2]
70+
71+
I_sub = _downSample(I_32F, scale)
72+
73+
self._I_sub = I_sub
74+
radius = int(radius / scale)
75+
76+
if _isGray(I):
77+
self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
78+
else:
79+
self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
80+
81+
## Apply filter for the input image.
82+
# @param p Input image for the filtering.
83+
def filter(self, p):
84+
p_32F = to32F(p)
85+
shape_original = p.shape[:2]
86+
87+
p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
88+
89+
if _isGray(p_sub):
90+
return self._filterGray(p_sub, shape_original)
91+
92+
cs = p.shape[2]
93+
q = np.array(p_32F)
94+
95+
for ci in range(cs):
96+
q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
97+
return to8U(q)
98+
99+
def _filterGray(self, p_sub, shape_original):
100+
ab_sub = self._guided_filter._computeCoefficients(p_sub)
101+
ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
102+
return self._guided_filter._computeOutput(ab, self._I)
103+
104+
105+
## Guide filter.
106+
class GuidedFilter:
107+
## Constructor.
108+
# @param I Input guidance image. Color or gray.
109+
# @param radius Radius of Guided Filter.
110+
# @param epsilon Regularization term of Guided Filter.
111+
def __init__(self, I, radius=5, epsilon=0.4):
112+
I_32F = to32F(I)
113+
114+
if _isGray(I):
115+
self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
116+
else:
117+
self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
118+
119+
## Apply filter for the input image.
120+
# @param p Input image for the filtering.
121+
def filter(self, p):
122+
return to8U(self._guided_filter.filter(p))
123+
124+
125+
## Common parts of guided filter.
126+
#
127+
# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
128+
# Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
129+
# GuidedFilterCommon.filter computes filtered image for color and gray.
130+
class GuidedFilterCommon:
131+
def __init__(self, guided_filter):
132+
self._guided_filter = guided_filter
133+
134+
## Apply filter for the input image.
135+
# @param p Input image for the filtering.
136+
def filter(self, p):
137+
p_32F = to32F(p)
138+
if _isGray(p_32F):
139+
return self._filterGray(p_32F)
140+
141+
cs = p.shape[2]
142+
q = np.array(p_32F)
143+
144+
for ci in range(cs):
145+
q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
146+
return q
147+
148+
def _filterGray(self, p):
149+
ab = self._guided_filter._computeCoefficients(p)
150+
return self._guided_filter._computeOutput(ab, self._guided_filter._I)
151+
152+
153+
## Guided filter for gray guidance image.
154+
class GuidedFilterGray:
155+
# @param I Input gray guidance image.
156+
# @param radius Radius of Guided Filter.
157+
# @param epsilon Regularization term of Guided Filter.
158+
def __init__(self, I, radius=5, epsilon=0.4):
159+
self._radius = 2 * radius + 1
160+
self._epsilon = epsilon
161+
self._I = to32F(I)
162+
self._initFilter()
163+
self._filter_common = GuidedFilterCommon(self)
164+
165+
## Apply filter for the input image.
166+
# @param p Input image for the filtering.
167+
def filter(self, p):
168+
return self._filter_common.filter(p)
169+
170+
def _initFilter(self):
171+
I = self._I
172+
r = self._radius
173+
self._I_mean = cv2.blur(I, (r, r))
174+
I_mean_sq = cv2.blur(I**2, (r, r))
175+
self._I_var = I_mean_sq - self._I_mean**2
176+
177+
def _computeCoefficients(self, p):
178+
r = self._radius
179+
p_mean = cv2.blur(p, (r, r))
180+
p_cov = p_mean - self._I_mean * p_mean
181+
a = p_cov / (self._I_var + self._epsilon)
182+
b = p_mean - a * self._I_mean
183+
a_mean = cv2.blur(a, (r, r))
184+
b_mean = cv2.blur(b, (r, r))
185+
return a_mean, b_mean
186+
187+
def _computeOutput(self, ab, I):
188+
a_mean, b_mean = ab
189+
return a_mean * I + b_mean
190+
191+
192+
## Guided filter for color guidance image.
193+
class GuidedFilterColor:
194+
# @param I Input color guidance image.
195+
# @param radius Radius of Guided Filter.
196+
# @param epsilon Regularization term of Guided Filter.
197+
def __init__(self, I, radius=5, epsilon=0.2):
198+
self._radius = 2 * radius + 1
199+
self._epsilon = epsilon
200+
self._I = to32F(I)
201+
self._initFilter()
202+
self._filter_common = GuidedFilterCommon(self)
203+
204+
## Apply filter for the input image.
205+
# @param p Input image for the filtering.
206+
def filter(self, p):
207+
return self._filter_common.filter(p)
208+
209+
def _initFilter(self):
210+
I = self._I
211+
r = self._radius
212+
eps = self._epsilon
213+
214+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
215+
216+
self._Ir_mean = cv2.blur(Ir, (r, r))
217+
self._Ig_mean = cv2.blur(Ig, (r, r))
218+
self._Ib_mean = cv2.blur(Ib, (r, r))
219+
220+
Irr_var = cv2.blur(Ir**2, (r, r)) - self._Ir_mean**2 + eps
221+
Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
222+
Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
223+
Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
224+
Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
225+
Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
226+
227+
Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
228+
Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
229+
Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
230+
Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
231+
Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
232+
Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
233+
234+
I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
235+
Irr_inv /= I_cov
236+
Irg_inv /= I_cov
237+
Irb_inv /= I_cov
238+
Igg_inv /= I_cov
239+
Igb_inv /= I_cov
240+
Ibb_inv /= I_cov
241+
242+
self._Irr_inv = Irr_inv
243+
self._Irg_inv = Irg_inv
244+
self._Irb_inv = Irb_inv
245+
self._Igg_inv = Igg_inv
246+
self._Igb_inv = Igb_inv
247+
self._Ibb_inv = Ibb_inv
248+
249+
def _computeCoefficients(self, p):
250+
r = self._radius
251+
I = self._I
252+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
253+
254+
p_mean = cv2.blur(p, (r, r))
255+
256+
Ipr_mean = cv2.blur(Ir * p, (r, r))
257+
Ipg_mean = cv2.blur(Ig * p, (r, r))
258+
Ipb_mean = cv2.blur(Ib * p, (r, r))
259+
260+
Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
261+
Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
262+
Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
263+
264+
ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
265+
ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
266+
ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
267+
b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
268+
269+
ar_mean = cv2.blur(ar, (r, r))
270+
ag_mean = cv2.blur(ag, (r, r))
271+
ab_mean = cv2.blur(ab, (r, r))
272+
b_mean = cv2.blur(b, (r, r))
273+
274+
return ar_mean, ag_mean, ab_mean, b_mean
275+
276+
def _computeOutput(self, ab, I):
277+
ar_mean, ag_mean, ab_mean, b_mean = ab
278+
279+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
280+
281+
q = ar_mean * Ir + ag_mean * Ig + ab_mean * Ib + b_mean
282+
283+
return q

0 commit comments

Comments
 (0)