Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
softmax,
sub,
tanh,
avg_pool3d,
histc,
log10,
log1p,
dot,
)

__all__ = [
Expand Down Expand Up @@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"avg_pool3d",
"histc",
"log10",
"log1p",
"dot",
]
101 changes: 101 additions & 0 deletions src/ntops/kernels/avg_pool3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed.language import libdevice
from ninetoothed import Tensor
from ninetoothed import Symbol

def arrangement(
    *tensors,
    kernel_size_d,
    kernel_size_h,
    kernel_size_w,
    stride_d,
    stride_h,
    stride_w,
    block_size,
    ceil_mode,
):
    """Arrange the avg_pool3d tensors into per-block tiles.

    ``tensors`` must unpack to ``(input, output, kernel_volume)``.  The input
    is cut into pooling windows, each window is flattened and padded up to a
    power-of-two length, and ``block_size`` windows are grouped per tile.  The
    output is grouped into matching tiles of ``block_size`` scalars.
    ``kernel_volume`` is passed through untouched.

    Shape notes below follow the layout
    input ``(N, C, D_in, H_in, W_in)`` / output ``(N, C, D_out, H_out, W_out)``.
    """
    source, target, kernel_volume = tensors
    if block_size is None:
        block_size = ninetoothed.block_size()

    window_shape = (1, 1, kernel_size_d, kernel_size_h, kernel_size_w)
    window_strides = (1, 1, stride_d, stride_h, stride_w)

    # Smallest power of two that is >= the number of elements in one window.
    window_volume = kernel_size_d * kernel_size_h * kernel_size_w
    padded_volume = 1 << (window_volume - 1).bit_length()

    # (N, C, D_out, H_out, W_out), dtype=(1, 1, k_d, k_h, k_w)
    windows = source.tile(window_shape, window_strides, floor_mode=not ceil_mode)
    # (N, C, D_out, H_out, W_out, 1, 1, k_d, k_h, k_w)
    windows = windows.ravel()
    # (N*C*D_out*H_out*W_out, k_d*k_h*k_w)
    windows = windows.flatten(end_dim=5)
    windows = windows.flatten(start_dim=1)

    # (..., 1), dtype=(1, padded_volume): one padded tile covers a whole window.
    windows = windows.tile((1, padded_volume))
    # (..., 1), dtype=(padded_volume, )
    windows.dtype = windows.dtype.squeeze(0)
    # (..., 1), dtype=(block_size, 1), dtype=(padded_volume, )
    windows = windows.tile((block_size, -1))
    # (..., 1), dtype=(block_size, padded_volume)
    windows.dtype = windows.dtype.ravel().squeeze(1)

    # (N, C, D_out, H_out, W_out), dtype=(1, 1, 1, 1, 1)
    results = target.tile((1, 1, 1, 1, 1))
    # (N, C, D_out, H_out, W_out, 1, 1, 1, 1, 1)
    results = results.ravel()
    # (N*C*D_out*H_out*W_out, 1)
    results = results.flatten(end_dim=5)
    results = results.flatten(start_dim=1)
    # (..., 1), dtype=(block_size, 1)
    results = results.tile((block_size, -1))
    # (..., 1), dtype=(block_size, )
    results.dtype = results.dtype.squeeze(1)

    return windows, results, kernel_volume


def application(input, output, kernel_volume):
    """Average each pooling window of the arranged input tile.

    Per the arrangement's shape notes, ``input`` is a
    ``(block_size, nearest_pow2)`` tile (one flattened window per row) and
    ``output`` is a ``(block_size,)`` tile.

    NOTE(review): the divisor is always the caller-supplied ``kernel_volume``,
    including for windows that are partly out of bounds — confirm this is the
    intended divisor when ``ceil_mode`` is enabled.
    """
    # input: (block_size, nearest_pow2)
    # output: (block_size,)

    # Input data: (block_size, nearest_pow2)
    # These are the actual pixel values; out-of-bounds positions are filled
    # with 0, so padding lanes do not change the sum.
    val_sum = ntl.sum(input, axis=1)  # (block_size, )
    # Assigning to the `output` parameter is presumably how ninetoothed
    # expresses the store to the output tile — confirm against its docs.
    output = val_sum / kernel_volume  # (block_size, )


