
Commit 9645ec0

Add Ascend NPU support for generate and chat
1 parent: 2640f6a

File tree: 7 files changed (+48, -25)

install/install_requirements.sh (+10)

@@ -71,6 +71,9 @@ then
 elif [[ -x "$(command -v xpu-smi)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
+elif [[ -x "$(command -v npu-smi)" ]]
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/test/cpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi

@@ -83,6 +86,13 @@ then
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
     #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
   )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0"
+    torchvision=="0.22.0"
+    torchtune=="0.6.0"
+  )
 else
   REQUIREMENTS_TO_INSTALL=(
     torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
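
Note that the NPU branch pins stable wheels from the test CPU channel rather than a nightly, presumably because the Ascend torch_npu adapter pairs with a matching CPU build of torch (hence the hard pins torch==2.7.0 / torchvision==0.22.0 / torchtune==0.6.0). A minimal post-install check, sketched under the assumption that torch_npu is present, might look like:

    import torch

    # torch_npu is the Ascend adapter; it is NOT installed by this script,
    # and importing it is what registers the torch.npu backend.
    try:
        import torch_npu  # noqa: F401
    except ImportError:
        torch_npu = None

    print(torch.__version__)                                   # expect 2.7.0 on Ascend hosts
    print(hasattr(torch, "npu") and torch.npu.is_available())  # True if an NPU is visible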

torchchat/cli/builder.py (+5, -3)

@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time

@@ -78,6 +78,8 @@ def __post_init__(self):
             self.device = "cuda"
         elif torch.xpu.is_available():
             self.device = "xpu"
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            self.device = "npu"
         else:
             self.device = "cpu"

@@ -539,7 +541,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )

@@ -573,7 +575,7 @@ def do_nothing(max_batch_size, max_seq_length):
         raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
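
Device auto-detection in __post_init__ now falls through cuda, then xpu, then npu, then cpu. The hasattr(torch, "npu") guard recurs throughout this commit because stock PyTorch has no torch.npu attribute; it only appears once the Ascend torch_npu plugin has been imported. A standalone sketch of the same fallback chain (the helper name is illustrative):

    import torch

    def resolve_default_device() -> str:
        # Mirrors the __post_init__ chain above.
        if torch.cuda.is_available():
            return "cuda"
        elif torch.xpu.is_available():
            return "xpu"
        elif hasattr(torch, "npu") and torch.npu.is_available():
            return "npu"
        return "cpu"

    print(resolve_default_device())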

torchchat/cli/cli.py (+2, -2)

@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",

torchchat/generate.py (+8, -3)

@@ -1213,6 +1213,8 @@ def callback(x, *, done_generating=False):
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
             elif self.builder_args.device == "cuda":
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+            elif self.builder_args.device == "npu":
+                print(prof.key_averages().table(sort_by="self_npu_time_total"))
             else:
                 print(prof.key_averages().table(sort_by="self_xpu_time_total"))
         prof.export_chrome_trace(f"{self.profile}.json")

@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
             print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")

@@ -1595,7 +1599,6 @@ def sample(

     return idx_next, probs

-
 def run_generator(
     args,
     rank: Optional[int] =None

@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()

     for _ in gen.chat(generator_args):
         pass
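
The last two hunks share one pattern: reset peak-memory stats on the active backend before generation, report max_memory_reserved afterwards. Turning the xpu checks into elif branches also ensures only one backend is queried per run. A consolidated sketch of that dispatch (the helper names are hypothetical, not part of this commit):

    import torch

    def reset_peak_memory() -> None:
        # Reset before generation; only the active backend is touched.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        elif torch.xpu.is_available():
            torch.xpu.reset_peak_memory_stats()
        elif hasattr(torch, "npu") and torch.npu.is_available():
            torch.npu.reset_peak_memory_stats()

    def report_peak_memory() -> None:
        # Report after generation, in GB of reserved device memory.
        if torch.cuda.is_available():
            print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
        elif torch.xpu.is_available():
            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
        elif hasattr(torch, "npu") and torch.npu.is_available():
            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")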

torchchat/utils/build_utils.py (+19, -14)

@@ -233,6 +233,8 @@ def device_sync(device="cpu"):
         torch.cuda.synchronize(device)
     elif "xpu" in device:
         torch.xpu.synchronize(device)
+    elif "npu" in device:
+        torch.npu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:

@@ -275,33 +277,36 @@ def is_mps_available() -> bool:
     # MPS, is that you?
     return True

+def select_device(device) -> str:
+    if torch.cuda.is_available():
+        return "cuda"
+    elif is_mps_available():
+        return "mps"
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        return "npu"
+    elif torch.xpu.is_available():
+        return "xpu"
+    else:
+        return "cpu"
+

 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device(device)
         return device
     else:
         return str(device)


 def get_device(device) -> str:
     if isinstance(device, str) and device == "fast":
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else "mps" if is_mps_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
-        )
+        device = select_device(device)
     return torch.device(device)


 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"

-def is_cuda_or_cpu_or_xpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
+def is_supported_device(device) -> bool:
+    device_str = str(device)
+    return is_cpu_device(device) or any(dev in device_str for dev in ('cuda', 'xpu', 'npu'))
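
select_device centralizes the "fast" resolution that was previously duplicated inline in get_device_str and get_device, and it is the one place the NPU probe happens; its priority order is cuda > mps > npu > xpu > cpu. A few illustrative calls against these helpers:

    from torchchat.utils.build_utils import is_supported_device, select_device

    print(select_device("fast"))         # e.g. "npu" on an Ascend host with no CUDA/MPS
    print(is_supported_device("npu:0"))  # True: the substring match covers indexed devices
    print(is_supported_device("mps"))    # False: AOTI artifacts are not loaded on MPS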

torchchat/utils/device_info.py (+3, -2)

@@ -9,12 +9,11 @@

 import torch

-
 def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type

     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu", "npu"}.
     Returns:
         str: A human-readable description of the hardware or an empty string if the device type is unhandled.

@@ -46,4 +45,6 @@ def get_device_info(device: str) -> str:
             .split("\n")[0]
             .split("Device Name:")[1]
         )
+    if device == "npu":
+        return torch.npu.get_device_name(0)
     return ""

torchchat/utils/quantize.py (+1, -1)

@@ -123,7 +123,7 @@ def quantize_model(
         raise RuntimeError(f"unknown quantizer {quantizer} specified")
     else:
         # Use tensor subclass API for int4 weight only.
-        if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
+        if (device in ["cuda", "xpu", "npu"]) and quantizer == "linear:int4":
             quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
             if not support_tensor_subclass:
                 unwrap_tensor_subclass(model)
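
This change simply adds "npu" to the devices eligible for torchao's tensor-subclass int4 weight-only path. A minimal sketch of that API under stated assumptions (torchao installed; the CUDA device and group size of 128 are illustrative; int4 weight-only generally expects bfloat16 weights on an accelerator):

    import torch
    from torchao.quantization import int4_weight_only, quantize_

    # Quantize a toy module in place; the positional argument is the group size,
    # matching the positional q_kwargs["groupsize"] call above.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
    model = model.to(device="cuda", dtype=torch.bfloat16)
    quantize_(model, int4_weight_only(128))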
