|
| 1 | +import argparse |
| 2 | +from typing import Callable, Generator, List, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark |
| 7 | + |
| 8 | +try: |
| 9 | + from liger_kernel.transformers.fused_linear_cross_entropy import ( |
| 10 | + LigerFusedLinearCrossEntropyLoss, |
| 11 | + ) |
| 12 | +except ModuleNotFoundError: |
| 13 | + LigerFusedLinearCrossEntropyLoss = None |
| 14 | + |
| 15 | +# Reference: https://github.com/linkedin/Liger-Kernel/blob/\ |
| 16 | +# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py |
| 17 | + |
| 18 | + |
| 19 | +def parse_op_args(args: List[str]): |
| 20 | + parser = argparse.ArgumentParser() |
| 21 | + parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size") |
| 22 | + parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size") |
| 23 | + return parser.parse_args(args) |
| 24 | + |
| 25 | + |
| 26 | +class TorchLMHeadCE(torch.nn.Module): |
| 27 | + """Ground truth implementation of the linear fused with torch based cross entropy loss. |
| 28 | +
|
| 29 | + :param H: hidden size |
| 30 | + :param V: vocab size |
| 31 | + :param ignore_index: index to ignore |
| 32 | + :param reduction: reduction method |
| 33 | + """ |
| 34 | + |
| 35 | + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): |
| 36 | + super().__init__() |
| 37 | + self.lin = torch.nn.Linear( |
| 38 | + in_features=H, out_features=V, bias=False, dtype=dtype |
| 39 | + ) |
| 40 | + self.ce_loss = torch.nn.CrossEntropyLoss( |
| 41 | + ignore_index=ignore_index, reduction="mean" |
| 42 | + ) |
| 43 | + |
| 44 | + def forward(self, input, target): |
| 45 | + logits = self.lin(input) |
| 46 | + return self.ce_loss(logits, target) |
| 47 | + |
| 48 | + |
| 49 | +class LigerLMHeadCE(torch.nn.Module): |
| 50 | + def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): |
| 51 | + super().__init__() |
| 52 | + self.lin = torch.nn.Linear( |
| 53 | + in_features=H, out_features=V, bias=False, dtype=dtype |
| 54 | + ) |
| 55 | + self.ce_loss = LigerFusedLinearCrossEntropyLoss( |
| 56 | + ignore_index=ignore_index, reduction="mean" |
| 57 | + ) |
| 58 | + |
| 59 | + def forward(self, input, target): |
| 60 | + return self.ce_loss(self.lin.weight, input, target) |
| 61 | + |
| 62 | + |
| 63 | +class Operator(BenchmarkOperator): |
| 64 | + def __init__( |
| 65 | + self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None |
| 66 | + ): |
| 67 | + super().__init__(tb_args, extra_args) |
| 68 | + op_args = parse_op_args(self.extra_args) |
| 69 | + self.hidden_size = op_args.hidden_size |
| 70 | + self.vocab_size = op_args.vocab_size |
| 71 | + self.baseline_model = TorchLMHeadCE( |
| 72 | + H=self.hidden_size, V=self.vocab_size, dtype=self.dtype |
| 73 | + ).to(self.device) |
| 74 | + self.liger_model = LigerLMHeadCE( |
| 75 | + H=self.hidden_size, V=self.vocab_size, dtype=self.dtype |
| 76 | + ).to(self.device) |
| 77 | + self.use_cuda_graphs = False |
| 78 | + |
| 79 | + def get_input_iter(self) -> Generator: |
| 80 | + for BT in [2**i for i in range(12, 16)]: |
| 81 | + _input = torch.randn( |
| 82 | + BT, |
| 83 | + self.hidden_size, |
| 84 | + requires_grad=True, |
| 85 | + dtype=self.dtype, |
| 86 | + device=self.device, |
| 87 | + ) |
| 88 | + target = torch.randint( |
| 89 | + self.vocab_size, (BT, 1), dtype=torch.long, device=self.device |
| 90 | + ).squeeze(1) |
| 91 | + yield _input, target |
| 92 | + |
| 93 | + @register_benchmark(baseline=True) |
| 94 | + def LMHeadCE(self, input, target) -> Callable: |
| 95 | + return lambda: self.baseline_model(input, target) |
| 96 | + |
| 97 | + @register_benchmark() |
| 98 | + def LigerLMHeadCE(self, input, target) -> Callable: |
| 99 | + return lambda: self.liger_model(input, target) |
| 100 | + |
| 101 | + @register_benchmark() |
| 102 | + def inductor_fused_linear_cross_entropy(self, input, target) -> Callable: |
| 103 | + compiled = torch.compile(self.baseline_model, dynamic=False) |
| 104 | + return lambda: compiled(input, target) |
| 105 | + |
| 106 | + def get_bwd_fn(self, fwd_fn: Callable) -> Callable: |
| 107 | + y = fwd_fn() |
| 108 | + return lambda: y.backward(retain_graph=True) |
0 commit comments