[Do not merge] Hacks for the ROCm port #1314

Closed
wants to merge 17 commits into from

Conversation

pcmoritz (Collaborator) commented Oct 10, 2023

This gets the kernel tests working on top of #1313.

Currently, the following tests are working (a sketch for running them is included after the list):

  • kernels/test_activation.py
  • kernels/test_cache.py
  • kernels/test_layernorm.py
  • kernels/test_pos_encoding.py
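
For reference, these can be run individually; a minimal sketch, assuming a source checkout with the kernels/ test directory and pytest installed (equivalent to invoking pytest on the command line):

    # Sketch: run the passing kernel tests via pytest's Python entry point.
    import pytest

    pytest.main([
        "kernels/test_activation.py",
        "kernels/test_cache.py",
        "kernels/test_layernorm.py",
        "kernels/test_pos_encoding.py",
        "-q",
    ])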

Currently, the following test is failing:

  • kernels/test_attention.py, which fails with:
________ test_single_query_cached_kv_attention[0-dtype0-8-False-64-num_heads0-7] ________

kv_cache_factory = <function create_kv_caches at 0x7f359bcb6700>, num_seqs = 7, num_heads = (40, 40), head_size = 64, use_alibi = False, block_size = 8, dtype = torch.float16, seed = 0

    @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
    @pytest.mark.parametrize("num_heads", NUM_HEADS)
    @pytest.mark.parametrize("head_size", HEAD_SIZES)
    @pytest.mark.parametrize("use_alibi", USE_ALIBI)
    @pytest.mark.parametrize("block_size", BLOCK_SIZES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("seed", SEEDS)
    @torch.inference_mode()
    def test_single_query_cached_kv_attention(
        kv_cache_factory,
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        block_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        query = torch.empty(num_seqs,
                            num_query_heads,
                            head_size,
                            dtype=dtype,
                            device="cuda")
        query.uniform_(-scale, scale)
    
        assert num_query_heads % num_kv_heads == 0
        num_queries_per_kv = num_query_heads // num_kv_heads
        head_mapping = torch.repeat_interleave(
            torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
            num_queries_per_kv)
        alibi_slopes = None
        if use_alibi:
            alibi_slopes = torch.randn(num_query_heads,
                                       dtype=torch.float,
                                       device="cuda")
    
        context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
        context_lens[-1] = MAX_SEQ_LEN
        max_context_len = max(context_lens)
        context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
    
        # Create the block tables.
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        block_tables = []
        for _ in range(num_seqs):
            block_table = [
                random.randint(0, NUM_BLOCKS - 1)
                for _ in range(max_num_blocks_per_seq)
            ]
            block_tables.append(block_table)
        block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
    
        # Create the KV caches.
        key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                    num_kv_heads, head_size, dtype,
                                                    seed)
        key_cache, value_cache = key_caches[0], value_caches[0]
    
        # Call the paged attention kernel.
        output = torch.empty_like(query)
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    
        # Run the reference implementation.
        ref_output = torch.empty_like(query)
        ref_single_query_cached_kv_attention(
            ref_output,
            query,
            num_queries_per_kv,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            scale,
            alibi_slopes,
        )
    
        # NOTE(woosuk): Due to the kernel-level differences in the two
        # implementations, there is a small numerical difference in the two
        # outputs. Thus, we use a relaxed tolerance for the test.
>       assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7f37adc8c980>(
E               tensor([[[inf, inf, inf,  ..., inf, inf, inf],
E                        ...,
E                        [inf, inf, inf,  ..., inf, inf, inf]]], device='cuda:0', dtype=torch.float16),
E               tensor([[[-2.9540e-04, -9.1505e-04, -5.2643e-03,  ..., -4.1389e-03,  3.7746e-03, -1.1806e-03],
E                        ...,
E                        [-4.8923e-04,  5.1618e-05,  4.1428e-03,  ...,  1.1559e-03, -2.8362e-03,  3.9363e-04]]], device='cuda:0', dtype=torch.float16),
E               atol=0.001, rtol=1e-05)
E          (full tensor dumps elided: the kernel output is entirely inf, while the reference output contains small finite values)
E        +    where <built-in method allclose of type object at 0x7f37adc8c980> = torch.allclose
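
As the (elided) dump shows, the kernel output is entirely inf while the reference output stays finite. A minimal diagnostic sketch that distinguishes this failure mode from an ordinary tolerance mismatch (the helper name check_finite is hypothetical, not part of the test suite):

    # Hypothetical diagnostic helper (not part of the test suite): separates an
    # all-inf kernel output from an ordinary accuracy mismatch.
    import torch

    def check_finite(output: torch.Tensor, ref_output: torch.Tensor) -> None:
        if torch.isinf(output).all():
            raise AssertionError("kernel output is entirely inf -- the kernel is producing non-finite values")
        if not torch.isfinite(ref_output).all():
            raise AssertionError("reference output contains non-finite values")
        # Same relaxed tolerance as the test above.
        assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)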

iAmir97 (Contributor) commented Oct 18, 2023

@pcmoritz I'd like to submit a pull request that replaces the HIP API calls used for Float16 type conversions with inline asm volatile instructions.

Additionally, I've modified setup.py to check for the availability of ROCM_HOME; if it is set, the compile flags are updated to include -DUSE_ROCM.

I'd appreciate it if you could review the pull request and let me know if there are any issues or concerns. Thanks!
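
For context, a minimal sketch of the setup.py check described above (hedged: variable names and flag lists are illustrative, not the actual patch):

    # Illustrative sketch of gating ROCm-specific flags in setup.py on ROCM_HOME (not the actual patch).
    import os

    CXX_FLAGS = ["-O2", "-std=c++17"]

    rocm_home = os.environ.get("ROCM_HOME")
    if rocm_home and os.path.isdir(rocm_home):
        # Build the kernels for ROCm/HIP instead of CUDA.
        CXX_FLAGS.append("-DUSE_ROCM")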

@WoosukKwon added the rocm (Related to AMD ROCm) label on Dec 1, 2023
WoosukKwon (Collaborator) commented Dec 10, 2023

Closed as we merged #1836, which is a superset of this PR. @pcmoritz Thanks for the amazing work!

@WoosukKwon closed this on Dec 10, 2023