Skip to content

Commit 4f88c23

Browse files
authored
[mlir][py] Add NVGPU's TensorMapDescriptorType in py bindings (llvm#88855)
This PR adds NVGPU dialects' TensorMapDescriptorType in the py bindings. This is a follow-up issue from [this PR](llvm#87153 (comment))
1 parent 856d1c4 commit 4f88c23

File tree

6 files changed

+101
-0
lines changed

6 files changed

+101
-0
lines changed

mlir/include/mlir-c/Dialect/NVGPU.h

+11
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,24 @@
1111
#define MLIR_C_DIALECT_NVGPU_H
1212

1313
#include "mlir-c/IR.h"
14+
#include "mlir-c/Support.h"
1415

1516
#ifdef __cplusplus
1617
extern "C" {
1718
#endif
1819

1920
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu);
2021

22+
//===---------------------------------------------------------------------===//
23+
// TensorMapDescriptorType
24+
//===---------------------------------------------------------------------===//
25+
26+
MLIR_CAPI_EXPORTED bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type);
27+
28+
MLIR_CAPI_EXPORTED MlirType mlirNVGPUTensorMapDescriptorTypeGet(
29+
MlirContext ctx, MlirType tensorMemrefType, int swizzle, int l2promo,
30+
int oobFill, int interleave);
31+
2132
#ifdef __cplusplus
2233
}
2334
#endif
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===--- DialectNvgpu.cpp - Pybind module for Nvgpu dialect API support ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/NVGPU.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
#include <pybind11/pybind11.h>
13+
14+
namespace py = pybind11;
15+
using namespace llvm;
16+
using namespace mlir;
17+
using namespace mlir::python;
18+
using namespace mlir::python::adaptors;
19+
20+
static void populateDialectNvgpuSubmodule(const pybind11::module &m) {
21+
auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
22+
m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);
23+
24+
nvgpuTensorMapDescriptorType.def_classmethod(
25+
"get",
26+
[](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
27+
int oobFill, int interleave, MlirContext ctx) {
28+
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
29+
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
30+
},
31+
"Gets an instance of TensorMapDescriptorType in the same context",
32+
py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"),
33+
py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"),
34+
py::arg("ctx") = py::none());
35+
}
36+
37+
PYBIND11_MODULE(_mlirDialectsNvgpu, m) {
38+
m.doc() = "MLIR NVGPU dialect.";
39+
40+
populateDialectNvgpuSubmodule(m);
41+
}

mlir/lib/CAPI/Dialect/NVGPU.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,23 @@
99
#include "mlir-c/Dialect/NVGPU.h"
1010
#include "mlir/CAPI/Registration.h"
1111
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::nvgpu;
1216

1317
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu, mlir::nvgpu::NVGPUDialect)
18+
19+
bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type) {
20+
return isa<nvgpu::TensorMapDescriptorType>(unwrap(type));
21+
}
22+
23+
MlirType mlirNVGPUTensorMapDescriptorTypeGet(MlirContext ctx,
24+
MlirType tensorMemrefType,
25+
int swizzle, int l2promo,
26+
int oobFill, int interleave) {
27+
return wrap(nvgpu::TensorMapDescriptorType::get(
28+
unwrap(ctx), cast<MemRefType>(unwrap(tensorMemrefType)),
29+
TensorMapSwizzleKind(swizzle), TensorMapL2PromoKind(l2promo),
30+
TensorMapOOBKind(oobFill), TensorMapInterleaveKind(interleave)));
31+
}

mlir/python/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
524524
MLIRCAPIQuant
525525
)
526526

527+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
528+
MODULE_NAME _mlirDialectsNvgpu
529+
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
530+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
531+
SOURCES
532+
DialectNVGPU.cpp
533+
PRIVATE_LINK_LIBS
534+
LLVMSupport
535+
EMBED_CAPI_LINK_LIBS
536+
MLIRCAPIIR
537+
MLIRCAPINVGPU
538+
)
539+
527540
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
528541
MODULE_NAME _mlirDialectsPDL
529542
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl

mlir/python/mlir/dialects/nvgpu.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from ._nvgpu_ops_gen import *
66
from ._nvgpu_enum_gen import *
7+
from .._mlir_libs._mlirDialectsNvgpu import *

mlir/test/python/dialects/nvgpu.py

+17
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,23 @@ def constructAndPrintInModule(f):
1515
return f
1616

1717

18+
# CHECK-LABEL: testTypes
19+
@constructAndPrintInModule
20+
def testTypes():
21+
tensorMemrefType = MemRefType.get(
22+
(128, 64), F16Type.get(), memory_space=Attribute.parse("3")
23+
)
24+
# CHECK: !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = l2promo_256b, oob = nan, interleave = none>
25+
tma_desc = nvgpu.TensorMapDescriptorType.get(
26+
tensorMemrefType,
27+
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
28+
nvgpu.TensorMapL2PromoKind.L2PROMO_256B,
29+
nvgpu.TensorMapOOBKind.OOB_NAN,
30+
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
31+
)
32+
print(tma_desc)
33+
34+
1835
# CHECK-LABEL: testSmoke
1936
@constructAndPrintInModule
2037
def testSmoke():

0 commit comments

Comments
 (0)