Skip to content

Commit 0dcc75c

Browse files
Add experimental --async-offload lowvram weight offloading. (Comfy-Org#7820)
This should speed up lowvram mode a bit. It is currently only enabled when --async-offload is used, but it will be enabled by default in the future if no problems arise.
1 parent b685b8a commit 0dcc75c

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

comfy/cli_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class LatentPreviewMethod(enum.Enum):
128128

129129
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
130130

131+
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
131132

132133
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
133134

comfy/model_management.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -939,15 +939,56 @@ def force_channels_last():
939939
#TODO
940940
return False
941941

942-
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
942+
943+
# Per-device pools of streams used for async weight offloading.
STREAMS = {}
# Streams per device; 1 means offload streams are disabled.
NUM_STREAMS = 1
if args.async_offload:
    NUM_STREAMS = 2  # experimental: round-robin between two copy streams
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

# Round-robin index into a device's stream pool (single global counter shared
# by all devices, advanced on every call).
stream_counter = 0
def get_offload_stream(device):
    """Return the next offload stream for *device* in round-robin order.

    Returns None when async offloading is disabled (NUM_STREAMS <= 1) or when
    *device* is not a CUDA device. CUDA stream pools are created lazily on the
    first request for a device.
    """
    global stream_counter
    if NUM_STREAMS <= 1:
        # --async-offload not set: caller falls back to synchronous copies.
        return None

    if device in STREAMS:
        ss = STREAMS[device]
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        if is_device_cuda(device):
            # NOTE(review): this makes the *next* stream in the rotation (not
            # the stream being returned) wait on the current compute stream;
            # the lazily-created path below performs no wait at all — confirm
            # this ordering is intended.
            ss[stream_counter].wait_stream(torch.cuda.current_stream())
        return s
    elif is_device_cuda(device):
        # First request for this CUDA device: build its stream pool.
        ss = []
        for k in range(NUM_STREAMS):
            ss.append(torch.cuda.Stream(device=device, priority=10))
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        return s
    return None
972+
def sync_stream(device, stream):
    """Order subsequent default-stream work on *device* after the copies
    queued on *stream*. A no-op when *stream* is None or the device is not
    a CUDA device."""
    if stream is None or not is_device_cuda(device):
        return
    torch.cuda.current_stream().wait_stream(stream)
977+
978+
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    """Return *weight* as a tensor with the given dtype on the given device.

    When *weight* is already on *device* (or *device* is None), this is a plain
    ``Tensor.to`` — and the original tensor itself when no dtype change or copy
    is needed. Otherwise a fresh tensor is allocated on *device* and the data
    is copied over, optionally inside *stream* (an offload stream from
    get_offload_stream) so the copy can overlap with compute.
    """
    if device is None or weight.device == device:
        if not copy and (dtype is None or weight.dtype == dtype):
            return weight  # nothing to do; avoid an allocation
        return weight.to(dtype=dtype, copy=copy)

    def _transfer():
        # Allocate on the target device, then copy (possibly asynchronously).
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
        return r

    if stream is not None:
        with stream:
            return _transfer()
    return _transfer()
952993

953994
def cast_to_device(tensor, device, dtype, copy=False):

comfy/ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
3737
if device is None:
3838
device = input.device
3939

40+
offload_stream = comfy.model_management.get_offload_stream(device)
4041
bias = None
4142
non_blocking = comfy.model_management.device_supports_non_blocking(device)
4243
if s.bias is not None:
4344
has_function = len(s.bias_function) > 0
44-
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
45+
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
4546
if has_function:
4647
for f in s.bias_function:
4748
bias = f(bias)
4849

4950
has_function = len(s.weight_function) > 0
50-
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
51+
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
5152
if has_function:
5253
for f in s.weight_function:
5354
weight = f(weight)
55+
56+
comfy.model_management.sync_stream(device, offload_stream)
5457
return weight, bias
5558

5659
class CastWeightBiasOp:

0 commit comments

Comments
 (0)