Some questions about FP8 performance profiling #39

Open

JimpleM opened this issue Jan 20, 2025 · 0 comments

JimpleM commented Jan 20, 2025

I used Nsight Compute to profile the example code (tma_gemm.py, listed below), but the result I get is only about 695.29 GB/s. The counters I looked at are listed here, with a rough traffic estimate sketched right after the list:

  • TMA instruction count: 19.97K (same as the reference result)
  • L2 to shared memory throughput: 2.70 TB/s (does not match)

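For context, this is the back-of-the-envelope traffic I am comparing the counters against. It is only my own estimate of what the shared-memory-side traffic should look like for this problem size; the variable names below are just for illustration:

# Rough TMA tile traffic for M=128, N=4096, K=4096 with block_m=block_n=64, block_k=256
# (the config used below). This is my own assumption, not a number from the repo.
M, N, K = 128, 4096, 4096
block_m, block_n, block_k = 64, 64, 256

num_programs = (M // block_m) * (N // block_n)   # 2 * 64 = 128 CTAs
k_iters = K // block_k                           # 16 A/B load pairs per CTA

# fp8 tiles are 1 byte per element; the fp16 C tile is 2 bytes per element.
load_bytes_per_cta = k_iters * (block_m * block_k + block_n * block_k)
store_bytes_per_cta = block_m * block_n * 2
total_bytes = num_programs * (load_bytes_per_cta + store_bytes_per_cta)

print(f"per-CTA TMA loads : {load_bytes_per_cta / 2**20:.2f} MiB")
print(f"total tile traffic: {total_bytes / 2**20:.2f} MiB (L2 reuse not subtracted)")
# Dividing total_bytes by the kernel duration reported by Nsight Compute should land near
# the L2-to-shared-memory throughput; the DRAM-side number should be lower because each
# A tile is reused by the 64 programs that share the same pid_m.
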
My environment is the following:

pytorch-triton           3.2.0+git0d4682f0
torch                    2.7.0.dev20250116+cu126
torchaudio               2.6.0.dev20250116+cu126
torchvision              0.22.0.dev20250116+cu126
#tma_gemm.py
import triton
import triton.language as tl
import numpy as np
import torch

@triton.jit
def gemm_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr,  #
                      prob_m, prob_n, prob_k, block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):
    
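    # Map the 1-D program id onto a (pid_m, pid_n) tile of the output C.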
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(prob_m, block_m)
    num_pid_k = tl.cdiv(prob_k, block_k)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * block_m
    offs_bn = pid_n * block_n
    offs_k = 0

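    # fp32 accumulator; each loop iteration TMA-loads one (block_m, block_k) tile of A and
    # one (block_n, block_k) tile of B and feeds them to tl.dot.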
    accumulator = tl.zeros((block_m, block_n), dtype=tl.float32)
    for kk in range(0, num_pid_k):

        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [block_m, block_k], tl.float8e4nv)
        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [block_n, block_k], tl.float8e4nv)
        
        accumulator = tl.dot(a, b.T, acc=accumulator, out_dtype=tl.float32)
        offs_k += block_k

    accumulator = accumulator.to(tl.float16)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


def matmul(a, b, config=None):

    m, _ = a.shape
    n, k = b.shape  # b is expected to be (N, K), i.e. already transposed
    
    block_m = 64
    block_n = 64
    block_k = 128
    num_warps = 4
    num_stages = 4
    TMA_SIZE = 128

    if config:
        block_m = config["block_m"]
        block_n = config["block_n"]
        block_k = config["block_k"]
        num_warps = config["num_warps"]
        num_stages = config["num_stages"]
        TMA_SIZE = config["TMA_SIZE"]

    print(block_m, block_n, block_k, num_warps, num_stages, TMA_SIZE)

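    # Host-side byte buffers for the three TMA descriptors; they are filled on the CPU
    # below and then copied to the GPU so the kernel can dereference them.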
    desc_a = np.empty(TMA_SIZE, dtype=np.int8)
    desc_b = np.empty(TMA_SIZE, dtype=np.int8)
    desc_c = np.empty(TMA_SIZE, dtype=np.int8)

    c = torch.empty((m, n), dtype=torch.float16, device='cuda')
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(a.data_ptr(), m, k, block_m, block_k, a.element_size(), desc_a)
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(b.data_ptr(), n, k, block_n, block_k, b.element_size(), desc_b)
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(c.data_ptr(), m, n, block_m, block_n, c.element_size(), desc_c)
    desc_a = torch.tensor(desc_a, device='cuda')
    desc_b = torch.tensor(desc_b, device='cuda')
    desc_c = torch.tensor(desc_c, device='cuda')

    total_blocks_m = triton.cdiv(m, block_m)
    total_blocks_n = triton.cdiv(n, block_n)

    # Launch one program (CTA) per (block_m, block_n) tile of C.
    grid = (total_blocks_m * total_blocks_n, 1, 1)
    kernel = gemm_kernel_tma[grid](
        desc_a, desc_b, desc_c,
        m, n, k,
        block_m,
        block_n,
        block_k,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # with open('tma_fp8.ttgir', 'w') as f:
    #     print(kernel.asm['ttgir'], file=f)

    # with open('tma_fp8.ptx', 'w') as f:
    #     print(kernel.asm['ptx'], file=f)

    return c


if __name__ == '__main__':

    M = 128
    N = 4096
    K = 4096

    torch.manual_seed(0)
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    b = b.T.contiguous()  # store B as (N, K) with K contiguous so the TMA loads read [block_n, block_k] tiles

    config = {
        "block_m":64,
        "block_n":64,
        "block_k":256,
        "num_warps":4,
        "num_stages":4,
        "TMA_SIZE":512
    }

    c = matmul(a, b, config=config)
    print(c)
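
To cross-check the profiler numbers against wall-clock time, I also append a rough timing block like this at the end of the __main__ section above (just my own sketch; it relies on triton.testing.do_bench and on the a, b, config, M, N, K names already defined in the script):

    # Rough wall-clock cross-check (my own addition, not part of the original example).
    import triton.testing

    # do_bench returns a single runtime estimate in milliseconds.
    ms = triton.testing.do_bench(lambda: matmul(a, b, config=config))

    # Unique bytes touched: fp8 A and B (1 byte/elem) plus the fp16 C output (2 bytes/elem).
    unique_bytes = a.numel() * a.element_size() + b.numel() * b.element_size() + M * N * 2
    tflops = 2.0 * M * N * K / (ms * 1e-3) / 1e12
    gbps = unique_bytes / (ms * 1e-3) / 1e9
    print(f"{ms:.3f} ms | {tflops:.2f} TFLOP/s | {gbps:.2f} GB/s over unique bytes")

(The print inside matmul fires on every timed call; for a cleaner measurement one could drop that print or time the kernel launch directly.)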