Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build failure in issues #178 #187

Merged
merged 15 commits into from
Jan 27, 2025
Merged

Fix build failure in issues #178 #187

merged 15 commits into from
Jan 27, 2025

Conversation

zhaoshiz
Copy link
Contributor

Fixing below compilation errors in include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:

/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:839:68: error: no member named 'getFile' in 'mlir::triton::AssertOp'
llvm::formatv("{0}.py:{1}: {2} Assertion {3} failed", op.getFile(),
~~ ^
/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:26: error: no member named 'getLine' in 'mlir::triton::AssertOp'
op.getLine(), op.getFunc(), op.getMessage());
~~ ^
/home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:40: error: no member named 'getFunc' in 'mlir::triton::AssertOp'
op.getLine(), op.getFunc(), op.getMessage());
~~ ^
3 errors generated.

This fix builds with triton @ab07e5472bcb414a0c8dd7ecab80f84370c4894e, and llvm @cfd3289a1f9a87e220737a634904a886a82d424a.

@nhat-nguyen
Copy link
Collaborator

@zhaoshiz Thank you so much! Would you mind also updating the submodule commit to match in this PR too? Otherwise, our build will fail.

@zhaoshiz
Copy link
Contributor Author

zhaoshiz commented Nov 6, 2024

@nhat-nguyen, I've updated the submodule sha and fixed additional compilation errors. I'm working with legal dept. to get the CLA approved.

@zhaoshiz
Copy link
Contributor Author

@microsoft-github-policy-service agree company="Qualcomm Innovation Center, Inc."

@mdehling
Copy link

Thank you, this saved me some time! :)

@zhaoshiz zhaoshiz closed this Dec 12, 2024
@zhaoshiz zhaoshiz reopened this Dec 12, 2024
nhat-nguyen added a commit that referenced this pull request Jan 10, 2025
Newer triton versions put an additional symlink in the llvm folder, so
the `ls` command ends up listing two file names separated by a newline
which breaks the pipeline run in #187. Use `find` with `top -1` to
ensure we only ever get one llvm path.
zhaoshiz pushed a commit to zhaoshiz/triton-shared that referenced this pull request Jan 10, 2025
Newer triton versions put an additional symlink in the llvm folder, so
the `ls` command ends up listing two file names separated by a newline
which breaks the pipeline run in microsoft#187. Use `find` with `top -1` to
ensure we only ever get one llvm path.
@nhat-nguyen
Copy link
Collaborator

Looks like the CPU backend needs to be updated with some of the new methods from the BaseBackend class

@zhaoshiz
Copy link
Contributor Author

zhaoshiz commented Jan 10, 2025

Looks like the CPU backend needs to be updated with some of the new methods from the BaseBackend class

Oh, sorry I missed that part.

That's get_active_torch_device: https://github.com/triton-lang/triton/blob/f9d9fad1b7b648e73ef03332737f000bed258f13/python/triton/backends/driver.py#L22C1-L24C13.

Should I just add below to CPUDriver class?

def get_active_torch_device(self):
    import torch
    return torch.device("cpu")

@zhaoshiz
Copy link
Contributor Author

test_core.py:1650: in
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9 or is_hip(),
...
ERROR test_core.py - AssertionError: Torch not compiled with CUDA enabled

Shall we add a check for torch.is_cuda_available()?
@pytest.mark.skipif(not torch.is_cuda_available() or torch.cuda.get_device_capability()[0] < 9 or is_hip(),

@nhat-nguyen
Copy link
Collaborator

Yeah looks like we need to. Although the test_core file is symlinked to the file in the triton submodule. @parsifal-47 is there a way we can work around this?

@zhaoshiz
Copy link
Contributor Author

https://github.com/triton-lang/triton/blob/110b66e649711a8fdda66359db5054a8a0ede9d2/python/test/unit/language/test_core.py#L1661C21-L1661C34

seems fixed by Triton already, let me try updating Triton

@nhat-nguyen
Copy link
Collaborator

Could we also extract the splat op change to a separate PR? That would make it easier to keep track of things.

@zhaoshiz
Copy link
Contributor Author

zhaoshiz commented Jan 10, 2025

Could we also extract the splat op change to a separate PR? That would make it easier to keep track of things.

sure. I'll revert the commit and xfail related test in this PR, and create another one.

@parsifal-47
Copy link
Contributor

Yeah looks like we need to. Although the test_core file is symlinked to the file in the triton submodule. @parsifal-47 is there a way we can work around this?

yes, so far I was able to workaround by describing exceptions in this file https://github.com/microsoft/triton-shared/blob/main/python/examples/conftest.py if you need to disable test cases please let me know, I am not sure whether the issue is resolved or not

nhat-nguyen added a commit that referenced this pull request Jan 14, 2025
Newer triton versions put an additional symlink in the llvm folder, so
the `ls` command ends up listing two file names separated by a newline
which breaks the pipeline run in #187. Use `find` with `top -1` to
ensure we only ever get one llvm path.
@nhat-nguyen
Copy link
Collaborator

looks like lots of tests are failing in test_core -- some of them seem to be because triton added a constexpr type which we don't handle in the CPU backend yet

@zhaoshiz
Copy link
Contributor Author

looks like lots of tests are failing in test_core -- some of them seem to be because triton added a constexpr type which we don't handle in the CPU backend yet

I've looked into the failures and am working on constexpr.

Fixing below compilation errors in include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:

  /home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:839:68: error: no member named 'getFile' in 'mlir::triton::AssertOp'
          llvm::formatv("{0}.py:{1}: {2} Assertion `{3}` failed", op.getFile(),
                                                                  ~~ ^
  /home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:26: error: no member named 'getLine' in 'mlir::triton::AssertOp'
                        op.getLine(), op.getFunc(), op.getMessage());
                        ~~ ^
  /home/runner/work/triton-shared/triton-shared/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:840:40: error: no member named 'getFunc' in 'mlir::triton::AssertOp'
                        op.getLine(), op.getFunc(), op.getMessage());
                                      ~~ ^
  3 errors generated.

