Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions midend/lib/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ add_mlir_public_c_api_library(BuddyMLIRCAPI
BuddyGemminiTransforms
BuddyRVV
BuddyRVVTransforms
# TODO: naming consistancy
VIR
VectorExp
BuddyMLIRInitAll
BuddyToLLVMIRTranslationRegistration
Expand Down
3 changes: 3 additions & 0 deletions midend/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "buddy-mlir-c/Dialects.h"

// TODO: Does those inclusion for xxxOps.h necessary?
#include "Dialect/Bud/BudDialect.h"
#include "Dialect/Bud/BudOps.h"
#include "Dialect/DAP/DAPDialect.h"
Expand All @@ -25,6 +26,7 @@
#include "Dialect/Gemmini/GemminiDialect.h"
#include "Dialect/Gemmini/GemminiOps.h"
#include "Dialect/RVV/RVVDialect.h"
#include "Dialect/VIR/VIRDialect.h"
#include "Dialect/VectorExp/VectorExpDialect.h"
#include "Dialect/VectorExp/VectorExpOps.h"

Expand All @@ -38,3 +40,4 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Gemmini, gemmini,
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(RVV, rvv, buddy::rvv::RVVDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(VectorExp, vector_exp,
buddy::vector_exp::VectorExpDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(VIR, vir, buddy::vir::VIRDialect)
1 change: 1 addition & 0 deletions midend/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(LinkedLibs
LowerLinalgToGemminiPass
LowerRVVPass
LowerVectorExpPass
VIRToVectorPass
MatMulOptimization
BatchMatMulOptimization
MatMulParallelVectorization
Expand Down
4 changes: 4 additions & 0 deletions midend/lib/InitAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "Dialect/DIP/DIPDialect.h"
#include "Dialect/Gemmini/GemminiDialect.h"
#include "Dialect/RVV/RVVDialect.h"
#include "Dialect/VIR/VIRDialect.h"
#include "Dialect/VectorExp/VectorExpDialect.h"

namespace mlir {
Expand All @@ -45,6 +46,7 @@ void registerMatMulOptimizePass();
void registerMatMulParallelVectorizationPass();
void registerMatMulVectorizationPass();
void registerTransposeOptimizationPass();
void registerVIRToVectorPass();
} // namespace buddy
} // namespace mlir

Expand All @@ -55,6 +57,7 @@ void mlir::buddy::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<::buddy::gemmini::GemminiDialect>();
registry.insert<::buddy::rvv::RVVDialect>();
registry.insert<::buddy::vector_exp::VectorExpDialect>();
registry.insert<::buddy::vir::VIRDialect>();
}

void mlir::buddy::registerAllPasses() {
Expand All @@ -74,4 +77,5 @@ void mlir::buddy::registerAllPasses() {
mlir::buddy::registerMatMulParallelVectorizationPass();
mlir::buddy::registerMatMulVectorizationPass();
mlir::buddy::registerTransposeOptimizationPass();
mlir::buddy::registerVIRToVectorPass();
}
9 changes: 8 additions & 1 deletion midend/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ declare_mlir_dialect_python_bindings(
dialects/rvv.py
DIALECT_NAME rvv)


declare_mlir_dialect_python_bindings(
ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
Expand All @@ -66,6 +65,14 @@ declare_mlir_dialect_python_bindings(
dialects/vector_exp.py
DIALECT_NAME vector_exp)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
TD_FILE dialects/VIRBinding.td
SOURCES
dialects/vir.py
DIALECT_NAME vir)

################################################################################
# Python extensions.
# The sources for these are all in lib/python/Bindings, but since they have to
Expand Down
22 changes: 22 additions & 0 deletions midend/python/buddy_mlir/dialects/VIRBinding.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===-------- VIROps.td - Python bindings for VIR --*- tablegen -*--------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_VIR_OPS
#define PYTHON_BINDINGS_VIR_OPS

include "VIR/VIROps.td"

#endif
17 changes: 17 additions & 0 deletions midend/python/buddy_mlir/dialects/vir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ===------------------------ vir.py -------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------

from ._vir_ops_gen import *
71 changes: 71 additions & 0 deletions tests/Python/dialects/test_vir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# RUN: %PYTHON %s | FileCheck %s

from buddy_mlir.dialects import arith, func, vir, memref
from buddy_mlir import ir
from buddy_mlir.passmanager import PassManager


def run(f):
print("\nTEST:", f.__name__)
f()
return f


# CHECK-LABEL: TEST: testVIROperations
@run
def testVIROperations():
with ir.Context(), ir.Location.unknown():
module = ir.Module.parse(
"""
memref.global "private" @gv : memref<10xf32> = dense<[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.]>
func.func private @printMemrefF32(memref<*xf32>)

func.func @main() {
%vl = arith.constant 5 : index
%f1 = arith.constant 1.0 : f32
%mem = memref.get_global @gv : memref<10xf32>
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index

vir.set_vl %vl : index {
%v1 = vir.constant { value = 2.0 : f32 } : !vir.vec<?xf32>
%v2 = vir.broadcast %f1 : f32 -> !vir.vec<?xf32>
vir.store %v1, %mem[%c0] : !vir.vec<?xf32> -> memref<10xf32>
vir.store %v2, %mem[%c5] : !vir.vec<?xf32> -> memref<10xf32>
vector.yield
}

%print_mem = memref.cast %mem : memref<10xf32> to memref<*xf32>
call @printMemrefF32(%print_mem) : (memref<*xf32>) -> ()

return
}
"""
)

module.operation.verify()

pm = PassManager("builtin.module")
pm.add("lower-vir-to-vector")
pm.run(module.operation)

# CHECK: #map = affine_map<(d0) -> (d0)>
# CHECK: func.func @main() {
# CHECK: %[[VL:.*]] = arith.constant 5 : index
# CHECK: %[[F1:.*]] = arith.constant 1.000000e+00 : f32
# CHECK: %[[MEM:.*]] = memref.get_global @gv : memref<10xf32>
# CHECK: %[[C0:.*]] = arith.constant 0 : index
# CHECK: %[[C5:.*]] = arith.constant 5 : index
# CHECK: %[[C256:.*]] = arith.constant 256 : index
# CHECK: affine.for %{{.*}} = #map(%{{.*}}) to #map(%{{.*}}) step 256 {
# CHECK: %[[CONST_VEC:.*]] = arith.constant dense<2.000000e+00> : vector<256xf32>
# CHECK: %[[BROADCAST_VEC:.*]] = vector.broadcast %[[F1]] : f32 to vector<256xf32>
# CHECK: vector.store %[[CONST_VEC]], %[[MEM]][%{{.*}}] : memref<10xf32>, vector<256xf32>
# CHECK: vector.store %[[BROADCAST_VEC]], %[[MEM]][%{{.*}}] : memref<10xf32>, vector<256xf32>
# CHECK: }
# CHECK: affine.for %{{.*}} = #map(%{{.*}}) to #map(%[[VL]]) {
# CHECK: %[[SCALAR_CONST:.*]] = arith.constant 2.000000e+00 : f32
# CHECK: memref.store %[[SCALAR_CONST]], %[[MEM]][%{{.*}}] : memref<10xf32>
# CHECK: memref.store %[[F1]], %[[MEM]][%{{.*}}] : memref<10xf32>
# CHECK: }
print(module)