Skip to content

Commit dde8528

Browse files
FindHao authored and facebook-github-bot committed
Add FusedLinearCrossEntropy (#2485)
Summary: As discussed in pytorch/pytorch#136168, I'm going to migrate the implementations of operator benchmarking. This PR adds different implementations of FusedLinearCrossEntropy as a starting example.

Execution command:
```
python run_benchmark.py triton --op FusedLinearCrossEntropy
```

Example output:
```
  x_val    LMHeadCE-latency    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-latency
-------  ------------------  -----------------------  ---------------------------------------------
      0             98.0041                   389.87                                        95.0412
      1            196.12                     652.619                                      193.219
      2            417.242                   1248.75                                       416.725
      3            824.906                   2356.25                                       809.56
```

Pull Request resolved: #2485
Reviewed By: xuzhao9
Differential Revision: D63859871
Pulled By: FindHao
fbshipit-source-id: 4b73a2144702c1f8f3ae5ed15e76112d03f12b87
1 parent a1f4b2e commit dde8528

File tree

4 files changed

+131
-0
lines changed

4 files changed

+131
-0
lines changed

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[build-system]
2+
# Use legacy backend to import local packages in setup.py
3+
build-backend = "setuptools.build_meta:__legacy__"
4+
5+
6+
[tool.black]
7+
line-length = 88
8+
target-version = ["py38"]
9+
exclude = '''/submodules/.*'''
10+
11+
[tool.usort]
12+
excludes = ["**/submodules/**"]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .operator import Operator
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import argparse
2+
from typing import Callable, Generator, List, Optional
3+
4+
import torch
5+
6+
from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark
7+
8+
try:
9+
from liger_kernel.transformers.fused_linear_cross_entropy import (
10+
LigerFusedLinearCrossEntropyLoss,
11+
)
12+
except ModuleNotFoundError:
13+
LigerFusedLinearCrossEntropyLoss = None
14+
15+
# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
16+
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
17+
18+
19+
def parse_op_args(args: List[str]):
    """Parse the operator-specific command-line options.

    :param args: argument strings left over for this operator
    :return: namespace with integer ``hidden_size`` and ``vocab_size`` attributes
    """
    # (flag, default, help text) for every integer-valued option.
    int_options = (
        ("--hidden-size", 4096, "hidden size"),
        ("--vocab-size", 128256, "vocab size"),
    )
    cli = argparse.ArgumentParser()
    for flag, default, description in int_options:
        cli.add_argument(flag, type=int, default=default, help=description)
    return cli.parse_args(args)
24+
25+
26+
class TorchLMHeadCE(torch.nn.Module):
    """Reference LM head: a bias-free linear projection followed by torch cross entropy.

    The loss always uses ``reduction="mean"``.

    :param H: hidden size (input features of the projection)
    :param V: vocab size (output features of the projection)
    :param dtype: dtype of the projection weight
    :param ignore_index: index to ignore
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        # Submodules are registered in the original order: projection, then loss.
        self.lin = torch.nn.Linear(H, V, bias=False, dtype=dtype)
        self.ce_loss = torch.nn.CrossEntropyLoss(
            reduction="mean", ignore_index=ignore_index
        )

    def forward(self, input, target):
        """Project ``input`` to vocabulary logits and return the loss against ``target``."""
        return self.ce_loss(self.lin(input), target)
47+
48+
49+
class LigerLMHeadCE(torch.nn.Module):
    """LM head whose loss is computed by Liger's fused linear cross entropy.

    The loss always uses ``reduction="mean"``.

    :param H: hidden size (input features of the projection)
    :param V: vocab size (output features of the projection)
    :param dtype: dtype of the projection weight
    :param ignore_index: index to ignore
    :raises ModuleNotFoundError: if ``liger_kernel`` could not be imported
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        # The module-level import guard leaves this name as None when
        # liger-kernel is absent; fail with an actionable message instead of
        # the opaque "'NoneType' object is not callable".
        if LigerFusedLinearCrossEntropyLoss is None:
            raise ModuleNotFoundError(
                "LigerLMHeadCE requires the `liger-kernel` package; "
                "install it with `pip install liger-kernel --no-deps`."
            )
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        """Return the fused loss for ``input`` against ``target``."""
        # self.lin is only a container for the weight: the fused loss receives
        # the weight itself rather than the output of calling the layer.
        return self.ce_loss(self.lin.weight, input, target)
61+
62+
63+
class Operator(BenchmarkOperator):
    """Benchmark of an LM-head linear projection combined with cross entropy.

    Registers three implementations: the eager torch baseline, the Liger
    fused kernel, and the baseline module wrapped in ``torch.compile``.
    """

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        # NOTE(review): self.extra_args, self.dtype and self.device are not set
        # in this class; presumably BenchmarkOperator.__init__ provides them.
        op_args = parse_op_args(self.extra_args)
        self.hidden_size = op_args.hidden_size
        self.vocab_size = op_args.vocab_size
        self.baseline_model = TorchLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        # liger-kernel is an optional dependency (the import guard sets the
        # name to None when it is missing). Only build the Liger model when it
        # is available so the other benchmarks remain usable without it.
        if LigerFusedLinearCrossEntropyLoss is None:
            self.liger_model = None
        else:
            self.liger_model = LigerLMHeadCE(
                H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
            ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        """Yield ``(input, target)`` pairs for 2**12 through 2**15 rows.

        ``input`` has shape ``(BT, hidden_size)`` and requires grad; ``target``
        holds ``BT`` class indices drawn from ``[0, vocab_size)``.
        """
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                self.hidden_size,
                requires_grad=True,
                dtype=self.dtype,
                device=self.device,
            )
            target = torch.randint(
                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
            ).squeeze(1)
            yield _input, target

    @register_benchmark(baseline=True)
    def LMHeadCE(self, input, target) -> Callable:
        """Return a callable running the eager torch baseline."""
        return lambda: self.baseline_model(input, target)

    @register_benchmark()
    def LigerLMHeadCE(self, input, target) -> Callable:
        """Return a callable running the Liger fused implementation.

        :raises ModuleNotFoundError: if ``liger_kernel`` could not be imported
        """
        if self.liger_model is None:
            raise ModuleNotFoundError(
                "The LigerLMHeadCE benchmark requires the `liger-kernel` package; "
                "install it with `pip install liger-kernel --no-deps`."
            )
        return lambda: self.liger_model(input, target)

    @register_benchmark()
    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
        """Return a callable running the baseline module through ``torch.compile``."""
        compiled = torch.compile(self.baseline_model, dynamic=False)
        return lambda: compiled(input, target)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        """Run ``fwd_fn`` once and return a callable that backpropagates its result."""
        y = fwd_fn()
        # retain_graph=True keeps the autograd graph alive so the returned
        # callable can be invoked more than once on the same forward result.
        return lambda: y.backward(retain_graph=True)

userbenchmark/triton/install.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def install_fa3():
6666
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))
6767

6868

69+
def install_liger():
    """Install the ``liger-kernel`` package with pip, skipping its dependencies.

    :raises subprocess.CalledProcessError: if the pip command exits non-zero
    """
    import sys

    # Liger-kernel has a conflicting `triton` dependency with pytorch,
    # so we need to install it without dependencies.
    # Run pip through the current interpreter so the package is installed into
    # this environment even if the `pip` on PATH belongs to a different Python.
    cmd = [sys.executable, "-m", "pip", "install", "liger-kernel", "--no-deps"]
    subprocess.check_call(cmd)
74+
75+
6976
def install_tk():
7077
try:
7178
from .tk.install import install_tk
@@ -88,6 +95,7 @@ def install_tk():
8895
)
8996
parser.add_argument("--jax", action="store_true", help="Install jax nightly")
9097
parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
98+
parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
9199
parser.add_argument("--test", action="store_true", help="Run test")
92100
args = parser.parse_args()
93101

@@ -105,3 +113,5 @@ def install_tk():
105113
install_jax()
106114
if args.tk and not args.test:
107115
install_tk()
116+
if args.liger and not args.test:
117+
install_liger()

0 commit comments

Comments
 (0)