Skip to content

Commit 040aec9

Browse files
[lib/conversion] Create seed only if needed in `convert-torch-convers… (#3926)
…ion-to-mlprogram` pass This PR changes `convert-torch-conversion-to-mlprogram` pass implementation by moving seed generation inside `ConvertGetNextSeedOp` pattern. Previously, global seed was being created by this pass, even when its only consumer `torch_c.get_next_seed` op is not present in the IR. This pass is part of Torch->Linalg conversion pipeline. Always creating global seed created an issue for the case when downstream compiler doesn't expect/support `ml_program` dialect in linalg on tensor IR format. However, when starting torch IR has `torch_c.get_next_seed` op, `ml_program` will still be present and will need to be handled by downstream compilers.
1 parent 62eb38b commit 040aec9

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
5959
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
6060
ConversionPatternRewriter &rewriter) const override {
6161
Location loc = op.getLoc();
62+
63+
// Check for global seed and create if it doesn't exist.
64+
auto module = op->getParentOfType<ModuleOp>();
65+
OpBuilder b(module.getBodyRegion());
66+
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
67+
return failure();
68+
6269
// Generate sequence for getting the next seed with LCG step:
6370
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
6471
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
@@ -115,11 +122,6 @@ class ConvertTorchConversionToMLProgram
115122
typeConverter.addConversion([](Type type) { return type; });
116123
TorchConversion::setupBackendTypeConversion(target, typeConverter);
117124

118-
auto module = getOperation();
119-
OpBuilder b(module.getBodyRegion());
120-
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
121-
signalPassFailure();
122-
123125
RewritePatternSet patterns(context);
124126
target.addIllegalOp<GetNextSeedOp>();
125127
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);

test/Conversion/TorchConversionToMLProgram/basic.mlir

+13
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,16 @@ module {
1717
return %seed : i64
1818
}
1919
}
20+
21+
// -----
22+
23+
module {
24+
func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> {
25+
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
26+
return %0 : !torch.vtensor<[2,3],f32>
27+
}
28+
}
29+
30+
// CHECK-NOT: ml_program.global
31+
// CHECK-LABEL: @no_seed_needed
32+
// CHECK-NEXT: torch_c.from_builtin_tensor

test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ module {
1111
func.func private @f7() -> i64
1212
}
1313

14-
// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
14+
// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
1515
// CHECK-NOT: @global_seed

0 commit comments

Comments
 (0)