Skip to content

Commit afbac6a

Browse files
committed
Updated per PR review
1 parent 4056b58 commit afbac6a

File tree

4 files changed

+22
-26
lines changed

4 files changed

+22
-26
lines changed

helpers/arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,11 +1428,11 @@ def parse_args(input_args=None):
14281428
),
14291429
)
14301430
parser.add_argument(
1431-
"--cuda_clear_cache_steps",
1431+
"--accelerator_cache_clear_interval",
14321432
default=None,
14331433
type=int,
14341434
help=(
1435-
"Clear the CUDA cache every X steps. This can help prevent memory leaks, but may slow down training."
1435+
"Clear the cache from VRAM every X steps. This can help prevent memory leaks, but may slow down training."
14361436
),
14371437
)
14381438

helpers/caching/memory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
def reclaim_memory():
    """Release cached accelerator memory, then run Python garbage collection.

    Clears the CUDA caching-allocator blocks and CUDA IPC handles when a
    CUDA device is available, clears the MPS cache on Apple silicon, and
    finishes with a full ``gc.collect()`` pass. Safe to call on CPU-only
    hosts (both branches are simply skipped).
    """
    # Deferred imports keep this helper cheap for callers that never hit it.
    import gc

    import torch

    # NVIDIA path: drop cached allocator blocks and collect IPC handles.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Apple-silicon path: drop the Metal Performance Shaders cache.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

    # Finally, let CPython reclaim unreferenced Python objects.
    gc.collect()

train_sd21.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from pathlib import Path
2323
from helpers.arguments import parse_args
24+
from helpers.caching.memory import reclaim_memory
2425
from helpers.legacy.validation import prepare_validation_prompt_list
2526
from helpers.training.validation import Validation
2627
from helpers.training.state_tracker import StateTracker
@@ -109,15 +110,6 @@
109110
check_min_version("0.27.0.dev0")
110111

111112

112-
def garbage_collection():
113-
import gc
114-
115-
if torch.cuda.is_available():
116-
torch.cuda.empty_cache()
117-
torch.cuda.ipc_collect()
118-
gc.collect()
119-
120-
121113
SCHEDULER_NAME_MAP = {
122114
"euler": EulerDiscreteScheduler,
123115
"euler-a": EulerAncestralDiscreteScheduler,
@@ -885,7 +877,7 @@ def main():
885877
for _, backend in StateTracker.get_data_backends().items():
886878
if "vaecache" in backend:
887879
backend["vaecache"].vae = None
888-
garbage_collection()
880+
reclaim_memory()
889881
memory_after_unload = torch.cuda.memory_allocated() / 1024**3
890882
memory_saved = memory_after_unload - memory_before_unload
891883
logger.info(
@@ -1570,7 +1562,7 @@ def main():
15701562
)
15711563

15721564
del text_encoder_lora_layers
1573-
garbage_collection()
1565+
reclaim_memory()
15741566

15751567
if args.use_ema:
15761568
ema_unet.copy_to(unet.parameters())

train_sdxl.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from pathlib import Path
2424
from helpers.arguments import parse_args
25+
from helpers.caching.memory import reclaim_memory
2526
from helpers.legacy.validation import prepare_validation_prompt_list
2627
from helpers.training.validation import Validation
2728
from helpers.training.state_tracker import StateTracker
@@ -219,15 +220,6 @@ def get_tokenizers(args):
219220
return tokenizer_1, tokenizer_2, tokenizer_3
220221

221222

222-
import gc
223-
224-
225-
def garbage_collection():
226-
if torch.cuda.is_available():
227-
torch.cuda.empty_cache()
228-
torch.cuda.ipc_collect()
229-
gc.collect()
230-
231223
def main():
232224
StateTracker.set_model_type("sdxl")
233225
args = parse_args()
@@ -742,7 +734,7 @@ def main():
742734
text_encoder_2 = None
743735
text_encoder_3 = None
744736
text_encoders = []
745-
garbage_collection()
737+
reclaim_memory()
746738
memory_after_unload = torch.cuda.memory_allocated() / 1024**3
747739
memory_saved = memory_after_unload - memory_before_unload
748740
logger.info(
@@ -1182,7 +1174,7 @@ def main():
11821174
for _, backend in StateTracker.get_data_backends().items():
11831175
if "vaecache" in backend:
11841176
backend["vaecache"].vae = None
1185-
garbage_collection()
1177+
reclaim_memory()
11861178
memory_after_unload = torch.cuda.memory_allocated() / 1024**3
11871179
memory_saved = memory_after_unload - memory_before_unload
11881180
logger.info(
@@ -1945,7 +1937,7 @@ def main():
19451937
)
19461938

19471939
if global_step % args.cuda_clear_cache == 0:
1948-
garbage_collection()
1940+
reclaim_memory()
19491941

19501942
logs = {
19511943
"step_loss": loss.detach().item(),
@@ -2042,7 +2034,7 @@ def main():
20422034
del transformer
20432035
del text_encoder_lora_layers
20442036
del text_encoder_2_lora_layers
2045-
garbage_collection()
2037+
reclaim_memory()
20462038
elif args.use_ema:
20472039
if unet is not None:
20482040
ema_unet.copy_to(unet.parameters())

0 commit comments

Comments
 (0)