
Commit b30c709

Options to stagger model loading for low-memory systems
* `--stagger_load` (default: `0`, off): Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` (default: `0`, off): Stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` (default: either `10` minutes for NCCL or `30` for other backends, set by PyTorch): torch distributed timeout in minutes
* Convert the stagger enter/leave calls into a proper `contextlib` context manager

Signed-off-by: Joshua Hursey <[email protected]>
Parent: f43cc04

2 files changed: +68 −16 lines
aiu_fms_testing_utils/utils/__init__.py

Lines changed: 26 additions & 1 deletion
@@ -12,14 +12,39 @@
 from fms.utils.generation import pad_input_ids
 import torch
 import torch.nn as nn
+import math
+import contextlib
 
+@contextlib.contextmanager
+def stagger_region(limit: int):
+    """
+    Limit the number of concurrent processes into this region of code.
+
+    Processes yield from this function when they are allowed to enter the region of code.
+    Processes return from this function when all of the processes have completed the region of code.
+
+    :param limit: Number of concurrent processes allowed in the code region if > 0.
+    """
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank < (_set+1)*limit:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})")
+    yield {}
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank >= (_set+1)*limit:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger: All Complete")
 
 def warmup_model(
     model: nn.Module,
     input_ids: torch.Tensor,
     max_new_tokens: int,
     compile_dynamic_sendnn: bool = False,
     use_cache: bool = True,
+    stagger_update_lazyhandle: int = 0,
     **extra_kwargs
 ):
     import torch_sendnn
@@ -51,7 +76,7 @@ def warmup_model(
 
     extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
 
-    with torch_sendnn.warmup_mode():
+    with stagger_region(stagger_update_lazyhandle) as _s, torch_sendnn.warmup_mode():
         generate(
             model,
             _warmup_input_ids,
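For intuition about the new `stagger_region` helper, here is a small standalone sketch (not part of the commit; the `world_size` and `limit` values are assumed) of how ranks are partitioned into sets by the same arithmetic. Ranks outside the active set wait in `torch.distributed.barrier()` calls that are matched by earlier sets as they leave the region, so roughly `limit` ranks are inside the region at a time and no rank returns until every set has finished.

```python
import math

# Standalone illustration (not from the commit) of the set arithmetic used by
# stagger_region(): which ranks are allowed into the region in which "set".
world_size = 8   # assumed number of ranks
limit = 3        # e.g. --stagger_load 3

n_sets = math.ceil(world_size / float(limit))
for rank in range(world_size):
    # A rank breaks out of the enter loop at the first _set where
    # rank < (_set + 1) * limit holds, i.e. its set index is rank // limit.
    _set = rank // limit
    print(f"rank {rank}: enters in set {_set + 1} of {n_sets}")
```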

scripts/inference.py

Lines changed: 42 additions & 15 deletions
@@ -1,15 +1,17 @@
 # Standard
 import argparse
+import datetime
 from functools import partial
 import itertools
 import json
 import os
 from pathlib import Path
 import random
 import time
+import contextlib
 
 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
@@ -234,6 +236,24 @@
     default="sdpa",
     help="which backend attention to use in mha",
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--dist_timeout",
+    type=int,
+    default=0,
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
 args = parser.parse_args()
 
 attention_map = {
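As a quick illustration of the new flags (a standalone sketch, not code from the repo), a minimal parser with the same three options shows their defaults and how they would be set on the command line:

```python
import argparse

# Minimal parser mirroring the three new options (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument("--stagger_load", type=int, default=0,
                    help="Limit concurrent processes during model loading (0 = no limit)")
parser.add_argument("--stagger_update_lazyhandle", type=int, default=0,
                    help="Limit concurrent processes during AIU update_lazyhandle (0 = no limit)")
parser.add_argument("--dist_timeout", type=int, default=0,
                    help="torch.distributed timeout in minutes (0 = PyTorch default)")

# e.g. `inference.py --stagger_load 2 --dist_timeout 30 ...`
args = parser.parse_args(["--stagger_load", "2", "--dist_timeout", "30"])
print(args.stagger_load, args.stagger_update_lazyhandle, args.dist_timeout)  # 2 0 30
```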
@@ -293,7 +313,13 @@
 is_aiu_backend = "aiu" in args.device_type
 
 if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
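The timeout branch only converts the minutes value into the `datetime.timedelta` that `torch.distributed.init_process_group` accepts; per the commit message, leaving `--dist_timeout` at `0` keeps PyTorch's own default (10 minutes for NCCL, 30 minutes for other backends). A minimal sketch of the conversion, with an assumed value:

```python
import datetime

# Convert an assumed --dist_timeout of 30 minutes into the timedelta passed to
# torch.distributed.init_process_group(timeout=...).
dist_timeout_minutes = 30
timeout = datetime.timedelta(minutes=dist_timeout_minutes)
print(timeout)  # 0:30:00
```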
@@ -471,18 +497,19 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")
 
-model = get_model(
-    args.architecture,
-    args.variant,
-    model_path=args.model_path,
-    device_type="cpu" if is_aiu_backend else args.device_type,
-    data_type=default_dtype,
-    source=args.model_source,
-    distributed_strategy=distr_param,
-    group=dist.group.WORLD,
-    linear_config=linear_config,
-    fused_weights=fused_weights,
-)
+with stagger_region(args.stagger_load) as _s:
+    model = get_model(
+        args.architecture,
+        args.variant,
+        model_path=args.model_path,
+        device_type="cpu" if is_aiu_backend else args.device_type,
+        data_type=default_dtype,
+        source=args.model_source,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )
 
 ### Quantization
 
@@ -759,7 +786,7 @@ def infer(use_cache, do_sample, warmup):
 pt_compile_model_time = time.time()
 if args.device_type == "aiu":  # only run warmup for AIU, no need for senulator
     for cache in use_cache:
-        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
 aiu_warmup_time = time.time()
 for sample, cache in itertools.product(do_sample, use_cache):
     infer(cache, sample, True)
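Taken together, the two stagger options trade some load/warmup time for peak host memory: only `limit` ranks materialize weights (or run update_lazyhandle) at once instead of all `world_size` ranks. A back-of-the-envelope sketch with assumed numbers (not measurements from this commit):

```python
# Rough illustration with assumed numbers (not measurements): peak host RAM
# while ranks load the model concurrently vs. staggered with --stagger_load.
per_rank_gb = 20   # assumed transient host memory per rank during get_model()
world_size = 8     # assumed number of ranks on the host

for stagger_load in (0, 4, 2):
    concurrent = world_size if stagger_load == 0 else min(stagger_load, world_size)
    print(f"--stagger_load {stagger_load}: ~{concurrent * per_rank_gb} GB peak host RAM")
```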
