Skip to content

Commit 69186b6

Browse files
committed
[mlir][python] bind block successors and predecessors
1 parent 2dfcc43 commit 69186b6

File tree

4 files changed

+144
-4
lines changed

4 files changed

+144
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
986986
MLIR_CAPI_EXPORTED void
987987
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
988988

989+
/// Returns the number of successor blocks of the block.
990+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
991+
992+
/// Returns `pos`-th successor of the block.
993+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
994+
intptr_t pos);
995+
996+
/// Returns the number of predecessor blocks of the block.
997+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
998+
999+
/// Returns `pos`-th predecessor of the block.
1000+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
1001+
intptr_t pos);
1002+
9891003
//===----------------------------------------------------------------------===//
9901004
// Value API.
9911005
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,85 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
26262626
PyOperationRef operation;
26272627
};
26282628

2629+
/// A list of block successors. Internally, these are stored as consecutive
2630+
/// elements, random access is cheap. The (returned) successor list is
2631+
/// associated with the operation and block whose successors these are, and thus
2632+
/// extends the lifetime of this operation and block.
2633+
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2634+
public:
2635+
static constexpr const char *pyClassName = "BlockSuccessors";
2636+
2637+
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2638+
intptr_t startIndex = 0, intptr_t length = -1,
2639+
intptr_t step = 1)
2640+
: Sliceable(startIndex,
2641+
length == -1 ? mlirBlockGetNumSuccessors(block.get())
2642+
: length,
2643+
step),
2644+
operation(operation), block(block) {}
2645+
2646+
private:
2647+
/// Give the parent CRTP class access to hook implementations below.
2648+
friend class Sliceable<PyBlockSuccessors, PyBlock>;
2649+
2650+
intptr_t getRawNumElements() {
2651+
block.checkValid();
2652+
return mlirBlockGetNumSuccessors(block.get());
2653+
}
2654+
2655+
PyBlock getRawElement(intptr_t pos) {
2656+
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2657+
return PyBlock(operation, block);
2658+
}
2659+
2660+
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2661+
return PyBlockSuccessors(block, operation, startIndex, length, step);
2662+
}
2663+
2664+
PyOperationRef operation;
2665+
PyBlock block;
2666+
};
2667+
2668+
/// A list of block predecessors. Internally, these are stored as consecutive
2669+
/// elements, random access is cheap. The (returned) predecessor list is
2670+
/// associated with the operation and block whose predecessors these are, and
2671+
/// thus extends the lifetime of this operation and block.
2672+
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2673+
public:
2674+
static constexpr const char *pyClassName = "BlockPredecessors";
2675+
2676+
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2677+
intptr_t startIndex = 0, intptr_t length = -1,
2678+
intptr_t step = 1)
2679+
: Sliceable(startIndex,
2680+
length == -1 ? mlirBlockGetNumPredecessors(block.get())
2681+
: length,
2682+
step),
2683+
operation(operation), block(block) {}
2684+
2685+
private:
2686+
/// Give the parent CRTP class access to hook implementations below.
2687+
friend class Sliceable<PyBlockPredecessors, PyBlock>;
2688+
2689+
intptr_t getRawNumElements() {
2690+
block.checkValid();
2691+
return mlirBlockGetNumPredecessors(block.get());
2692+
}
2693+
2694+
PyBlock getRawElement(intptr_t pos) {
2695+
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2696+
return PyBlock(operation, block);
2697+
}
2698+
2699+
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2700+
intptr_t step) {
2701+
return PyBlockPredecessors(block, operation, startIndex, length, step);
2702+
}
2703+
2704+
PyOperationRef operation;
2705+
PyBlock block;
2706+
};
2707+
26292708
/// A list of operation attributes. Can be indexed by name, producing
26302709
/// attributes, or by index, producing named attributes.
26312710
class PyOpAttributeMap {
@@ -3655,7 +3734,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553734
},
36563735
nb::arg("operation"),
36573736
"Appends an operation to this block. If the operation is currently "
3658-
"in another block, it will be moved.");
3737+
"in another block, it will be moved.")
3738+
.def_prop_ro(
3739+
"successors",
3740+
[](PyBlock &self) {
3741+
return PyBlockSuccessors(self, self.getParentOperation());
3742+
},
3743+
"Returns the list of Block successors.")
3744+
.def_prop_ro(
3745+
"predecessors",
3746+
[](PyBlock &self) {
3747+
return PyBlockPredecessors(self, self.getParentOperation());
3748+
},
3749+
"Returns the list of Block predecessors.");
36593750

36603751
//----------------------------------------------------------------------------
36613752
// Mapping of PyInsertionPoint.
@@ -4099,6 +4190,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40994190
PyBlockArgumentList::bind(m);
41004191
PyBlockIterator::bind(m);
41014192
PyBlockList::bind(m);
4193+
PyBlockSuccessors::bind(m);
4194+
PyBlockPredecessors::bind(m);
41024195
PyOperationIterator::bind(m);
41034196
PyOperationList::bind(m);
41044197
PyOpAttributeMap::bind(m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
10591059
unwrap(block)->print(stream);
10601060
}
10611061

1062+
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
1063+
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
1064+
}
1065+
1066+
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
1067+
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
1068+
}
1069+
1070+
intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
1071+
Block *b = unwrap(block);
1072+
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
1073+
}
1074+
1075+
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
1076+
Block *b = unwrap(block);
1077+
Block::pred_iterator it = b->pred_begin();
1078+
std::advance(it, pos);
1079+
return wrap(*it);
1080+
}
1081+
10621082
//===----------------------------------------------------------------------===//
10631083
// Value API.
10641084
//===----------------------------------------------------------------------===//

mlir/test/python/ir/blocks.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4-
import io
5-
import itertools
6-
from mlir.ir import *
4+
75
from mlir.dialects import builtin
86
from mlir.dialects import cf
97
from mlir.dialects import func
8+
from mlir.ir import *
109

1110

1211
def run(f):
@@ -54,10 +53,24 @@ def testBlockCreation():
5453
with InsertionPoint(middle_block) as middle_ip:
5554
assert middle_ip.block == middle_block
5655
cf.BranchOp([i32_arg], dest=successor_block)
56+
5757
module.print(enable_debug_info=True)
5858
# Ensure region back references are coherent.
5959
assert entry_block.region == middle_block.region == successor_block.region
6060

61+
assert len(entry_block.successors) == 1
62+
assert len(entry_block.predecessors) == 0
63+
assert middle_block == entry_block.successors[0]
64+
assert len(middle_block.predecessors) == 1
65+
assert entry_block == middle_block.predecessors[0]
66+
67+
assert len(middle_block.successors) == 1
68+
assert successor_block == middle_block.successors[0]
69+
assert len(successor_block.predecessors) == 1
70+
assert middle_block == successor_block.predecessors[0]
71+
72+
assert len(successor_block.successors) == 0
73+
6174

6275
# CHECK-LABEL: TEST: testBlockCreationArgLocs
6376
@run

0 commit comments

Comments
 (0)