
Commit 29c42a9

Options to stagger model loading on low-memory systems

* `--stagger_load` : (default: `0`, off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0`, off) Stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` : (default: either `10` minutes for NCCL or `30` for other backends, as set by PyTorch) torch distributed timeout in minutes
* Convert the stagger enter/leave into a proper contextlib function

Signed-off-by: Joshua Hursey <[email protected]>
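These options are additive and disabled by default. As a purely illustrative example (values chosen for the sake of argument, not recommendations), an 8-rank run on a memory-constrained host might append `--stagger_load 2 --stagger_update_lazyhandle 2 --dist_timeout 30` to its usual `scripts/inference.py` invocation, so that only two ranks load weights (or run the AIU update_lazyhandle phase) at any one time, with the longer distributed timeout keeping the waiting ranks from tripping the collective timeout.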
1 parent: 9779f4d


2 files changed (+69, -16 lines)


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 27 additions & 2 deletions
@@ -7,19 +7,44 @@
import time

# Third Party
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
from fms.utils.tokenizers import BaseTokenizer
from fms.utils.generation import pad_input_ids
import torch
import torch.nn as nn
+import math
+import contextlib

+@contextlib.contextmanager
+def stagger_region(limit: int):
+    """
+    Limit the number of concurrent processes into this region of code.
+    Processes yield from this function when they are allowed to enter the region of code.
+    Processes return from this function when all of the processes have completed the region of code.
+
+    :param limit: Number of concurrent processes allowed in the code region if > 0.
+    """
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank < (_set+1)*limit:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})")
+    yield {}
+    if limit > 0 and limit != world_size:
+        for _set in range( math.ceil(world_size / float(limit)) ):
+            if rank >= (_set+1)*limit:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger: All Complete")


def warmup_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    compile_dynamic_sendnn: bool = False,
    use_cache: bool = True,
+    stagger_update_lazyhandle: int = 0,
    **extra_kwargs,
):
    import torch_sendnn
@@ -53,7 +78,7 @@ def warmup_model(

    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

-    with torch_sendnn.warmup_mode():
+    with stagger_region(stagger_update_lazyhandle) as _s, torch_sendnn.warmup_mode():
        generate(
            model,
            _warmup_input_ids,
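For readers of the diff above, the following sketch shows one way the new `stagger_region` context manager could be driven outside of `scripts/inference.py`. It is illustrative only: it assumes a `torchrun` launch, mirrors the `dist.init_process_group()` / `aiu_setup.aiu_dist_setup(...)` sequence that `inference.py` already performs, and substitutes a hypothetical `load_big_model()` for the real memory-hungry work.

```python
# Minimal usage sketch (not part of this commit); assumes a launch such as
# `torchrun --nproc-per-node=8 sketch.py`.
import torch.distributed as dist

from aiu_fms_testing_utils.utils import aiu_setup, stagger_region


def load_big_model():
    # Hypothetical placeholder for the work that is heavy on host memory.
    return [0.0] * 1_000_000


if __name__ == "__main__":
    dist.init_process_group()
    # stagger_region sees rank/world_size through aiu_setup, so initialize
    # them the same way scripts/inference.py does.
    aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())

    limit = 2  # illustrative; inference.py takes this from --stagger_load
    # Ranks 0..limit-1 enter immediately; each later wave of `limit` ranks is
    # released from its barrier only after the earlier waves have left the
    # region, so at most `limit` copies of load_big_model() run at once.
    with stagger_region(limit):
        model = load_big_model()

    # Every rank has re-synchronized by the time the context manager exits.
    dist.destroy_process_group()
```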

scripts/inference.py

Lines changed: 42 additions & 14 deletions
@@ -1,15 +1,17 @@
# Standard
import argparse
+import datetime
from functools import partial
import itertools
import json
import os
from pathlib import Path
import random
import time
+import contextlib

# Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
import numpy as np
import torch
@@ -235,6 +237,24 @@
    default="sdpa",
    help="which backend attention to use in mha",
)
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--dist_timeout",
+    type=int,
+    default=0,
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
args = parser.parse_args()

attention_map = {
@@ -296,7 +316,13 @@
is_aiu_backend = "aiu" in args.device_type

if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
    # Fix until PT 2.3
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
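Presumably, the timeout option accompanies the staggering options because, while earlier waves load, the remaining ranks sit blocked in `torch.distributed.barrier()`; for large checkpoints or slow filesystems that wait can exceed the default collective timeout (10 minutes for NCCL, 30 for other backends, as noted in the commit message), so raising `--dist_timeout` keeps the idle ranks from aborting.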
@@ -476,18 +502,19 @@ def select_int8_module(
dprint(f"data_type={default_dtype}")
dprint("=" * 60 + "\n")

-model = get_model(
-    args.architecture,
-    args.variant,
-    model_path=args.model_path,
-    device_type="cpu" if is_aiu_backend else args.device_type,
-    data_type=default_dtype,
-    source=args.model_source,
-    distributed_strategy=distr_param,
-    group=dist.group.WORLD,
-    linear_config=linear_config,
-    fused_weights=fused_weights,
-)
+with stagger_region(args.stagger_load) as _s:
+    model = get_model(
+        args.architecture,
+        args.variant,
+        model_path=args.model_path,
+        device_type="cpu" if is_aiu_backend else args.device_type,
+        data_type=default_dtype,
+        source=args.model_source,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )

### Quantization
@@ -801,6 +828,7 @@ def infer(use_cache, do_sample, warmup):
        ids,
        args.max_new_tokens,
        args.compile_dynamic_sendnn,
+        stagger_update_lazyhandle=args.stagger_update_lazyhandle,
        **extra_generation_kwargs,
    )
    aiu_warmup_time = time.time()
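The intended payoff of the staggering is that at most `--stagger_load` (or `--stagger_update_lazyhandle`) ranks execute the guarded phase concurrently instead of all `world_size` ranks at once. As a rough, purely hypothetical illustration: if each rank transiently needed ~20 GB of host RAM while loading, an 8-rank unstaggered load would peak near 8 × 20 = 160 GB, whereas `--stagger_load 2` would peak near 2 × 20 = 40 GB, at the cost of loading in `math.ceil(8 / 2) = 4` sequential waves.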
