61 changes: 61 additions & 0 deletions dlblas/kernels/kernelagent/level1/100_HingeLoss
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.jit
def hinge_loss_kernel(
predictions_ptr,
targets_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

    # Load with explicit defaults so out-of-range lanes are well defined
    p = tl.load(predictions_ptr + offsets, mask=mask, other=0.0)
    t = tl.load(targets_ptr + offsets, mask=mask, other=0.0)

    # Per-element hinge term max(0, 1 - p * t); zero out padded lanes
    element = 1.0 - p * t
    clamped = tl.where(mask & (element > 0), element, 0.0)

    # Single-program launch (grid == (1,)), so the block-level sum is the full sum
    total = tl.sum(clamped, axis=0)
    if pid == 0:
        mean_val = total / n_elements
        tl.store(output_ptr, mean_val)

class ModelNew(nn.Module):
def __init__(self):
super(ModelNew, self).__init__()

def forward(self, predictions, targets):
total_elements = predictions.numel()
if total_elements == 0:
return torch.tensor(0.0, device=predictions.device)

predictions_flat = predictions.view(-1)
targets_flat = targets.view(-1)
output = torch.empty(1, device=predictions.device)

        # Single-block launch: the entire flattened input must fit in one block
        grid = (1,)
        BLOCK_SIZE = triton.next_power_of_2(total_elements)
hinge_loss_kernel[grid](
predictions_flat,
targets_flat,
output,
total_elements,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.squeeze(0)

batch_size = 128
input_shape = (1,)
dim = 1

def get_inputs():
return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]

def get_init_inputs():
return []
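For reference, the kernel computes the mean hinge loss, mean(max(0, 1 - predictions * targets)), in a single block. A minimal sanity-check sketch against the eager formulation; the device move, `__main__` guard, and tolerance are illustrative assumptions:

if __name__ == "__main__":
    # Compare the Triton kernel against eager PyTorch hinge loss
    preds, targs = [t.cuda() for t in get_inputs()]
    ref = torch.mean(torch.clamp(1.0 - preds * targs, min=0.0))
    out = ModelNew()(preds, targs)
    assert torch.allclose(out, ref, atol=1e-5)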
107 changes: 107 additions & 0 deletions dlblas/kernels/kernelagent/level1/10_3D_tensor_matrix_multiplication
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 64, 'BLOCK_L': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_L': 128, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
triton.Config({'BLOCK_M': 128, 'BLOCK_L': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
triton.Config({'BLOCK_M': 128, 'BLOCK_L': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_L': 64, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
],
key=['M', 'K', 'L'],
)
@triton.jit
def _matmul_kernel(
A_ptr, B_ptr, C_ptr,
N, M, K, L,
stride_An, stride_Am, stride_Ak,
stride_Bk, stride_Bl,
stride_Cn, stride_Cm, stride_Cl,
BLOCK_M: tl.constexpr, BLOCK_L: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_m = tl.program_id(1)
pid_l = tl.program_id(2)

if pid_n >= N:
return

# Create block offsets with proper masking
m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
k_offsets = tl.arange(0, BLOCK_K)

# Initialize accumulator
acc = tl.zeros((BLOCK_M, BLOCK_L), dtype=tl.float32)

# Compute pointer bases
a_base = A_ptr + pid_n * stride_An
b_base = B_ptr
c_base = C_ptr + pid_n * stride_Cn

# Blocked matrix multiplication
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Compute current K block
k_start = k * BLOCK_K

# Load A block with coalesced access
a_ptrs = a_base + m_offsets[:, None] * stride_Am + (k_start + k_offsets[None, :]) * stride_Ak
a_mask = (m_offsets[:, None] < M) & ((k_start + k_offsets[None, :]) < K)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)

# Load B block with coalesced access
b_ptrs = b_base + (k_start + k_offsets[:, None]) * stride_Bk + l_offsets[None, :] * stride_Bl
b_mask = ((k_start + k_offsets[:, None]) < K) & (l_offsets[None, :] < L)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)

# Accumulate matrix product with full FP32 precision
acc += tl.dot(a, b, allow_tf32=False)

# Store result with masking
c_ptrs = c_base + m_offsets[:, None] * stride_Cm + l_offsets[None, :] * stride_Cl
c_mask = (m_offsets[:, None] < M) & (l_offsets[None, :] < L)
tl.store(c_ptrs, acc, mask=c_mask)

class ModelNew(nn.Module):
def __init__(self):
super(ModelNew, self).__init__()

def forward(self, A, B):
N, M, K = A.shape
L = B.shape[1]
A = A.contiguous()
B = B.contiguous()
C = torch.empty((N, M, L), device=A.device, dtype=A.dtype)

# Dynamic grid using autotuner meta-parameters
grid = lambda meta: (
N,
triton.cdiv(M, meta['BLOCK_M']),
triton.cdiv(L, meta['BLOCK_L'])
)

# Launch kernel without overriding autotuned parameters
_matmul_kernel[grid](
A, B, C,
N, M, K, L,
A.stride(0), A.stride(1), A.stride(2),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1), C.stride(2),
)
return C

N = 16
M = 1024
K = 2048
L = 768

def get_inputs():
A = torch.randn(N, M, K, device='cuda')
B = torch.randn(K, L, device='cuda')
return [A, B]

def get_init_inputs():
return []
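For reference, each program in the grid handles one batch index n and one (BLOCK_M x BLOCK_L) tile of C[n] = A[n] @ B. A minimal sanity-check sketch against torch.matmul; the TF32 toggle and tolerances are illustrative assumptions:

if __name__ == "__main__":
    torch.backends.cuda.matmul.allow_tf32 = False  # keep the eager reference in full FP32
    A, B = get_inputs()
    C = ModelNew()(A, B)
    ref = torch.matmul(A, B)  # (N, M, K) @ (K, L) -> (N, M, L)
    assert torch.allclose(C, ref, rtol=1e-3, atol=1e-3)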
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def _triton_matmul(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Blocked reduction over the K dimension
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        a_mask = (rm[:, None] < M) & (rk[None, :] < K)
        b_mask = (rk[:, None] < K) & (rn[None, :] < N)

        # Load A and B tiles, zero-padding out-of-range elements
        a = tl.load(a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak,
                    mask=a_mask, other=0.0)
        b = tl.load(b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn,
                    mask=b_mask, other=0.0)
        # Accumulate the tile product in full FP32 precision
        acc += tl.dot(a, b, allow_tf32=False, out_dtype=tl.float32)

c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
c_mask = (rm[:, None] < M) & (rn[None, :] < N)
tl.store(c_ptrs, acc, mask=c_mask)

class ModelNew(nn.Module):
def __init__(self):
super(ModelNew, self).__init__()

def forward(self, A, B):
b, i, j, l = A.shape
k = B.shape[1]

# Ensure contiguous memory layout
A_flat = A.reshape(-1, l).contiguous()
B = B.contiguous()

M, K = A_flat.shape
N = k

C_flat = torch.empty((M, N), device=A.device, dtype=A.dtype)

grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']),
triton.cdiv(N, META['BLOCK_SIZE_N']),
)

_triton_matmul[grid](
A_flat, B, C_flat,
M, N, K,
A_flat.stride(0), A_flat.stride(1),
B.stride(0), B.stride(1),
C_flat.stride(0), C_flat.stride(1),
)

return C_flat.reshape(b, i, j, k)

# Test code
b = 16
i = 256
j = 512
l = 256
k = 768

def get_inputs():
A = torch.randn(b, i, j, l)
B = torch.randn(l, k)
return [A, B]

def get_init_inputs():
return [] # No special initialization inputs needed
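For reference, this flattens the einsum "bijl,lk->bijk" into a single (b*i*j, l) x (l, k) GEMM. A minimal sanity-check sketch against torch.einsum; the TF32 toggle, device move, and tolerances are illustrative assumptions:

if __name__ == "__main__":
    torch.backends.cuda.matmul.allow_tf32 = False  # full-FP32 eager reference
    A, B = [t.cuda() for t in get_inputs()]
    out = ModelNew()(A, B)
    ref = torch.einsum("bijl,lk->bijk", A, B)
    assert out.shape == (b, i, j, k)
    assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)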
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.jit
def diag_matmul_kernel(
A_ptr,
B_ptr,
C_ptr,
N, M,
stride_B0, stride_B1,
stride_C0, stride_C1,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
if pid >= N:
return

    # Each program scales one row of B by the corresponding diagonal entry A[pid]
    a_val = tl.load(A_ptr + pid)
    row_start_B = pid * stride_B0
    row_start_C = pid * stride_C0

    # Sweep the row in BLOCK_SIZE-wide chunks
    for col_block in range(0, tl.cdiv(M, BLOCK_SIZE)):
        col_offset = col_block * BLOCK_SIZE
        col_indices = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_indices < M

b_vals = tl.load(
B_ptr + row_start_B + col_indices * stride_B1,
mask=mask,
other=0.0
)
c_vals = a_val * b_vals
tl.store(
C_ptr + row_start_C + col_indices * stride_C1,
c_vals,
mask=mask
)

class ModelNew(nn.Module):
def __init__(self):
super(ModelNew, self).__init__()

def forward(self, A, B):
N, M = B.shape
C = torch.empty_like(B)

if B.numel() == 0:
return C

BLOCK_SIZE = 1024
grid = (N,)
diag_matmul_kernel[grid](
A, B, C,
N, M,
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_SIZE=BLOCK_SIZE
)
return C

M = 4096
N = 4096

def get_inputs():
A = torch.randn(N)
B = torch.randn(N, M)
return [A, B]

def get_init_inputs():
return [] # No special initialization inputs needed
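For reference, the kernel scales row n of B by A[n], i.e. it computes diag(A) @ B without materializing the N x N diagonal matrix. A minimal sanity-check sketch; the device move is an illustrative assumption:

if __name__ == "__main__":
    A, B = [t.cuda() for t in get_inputs()]
    C = ModelNew()(A, B)
    ref = A.unsqueeze(1) * B  # same result as torch.diag(A) @ B
    assert torch.allclose(C, ref)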