Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/ntops/kernels/addmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def arrangement(
beta,
alpha,
output,
input_precision,
block_size_m=None,
block_size_n=None,
block_size_k=None,
Expand All @@ -26,27 +27,46 @@ def arrangement(
if block_size_k is None:
block_size_k = mm.BLOCK_SIZE_K

_, _, input_arranged = mm.arrangement(
_, _, input_arranged, _ = mm.arrangement(
x,
y,
input,
input_precision,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

x_arranged, y_arranged, output_arranged = mm.arrangement(x, y, output)
x_arranged, y_arranged, output_arranged, _ = mm.arrangement(
x, y, output, input_precision
)

input_precision_arranged = input_precision

return input_arranged, x_arranged, y_arranged, beta, alpha, output_arranged
return (
input_arranged,
x_arranged,
y_arranged,
beta,
alpha,
output_arranged,
input_precision_arranged,
)


def application(input, x, y, beta, alpha, output):
def application(input, x, y, beta, alpha, output, input_precision):
mm_output = ntl.zeros(output.shape, dtype=ntl.float32)
mm.application(x, y, mm_output)
mm.application(x, y, mm_output, input_precision)
output = beta * input + alpha * mm_output


def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
def premake(
input_precision=None,
dtype=None,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
arrangement_ = functools.partial(
arrangement,
block_size_m=block_size_m,
Expand All @@ -61,6 +81,7 @@ def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None)
Tensor(0, dtype=dtype),
Tensor(0, dtype=dtype),
Tensor(2, dtype=dtype),
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
)

return arrangement_, application, tensors
27 changes: 23 additions & 4 deletions src/ntops/kernels/bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@


def arrangement(
input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
input,
other,
output,
input_precision,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
if block_size_m is None:
block_size_m = BLOCK_SIZE_M
Expand All @@ -32,17 +38,30 @@ def arrangement(
other_arranged.dtype = other_arranged.dtype.squeeze((0, 2))
other_arranged.dtype.dtype = other_arranged.dtype.dtype.squeeze(0)

return input_arranged, other_arranged, output_arranged
input_precision_arranged = input_precision

return input_arranged, other_arranged, output_arranged, input_precision_arranged

def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):

def premake(
input_precision=None,
dtype=None,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
arrangement_ = functools.partial(
arrangement,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

tensors = (Tensor(3, dtype=dtype), Tensor(3, dtype=dtype), Tensor(3, dtype=dtype))
tensors = (
Tensor(3, dtype=dtype),
Tensor(3, dtype=dtype),
Tensor(3, dtype=dtype),
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
)

return arrangement_, application, tensors
43 changes: 37 additions & 6 deletions src/ntops/kernels/mm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import functools

import ninetoothed
Expand All @@ -9,8 +10,20 @@
BLOCK_SIZE_K = ninetoothed.block_size()


class InputPrecisionVariant(enum.IntEnum):
    """Supported ``input_precision`` choices for the matmul kernels.

    The numeric values are part of the contract: the kernel application
    compares the runtime value against ``2`` to select the "ieee" path,
    so TF32 must stay 1 and IEEE must stay 2 (same as ``enum.auto()``).
    """

    TF32 = 1
    IEEE = 2


def arrangement(
input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
input,
other,
output,
input_precision,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
if block_size_m is None:
block_size_m = BLOCK_SIZE_M
Expand All @@ -33,26 +46,44 @@ def arrangement(
other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
other_arranged.dtype = other_arranged.dtype.squeeze(1)

return input_arranged, other_arranged, output_arranged
input_precision_arranged = input_precision

return input_arranged, other_arranged, output_arranged, input_precision_arranged


def application(input, other, output):
def application(input, other, output, input_precision):
accumulator = ntl.zeros(output.shape, dtype=ntl.float32)

if input_precision == 2: # InputPrecisionVariant.IEEE:
input_precision_: ntl.constexpr = "ieee"
else:
input_precision_: ntl.constexpr = "tf32"

for k in range(input.shape[0]):
accumulator += ntl.dot(input[k], other[k])
accumulator += ntl.dot(input[k], other[k], input_precision=input_precision_)

output = accumulator


def premake(dtype=None, block_size_m=None, block_size_n=None, block_size_k=None):
def premake(
input_precision=None,
dtype=None,
block_size_m=None,
block_size_n=None,
block_size_k=None,
):
arrangement_ = functools.partial(
arrangement,
block_size_m=block_size_m,
block_size_n=block_size_n,
block_size_k=block_size_k,
)

tensors = (Tensor(2, dtype=dtype), Tensor(2, dtype=dtype), Tensor(2, dtype=dtype))
tensors = (
Tensor(2, dtype=dtype),
Tensor(2, dtype=dtype),
Tensor(2, dtype=dtype),
Tensor(0, dtype=dtype, constexpr=True, value=input_precision),
)

return arrangement_, application, tensors
13 changes: 10 additions & 3 deletions src/ntops/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):

kernel = _cached_make(ntops.kernels.addmm.premake)

kernel(input, mat1, mat2, beta, alpha, out)
kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision())

return out

Expand Down Expand Up @@ -125,7 +125,7 @@ def bmm(input, mat2, *, out=None):

kernel = _cached_make(ntops.kernels.bmm.premake)

kernel(input, mat2, out)
kernel(input, mat2, out, _get_matmul_input_precision())

return out

Expand Down Expand Up @@ -294,7 +294,7 @@ def mm(input, mat2, *, out=None):

kernel = _cached_make(ntops.kernels.mm.premake)

kernel(input, mat2, out)
kernel(input, mat2, out, _get_matmul_input_precision())

return out

Expand Down Expand Up @@ -619,3 +619,10 @@ def _cached_make(
num_stages=num_stages,
max_num_configs=max_num_configs,
)


def _get_matmul_input_precision():
    """Map torch's global float32 matmul precision to a kernel variant.

    Returns ``InputPrecisionVariant.IEEE`` when the user requested the
    "highest" precision via ``torch.set_float32_matmul_precision``; every
    other setting falls back to TF32.
    """
    variant = ntops.kernels.mm.InputPrecisionVariant
    wants_full_precision = torch.get_float32_matmul_precision() == "highest"
    return variant.IEEE if wants_full_precision else variant.TF32
8 changes: 4 additions & 4 deletions tests/test_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ def generate_arguments():

for dtype in (torch.float32, torch.float16):
if dtype is torch.float32:
atol = 0.05
rtol = 0.05
atol = 0.001
rtol = 0.001
else:
atol = 0.025
rtol = 0.025
atol = 0.01
rtol = 0.01

def generate_random_size():
return random.randint(1, 1024)
Expand Down