|
@@ -1,15 +1,17 @@
 # Standard
 import argparse
+import datetime
 from functools import partial
 import itertools
 import json
 import os
 from pathlib import Path
 import random
 import time
+import contextlib

 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
|
@@ -234,6 +236,24 @@
     default="sdpa",
     help="which backend attention to use in mha",
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--dist_timeout",
+    type=int,
+    default=0,
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
 args = parser.parse_args()

 attention_map = {
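
The three options added above are plain integers where `0` means "no limit / fall back to the default". As a quick self-contained illustration of the intended semantics, here is a throwaway parser that mirrors only these flags (not the script's real parser; the values are arbitrary):

```python
# Throwaway parser mirroring only the three new flags, to show their types,
# defaults, and units. Values below are arbitrary examples.
import argparse

demo = argparse.ArgumentParser()
demo.add_argument("--stagger_load", type=int, default=0)
demo.add_argument("--stagger_update_lazyhandle", type=int, default=0)
demo.add_argument("--dist_timeout", type=int, default=0)

ns = demo.parse_args(["--stagger_load", "8", "--dist_timeout", "120"])
assert ns.stagger_load == 8               # at most 8 ranks load weights at once
assert ns.stagger_update_lazyhandle == 0  # default: no cap on the update_lazyhandle phase
assert ns.dist_timeout == 120             # minutes, forwarded to init_process_group
```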
|
|
@@ -293,7 +313,13 @@
 is_aiu_backend = "aiu" in args.device_type

 if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
|
@@ -471,18 +497,19 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("="*60 + "\n")

-model = get_model(
-    args.architecture,
-    args.variant,
-    model_path=args.model_path,
-    device_type="cpu" if is_aiu_backend else args.device_type,
-    data_type=default_dtype,
-    source=args.model_source,
-    distributed_strategy=distr_param,
-    group=dist.group.WORLD,
-    linear_config=linear_config,
-    fused_weights=fused_weights,
-)
+with stagger_region(args.stagger_load) as _s:
+    model = get_model(
+        args.architecture,
+        args.variant,
+        model_path=args.model_path,
+        device_type="cpu" if is_aiu_backend else args.device_type,
+        data_type=default_dtype,
+        source=args.model_source,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )

 ### Quantization

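
`stagger_region` is imported from `aiu_fms_testing_utils.utils` and wraps the model load above so that only `--stagger_load` ranks run it at a time. The helper below is not that implementation; it is a minimal sketch of the idea suggested by the help text (at most `limit` ranks inside the block concurrently, `0` meaning everyone at once), built from `torch.distributed.barrier` waves:

```python
# Hypothetical sketch only; not the aiu_fms_testing_utils implementation.
# Lets at most `limit` ranks execute the wrapped block concurrently by
# releasing ranks in waves of `limit`, with a barrier between waves.
import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def stagger_region_sketch(limit: int):
    if limit <= 0 or not dist.is_initialized() or limit >= dist.get_world_size():
        # 0 (or a limit >= world size) means no staggering at all.
        yield
        return

    world_size = dist.get_world_size()
    my_wave = dist.get_rank() // limit         # which wave this rank belongs to
    waves = (world_size + limit - 1) // limit  # total number of waves

    # Wait for every earlier wave to finish before entering the block.
    for _ in range(my_wave):
        dist.barrier()

    yield  # protected region: at most `limit` ranks are in here at once

    # Release the remaining waves, matching their barrier calls so every rank
    # issues exactly `waves - 1` barriers in total.
    for _ in range(my_wave, waves - 1):
        dist.barrier()
```

Under that assumption, the call site reads just like the hunk above: `with stagger_region_sketch(args.stagger_load): model = get_model(...)`.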
|
@@ -759,7 +786,7 @@ def infer(use_cache, do_sample, warmup):
 pt_compile_model_time = time.time()
 if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
     for cache in use_cache:
-        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs)
+        warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, args.stagger_update_lazyhandle, **extra_generation_kwargs)
     aiu_warmup_time = time.time()
     for sample, cache in itertools.product(do_sample, use_cache):
         infer(cache, sample, True)
|
|