Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions comfy/conds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import math
import comfy.utils
import logging


class CONDRegular:
Expand All @@ -10,12 +11,15 @@ def __init__(self, cond):
# Wrap `cond` in a new instance of this object's own class (self.__class__),
# so a subclass instance produces a copy of the same subclass.
def _copy_with(self, cond):
return self.__class__(cond)

def process_cond(self, batch_size, **kwargs):
    """Return a copy of this cond prepared for ``batch_size``.

    ``self.cond`` is passed through ``comfy.utils.repeat_to_batch_size``
    and wrapped with ``_copy_with``. No device transfer happens here: the
    result stays on whatever device ``self.cond`` already lives on.

    Extra keyword arguments (e.g. ``area=...`` or a legacy ``device=...``)
    are accepted and ignored.
    """
    batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
    return self._copy_with(batched)

def can_concat(self, other):
    """Tell whether ``other`` may be concatenated with this cond.

    Returns False when the two ``cond`` shapes differ, and False (after
    logging a warning) when they sit on different devices. Returns True
    otherwise.
    """
    mine, theirs = self.cond, other.cond
    if mine.shape != theirs.shape:
        return False
    if mine.device == theirs.device:
        return True
    # Same shape but different devices: report it and refuse to concat.
    logging.warning("WARNING: conds not on same device, skipping concat.")
    return False

def concat(self, others):
Expand All @@ -29,14 +33,14 @@ def size(self):


class CONDNoiseShape(CONDRegular):
    """Cond that can be cropped to an area before being batched."""

    def process_cond(self, batch_size, area, **kwargs):
        """Return a copy of this cond cropped to ``area`` and batched.

        ``area`` is either None (no cropping) or a flat sequence read in
        two halves: the first ``len(area) // 2`` entries are passed as the
        third argument to ``narrow`` and the second half as the second
        argument, one pair per dimension starting at dimension 2.

        The result is passed through ``comfy.utils.repeat_to_batch_size``
        and is not moved to another device. Extra keyword arguments
        (e.g. a legacy ``device=...``) are accepted and ignored.
        """
        data = self.cond
        if area is not None:
            dims = len(area) // 2
            for i in range(dims):
                # Assumes self.cond is a torch tensor, whose narrow takes
                # (dim, start, length) -- TODO confirm against callers.
                data = data.narrow(i + 2, area[i + dims], area[i])

        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))


class CONDCrossAttn(CONDRegular):
Expand All @@ -51,6 +55,9 @@ def can_concat(self, other):
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
return False
return True

def concat(self, others):
Expand All @@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
# Store `cond` on the instance exactly as given.
def __init__(self, cond):
self.cond = cond

def process_cond(self, batch_size, **kwargs):
    """Return a copy wrapping the very same ``self.cond`` value.

    ``batch_size`` and any extra keyword arguments (e.g. ``area=...`` or a
    legacy ``device=...``) are accepted but not used: the stored value is
    handed to ``_copy_with`` unchanged.
    """
    return self._copy_with(self.cond)

def can_concat(self, other):
Expand All @@ -92,10 +99,10 @@ class CONDList(CONDRegular):
# Store `cond` on the instance exactly as given.
def __init__(self, cond):
self.cond = cond

def process_cond(self, batch_size, **kwargs):
    """Return a copy whose items have each been prepared for ``batch_size``.

    Every element of ``self.cond`` is passed through
    ``comfy.utils.repeat_to_batch_size`` exactly once, in order. The
    elements are not moved to another device. Extra keyword arguments
    (e.g. ``area=...`` or a legacy ``device=...``) are accepted and ignored.
    """
    out = [comfy.utils.repeat_to_batch_size(c, batch_size) for c in self.cond]
    return self._copy_with(out)

Expand Down
5 changes: 3 additions & 2 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
import comfy.model_base

import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
Expand Down Expand Up @@ -264,12 +265,12 @@ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)

def copy(self):
Expand Down
20 changes: 11 additions & 9 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ class ModelSampling(s, c):
def convert_tensor(extra, dtype, device):
    """Move ``extra`` to ``device``, casting to ``dtype`` unless it is integer.

    Objects without a ``dtype`` attribute are returned untouched. Values
    whose dtype is ``torch.int`` or ``torch.long`` are moved to ``device``
    with their dtype left alone (``None`` is passed as the target dtype).
    Everything else is moved and cast to ``dtype``. The conversion is done
    once, through ``comfy.model_management.cast_to_device``.
    """
    if not hasattr(extra, "dtype"):
        return extra

    if extra.dtype != torch.int and extra.dtype != torch.long:
        target_dtype = dtype
    else:
        # Integer tensors keep their dtype; only the device changes.
        target_dtype = None
    return comfy.model_management.cast_to_device(extra, device, target_dtype)


Expand Down Expand Up @@ -162,7 +162,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
xc = self.model_sampling.calculate_input(sigma, x)

if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

context = c_crossattn
dtype = self.get_dtype()
Expand All @@ -174,7 +174,7 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype=dtype, device=device)
context = comfy.model_management.cast_to_device(context, device, dtype)

extra_conds = {}
for o in kwargs:
Expand Down Expand Up @@ -401,7 +401,7 @@ def encode_adm(self, **kwargs):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is None:
return torch.zeros((1, self.adm_channels))
return torch.zeros((1, self.adm_channels), device=device)
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)

Expand Down Expand Up @@ -615,9 +615,11 @@ def concat_cond(self, **kwargs):

if image is None:
image = torch.zeros_like(noise)
else:
image = image.to(device=device)

if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
Expand Down Expand Up @@ -696,7 +698,7 @@ def extra_conds(self, **kwargs):
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))

out["effnet"] = comfy.conds.CONDRegular(prior)
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out

Expand Down Expand Up @@ -1161,10 +1163,10 @@ def extra_conds(self, **kwargs):

vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
vace_frames_out.append(vf)

vace_frames = torch.stack(vace_frames_out, dim=1)
Expand Down
2 changes: 1 addition & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)

hooks = conds.get('hooks', None)
control = conds.get('control', None)
Expand Down
Loading