Merged
55 changes: 43 additions & 12 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -7,21 +7,51 @@
import time

# Third Party
from aiu_fms_testing_utils.utils.aiu_setup import dprint

from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from fms.utils.generation import pad_input_ids
import torch
import torch.nn as nn

import math
import contextlib
import warnings


@contextlib.contextmanager
def stagger_region(limit: int):
"""
Limit the number of concurrent processes into this region of code.
Processes yield from this function when they are allowed to enter the region of code.
Processes return from this function when all of the processes have completed the region of code.

:param limit: Number of concurrent processes allowed in the code region if > 0.
"""
if limit > 0 and limit != world_size:
for _set in range(math.ceil(world_size / float(limit))):
if rank < (_set + 1) * limit:
break
torch.distributed.barrier()
dprint(
f"Stagger: Enter (Set: {_set + 1} of {math.ceil(world_size / float(limit))})"
)
yield
if limit > 0 and limit != world_size:
for _set in range(math.ceil(world_size / float(limit))):
if rank >= (_set + 1) * limit:
continue
torch.distributed.barrier()
dprint("Stagger: All Complete")


def warmup_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    compile_dynamic_sendnn: bool = False,
    use_cache: bool = True,
    stagger_update_lazyhandle: int = 0,
    **extra_kwargs,
):
    import torch_sendnn
@@ -55,16 +85,17 @@ def warmup_model(

    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

    with torch_sendnn.warmup_mode():
        generate(
            model,
            _warmup_input_ids,
            max_new_tokens=_max_new_tokens,
            do_sample=False,
            use_cache=use_cache,
            extra_kwargs=extra_kwargs,
            **attention_specific_kwargs,
        )
    with stagger_region(stagger_update_lazyhandle):
        with torch_sendnn.warmup_mode():
            generate(
                model,
                _warmup_input_ids,
                max_new_tokens=_max_new_tokens,
                do_sample=False,
                use_cache=use_cache,
                extra_kwargs=extra_kwargs,
                **attention_specific_kwargs,
            )
    pt_compile_model_time = time.time() - pt_compile_model_time
    dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

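For illustration, here is a minimal standalone sketch of the set arithmetic that stagger_region relies on, using a hypothetical world_size of 8 and a limit of 3 (neither value comes from this diff). Each rank passes one barrier per earlier set before entering the region, and one barrier for its own set and each later set on the way out, so every rank makes the same total number of barrier calls and set i only enters after set i-1 has finished.

import math

world_size = 8  # hypothetical rank count
limit = 3       # hypothetical stagger limit

n_sets = math.ceil(world_size / float(limit))  # 3 sets: {0,1,2}, {3,4,5}, {6,7}

for rank in range(world_size):
    # entry loop: one barrier per earlier set, then break into the region
    entry_barriers = next(s for s in range(n_sets) if rank < (s + 1) * limit)
    # exit loop: one barrier for this rank's own set and each later set
    exit_barriers = n_sets - entry_barriers
    print(
        f"rank {rank}: set {entry_barriers + 1} of {n_sets}, "
        f"{entry_barriers} entry + {exit_barriers} exit barriers"
    )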
56 changes: 42 additions & 14 deletions scripts/inference.py
@@ -1,5 +1,6 @@
# Standard
import argparse
import datetime
from functools import partial
import itertools
import json
@@ -9,7 +10,7 @@
import time

# Third Party
from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
import numpy as np
import torch
@@ -237,6 +238,24 @@
default="sdpa",
help="which backend attention to use in mha",
)
parser.add_argument(
"--stagger_load",
type=int,
default=0,
help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes",
)
parser.add_argument(
"--stagger_update_lazyhandle",
type=int,
default=0,
help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes",
)
parser.add_argument(
"--dist_timeout",
type=int,
default=0,
help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group",
)
args = parser.parse_args()

attention_map = {
@@ -298,7 +317,13 @@
is_aiu_backend = "aiu" in args.device_type

if args.distributed:
    dist.init_process_group()
    if args.dist_timeout > 0:
        # Default timeout:
        # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
        dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
    else:
        dist.init_process_group()
    # Fix until PT 2.3
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
@@ -478,18 +503,19 @@ def select_int8_module(
dprint(f"data_type={default_dtype}")
dprint("=" * 60 + "\n")

model = get_model(
    args.architecture,
    args.variant,
    model_path=args.model_path,
    device_type="cpu" if is_aiu_backend else args.device_type,
    data_type=default_dtype,
    source=args.model_source,
    distributed_strategy=distr_param,
    group=dist.group.WORLD,
    linear_config=linear_config,
    fused_weights=fused_weights,
)
with stagger_region(args.stagger_load):
    model = get_model(
        args.architecture,
        args.variant,
        model_path=args.model_path,
        device_type="cpu" if is_aiu_backend else args.device_type,
        data_type=default_dtype,
        source=args.model_source,
        distributed_strategy=distr_param,
        group=dist.group.WORLD,
        linear_config=linear_config,
        fused_weights=fused_weights,
    )

### Quantization

@@ -814,6 +840,8 @@ def infer(use_cache, do_sample, warmup):
ids,
args.max_new_tokens,
args.compile_dynamic_sendnn,
use_cache=cache,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
**extra_generation_kwargs,
)
if (
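As a quick reference, a sketch of how the three new flags parse, using the defaults declared above; the non-default values here (a loading set of 4 ranks, a 60 minute timeout) are hypothetical, and 0 keeps the current behaviour (no staggering, PyTorch's default timeout).

import argparse
import datetime

parser = argparse.ArgumentParser()
parser.add_argument("--stagger_load", type=int, default=0)
parser.add_argument("--stagger_update_lazyhandle", type=int, default=0)
parser.add_argument("--dist_timeout", type=int, default=0)

args = parser.parse_args(["--stagger_load", "4", "--dist_timeout", "60"])

print(args.stagger_load)               # 4 -> at most 4 ranks load the model at once
print(args.stagger_update_lazyhandle)  # 0 -> all ranks run the warmup phase together
print(datetime.timedelta(minutes=args.dist_timeout))  # 1:00:00, passed to dist.init_process_group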