
Commit 2fd81b0

Options to stagger model loading for low-memory systems
* `--stagger_load` (default: `0`, off): stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` (default: `0`, off): stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` (default: set by PyTorch, `10` minutes for NCCL or `30` for other backends): torch distributed timeout, in minutes

Signed-off-by: Joshua Hursey <[email protected]>
1 parent f067cd9 commit 2fd81b0
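
For illustration, a small sketch of what these flag values mean in practice; the world size and flag values below are assumptions chosen for the example, not taken from the commit:

```python
# Illustrative only: with 8 ranks and --stagger_load 2 (assumed values),
# model loading proceeds in 4 sets of 2 ranks each; the same arithmetic
# applies to --stagger_update_lazyhandle.
import math

world_size, stagger = 8, 2
num_sets = math.ceil(world_size / float(stagger))            # -> 4
load_set = {rank: rank // stagger for rank in range(world_size)}
print(num_sets, load_set)  # 4 {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3}
```

Ranks in later sets simply wait at barriers until their set is reached, which is why `--dist_timeout` may also need to be raised.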

File tree

2 files changed: +61 −5 lines


aiu_fms_testing_utils/utils/__init__.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -3,26 +3,42 @@
 import time
 from fms.utils.tokenizers import BaseTokenizer
 from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 from typing import Optional, List, Tuple
 import os
 import requests
 import json
 import random
+import math

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, stagger_update_lazyhandle = 0, **padding_kwargs):
     import torch_sendnn
     dprint("AIU warmup")
-    pt_compile_model_time = time.time()
     extra_kwargs = {**padding_kwargs, "only_last_token": True}
     max_new_tokens_warmup = max_new_tokens
     if compile_dynamic_sendnn:
         max_new_tokens_warmup = 2
+
+    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
+        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
+            if rank < (_set+1)*stagger_update_lazyhandle:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger update_lazyhandle: Begin (Set: {_set+1} of {math.ceil(world_size / float(stagger_update_lazyhandle))})")
+
+    pt_compile_model_time = time.time()
     with torch_sendnn.warmup_mode():
         generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

+    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
+        for _set in range( math.ceil(world_size / float(stagger_update_lazyhandle)) ):
+            if rank >= (_set+1)*stagger_update_lazyhandle:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger update_lazyhandle: All Complete")
+
 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)
     ids = tokenizer.convert_tokens_to_ids(tokens)
```
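
Read together, the two added loops implement a set-based gate around the warmup: an entry gate that admits `stagger_update_lazyhandle` ranks at a time, and an exit gate that keeps the barrier count identical on every rank. A minimal standalone sketch of that pattern, with a hypothetical `run_staggered` helper and `work` callable, assuming `torch.distributed` is already initialized:

```python
import math
import torch.distributed as dist

def run_staggered(work, stagger: int, rank: int, world_size: int):
    """Run `work` on at most `stagger` ranks at a time (hypothetical helper)."""
    if stagger <= 0 or stagger == world_size:
        work()
        return
    num_sets = math.ceil(world_size / float(stagger))
    # Entry gate: a rank in set k waits through k barriers before starting.
    for _set in range(num_sets):
        if rank < (_set + 1) * stagger:
            break
        dist.barrier()
    work()
    # Exit gate: release the later sets. Every rank hits num_sets barriers in
    # total, so the collective calls stay matched across ranks.
    for _set in range(num_sets):
        if rank >= (_set + 1) * stagger:
            continue
        dist.barrier()
```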

scripts/inference.py

Lines changed: 42 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 # Standard
 import argparse
+import datetime
 from functools import partial
 import itertools
 import json
@@ -8,6 +9,7 @@
 import random
 import time
 import contextlib
+import math

 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
@@ -218,6 +220,24 @@
     default=0,
     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Stagger model loading to avoid OOM issues on the host"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Stagger update_lazyhandle to avoid OOM issues on the host"
+)
+parser.add_argument(
+    "--dist_timeout",
+    type=int,
+    default=0,
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
 args = parser.parse_args()

 if args.quantization == "gptq":
@@ -260,7 +280,13 @@
 is_aiu_backend = "aiu" in args.device_type

 if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
@@ -438,6 +464,13 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")

+if args.stagger_load > 0 and args.stagger_load != world_size:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank < (_set+1)*args.stagger_load:
+            break
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
+
 model = get_model(
     args.architecture,
     args.variant,
@@ -467,6 +500,13 @@ def select_int8_module(
 loading_model_time = time.time() - loading_model_time
 dprint(f"loading complete, took {loading_model_time:.3f}s")

+if args.stagger_load > 0 and args.stagger_load != world_size:
+    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
+        if rank >= (_set+1)*args.stagger_load:
+            continue
+        torch.distributed.barrier()
+    dprint(f"Stagger Model Load: All Complete")
+
 if args.compile:
     dprint("compiling model")
     if is_aiu_backend:
@@ -695,7 +735,7 @@ def infer(use_cache, do_sample, warmup):
     dprint(f"compilation warmup")
     pt_compile_model_time = time.time()
     if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
-        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
     aiu_warmup_time = time.time()
     for sample, cache in itertools.product(do_sample, use_cache):
         infer(cache, sample, True)
```
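
Because ranks in the last set sit at barriers while every earlier set loads, the default distributed timeout can expire before they ever start. A rough, illustrative way to size `--dist_timeout` under staggering; the world size, stagger value, and per-set load time below are assumptions, not measurements from the commit:

```python
import math

world_size = 8              # assumed rank count
stagger_load = 2            # assumed --stagger_load value
per_set_load_minutes = 12   # assumed time for one set to load the model

num_sets = math.ceil(world_size / float(stagger_load))
# The last set waits through (num_sets - 1) earlier loads plus its own; add headroom.
suggested_timeout_minutes = num_sets * per_set_load_minutes + 10
print(f"--dist_timeout {suggested_timeout_minutes}")  # -> --dist_timeout 58
```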
