Thanks for the great work!
I tried to enable AutoQuant on top of the latest gpt-fast repository, since the gpt-fast version that the ao repo provides as an example is outdated.
Here is the diff enabling AutoQuant on top of the latest gpt-fast codebase.
But I'm getting an error: "CUDA generator expects graph capture to be underway, but the current stream is not capturing".
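For reference, here is a minimal sketch of the shape of the change (not the exact diff); it assumes torchao 0.4's `autoquant` API, and `manual=True` matches the `finalize_autoquant()` call that fails in the traceback below. The toy model and shapes are placeholders mirroring the log:

```python
import torch
import torchao

# Sketch, not the exact patch: wrap the model in AutoQuant before compiling.
# manual=True defers quantization until finalize_autoquant().
model = torch.nn.Sequential(torch.nn.Linear(4096, 12288)).bfloat16().cuda()
model = torchao.autoquant(torch.compile(model, mode="max-autotune"), manual=True)

# Calibration pass records activation shapes for benchmarking.
model(torch.randn(6, 4096, dtype=torch.bfloat16, device="cuda"))
model.finalize_autoquant()  # this is the step that fails below
```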
```
(/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin) yejinlee@a100-st-p4de24xlarge-47:/fsx-checkpoints/yejinlee/gpt-fast$ python generate.py --compile --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --quantization autoquant
uintx feature need torch 2.3+, please upgrade pytorch
Using device=cuda
Loading model ...
Time to load model: 39.35 seconds
activation_shapes: torch.Size([6, 4096]), times_seen: 1
activation_shapes: torch.Size([1, 4096]), times_seen: 199
weight_shape: torch.Size([12288, 4096]), dtype: torch.bfloat16, bias_shape: None
AUTOTUNE mm(6x4096, 4096x12288)
mm 0.0832 ms 100.0%
triton_mm_8 0.0840 ms 99.0%
triton_mm_6 0.0842 ms 98.8%
triton_mm_4 0.0857 ms 97.1%
triton_mm_3 0.0861 ms 96.6%
triton_mm_9 0.0879 ms 94.7%
triton_mm_5 0.0887 ms 93.8%
triton_mm_2 0.0944 ms 88.1%
triton_mm_1 0.0962 ms 86.5%
triton_mm_0 0.1044 ms 79.7%
SingleProcess AUTOTUNE takes 2.7937 seconds
warning: failed to autoquant AQFloatLinearWeight for shape: (torch.Size([6, 4096]), torch.Size([12288, 4096]), None, torch.bfloat16) due to CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 480, in <module>
main(
File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 354, in main
model.finalize_autoquant()
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 620, in finalize_autoquant
_change_autoquantizable_to_quantized(
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 494, in _change_autoquantizable_to_quantized
_replace_with_custom_fn_if_matches_filter(
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
[Previous line repeated 1 more time]
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 183, in _replace_with_custom_fn_if_matches_filter
model = replacement_fn(model)
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 238, in insert_subclass
getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 146, in to_quantized
self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 94, in tune_autoquant
act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
RuntimeError: CUDA generator expects graph capture to be underway, but the current stream is not capturing.
```
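The failing frame is AutoQuant's benchmarking step, which allocates a synthetic activation with `torch.randn` on the CUDA device. My assumption (not confirmed) is that this collides with CUDA-graph RNG state, since gpt-fast compiles its decode step with `mode="reduce-overhead"`, which captures CUDA graphs. A hedged workaround sketch along those lines, with a toy model standing in for gpt-fast:

```python
import torch
import torchao

# Hedged workaround sketch, not a confirmed fix. Assumption: tune_autoquant's
# torch.randn collides with CUDA-graph RNG state captured by the decode path
# compiled with mode="reduce-overhead". Finishing quantization before any
# graph capture keeps the benchmarking allocations on an ordinary stream.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).bfloat16().cuda()
model = torchao.autoquant(model, manual=True)

model(torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda"))  # eager calibration
model.finalize_autoquant()  # no CUDA graph has been captured yet

# Only afterwards compile with CUDA graphs enabled, mirroring gpt-fast's
# decode_one_token compilation.
decoded = torch.compile(model, mode="reduce-overhead", fullgraph=True)
decoded(torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda"))
```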
Attaching the env info here:
```
torch        2.3.1+cu121
torchao      0.4.0+gitc2f44608
torchaudio   2.3.1+cu121
torchvision  0.18.1+cu121
```
Thanks for the help in advance!
msaroufim commented on Aug 27, 2024
I'm actually pretty happy to see this bug report lol, we were hoping to enable torchao in gpt-fast and delete the broken quantization flows. We do need a fork of gpt-fast locally because we are making model changes, and unfortunately there isn't a good solution for us outside of occasionally syncing upstream.
So an action item for @HDCharles is to fix the existing code here in AO, but @YJYJLee would you be open to contributing your patch to gpt-fast directly as well? We're doing a big launch on Sep 21 at the CUDA MODE IRL conference and were hoping to feature an integration with gpt-fast by then. Granted, I would highly recommend you try out the PyTorch nightlies first.
HDCharles commented on Aug 27, 2024
I can look at it, but in reality these are updates to TorchAO, not gpt-fast; i.e., TorchAO's model/generate code is more up to date than gpt-fast's, rather than vice versa.