def premake(
    ndim,
    kernel_size_d,
    kernel_size_h,
    kernel_size_w,
    stride_d,
    stride_h,
    stride_w,
    block_size=None,
    ceil_mode=False,
    dtype=None,
):
    """Build the ``(arrangement, application, tensors)`` triple for avg_pool3d.

    The kernel and stride sizes, ``block_size`` and ``ceil_mode`` are bound
    into the arrangement; ``ndim`` and ``dtype`` describe the input and output
    tensors.  ``kernel_volume`` is declared as a 0-dimensional tensor.
    """
    arrangement_options = {
        "kernel_size_d": kernel_size_d,
        "kernel_size_h": kernel_size_h,
        "kernel_size_w": kernel_size_w,
        "stride_d": stride_d,
        "stride_h": stride_h,
        "stride_w": stride_w,
        "block_size": block_size,
        "ceil_mode": ceil_mode,
    }
    bound_arrangement = functools.partial(arrangement, **arrangement_options)

    # Out-of-bounds input positions are filled with 0 so they add nothing to
    # the window sum.
    input_spec = Tensor(ndim, dtype=dtype, other=0)
    output_spec = Tensor(ndim, dtype=dtype)
    kernel_volume_spec = Tensor(0, dtype=dtype)

    return (
        bound_arrangement,
        application,
        (input_spec, output_spec, kernel_volume_spec),
    )
73 changes: 73 additions & 0 deletions src/ntops/kernels/dot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement

def arrangement_dot_full(input, tensor, out, block_size):
    """Arrange both 1-D operands as a single tile and ``out`` as one scalar.

    ``block_size`` is accepted for signature parity with the other dot
    arrangements but is not used here.
    """
    # A tile size of -1 spans the whole dimension, so each operand becomes a
    # single tile covering all N elements (outer shape (1, )).
    whole_dim = (-1,)
    lhs_tile = input.tile(whole_dim)
    rhs_tile = tensor.tile(whole_dim)
    # out: (1, ), dtype=(1, )
    result_tile = out.tile((1,))
    return lhs_tile, rhs_tile, result_tile

def application_dot_full(input, tensor, out):
    """Compute the dot product of two whole-vector tiles in one step.

    Multiplies ``input`` and ``tensor`` element-wise and sums every product.
    Assigning to ``out`` is presumably how ninetoothed expresses the store to
    the output tile — confirm against its docs.
    """
    out = ntl.sum(input * tensor)

def premake_dot_full(dtype, block_size):
    """Build the ``(arrangement, application, tensors)`` triple for a one-shot dot.

    Both operands are 1-D tensors with constexpr shapes; the result is a 1-D
    tensor holding the scalar product.

    The operands are declared with ``other=0`` so that any masked
    (out-of-bounds) lanes of the loaded tile read as zero and cannot perturb
    the sum.  Without an explicit fill value those lanes are unspecified.
    """
    arrangement_ = functools.partial(arrangement_dot_full, block_size=block_size)

    tensors = (
        Tensor(1, dtype=dtype, other=0, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype, other=0, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype),
    )

    return arrangement_, application_dot_full, tensors


# ========= Block-wise (divide-and-conquer) computation =========

def arrangement_dot_divide(input, tensor, out_temp, block_size):
    """Arrange the operands into ``block_size`` chunks, one partial sum each.

    input/tensor: (N, ) -> outer (N // block_size, ), dtype=(block_size, )
    out_temp: one scalar tile per chunk, dtype=(1, )
    """
    chunk = (block_size,)
    lhs_chunks = input.tile(chunk)
    rhs_chunks = tensor.tile(chunk)
    partial_sums = out_temp.tile((1,))
    return lhs_chunks, rhs_chunks, partial_sums

def application_dot_divide(input, tensor, out_temp):
    """Compute one partial dot product for a single chunk.

    Multiplies the ``input`` and ``tensor`` chunks element-wise and sums the
    products along axis 0.  Assigning to ``out_temp`` is presumably how
    ninetoothed expresses the store of the partial sum — confirm against its
    docs.
    """
    out_temp = ntl.sum(input * tensor, 0)

