Skip to content

Commit c760955

Browse files
committed
[mlir][python] bind block successors and predecessors
1 parent 2dfcc43 commit c760955

File tree

4 files changed

+144
-4
lines changed

4 files changed

+144
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
986986
MLIR_CAPI_EXPORTED void
987987
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
988988

989+
/// Returns the number of successor blocks of the block.
990+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
991+
992+
/// Returns `pos`-th successor of the block.
993+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
994+
intptr_t pos);
995+
996+
/// Returns the number of predecessor blocks of the block.
997+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
998+
999+
/// Returns `pos`-th predecessor of the block.
1000+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
1001+
intptr_t pos);
1002+
9891003
//===----------------------------------------------------------------------===//
9901004
// Value API.
9911005
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,84 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
26262626
PyOperationRef operation;
26272627
};
26282628

2629+
/// A list of block successors. Internally, these are stored as consecutive
2630+
/// elements, random access is cheap. The (returned) successor list is
2631+
/// associated with the operation and block whose successors these are, and thus
2632+
/// extends the lifetime of this operation and block.
2633+
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2634+
public:
2635+
static constexpr const char *pyClassName = "BlockSuccessors";
2636+
2637+
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2638+
intptr_t startIndex = 0, intptr_t length = -1,
2639+
intptr_t step = 1)
2640+
: Sliceable(startIndex,
2641+
length == -1 ? mlirBlockGetNumSuccessors(block.get())
2642+
: length,
2643+
step),
2644+
operation(operation), block(block) {}
2645+
2646+
private:
2647+
/// Give the parent CRTP class access to hook implementations below.
2648+
friend class Sliceable<PyBlockSuccessors, PyBlock>;
2649+
2650+
intptr_t getRawNumElements() {
2651+
block.checkValid();
2652+
return mlirBlockGetNumSuccessors(block.get());
2653+
}
2654+
2655+
PyBlock getRawElement(intptr_t pos) {
2656+
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2657+
return PyBlock(operation, block);
2658+
}
2659+
2660+
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2661+
return PyBlockSuccessors(block, operation, startIndex, length, step);
2662+
}
2663+
2664+
PyOperationRef operation;
2665+
PyBlock block;
2666+
};
2667+
2668+
/// A list of block predecessors. The (returned) predecessor list is
2669+
/// associated with the operation and block whose predecessors these are, and
2670+
/// thus extends the lifetime of this operation and block.
2671+
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
2672+
public:
2673+
static constexpr const char *pyClassName = "BlockPredecessors";
2674+
2675+
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
2676+
intptr_t startIndex = 0, intptr_t length = -1,
2677+
intptr_t step = 1)
2678+
: Sliceable(startIndex,
2679+
length == -1 ? mlirBlockGetNumPredecessors(block.get())
2680+
: length,
2681+
step),
2682+
operation(operation), block(block) {}
2683+
2684+
private:
2685+
/// Give the parent CRTP class access to hook implementations below.
2686+
friend class Sliceable<PyBlockPredecessors, PyBlock>;
2687+
2688+
intptr_t getRawNumElements() {
2689+
block.checkValid();
2690+
return mlirBlockGetNumPredecessors(block.get());
2691+
}
2692+
2693+
PyBlock getRawElement(intptr_t pos) {
2694+
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
2695+
return PyBlock(operation, block);
2696+
}
2697+
2698+
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
2699+
intptr_t step) {
2700+
return PyBlockPredecessors(block, operation, startIndex, length, step);
2701+
}
2702+
2703+
PyOperationRef operation;
2704+
PyBlock block;
2705+
};
2706+
26292707
/// A list of operation attributes. Can be indexed by name, producing
26302708
/// attributes, or by index, producing named attributes.
26312709
class PyOpAttributeMap {
@@ -3655,7 +3733,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553733
},
36563734
nb::arg("operation"),
36573735
"Appends an operation to this block. If the operation is currently "
3658-
"in another block, it will be moved.");
3736+
"in another block, it will be moved.")
3737+
.def_prop_ro(
3738+
"successors",
3739+
[](PyBlock &self) {
3740+
return PyBlockSuccessors(self, self.getParentOperation());
3741+
},
3742+
"Returns the list of Block successors.")
3743+
.def_prop_ro(
3744+
"predecessors",
3745+
[](PyBlock &self) {
3746+
return PyBlockPredecessors(self, self.getParentOperation());
3747+
},
3748+
"Returns the list of Block predecessors.");
36593749

36603750
//----------------------------------------------------------------------------
36613751
// Mapping of PyInsertionPoint.
@@ -4099,6 +4189,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40994189
PyBlockArgumentList::bind(m);
41004190
PyBlockIterator::bind(m);
41014191
PyBlockList::bind(m);
4192+
PyBlockSuccessors::bind(m);
4193+
PyBlockPredecessors::bind(m);
41024194
PyOperationIterator::bind(m);
41034195
PyOperationList::bind(m);
41044196
PyOpAttributeMap::bind(m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
10591059
unwrap(block)->print(stream);
10601060
}
10611061

1062+
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
1063+
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
1064+
}
1065+
1066+
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
1067+
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
1068+
}
1069+
1070+
intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
1071+
Block *b = unwrap(block);
1072+
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
1073+
}
1074+
1075+
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
1076+
Block *b = unwrap(block);
1077+
Block::pred_iterator it = b->pred_begin();
1078+
std::advance(it, pos);
1079+
return wrap(*it);
1080+
}
1081+
10621082
//===----------------------------------------------------------------------===//
10631083
// Value API.
10641084
//===----------------------------------------------------------------------===//

mlir/test/python/ir/blocks.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4-
import io
5-
import itertools
6-
from mlir.ir import *
4+
75
from mlir.dialects import builtin
86
from mlir.dialects import cf
97
from mlir.dialects import func
8+
from mlir.ir import *
109

1110

1211
def run(f):
@@ -54,10 +53,25 @@ def testBlockCreation():
5453
with InsertionPoint(middle_block) as middle_ip:
5554
assert middle_ip.block == middle_block
5655
cf.BranchOp([i32_arg], dest=successor_block)
56+
5757
module.print(enable_debug_info=True)
5858
# Ensure region back references are coherent.
5959
assert entry_block.region == middle_block.region == successor_block.region
6060

61+
assert len(entry_block.predecessors) == 0
62+
63+
assert len(entry_block.successors) == 1
64+
assert middle_block == entry_block.successors[0]
65+
assert len(middle_block.predecessors) == 1
66+
assert entry_block == middle_block.predecessors[0]
67+
68+
assert len(middle_block.successors) == 1
69+
assert successor_block == middle_block.successors[0]
70+
assert len(successor_block.predecessors) == 1
71+
assert middle_block == successor_block.predecessors[0]
72+
73+
assert len(successor_block.successors) == 0
74+
6175

6276
# CHECK-LABEL: TEST: testBlockCreationArgLocs
6377
@run

0 commit comments

Comments
 (0)