Nhat/debug env #219

Closed · wants to merge 10 commits
74 changes: 22 additions & 52 deletions .github/workflows/test-plugin.yml
@@ -57,12 +57,6 @@ jobs:
run: |
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"

- name: Check pre-commit
working-directory: triton_shared/triton
run: |
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files --verbose

- name: Build/Install Triton
working-directory: triton_shared/triton/python
run: |
@@ -73,30 +67,21 @@ jobs:
export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared"
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'

- name: Run shared middle-layer lit tests
working-directory: triton_shared/triton/python
run: |
python3 -m pip install lit
LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test"
if [ ! -d "${LIT_TEST_DIR}" ]; then
echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
fi
lit -v "${LIT_TEST_DIR}"

- name: Install CPU backend example dependencies
run: |
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install pytest

- name: Prepare CPU backend environment
working-directory: triton_shared/triton/python
run: |
echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}"
echo "LLVM_BINARY_DIR=${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin" >> "${GITHUB_ENV}"
CMAKE_BUILD_DIR=$(ls $(pwd)/build | grep -i cmake)
LLVM_BIN_DIR=$(find ${HOME}/.triton/llvm/ -name bin | head -1)

ls $(pwd)/build
echo "---"
ls ${HOME}/.triton/llvm/
echo "cmake: +${CMAKE_BUILD_DIR}+"
echo "llvm: +${LLVM_BIN_DIR}+"


- name: Run CPU backend examples
working-directory: triton_shared/python/examples
run: pytest .
echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/${CMAKE_BUILD_DIR}/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}"
echo "LLVM_BINARY_DIR=${HOME}/.triton/llvm/${LLVM_BIN_DIR}" >> "${GITHUB_ENV}"


build_and_test_triton_shared_arm:
@@ -130,12 +115,6 @@ jobs:
run: |
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"

- name: Check pre-commit
working-directory: triton_shared/triton
run: |
python3 -m pip install --upgrade pre-commit
python3 -m pre_commit run --all-files --verbose

- name: Build/Install Triton
working-directory: triton_shared/triton/python
run: |
@@ -146,27 +125,18 @@ jobs:
export TRITON_PLUGIN_DIRS="${GITHUB_WORKSPACE}/triton_shared"
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'

- name: Run shared middle-layer lit tests
working-directory: triton_shared/triton/python
run: |
python3 -m pip install lit
LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test"
if [ ! -d "${LIT_TEST_DIR}" ]; then
echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
fi
lit -v "${LIT_TEST_DIR}"

- name: Install CPU backend example dependencies
run: |
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install pytest

- name: Prepare CPU backend environment
working-directory: triton_shared/triton/python
run: |
echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}"
echo "LLVM_BINARY_DIR=${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin" >> "${GITHUB_ENV}"
CMAKE_BUILD_DIR=$(ls $(pwd)/build | grep -i cmake)
LLVM_BINARY_DIR=$(find ${HOME}/.triton/llvm/ -name bin | head -1)

ls $(pwd)/build
echo "---"
ls ${HOME}/.triton/llvm/
echo "cmake: +${CMAKE_BUILD_DIR}+"
echo "llvm: +${LLVM_BINARY_DIR}+"


- name: Run CPU backend examples
working-directory: triton_shared/python/examples
run: pytest .
echo "TRITON_SHARED_OPT_PATH=$(pwd)/build/${CMAKE_BUILD_DIR}/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" >> "${GITHUB_ENV}"
echo "LLVM_BINARY_DIR=${LLVM_BINARY_DIR}" >> "${GITHUB_ENV}"
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -5,6 +5,9 @@ set(TRITON_SHARED_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files
include_directories(${Python3_INCLUDE_DIR})
include_directories(${pybind11_INCLUDE_DIR})

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
1 change: 1 addition & 0 deletions backend/compiler.py
@@ -127,6 +127,7 @@ class CPUOptions:
shared: bool = False
allow_fp8e4nv: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
sanitize_overflow: bool = True

def __post_init__(self):
pass
4 changes: 4 additions & 0 deletions backend/driver.py
@@ -347,6 +347,10 @@ def __init__(self):
def is_active():
return False

def get_benchmarker(self):
from triton.testing import do_bench
return do_bench

def get_device_capability(self):
return ("cpu", 0)

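
The new get_benchmarker hook lets upstream Triton (for example its autotuner) obtain a timing function from the active driver rather than importing do_bench directly. A hedged usage sketch; the kernel, grid, and arguments are placeholders, and the lookup assumes the standard triton.runtime.driver.active mechanism:

import triton

def time_kernel(kernel, args, grid=(1,)):
    # The active driver (the CPU driver patched above) returns
    # triton.testing.do_bench from get_benchmarker().
    bench = triton.runtime.driver.active.get_benchmarker()
    # do_bench repeatedly launches the callable and reports a runtime in ms.
    return bench(lambda: kernel[grid](*args))
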
3 changes: 3 additions & 0 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -262,6 +262,9 @@ class PtrAnalysis {

LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);

// Only rewrite if a scalar ptr is splatted into a tensor of ptr
LogicalResult rewriteSplatOp(triton::SplatOp op);

LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
};

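
For context, the splat-of-pointer case that rewriteSplatOp handles shows up whenever a kernel adds a tensor of offsets to a scalar pointer argument; a minimal Triton kernel that produces the pattern (illustrative only, not part of this PR):

import triton
import triton.language as tl

@triton.jit
def add_one(ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    # `ptr` is a scalar pointer; `ptr + offs` first splats it into a tensor
    # of pointers (tt.splat) before tt.addptr, which is the case
    # PtrAnalysis::rewriteSplatOp now rewrites into a tts.make_tptr.
    x = tl.load(ptr + offs, mask=mask)
    tl.store(ptr + offs, x + 1, mask=mask)
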
@@ -837,8 +837,7 @@ struct AssertConverter : public OpConversionPattern<triton::AssertOp> {
}

auto assertMessage =
llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(),
op.getLine(), op.getFunc(), op.getMessage());
llvm::formatv("Assertion `{0}` failed", op.getMessage());
rewriter.create<mlir::cf::AssertOp>(op.getLoc(), condVal,
assertMessage.str());

35 changes: 33 additions & 2 deletions lib/AnalysisStructured/PtrAnalysis.cpp
@@ -1106,7 +1106,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op,
auto loc = op.getLoc();

if (!ptr) {
op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so "
op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr so "
"loadOp cannot be rewritten");
return failure();
}
@@ -1243,7 +1243,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op,
auto loc = op.getLoc();

if (!ptr) {
op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so "
op->emitRemark("PtrAnalysis: pointer is not replaced with tts.make_tptr so "
"storeOp cannot be rewritten");
return failure();
}
@@ -1280,6 +1280,30 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op,
return success();
}

LogicalResult PtrAnalysis::rewriteSplatOp(triton::SplatOp op) {
if (isa<triton::PointerType>(op.getSrc().getType())) {
LLVM_DEBUG({
llvm::dbgs() << "SplatOp has ptr-typed src: " << op.getSrc()
<< "\nsplatted into type: " << op.getType() << "\n";
});

OpBuilder builder(op);
PtrState state;
if (visitOperandSplat(op, state, op.getLoc(), builder).failed())
return failure();

knownPtrs[op.getResult()] = state;

if (isa<RankedTensorType>(op.getResult().getType())) {
auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc());
ptrMap.map(op.getResult(), maketptrOp.getResult());
} else {
ptrMap.map(op.getResult(), op.getResult());
}
}
return success();
}

LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
LLVM_DEBUG({
llvm::dbgs() << "rewriting rootOp\n";
@@ -1324,6 +1348,13 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
}
return WalkResult::skip();
})
.Case<triton::SplatOp>([&](auto splat) {
if (rewriteSplatOp(splat).failed()) {
splat->emitRemark("PtrAnalysis: Failed to rewrite SplatOp");
return WalkResult::advance();
}
return WalkResult::skip();
})
.Case<scf::ForOp>([&](auto forOp) {
// `rewriteForOp` recursively visits its children, so regardless
// whether the rewrite succeeds or not, we need to return "skip" so
17 changes: 9 additions & 8 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
@@ -71,14 +71,14 @@ class TritonFunctionSignatureConverter : public TypeConverter {
// handled when we convert addptr op later.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addArgumentMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
@@ -118,7 +118,7 @@ class LoopTypeConverter : public TypeConverter {
// reinterpret_cast.
addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
auto reinterpretCast =
inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
return builder.create<memref::ReinterpretCastOp>(
@@ -167,9 +167,10 @@ struct ScalarAddptrConverter
}
};

static std::optional<SmallVector<Value>>
buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input,
static SmallVector<Value>
buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) {
Value input = inputs.front();
assert(resultTypes.size() == 2 && isa<MemRefType>(resultTypes[0]) &&
isa<IndexType>(resultTypes[1]) &&
"Unexpected result types when converting addptr");
@@ -201,8 +202,8 @@ buildCastAndOffsetOps(OpBuilder &builder, TypeRange resultTypes, Value input,
return SmallVector<Value>{cast, zero};
}

static std::optional<Value> buildCastOp(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
static Value buildCastOp(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
assert(isa<triton::PointerType>(resultType));
assert(inputs.size() && isa<MemRefType>(inputs[0].getType()) &&
isa<IndexType>(inputs[1].getType()));
@@ -311,7 +312,7 @@ class StructuredToMemrefPass
RewritePatternSet patterns(&getContext());

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->2 type conversion here, where a triton pointer type
14 changes: 7 additions & 7 deletions lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
@@ -74,7 +74,7 @@ class TritonToStructuredPass
RewritePatternSet patterns(&getContext());

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->1 type conversion here, where a triton pointer type
@@ -145,10 +145,10 @@ class TritonToStructuredPass
// Compute the target materialization, given a value with the pointer type,
// convert that value to a tuple type.
converter.addTargetMaterialization(
[](OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) -> std::optional<SmallVector<Value>> {
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) -> SmallVector<Value> {
return builder
.create<UnrealizedConversionCastOp>(loc, resultTypes, input)
.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs.front())
->getResults();
});

@@ -172,7 +172,7 @@ class TritonToStructuredPass
auto moduleOp = getOperation();

auto context = &getContext();
OneToNTypeConverter converter;
TypeConverter converter;
converter.addConversion([](Type type) { return type; });

// We are doing a 1->N type conversion here, where a pointer tuple type
@@ -208,10 +208,10 @@ class TritonToStructuredPass
// At the end of pointer analysis, we will use the PtrState to create the
// correct offsets, strides, and remove these ops.
converter.addTargetMaterialization([](OpBuilder &builder,
TypeRange resultTypes, Value input,
TypeRange resultTypes, ValueRange inputs,
Location loc) {
auto placeholder = builder.create<tts::GetStructuredStateOp>(
loc, input.getDefiningOp()->getOperand(0));
loc, inputs.front().getDefiningOp()->getOperand(0));
assert(llvm::equal(placeholder.getResultTypes(), resultTypes));
return placeholder.getResults();
});
@@ -13,9 +13,9 @@ module {
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>
5 changes: 2 additions & 3 deletions test/Conversion/StructuredToMemref/get_num_programs.mlir
@@ -24,14 +24,13 @@ module {

// CHECK-LABEL: func.func @num_programs
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>>
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1xi32>
// CHECK: [[VAR_1_:%.+]] = linalg.fill ins([[PARAM_1_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<1xi32>, memref<1xi32, strided<[1], offset: ?>>) -> ()
// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_]] : (tensor<1xi32>, memref<1xi32, strided<[1]>>) -> ()
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_1_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>>
// CHECK-DAG: [[VAR_2_:%.+]] = linalg.fill ins([[PARAM_2_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<1xi32>, memref<1xi32, strided<[1], offset: ?>>) -> ()
4 changes: 2 additions & 2 deletions test/Conversion/StructuredToMemref/triton_assert.mlir
@@ -3,7 +3,7 @@ tt.func public @assert_lol(%arg0: i32) {
%c0_i32 = arith.constant 0 : i32
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
%1 = tt.splat %0 : i1 -> tensor<1xi1>
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
tt.assert %1, "lol" : tensor<1xi1>
tt.return
}

@@ -12,6 +12,6 @@ tt.func public @assert_lol(%arg0: i32) {
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
// CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed"
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
// CHECK: return
// CHECK: }
@@ -1,3 +1,6 @@
// XFAIL: *
// Note: LLVM commit 889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes like
// <?x?> and <2x?> are mismatch.
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s

module {
3 changes: 3 additions & 0 deletions test/Conversion/StructuredToMemref/wraparound_stacked.mlir
@@ -1,3 +1,6 @@
// XFAIL: *
// Note: LLVM commit 889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes like
// <?x?> and <2x?> are mismatch.
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s

module {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ module {
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>