1 | 1 | # Standard
2 | 2 | import argparse
| 3 | +import datetime
3 | 4 | from functools import partial
4 | 5 | import itertools
5 | 6 | import json

8 | 9 | import random
9 | 10 | import time
10 | 11 | import contextlib
| 12 | +import math
11 | 13 |
12 | 14 | # Third Party
13 | 15 | from aiu_fms_testing_utils.utils import aiu_setup, warmup_model

218 | 220 |     default=0,
219 | 221 |     help="Set verbosity level (pass flag as `-v`, `-vv`, `-vvv`)"
220 | 222 | )
| 223 | +parser.add_argument(
| 224 | +    "--stagger_load",
| 225 | +    type=int,
| 226 | +    default=0,
| 227 | +    help="Stagger model loading to avoid OOM issues on the host"
| 228 | +)
| 229 | +parser.add_argument(
| 230 | +    "--stagger_update_lazyhandle",
| 231 | +    type=int,
| 232 | +    default=0,
| 233 | +    help="Stagger update_lazyhandle to avoid OOM issues on the host"
| 234 | +)
| 235 | +parser.add_argument(
| 236 | +    "--dist_timeout",
| 237 | +    type=int,
| 238 | +    default=0,
| 239 | +    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
| 240 | +)
221 | 241 | args = parser.parse_args()
222 | 242 |
223 | 243 | if args.quantization == "gptq":
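
Usage note (illustrative values, not from the PR): a distributed launch could add, for example, "--stagger_load 2 --stagger_update_lazyhandle 2 --dist_timeout 60" to its existing command line, so that only two ranks at a time load the model or run update_lazyhandle, and collectives wait up to 60 minutes instead of the PyTorch default.
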
260 | 280 | is_aiu_backend = "aiu" in args.device_type
261 | 281 |
262 | 282 | if args.distributed:
263 | | -    dist.init_process_group()
| 283 | +    if args.dist_timeout > 0:
| 284 | +        # Default timeout:
| 285 | +        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
| 286 | +        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
| 287 | +        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
| 288 | +    else:
| 289 | +        dist.init_process_group()
264 | 290 |     # Fix until PT 2.3
265 | 291 |     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
266 | 292 |     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
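
For context (not part of the diff): init_process_group expects its timeout as a datetime.timedelta, which is why the integer minutes from --dist_timeout are wrapped before the call. A minimal alternative sketch of the same branch, assuming the args namespace defined above, that builds the kwargs once instead of duplicating the call:

    # Sketch only, not the PR's code: forward the timeout only when the user
    # set --dist_timeout, otherwise keep PyTorch's default.
    init_kwargs = {}
    if args.dist_timeout > 0:
        init_kwargs["timeout"] = datetime.timedelta(minutes=args.dist_timeout)
    dist.init_process_group(**init_kwargs)
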
@@ -438,6 +464,13 @@ def select_int8_module(
438 | 464 | dprint(f"data_type={default_dtype}")
439 | 465 | dprint("="*60 + "\n")
440 | 466 |
| 467 | +if args.stagger_load > 0 and args.stagger_load != world_size:
| 468 | +    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
| 469 | +        if rank < (_set+1)*args.stagger_load:
| 470 | +            break
| 471 | +        torch.distributed.barrier()
| 472 | +    dprint(f"Stagger Model Load: Begin (Set: {_set+1} of {math.ceil(world_size / float(args.stagger_load))})")
| 473 | +
441 | 474 | model = get_model(
442 | 475 |     args.architecture,
443 | 476 |     args.variant,

@@ -467,6 +500,13 @@ def select_int8_module(
467 | 500 | loading_model_time = time.time() - loading_model_time
468 | 501 | dprint(f"loading complete, took {loading_model_time:.3f}s")
469 | 502 |
| 503 | +if args.stagger_load > 0 and args.stagger_load != world_size:
| 504 | +    for _set in range( math.ceil(world_size / float(args.stagger_load)) ):
| 505 | +        if rank >= (_set+1)*args.stagger_load:
| 506 | +            continue
| 507 | +        torch.distributed.barrier()
| 508 | +    dprint(f"Stagger Model Load: All Complete")
| 509 | +
470 | 510 | if args.compile:
471 | 511 |     dprint("compiling model")
472 | 512 |     if is_aiu_backend:
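
The two stagger blocks (new lines 467-473 and 503-509) work as a pair: --stagger_load N lets at most N ranks run get_model at a time. Before loading, each rank waits on one barrier per set of N ranks scheduled ahead of it; after loading, it joins every remaining barrier so later sets are released in order. Every barrier index is therefore entered by all world_size ranks exactly once across the two phases, so the collectives match up and no rank deadlocks. A standalone sketch of that bookkeeping (hypothetical helper name, no process group needed):

    import math

    def stagger_barrier_plan(rank, world_size, stagger_load):
        # Mirrors the two loops above, but records which barrier indices a rank
        # would enter before ("begin") and after ("complete") loading, instead
        # of calling torch.distributed.barrier().
        num_sets = math.ceil(world_size / float(stagger_load))
        begin, complete = [], []
        for _set in range(num_sets):        # before get_model(...)
            if rank < (_set + 1) * stagger_load:
                break
            begin.append(_set)
        for _set in range(num_sets):        # after the model has loaded
            if rank >= (_set + 1) * stagger_load:
                continue
            complete.append(_set)
        return begin, complete

    # Example: 8 ranks with --stagger_load 2 -> ranks 0-1 load first, then 2-3, ...
    for r in range(8):
        print(r, stagger_barrier_plan(r, 8, 2))
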
@@ -695,7 +735,7 @@ def infer(use_cache, do_sample, warmup):
695 | 735 | dprint(f"compilation warmup")
696 | 736 | pt_compile_model_time = time.time()
697 | 737 | if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
698 | | -    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
| 738 | +    warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
699 | 739 |     aiu_warmup_time = time.time()
700 | 740 |     for sample, cache in itertools.product(do_sample, use_cache):
701 | 741 |         infer(cache, sample, True)
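
The warmup call now forwards args.stagger_update_lazyhandle positionally, just before the keyword-expanded extra_generation_kwargs, presumably so warmup_model can apply the same set-based gating around its update_lazyhandle step (per the flag's help text). If the new parameter keeps the flag's name in warmup_model's signature (an assumption; only the call site is visible in this diff), the call could equivalently be written with a keyword argument for readability:

    # Assumes warmup_model's new parameter is named stagger_update_lazyhandle;
    # only the positional form appears in this diff.
    warmup_model(
        model,
        ids,
        args.max_new_tokens,
        args.compile_dynamic_sendnn,
        stagger_update_lazyhandle=args.stagger_update_lazyhandle,
        **extra_generation_kwargs,
    )
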