Commit 56b629f

kcirred committed
clean up ruff check
Signed-off-by: kcirred <[email protected]>
1 parent 89056cb commit 56b629f

File tree: 8 files changed, +18 −181 lines
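The theme of this commit: ruff's F401 check flags imports that are never referenced, but several imports in this repo exist only for their side effects (importing `torch_sendnn` or the fms-mo addons appears to register backends and ops at import time), so they are kept and marked with `# noqa` instead of being deleted. A minimal sketch of the pattern, using a made-up module name:

    # Hypothetical module, imported only for its side effects (e.g., registering
    # a plugin at import time). "noqa: F401" tells ruff not to flag the unused import.
    import my_plugin_module  # noqa: F401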

aiu_fms_testing_utils/utils/model_setup.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,8 @@ def get_device(args: argparse.Namespace) -> torch.device:
         device = torch.device(args.device_type, local_rank)
         torch.cuda.set_device(device)
     elif args.is_aiu_backend:
+        from torch_sendnn import torch_sendnn  # noqa: F401
+
         if args.distributed:
             aiu_setup.aiu_dist_setup(
                 distributed.get_rank(),
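Placing the `torch_sendnn` import inside the `is_aiu_backend` branch keeps it a soft dependency: runs on other device types never import it. A sketch of how such a lazy import could be wrapped to fail with a clearer message — this helper is hypothetical, not part of the commit:

    def require_torch_sendnn() -> None:
        # Hypothetical guard: surface a clear error when the AIU stack is missing.
        try:
            from torch_sendnn import torch_sendnn  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "AIU backend requested but torch_sendnn is not installed."
            ) from exc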

aiu_fms_testing_utils/utils/paged.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 import time
 from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 import torch
+import fms.utils.spyre.paged  # noqa


 def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):

scripts/generate_metrics.py

Lines changed: 4 additions & 5 deletions
@@ -175,12 +175,11 @@ def find_eos_index(reference_tokens, eos_token_id):
     return result


-def filter_before_eos(level_metric, filter_indexes):
+def filter_before_eos(metrics, filter_indexes):
     from itertools import groupby

     filtered_results = [
-        list(g)[: filter_indexes[k]]
-        for k, g in groupby(level_metric, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]

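`filter_before_eos` groups metric rows by their first element (the sequence index) and truncates each group at that sequence's EOS cutoff; because `itertools.groupby` only groups consecutive items, the rows must already be ordered by sequence index. A small worked example with illustrative values:

    metrics = [(0, "a"), (0, "b"), (0, "c"), (1, "x"), (1, "y")]
    filter_indexes = {0: 2, 1: 1}  # keep 2 rows of sequence 0, 1 row of sequence 1
    filter_before_eos(metrics, filter_indexes)
    # -> [(0, "a"), (0, "b"), (1, "x")]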

@@ -197,10 +196,10 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     return input_ids, padding_kwargs


-def write_csv(metric, path, metric_name):
+def write_csv(metrics, path, metric_name):
     with open(path, "w") as f:
         f.write(f"{metric_name}\n")
-        for t in metric:
+        for t in metrics:
             f.write(f"{t[2].item()}\n")
         f.close()

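`write_csv` emits a one-column CSV: the metric name as a header line, then `t[2].item()` for each row, which implies each row carries a 0-dim tensor in its third slot. A hedged usage sketch (the tuple layout is inferred from the code, not documented in this commit):

    import torch

    rows = [(0, 0, torch.tensor(0.5)), (0, 1, torch.tensor(0.25))]
    write_csv(rows, "ce_metric.csv", "ce")
    # ce_metric.csv contains:
    # ce
    # 0.5
    # 0.25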

scripts/inference.py

Lines changed: 7 additions & 1 deletion
@@ -252,16 +252,20 @@
     from fms.utils.generation import generate

     if "fp8" in attn_name:
-        pass
+        import fms_mo.aiu_addons.fp8.fp8_attn  # noqa: F401

     if args.quantization == "gptq":
         if "aiu" in args.device_type:
             try:
+                from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear  # noqa
+
                 print("Loaded `aiu_addons` functionalities")
             except ImportError:
                 raise ImportError("Failed to import GPTQ addons from fms-mo.")
     elif args.quantization == "int8":
         try:
+            from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear  # noqa
+
             print("Loaded `aiu_addons` functionalities")
         except ImportError:
             raise ImportError("Failed to import INT8 addons from fms-mo.")
@@ -301,6 +305,8 @@
     device = torch.device(args.device_type, local_rank)
     torch.cuda.set_device(device)
 elif is_aiu_backend:
+    from torch_sendnn import torch_sendnn  # noqa
+
     if not args.distributed:
         aiu_setup.aiu_setup(rank, world_size)

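Two suppression styles appear above: the fp8 import uses the targeted `# noqa: F401`, while the addon imports use a bare `# noqa`, which silences every ruff rule on the line. When an import exists purely for side effects, the targeted form is the safer habit:

    # A bare "# noqa" disables all checks on the line; prefer naming the rule:
    import fms_mo.aiu_addons.fp8.fp8_attn  # noqa: F401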

scripts/roberta.py

Lines changed: 0 additions & 172 deletions
This file was deleted.

scripts/small-toy.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 from fms.utils.tp_wrapping import apply_tp

 # Import AIU Libraries
+from torch_sendnn import torch_sendnn  # noqa


 # ==============================================================

scripts/validation.py

Lines changed: 1 addition & 0 deletions
@@ -297,6 +297,7 @@
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())

 # Always initialize AIU in this script
+from torch_sendnn import torch_sendnn  # noqa

 if not args.distributed:
     aiu_setup.aiu_setup(rank, world_size)

tests/models/test_decoders.py

Lines changed: 2 additions & 3 deletions
@@ -315,12 +315,11 @@ def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens)
     return result


-def __filter_before_eos(level_metric, filter_indexes):
+def __filter_before_eos(metrics, filter_indexes):
     from itertools import groupby

     filtered_results = [
-        list(g)[: filter_indexes[k]]
-        for k, g in groupby(level_metric, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]
326325