This fix builds with triton @ab07e5472bcb414a0c8dd7ecab80f84370c4894e,
and llvm @cfd3289a1f9a87e220737a634904a886a82d424a.
Fix compilation errors caused by LLVM commits:
f18c3e4e7335df282c468b6dff3d29be1822a96d [mlir][Transforms] Dialect Conversion: Simplify materialization fn result type (#113031)
8c4bc1e75de27adfbaead34b895b0efbaf17bd02 [mlir][Transforms] Merge 1:1 and 1:N type converters (#113032)

Update Triton to 94684d326723b67b146f23f342623ea058a32098
Remove parts of llvm::formatv string correspond to File, Line and Func
arguments of triton::AssertOp.
Add `sanitize_overflow: bool = True` to class CPUOptions in compiler.py
and `get_benchmarker(self)` to class CPUDriver in driver.py to run the
tests.

XFAILing TritonToLinalg tests since this pass will be retire soon:
test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir
test/Conversion/TritonToLinalg/wraparound_stacked.mlir

XFAILing StructuredToMemref tests due to LLVM commit
889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes
like <2x?> and <?x?> are mismatch:
test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir
test/Conversion/StructuredToMemref/wraparound_stacked.mlir
zhaoshiz and others added 9 commits January 17, 2025 14:30
Update CMakeList.txt for python and pybind11 headers.
Fixed test/Conversion/TritonArithToLinalg/split.mlir.
Working on test/Conversion/StructuredToMemref/get_num_programs.mlir.

Builds with Triton@acc25d91fba850c18c099e7e577962ba56bdd06c and
LLVM@86b69c31642e98f8357df62c09d118ad1da4e16a.
Add rewriteSplatOp() in PtrAnalysis pass. This function creates a
tts.makeptr for the case below:
    %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>

Previously we rely on rewriteAddPtrOp to create the tts.makeptr:
    %3 = arith.constant 0 : index
    %6 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
    %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr<i32>>, tensor<1xi32>

Creating a constant 0 and adding it to a pointer is optimized away by
Triton.
This reverts commit 2153c53.

The commit being reverted will be sent in a separate PR.
In commit 9743ec0dca5bbd9dbce20adc3ee273af6b095f94, Triton moved to use
"constexpr"s instead of "constant"s in its function signature.
Also update Triton to 2efb067bfc0f9acabcd8b4ffe7c55ad248dfb282.
Change various lambda return type from `std::optional<Value>` to `Value`
per LLVM API change.
@zhaoshiz
Copy link
Contributor Author

rebased and fixed build failures

@zhaoshiz
Copy link
Contributor Author

zhaoshiz commented Jan 17, 2025

yes, so far I was able to workaround by describing exceptions in this file https://github.com/microsoft/triton-shared/blob/main/python/examples/conftest.py if you need to disable test cases please let me know, I am not sure whether the issue is resolved or not

I have disabled several tests from Triton's test_core.py in conftest.py. I think some are not supported by Triton-Shared but I'm unsure about FP8 data types on CPUs. Please take a look: 12d816e

@parsifal-47
Copy link
Contributor

I have disabled several tests from Triton's test_core.py in conftest.py. I think some are not supported by Triton-Shared but I'm unsure about FP8 data types on CPUs. Please take a look: 12d816e

you also provided comments for each disabled test, looks good to me, thanks a lot for doing that!

@nhat-nguyen
Copy link
Collaborator

@zhaoshiz Looks like we're very close! For the modulo tests, you can use my patch here to fix both the lit tests and the CPU backend tests:

diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
index fa195c6..8fac5d8 100644
--- a/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
+++ b/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -176,9 +176,9 @@ private:
         SmallVector<int64_t>(resultShape.size(), ShapedType::kDynamic),
         /* result shape */
         SmallVector<int64_t>{
-
-            // Row stays the same
-            resultShape[0],
+            // Row stays the same, but mlir doesn't allow this anymore. Put
+            // dynamic.
+            ShapedType::kDynamic,
 
             // Column is dynamic, in most cases, this
             // should be the same as the original column.
@@ -286,9 +286,9 @@ private:
             // around.
             ShapedType::kDynamic,
 
-            // Col stays the same.
-            resultShape[1],
-        });
+            // Col stays the same, which is resultShape[1], but mlir doesn't
+            // allow this anymore. So we put dynamic instead.
+            ShapedType::kDynamic});
 
     Value rowSize = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getIndexAttr(op.getSizes()[0]));

We can disable some of the tests in core.py to unblock this update.

Thanks Nhat Nguyen for the fix.
UnXFAILed and updated wraparound_side_by_side.mlir and wraparound_stacked.mlir
in test/Conversion/StructuredToMemref.
@zhaoshiz
Copy link
Contributor Author

@zhaoshiz Looks like we're very close! For the modulo tests, you can use my patch here to fix both the lit tests and the CPU backend tests:
...
We can disable some of the tests in core.py to unblock this update.

Thanks! I was looking to fix it in MLIR but this is a better solution.

Copy link
Collaborator

@nhat-nguyen nhat-nguyen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you @zhaoshiz for your help here!

@nhat-nguyen nhat-nguyen merged commit 560c064 into microsoft:main Jan 27, 2025
3 checks passed
@zhaoshiz
Copy link
Contributor Author

thank you @zhaoshiz for your help here!

my pleasure!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants