Skip to content

Commit

Permalink
[MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update fai…
Browse files Browse the repository at this point in the history
…ling obsolete OpDSL tests. (llvm#115319)

The earlier PR (llvm#104783), which
introduced
transpose and broadcast semantics to linalg.matmul, was reverted due to
two failing
OpDSL tests for linalg.matmul.

Since linalg.matmul is now defined using TableGen ODS instead of
Python-based OpDSL,
these tests started failing and needed to be removed or updated.

This commit removes/updates the failing obsolete tests from below files.
All other files
were part of earlier PR and just cherry picked.
    "mlir/test/python/integration/dialects/linalg/opsrun.py"
    "mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <[email protected]>
  • Loading branch information
shahidact and rengolin authored Nov 7, 2024
1 parent 21835ee commit 3ad0148
Show file tree
Hide file tree
Showing 16 changed files with 959 additions and 309 deletions.
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,16 @@ def LinalgStructuredInterface
return;
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if the user has supplied an explicit indexing maps for this op.
}],
/*retTy=*/"bool",
/*methodName=*/"hasUserDefinedMaps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{ return false; }]
>,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
Expand Down
72 changes: 0 additions & 72 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
doc: |-
Performs a matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
kind: input_tensor
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
kind: input_tensor
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: C
kind: output_tensor
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- !LinalgOperandDefConfig
name: cast
kind: type_fn_attr
default_fn: cast_signed
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
iterator_types:
- parallel
- parallel
- reduction
assignments:
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_fn:
kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
kind: type
attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- !ScalarExpression
scalar_fn:
kind: type
attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
cpp_class_name: QuantizedMatmulOp
Expand Down
134 changes: 134 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//

def MatmulOp : LinalgStructuredBase_Op<"matmul", [
    AttrSizedOperandSegments,
    LinalgContractionOpInterface]> {

  let summary = [{
    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
  }];
  let description = [{
    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.

    Broadcast and Transpose semantics can be applied by specifying the explicit attribute
    'indexing_maps' as shown below. This is a list attribute, so the list must include all
    the maps if specified.

    Example Transpose:
    ```
    linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                   affine_map<(d0, d1, d2) -> (d2, d1)>,
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                   ]
                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
                   outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.matmul indexing_maps = [
                     affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>
                   ]
                   ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
                   outs(%arg2: memref<3x7xf32>)
    ```

    Example Broadcast and transpose:
    ```
    linalg.matmul indexing_maps = [
                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
                   affine_map<(d0, d1, d2) -> (d0, d1)>
                   ]
                   ins(%arg0, %arg1 : memref<5x3xf32>, memref<7xf32>) outs(%arg2: memref<3x7xf32>)
    ```
  }];

  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, MatmulOp::getRegionBuilder());
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>,
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addAttribute("cast", cast);
        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
          attributes, MatmulOp::getRegionBuilder());
      }]>

  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    SmallVector<utils::IteratorType> getIteratorTypesArray();

    /// Implements the block region builder.
    static void regionBuilder(ImplicitLocOpBuilder &b,
                              Block &block, ArrayRef<NamedAttribute> attrs);

    /// Returns a list of AffineMap with the typical matmul indexing characteristic.
    SmallVector<AffineMap> getDefaultIndexingMaps();

    /// Returns true if the given broadcast map \p bcastMap is valid for this op.
    bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

    static std::function<void(ImplicitLocOpBuilder &,
                              Block &, ArrayRef<NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    ::mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputsMutable();
    }

    // Generic methods.
    static unsigned getNumRegionArgs();
    std::string getLibraryCallName();
    bool hasDynamicIndexingMaps();
    /// Check if the op has broadcast and/or transpose semantics. Returns true if
    /// the user-defined indexing maps are not equal to the default map.
    bool hasUserDefinedMaps();
  }];
}

//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <numeric>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;
Expand Down Expand Up @@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);

// Mixed tensor/buffer operands are not allowed.
if (!linalgOp.hasPureTensorSemantics() &&
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
Expand All @@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< ") to be equal to the number of input/output operands ("
<< linalgOp->getNumOperands() << ")";

// Set this flag if this op has user defined maps. This is required to guard
// the below error condition which assume default indexing maps.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

Expand All @@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< " dim(s) to match the number of loops";

int64_t rank = linalgOp.getRank(&opOperand);

if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand.getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
}

SmallVector<unsigned> redDims;
linalgOp.getReductionDims(redDims);

Expand All @@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
// Check if given shapes match to inferred shapes.
SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

// Verify only static cases since we can't get exact dimension sizes and loop
// ranges for dynamic cases in this stage.
// Verify only static cases since we can't get exact dimension sizes and
// loop ranges for dynamic cases in this stage.
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
Expand Down
Loading

0 comments on commit 3ad0148

Please sign in to comment.