def arrangement_dot_conquer(input_block_wise, out, block_size):
    """Arrange the partial sums as a single tile and ``out`` as one scalar.

    ``block_size`` is accepted for signature parity with the other dot
    arrangements but is not used here.
    """
    # A tile size of -1 spans the whole dimension: all partial sums land in
    # one tile (outer shape (1, )).
    all_partials = input_block_wise.tile((-1,))
    # out: (1, ), dtype=(1, )
    result_tile = out.tile((1,))
    return all_partials, result_tile

def application_dot_conquer(input_block_wise, out):
    """Reduce the per-chunk partial sums into the final dot product.

    Sums every element of ``input_block_wise``.  Assigning to ``out`` is
    presumably how ninetoothed expresses the store to the output tile —
    confirm against its docs.
    """
    out = ntl.sum(input_block_wise)

def premake_dot_divide(dtype, block_size):
    """Build the ``(arrangement, application, tensors)`` triple for the divide step.

    Both operands are 1-D tensors with constexpr shapes; the third tensor
    receives one partial sum per ``block_size`` chunk.

    The operands are declared with ``other=0`` so that the out-of-bounds lanes
    of the last chunk (when the length is not a multiple of ``block_size``)
    read as zero and cannot perturb that chunk's partial sum.  Without an
    explicit fill value those lanes are unspecified.
    """
    arrangement_ = functools.partial(arrangement_dot_divide, block_size=block_size)

    tensors = (
        Tensor(1, dtype=dtype, other=0, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype, other=0, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype),
    )

    return arrangement_, application_dot_divide, tensors

def premake_dot_conquer(dtype, block_size):
    """Build the ``(arrangement, application, tensors)`` triple for the conquer step.

    The first tensor holds the per-chunk partial sums (1-D, constexpr shape);
    the second receives the final scalar result.

    The partial-sum tensor is declared with ``other=0`` so that any masked
    (out-of-bounds) lanes of the loaded tile read as zero and cannot perturb
    the final sum.  Without an explicit fill value those lanes are
    unspecified.
    """
    arrangement_ = functools.partial(arrangement_dot_conquer, block_size=block_size)

    tensors = (
        Tensor(1, dtype=dtype, other=0, shape_options={"constexpr": True}),
        Tensor(1, dtype=dtype),
    )

    return arrangement_, application_dot_conquer, tensors
133 changes: 133 additions & 0 deletions src/ntops/kernels/histc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@

import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed.language import libdevice
from ninetoothed import Tensor


def arrangement(*tensors, block_size):
    """Arrange the histc tensors so every input block sees the whole histogram.

    ``tensors`` must unpack to
    ``(input, output, min_val, max_val, num_bins_pow2)``.  The input is
    flattened and cut into ``block_size`` chunks; the output is broadcast so
    that each chunk is paired with the full ``(bins, )`` histogram.  The three
    scalar arguments are passed through untouched.
    """
    source, hist, min_val, max_val, num_bins_pow2 = tensors

    if block_size is None:
        block_size = ninetoothed.block_size()

    # input: (N, ) -> (N // block_size, ), dtype=(block_size, )
    source_blocks = source.flatten().tile((block_size,))
    num_blocks = source_blocks.shape[0]

    # output: (bins, ) -> (N // block_size, bins)
    hist_per_block = hist.unsqueeze(0).expand((num_blocks, -1))
    # (N // block_size, ), dtype=(1, bins)
    hist_per_block = hist_per_block.tile((1, -1)).squeeze(1)
    # dtype=(bins, )
    hist_per_block.dtype = hist_per_block.dtype.squeeze(0)

    return source_blocks, hist_per_block, min_val, max_val, num_bins_pow2

