Skip to content

[1xMI300X] GPT-2 XL 1.5B FP8 Training ~30% slower than H100 FP8 #72

@functionstackx

Description

@functionstackx

Problem Description

Hi AMD team,

When trying to do FP8 Training on MI300X, it is extremely slower due to extremely high cpu overhead taking up more than 81% of the time. As you can see from the profile, most of the time is spent in CPU & doing hipFree. On GPT-2 XL 1.5B, TFLOP/s is at 22 TFLOP/s. This is 10x slower than mi300x bf16.

For Comparsion, On H100 GPT-2 XL 1.5B, FP8 makes it to be 1.3x faster than BF16 H100. Not slower.

The Reprod Script is attached Below & can be ran using NVTE_FUSED_ATTN_CK=0 python3 ./train.py

image

image

cc: @hliuca

Steps to Reproduce

Versions

root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
^[[Apytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241012+rocm6.2
torchvision             0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine      1.8.0.dev0+691dc23

Install Instructions

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2


WORKDIR /workspace/

RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942

RUN cd TransformerEngine && pip install .

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

Reprod GPT2 XL 1.5B Training

import contextlib

import torch
import torch.nn.functional as F
import torch.nn as nn

from pydantic.dataclasses import dataclass

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())
        
        # print param count in B
        print(f"Param count: {N/1e9}B")
                 
        head_dim = config['d_embd'] // config['n_heads'] 
         
        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']
        
        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        
        
        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                    d_embd,
                    d_embd * 4,
                    kwargs['n_heads'],
                    layer_number=i+1,
                    # Optional, for speedups
                    fuse_qkv_params=True,
                    attn_input_format='bshd'
                ) 
                for i in range(n_layers)                       
                )
        
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def train(
    gpu_id: int = 0,
    bsz: int = 8,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2600e12

    model.train()
    
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                yield

    with ctx():
         for step_idx in range(100):
            input_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [8, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')

if __name__ == '__main__':
    import fire
    fire.Fire(train)

Operating System

Ubuntu

CPU

AMD CPU

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions