|
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +from torchao.quantization.utils import _torchtitan_available, _fbgemm_available |
| 4 | + |
| 5 | +grouped_gemm_fp8_rowwise = None |
| 6 | +if _fbgemm_available: |
| 7 | + try: |
| 8 | + from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm_fp8_rowwise |
| 9 | + except: |
| 10 | + pass |
| 11 | + |
| 12 | +__all__ = ["fp8_dq_moe_op", |
| 13 | + "manual_pad", |
| 14 | + "torchtitan_pad", |
| 15 | + ] |
| 16 | + |
| 17 | + |
| 18 | +def fp8_dq_moe_op(input, w1, w2, w3, expert_indices, scores, fast_accum=True, use_fbgemm_kernel=False): |
| 19 | + # parameters |
| 20 | + orig_in_shape = input.shape |
| 21 | + input.reshape(-1, orig_in_shape[-1]) |
| 22 | + num_tokens, dim = input.shape |
| 23 | + num_experts, expert_dim, _ = w1.shape |
| 24 | + scores = scores.view(-1, scores.shape[-1]) |
| 25 | + top_k = scores.shape[-1] |
| 26 | + total_activations = num_tokens*top_k |
| 27 | + |
| 28 | + # preprocess indices |
| 29 | + expert_indices = expert_indices.view(-1) |
| 30 | + activation_shuffle = expert_indices.argsort(stable=True) |
| 31 | + token_shuffle = activation_shuffle.div(top_k).floor().to(torch.int64) |
| 32 | + num_tokens_per_expert = torch.histc(expert_indices, bins=num_experts, min=0, max=num_experts) |
| 33 | + |
| 34 | + # padding |
| 35 | + alignment = 16 |
| 36 | + if _torchtitan_available: |
| 37 | + num_ranks = 1 |
| 38 | + padded_indices, m_offsets = torchtitan_pad(num_tokens_per_expert, alignment, num_ranks) |
| 39 | + else: |
| 40 | + padded_indices, m_offsets = manual_pad(num_tokens_per_expert, alignment) |
| 41 | + |
| 42 | + pad_len = padded_indices.shape[0] |
| 43 | + valid_values = padded_indices >= 0 |
| 44 | + |
| 45 | + # get data for weights |
| 46 | + w1_fp8 = w1.original_weight_tensor.tensor_impl.float8_data |
| 47 | + w1_scale = w1.original_weight_tensor.tensor_impl.scale.squeeze() |
| 48 | + w1_qfunc = w1.input_quant_func |
| 49 | + w1_quant_kwargs = w1.quant_kwargs |
| 50 | + |
| 51 | + w3_fp8 = w3.original_weight_tensor.tensor_impl.float8_data |
| 52 | + w3_scale = w3.original_weight_tensor.tensor_impl.scale.squeeze() |
| 53 | + |
| 54 | + w2_fp8 = w2.original_weight_tensor.tensor_impl.float8_data |
| 55 | + w2_scale = w2.original_weight_tensor.tensor_impl.scale.squeeze() |
| 56 | + w2_qfunc = w2.input_quant_func |
| 57 | + w2_quant_kwargs = w2.quant_kwargs |
| 58 | + |
| 59 | + |
| 60 | + # quantize then shuffle input |
| 61 | + q_input = w1_qfunc(input, **w1_quant_kwargs) |
| 62 | + q_input_data = q_input.tensor_impl.float8_data |
| 63 | + q_input_scale = q_input.tensor_impl.scale.squeeze() |
| 64 | + input_fp8 = torch.zeros((pad_len, q_input_data.shape[-1]), dtype=q_input_data.dtype, device=q_input_data.device) |
| 65 | + input_scale = torch.zeros(pad_len, dtype=q_input_scale.dtype, device=q_input_scale.device) |
| 66 | + input_fp8[valid_values] = q_input_data[token_shuffle] |
| 67 | + input_scale[valid_values] = q_input_scale[token_shuffle] if q_input_scale.numel()>1 else q_input_scale |
| 68 | + |
| 69 | + if use_fbgemm_kernel: |
| 70 | + assert grouped_gemm_fp8_rowwise is not None, "fbgemm kernel requires fbgemm-gpu-genai to be installed: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/README.md" |
| 71 | + y1 = grouped_gemm_fp8_rowwise(input_fp8, w1_fp8.reshape(-1, w1_fp8.shape[-1]), m_offsets, input_scale, w1_scale.reshape(-1), use_fast_accum=True) |
| 72 | + y3 = grouped_gemm_fp8_rowwise(input_fp8, w3_fp8.reshape(-1, w3_fp8.shape[-1]), m_offsets, input_scale, w3_scale.reshape(-1), use_fast_accum=True) |
| 73 | + |
| 74 | + y = F.silu(y1)*y3 |
| 75 | + |
| 76 | + y_q = w2_qfunc(y, **w2_quant_kwargs) |
| 77 | + |
| 78 | + y_fp8 = y_q.tensor_impl.float8_data |
| 79 | + y_scale = y_q.tensor_impl.scale.squeeze() |
| 80 | + out = grouped_gemm_fp8_rowwise(y_fp8, w2_fp8.view(-1, w1_fp8.shape[-1]), m_offsets, y_scale, w2_scale.view(-1), use_fast_accum=fast_accum) |
| 81 | + # unpad and combine output with weights |
| 82 | + out = out[valid_values] |
| 83 | + sorted_scores = scores.reshape(-1,1)[activation_shuffle] |
| 84 | + out = out*sorted_scores |
| 85 | + |
| 86 | + # sum weighted outputs |
| 87 | + final_out = torch.zeros_like(input) |
| 88 | + final_out = final_out.scatter_add( |
| 89 | + dim=0, |
| 90 | + index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64), |
| 91 | + src=out |
| 92 | + ) |
| 93 | + final_out = final_out.reshape(orig_in_shape) |
| 94 | + return final_out |
| 95 | + |
| 96 | + else: |
| 97 | + y1 = torch._scaled_grouped_mm(input_fp8, w1_fp8.transpose(-2, -1), input_scale, w1_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) |
| 98 | + y3 = torch._scaled_grouped_mm(input_fp8, w3_fp8.transpose(-2, -1), input_scale, w3_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) |
| 99 | + y = F.silu(y1)*y3 |
| 100 | + y_q = w2_qfunc(y, **w2_quant_kwargs) |
| 101 | + |
| 102 | + y_fp8 = y_q.tensor_impl.float8_data |
| 103 | + y_scale = y_q.tensor_impl.scale.squeeze() |
| 104 | + out = torch._scaled_grouped_mm(y_fp8, w2_fp8.transpose(-2, -1), y_scale, w2_scale, offs=m_offsets, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) |
| 105 | + |
| 106 | + # unpad and combine output with weights |
| 107 | + out = out[valid_values] |
| 108 | + sorted_scores = scores.reshape(-1,1)[activation_shuffle] |
| 109 | + out = out*sorted_scores |
| 110 | + |
| 111 | + # sum weighted outputs |
| 112 | + final_out = torch.zeros_like(input) |
| 113 | + final_out = final_out.scatter_add( |
| 114 | + dim=0, |
| 115 | + index=token_shuffle.unsqueeze(-1).expand(total_activations, dim).to(torch.int64), |
| 116 | + src=out |
| 117 | + ) |
| 118 | + final_out = final_out.reshape(orig_in_shape) |
| 119 | + return final_out |
| 120 | + |
| 121 | +def torchtitan_pad(num_tokens_per_expert, alignment, num_ranks): |
| 122 | + from torchtitan.experiments.kernels.moe.indices import generate_permute_indices |
| 123 | + num_experts = num_tokens_per_expert.shape[0] |
| 124 | + |
| 125 | + # pad to nearest multiple of alignment that's greater than 0 |
| 126 | + padded_sizes = (((num_tokens_per_expert + (num_tokens_per_expert==0))/alignment).ceil() * alignment) |
| 127 | + pad_len = int(padded_sizes.sum().item()) |
| 128 | + |
| 129 | + padded_indices, _, m_offsets = generate_permute_indices( |
| 130 | + num_tokens_per_expert, |
| 131 | + num_experts, |
| 132 | + num_ranks, |
| 133 | + pad_len, |
| 134 | + alignment, |
| 135 | + use_cpu=False |
| 136 | + ) |
| 137 | + return padded_indices, m_offsets |
| 138 | + |
| 139 | +def manual_pad(num_tokens_per_expert, alignment): |
| 140 | + num_experts = num_tokens_per_expert.shape[0] |
| 141 | + |
| 142 | + padded_sizes = (((num_tokens_per_expert + (num_tokens_per_expert==0))/alignment).ceil() * alignment) |
| 143 | + pad_len = int(padded_sizes.sum().item()) |
| 144 | + |
| 145 | + padded_indices = torch.zeros(pad_len, dtype=torch.int32, device=num_tokens_per_expert.device)-1 |
| 146 | + start_tok_index = 0 |
| 147 | + start_pad_index = 0 |
| 148 | + for i in range(num_experts): |
| 149 | + end_tok_index = int(start_tok_index+num_tokens_per_expert[i].item()) |
| 150 | + end_pad_index = int(start_pad_index+num_tokens_per_expert[i].item()) |
| 151 | + padded_indices[start_pad_index:end_pad_index] = torch.arange(start_tok_index, end_tok_index, dtype=torch.int32, device=num_tokens_per_expert.device) |
| 152 | + start_tok_index = end_tok_index |
| 153 | + start_pad_index = start_pad_index + int(padded_sizes[i].item()) |
| 154 | + m_offsets = padded_sizes.cumsum(0).to(torch.int32) |
| 155 | + return padded_indices, m_offsets |
0 commit comments