def application_manual_histogram(input, output, min_val, max_val, num_bins_pow2):
    """Compute the histogram by hand.

    The built-in histogram function on Moore Threads GPUs does not compute
    the histogram correctly, so it is implemented manually here with
    ntl.arange and ntl.where.

    NOTE(review): the scaling divides by ``max_val - min_val``; presumably the
    caller guarantees ``max_val > min_val`` — confirm.
    """
    # input: (block_size,)
    # output: (bins,)
    n_out_bins = output.shape[0]

    # Only values within [min_val, max_val] are counted.
    mask = (input >= min_val) & (input <= max_val)

    # Normalize to [0, n_out_bins); a value equal to max_val maps to
    # n_out_bins exactly and is clamped below.
    input_scaled = (input - min_val) / (max_val - min_val) * n_out_bins

    # The histogram needs integer bin indices.
    input_indices = ntl.cast(input_scaled, ntl.int32)

    # max_val should fall into the last bin.
    input_indices = ntl.minimum(input_indices, n_out_bins - 1)

    # Set out-of-range indices to -1 so that they are never counted.
    input_indices = ntl.where(mask, input_indices, -1)

    # Initialize the histogram tensor.
    local_hist = ntl.zeros((num_bins_pow2,), dtype=output.dtype)

    # Count bin by bin: for each bin, use `where` to count the matching
    # elements.  Moore Threads does not support a dynamically indexed
    # histogram, so it has to be done by hand.
    for bin_idx in range(num_bins_pow2):
        bin_idx_tensor = ntl.cast(bin_idx, ntl.int32)
        match_mask = (input_indices == bin_idx_tensor)
        count = ntl.sum(match_mask.to(output.dtype))
        idx = ntl.arange(0, num_bins_pow2)
        # Write `count` into slot `bin_idx` only; other slots keep their value.
        update_mask = (idx == bin_idx_tensor)
        local_hist = ntl.where(update_mask, count, local_hist)

    # Only the first n_out_bins bins are needed.
    valid_mask = ntl.arange(0, num_bins_pow2) < n_out_bins
    local_hist = local_hist.to(output.dtype)
    # Every input block is paired with the same output histogram (see the
    # arrangement), so the per-block counts are accumulated atomically.
    ntl.atomic_add(output.data_ptr() + output.offsets(),
                   local_hist,
                   mask=valid_mask)


def application_builtin_histogram(input, output, min_val, max_val, num_bins_pow2):
    """Compute the histogram of one input block with ``ntl.histogram``.

    NOTE(review): the scaling divides by ``max_val - min_val``; presumably the
    caller guarantees ``max_val > min_val`` — confirm.
    """
    # input: (block_size,)
    # output: (bins,)
    n_out_bins = output.shape[0]

    # Only values within [min_val, max_val] are counted.
    mask = (input >= min_val) & (input <= max_val)

    # Normalize to [0, n_out_bins); a value equal to max_val maps to
    # n_out_bins exactly and is clamped below.
    input_scaled = (input - min_val) / (max_val - min_val) * n_out_bins

    # The histogram needs integer bin indices.
    input_indices = ntl.cast(input_scaled, ntl.int32)

    # max_val should fall into the last bin.
    input_indices = ntl.minimum(input_indices, n_out_bins - 1)

    # Set out-of-range indices to -1 so that they are never counted.  This is
    # done by hand because masked histogram was only introduced in Triton
    # 3.5.0.
    input_indices = ntl.where(mask, input_indices, -1)

    local_hist = ntl.histogram(input_indices,
                               num_bins=num_bins_pow2)  # shape: (num_bins_pow2,)

    # Only the first n_out_bins bins are needed.
    valid_mask = ntl.arange(0, num_bins_pow2) < n_out_bins
    local_hist = local_hist.to(output.dtype)
    # Every input block is paired with the same output histogram (see the
    # arrangement), so the per-block counts are accumulated atomically.
    ntl.atomic_add(output.data_ptr() + output.offsets(),
                   local_hist,
                   mask=valid_mask)



