@@ -1,15 +1,17 @@
 # Standard
 import argparse
+import datetime
 from functools import partial
 import itertools
 import json
 import os
 from pathlib import Path
 import random
 import time
+import contextlib
 
 # Third Party
-from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
+from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
 import numpy as np
 import torch
@@ -235,6 +237,24 @@
     default="sdpa",
     help="which backend attention to use in mha",
 )
+parser.add_argument(
+    "--stagger_load",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--stagger_update_lazyhandle",
+    type=int,
+    default=0,
+    help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes"
+)
+parser.add_argument(
+    "--dist_timeout",
+    type=int,
+    default=0,
+    help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group"
+)
 args = parser.parse_args()
 
 attention_map = {
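A note on semantics, interpreted from the help strings above rather than stated in the diff: both stagger flags cap how many ranks may be inside a phase at the same time, and 0 (the default) means no cap. A minimal illustration of the wave arithmetic this implies, using a hypothetical `num_waves` helper that is not part of this PR:

```python
# Illustration only: how a stagger limit splits ranks into sequential waves.
# num_waves is a hypothetical helper, not part of this change.
import math

def num_waves(world_size: int, limit: int) -> int:
    if limit <= 0 or limit >= world_size:  # 0 (the default) means no staggering
        return 1
    return math.ceil(world_size / limit)

assert num_waves(16, 4) == 4   # --stagger_load 4 on 16 ranks -> 4 loading waves
assert num_waves(16, 0) == 1   # default: all 16 ranks load concurrently
```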
@@ -296,7 +316,13 @@
 is_aiu_backend = "aiu" in args.device_type
 
 if args.distributed:
-    dist.init_process_group()
+    if args.dist_timeout > 0:
+        # Default timeout:
+        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
+        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
+    else:
+        dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
@@ -476,18 +502,19 @@ def select_int8_module(
 dprint(f"data_type={default_dtype}")
 dprint("=" * 60 + "\n")
 
-model = get_model(
-    args.architecture,
-    args.variant,
-    model_path=args.model_path,
-    device_type="cpu" if is_aiu_backend else args.device_type,
-    data_type=default_dtype,
-    source=args.model_source,
-    distributed_strategy=distr_param,
-    group=dist.group.WORLD,
-    linear_config=linear_config,
-    fused_weights=fused_weights,
-)
+with stagger_region(args.stagger_load) as _s:
+    model = get_model(
+        args.architecture,
+        args.variant,
+        model_path=args.model_path,
+        device_type="cpu" if is_aiu_backend else args.device_type,
+        data_type=default_dtype,
+        source=args.model_source,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        linear_config=linear_config,
+        fused_weights=fused_weights,
+    )
 
 ### Quantization
 
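`stagger_region` itself is imported from `aiu_fms_testing_utils.utils` and its body is not shown in this diff. As a rough mental model only (the real implementation may differ), a barrier-based gate that admits at most `limit` ranks at a time could be written with `contextlib` and `torch.distributed`, which would be consistent with the `import contextlib` added at the top of the file:

```python
# Sketch only: a hypothetical stand-in for stagger_region, assuming
# torch.distributed is already initialized. Ranks are grouped into waves of
# size `limit`; each wave enters the block only after the previous wave leaves.
import contextlib
import math
import torch.distributed as dist

@contextlib.contextmanager
def stagger_region_sketch(limit: int):
    if limit <= 0 or not dist.is_initialized() or limit >= dist.get_world_size():
        yield  # 0 (the default) disables staggering entirely
        return
    world = dist.get_world_size()
    wave = dist.get_rank() // limit          # this rank's wave index
    waves = math.ceil(world / limit)         # total number of waves
    for _ in range(wave):                    # wait until all earlier waves are done
        dist.barrier()
    try:
        yield
    finally:
        for _ in range(waves - 1 - wave):    # release the waves that follow
            dist.barrier()
```

Every rank issues exactly `waves - 1` barriers in the same order, so the scheme cannot deadlock; it only serializes the waves.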
@@ -801,6 +828,7 @@ def infer(use_cache, do_sample, warmup):
         ids,
         args.max_new_tokens,
         args.compile_dynamic_sendnn,
+        args.stagger_update_lazyhandle,
         **extra_generation_kwargs,
     )
     aiu_warmup_time = time.time()