Commit f7cd64c
[not for land] testing out float8 128_1_128_128 blockwise scaling
Summary: Test drive of pytorch/ao#2386, not for land

Test Plan:

```bash
with-proxy CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.converters float8 --model.print_after_conversion
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 1ab4353

1 file changed

torchtitan/components/quantization/float8.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -101,6 +101,19 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return

+        from torchao.quantization import quantize_
+        from torchao.prototype.deep_gemm_float8_training.linear import (
+            DeepGemmFloat8LinearConfig,
+        )
+
+        quantize_(
+            model,
+            config=DeepGemmFloat8LinearConfig(),
+            filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
+        )
+        logger.info("enabled DeepGemm dense training")
+        return
+
         from torchao.float8 import convert_to_float8_training

         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
```
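For context, the snippet below is a minimal, self-contained sketch of the same conversion pattern applied to a toy model rather than torchtitan's llama3. It assumes a torchao build that includes the `deep_gemm_float8_training` prototype from pytorch/ao#2386 (not in a released torchao), and the `ToyModel` module names, including the `output` layer the filter skips, are illustrative.

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.deep_gemm_float8_training.linear import (
    DeepGemmFloat8LinearConfig,
)


class ToyModel(nn.Module):
    """Stand-in for the real model; layer names are illustrative."""

    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(4096, 4096, bias=False)
        self.w2 = nn.Linear(4096, 4096, bias=False)
        # Named "output" so the filter below skips it, mirroring the
        # fqn != "output" check in the commit.
        self.output = nn.Linear(4096, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(self.w2(self.w1(x)))


# DeepGemm float8 kernels target recent (Hopper-class) GPUs.
model = ToyModel().cuda().bfloat16()

# Swap every nn.Linear except the module whose fully qualified name is
# "output" for the DeepGemm blockwise-scaled float8 training linear.
quantize_(
    model,
    config=DeepGemmFloat8LinearConfig(),
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "output",
)

# Printing the model shows which linears were converted, which is what
# --model.print_after_conversion surfaces in the torchtitan run above.
print(model)
```

Skipping the module named `output` matches torchtitan's usual float8 recipes, which tend to leave the final output projection in higher precision.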