def premake_builtin(dtype=None, block_size=None):
    """Build the histc triple that uses the built-in histogram application.

    Returns ``(arrangement, application_builtin_histogram, tensors)`` with
    ``block_size`` bound into the arrangement.
    """
    # Out-of-bounds input lanes are filled with +inf; presumably this makes
    # them fail the `input <= max_val` range check so padding is not counted.
    input_spec = Tensor(
        1, dtype=dtype, other=float("inf"), shape_options=dict(constexpr=True)
    )
    output_spec = Tensor(1, dtype=dtype, shape_options=dict(constexpr=True))
    min_spec = Tensor(0, dtype=dtype)
    max_spec = Tensor(0, dtype=dtype)
    num_bins_pow2_spec = Tensor(0, dtype=int, constexpr=True)

    return (
        functools.partial(arrangement, block_size=block_size),
        application_builtin_histogram,
        (input_spec, output_spec, min_spec, max_spec, num_bins_pow2_spec),
    )


def premake_manual(dtype=None, block_size=None):
    """Build the histc triple that uses the hand-written histogram application.

    Returns ``(arrangement, application_manual_histogram, tensors)`` with
    ``block_size`` bound into the arrangement.
    """
    # Out-of-bounds input lanes are filled with +inf; presumably this makes
    # them fail the `input <= max_val` range check so padding is not counted.
    input_spec = Tensor(
        1, dtype=dtype, other=float("inf"), shape_options=dict(constexpr=True)
    )
    output_spec = Tensor(1, dtype=dtype, shape_options=dict(constexpr=True))
    min_spec = Tensor(0, dtype=dtype)
    max_spec = Tensor(0, dtype=dtype)
    num_bins_pow2_spec = Tensor(0, dtype=int, constexpr=True)

    return (
        functools.partial(arrangement, block_size=block_size),
        application_manual_histogram,
        (input_spec, output_spec, min_spec, max_spec, num_bins_pow2_spec),
    )

21 changes: 21 additions & 0 deletions src/ntops/kernels/log10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    """Compute the element-wise base-10 logarithm of ``input``.

    Uses the identity log10(x) = ln(x) * log10(e); the literal
    0.4342944819032518 is log10(e) = 1 / ln(10).

    NOTE(review): only float16 is upcast before the logarithm; confirm that
    ``ntl.log`` accepts every other dtype this kernel is used with (e.g.
    bfloat16).
    """
    if input.dtype == ntl.float16:
        # float16 inputs are upcast to float32 before taking the logarithm.
        output = ntl.log(ntl.cast(input, ntl.float32)) * 0.4342944819032518
    else:
        output = ntl.log(input) * 0.4342944819032518


def premake(ndim, dtype=None, block_size=None):
    """Build the ``(arrangement, application, tensors)`` triple for log10.

    The input and output are both ``ndim``-dimensional tensors of ``dtype``;
    ``block_size`` is bound into the element-wise arrangement.
    """
    input_spec = Tensor(ndim, dtype=dtype)
    output_spec = Tensor(ndim, dtype=dtype)

    return (
        functools.partial(arrangement, block_size=block_size),
        application,
        (input_spec, output_spec),
    )
18 changes: 18 additions & 0 deletions src/ntops/kernels/log1p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, output):
    """Compute the element-wise ``log(1 + input)``.

    The input is always cast to float32 before 1 is added and the natural
    logarithm is taken.

    NOTE(review): this forms ``1 + x`` explicitly instead of calling a
    dedicated log1p routine, so accuracy for very small ``|x|`` may be
    reduced, and float64 inputs are narrowed to float32 — confirm both are
    acceptable.
    """
    output = ntl.log(ntl.cast(input, ntl.float32) + 1)


def premake(ndim, dtype=None, block_size=None):
    """Build the ``(arrangement, application, tensors)`` triple for log1p.

    The input and output are both ``ndim``-dimensional tensors of ``dtype``;
    ``block_size`` is bound into the element-wise arrangement.
    """
    input_spec = Tensor(ndim, dtype=dtype)
    output_spec = Tensor(ndim, dtype=dtype)

    return (
        functools.partial(arrangement, block_size=block_size),
        application,
        (input_spec, output_spec),
    )
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh
from ntops.torch.avg_pool3d import avg_pool3d
from ntops.torch.histc import histc
from ntops.torch.log10 import log10
from ntops.torch.log1p import log1p
from ntops.torch.dot import dot

__all__ = [
"abs",
Expand Down Expand Up @@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"avg_pool3d",
"histc",
"log10",
"log1p",
"dot",
]
Loading