[Bug] TypeError: 'constexpr' object is not iterable #138

Open
York-Cheung opened this issue Jan 23, 2025 · 5 comments

York-Cheung commented Jan 23, 2025

Describe the Bug

When I run the benchmark for the retnet kernel with Triton 2.x (I have to use 2.x for some reason):

# -*- coding: utf-8 -*-

import os

import torch
import triton

from fla.ops.retention import (fused_chunk_retention, chunk_retention)

try:
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['T'],
        # different possible values for `x_name`
        x_vals=[128 * 2 ** i for i in range(0, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['fused_chunk', 'fused_chunk_bwd', 'chunk', 'chunk_bwd'],
        # label name for the lines
        line_names=['fused_chunk', 'fused_chunk_bwd', 'chunk', 'chunk_bwd'],
        # line styles
        styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), ('blue', 'dotted'),
                ('red', 'dotted')] + ([('cyan', '-'), ('cyan', 'dotted')] if HAS_FLASH else []),
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    )
)
def benchmark(T, provider):
    device = 'cuda'
    dtype = torch.bfloat16
    requires_grad = True
    B, H, D = 4, 32, 2048
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    if provider == 'flash' or provider == 'flash_bwd':
        q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
        k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
        v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype)
    else:
        q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
        k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
        v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype)
    do = torch.ones_like(q, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    results = 0, 0, 0
    if provider == 'fused_chunk':
        print("do fused_chunk")
        results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v), quantiles=quantiles)
    elif provider == 'fused_chunk_bwd':
        print("do fused_chunk_bwd")
        results = triton.testing.do_bench(lambda: fused_chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
    elif provider == 'chunk':
        print("do chunk")
        results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles)
    elif provider == 'chunk_bwd':
        print("do chunk_bwd")
        results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
    return results


if __name__ == '__main__':
    benchmark.run(print_data=True)

The error message appears when running chunk_bwd:

Traceback (most recent call last):
  File "benchmark_retnet.py", line 73, in <module>
    benchmark.run(print_data=True)
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 321, in run
    self._run(bench, save_path, show_plots, print_data)
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 276, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
  File "benchmark_retnet.py", line 68, in benchmark
    results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "benchmark_retnet.py", line 68, in <lambda>
    results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.8/site-packages/fla/utils.py", line 18, in wrapper
    return fn(ctx,
  File "/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/simple_gla/chunk.py", line 204, in backward
    dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, head_first=head_first).to(g.dtype)
  File "/opt/conda/lib/python3.8/site-packages/fla/utils.py", line 18, in wrapper
    return fn(ctx,
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/utils/cumsum.py", line 407, in chunk_local_cumsum
    return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/utils/cumsum.py", line 252, in chunk_local_cumsum_scalar
    chunk_local_cumsum_scalar_kernel[grid](
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 232, in run
    return self.fn.run(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 63, in chunk_local_cumsum_scalar_kernel
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/compiler.py", line 522, in compile
    next_module = compile_kernel(module)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/compiler.py", line 427, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 33:25:    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
                         ^
TypeError("'constexpr' object is not iterable")

Steps to Reproduce the Bug

See the code.

Expected Behavior

The benchmark should run successfully.

Environment Information

  1. Torch: 2.2
  2. Triton: 2.1

York-Cheung added the bug (Something isn't working) label on Jan 23, 2025
yzhangcs added a commit that referenced this issue Jan 23, 2025

yzhangcs (Member) commented:

@York-Cheung Hello, check out 7e0a972

yzhangcs (Member) commented:

We address this issue by converting the pointer offsets to int64 when necessary, since the product H*T*D can exceed what int32 can hold.
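
For illustration, a rough sketch of the idea (not the referenced commit), assuming a kernel whose offsets involve products such as i_nh * T * D: casting the program id to int64 up front keeps the whole pointer arithmetic in 64-bit.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x, y, D: tl.constexpr):
    # tl.program_id returns an int32; promoting it to int64 ensures the
    # offset i_row * D cannot wrap around once the tensor holds more than
    # 2**31 - 1 elements (e.g. B * H * T * D with D = 2048).
    i_row = tl.program_id(0).to(tl.int64)
    o_d = tl.arange(0, D)
    tl.store(y + i_row * D + o_d, tl.load(x + i_row * D + o_d))


# tiny example launch; the cast only matters once x.numel() > 2**31 - 1
x = torch.randn(8, 2048, device='cuda', dtype=torch.bfloat16)
y = torch.empty_like(x)
copy_kernel[(x.shape[0],)](x, y, D=x.shape[1])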

yzhangcs self-assigned this on Jan 23, 2025

York-Cheung (Author) commented:

I tried a smaller D = 64 and got the same error.
It seems the error comes from the chunk_retention and cumsum kernels:

do chunk_bwd
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/opt/conda/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 614, in visit_If
    self.visit_compound_statement(node.body)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 462, in visit_BinOp
    lhs = self.visit(node.left)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 463, in visit_BinOp
    rhs = self.visit(node.right)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/opt/conda/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 737, in visit_Subscript
    return lhs.__getitem__(slices, _builder=self.builder)
  File "/opt/conda/lib/python3.8/site-packages/triton/language/core.py", line 30, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/language/core.py", line 733, in __getitem__
    for dim, sl in enumerate(slices):
TypeError: 'constexpr' object is not iterable
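
For reference, a minimal standalone kernel (illustrative, not taken from fla) that triggers the same TypeError on Triton 2.1: subscripting the scalar produced by tl.sum with a bare None lands in tl.tensor.__getitem__, which on 2.x tries to iterate over the slices; recent Triton (3.x) accepts the same subscript.

import torch
import triton
import triton.language as tl


@triton.jit
def repro_kernel(x, y, BT: tl.constexpr):
    b_x = tl.load(x + tl.arange(0, BT))
    b_z = tl.sum(b_x, axis=0)    # reduces [BT] -> scalar
    b_y = b_x + b_z[None]        # b_z[None] raises "'constexpr' object is not iterable" on Triton 2.1
    tl.store(y + tl.arange(0, BT), b_y)


x = torch.randn(64, device='cuda')
y = torch.empty_like(x)
repro_kernel[(1,)](x, y, BT=64)  # CompilationError on Triton 2.1; compiles on >= 3.0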

yzhangcs (Member) commented Jan 23, 2025

@York-Cheung Could you upgrade triton to >=3.0? We are removing support for triton 2.x.
3.0 works for me.

Or you may try 2.2.
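
As an aside, a small guard along these lines (a sketch for illustration, not part of fla) can surface the version mismatch before any kernel is compiled:

import warnings

import triton

# warn early when running on Triton 2.x, for which support is being removed
if int(triton.__version__.split('.')[0]) < 3:
    warnings.warn(
        f"Detected triton=={triton.__version__}; upgrading to triton>=3.0 is recommended.",
        stacklevel=2,
    )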

York-Cheung (Author) commented Jan 27, 2025

@yzhangcs triton==3.0 + py3.8 leads to other errors. Triton 2.2 works for me when using only the chunk_retention kernel (✅), but it raises an error when using the fused_recurrent_retention kernel (❎).

By the way, benchmark_gla.py also runs successfully with triton 2.2.

The error message for the fused_recurrent_retention kernel:

Traceback (most recent call last):
  File "./flash-linear-attention/benchmarks/ops/benchmark_retention.py", line 94, in <module>
    benchmark.run(print_data=True)
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 341, in run
    result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 287, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
  File "./flash-linear-attention/benchmarks/ops/benchmark_retention.py", line 87, in benchmark
    results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v, g), quantiles=quantiles)
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "./flash-linear-attention/benchmarks/ops/benchmark_retention.py", line 87, in <lambda>
    results = triton.testing.do_bench(lambda: fused_recurrent_retention(q, k, v, g), quantiles=quantiles)
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/retention/fused_recurrent.py", line 31, in fused_recurrent_retention
    return fused_recurrent_simple_gla(
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/simple_gla/fused_recurrent.py", line 97, in fused_recurrent_simple_gla
    o, final_state = fused_recurrent(
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/common/fused_recurrent.py", line 561, in fused_recurrent
    return FusedRecurrentFunction.apply(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.8/site-packages/fla/utils.py", line 18, in wrapper
    return fn(ctx,
  File "/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/common/fused_recurrent.py", line 491, in forward
    o, ht = fused_recurrent_fwd(
  File "/opt/conda/lib/python3.8/site-packages/fla/ops/common/fused_recurrent.py", line 352, in fused_recurrent_fwd
    fused_recurrent_fwd_kernel[grid](
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 305, in run
    return self.fn.run(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/compiler.py", line 588, in compile
    next_module = compile_kernel(module)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/compiler.py", line 480, in <lambda>
    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler/code_generator.py", line 1242, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 73:66:
    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
                                                                  ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<bf16> and triton.language.fp32')
