|
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
@@ -246,5 +247,46 @@ TEST(ShapeAnalysisTest, DynamicShapesFusion) {
 }
 }

+TEST(ShapeAnalysisTest, MovingConstantOutOfFusionGroups) {
+  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
+  const auto graph_string = R"IR(
+      graph(%x.1 : Tensor):
+        %none : NoneType = prim::Constant()
+        %size1 : int = prim::Constant[value=1]()
+        %size10 : int = prim::Constant[value=10]()
+        %sizes : int[] = prim::ListConstruct(%size10, %size1)
+        %device : Device = prim::Constant[value="cpu"]()
+        %10 : Tensor = aten::ones(%sizes, %none, %none, %device, %none)
+        %3 : Tensor = aten::tanh(%x.1)
+        %29 : Tensor = aten::mul(%3, %10)
+        return (%29))IR";
+  torch::jit::parseIR(graph_string, subgraph.get());
+  // Fold the constant-argument aten::ones into a constant tensor inside the
+  // subgraph, so the fusion group contains a constant that must be moved out.
+  ConstantPropagation(subgraph);
+
+  // Wrap the subgraph in a single TensorExprGroup node in an outer graph.
+  std::shared_ptr<Graph> g = std::make_shared<Graph>();
+  auto x_inp = g->addInput("x_inp");
+  auto x_type = TensorType::create(at::rand({10, 5}));
+  x_inp->setType(x_type);
+  subgraph->inputs().at(0)->setType(x_type);
+  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
+  output->node()->addInput(x_inp);
+  output->node()->g_(attr::Subgraph, subgraph);
+
+  // Generate the runtime shape guard and fallback path for the fusion group.
+  auto success = GenerateGuard(output->node());
+  TORCH_INTERNAL_ASSERT(success);
+
+  // Check that the constants have been moved out of the fused graph.
+  // This should result in not having any conditionals other than the one
+  // checking the result of TensorExprDynamicGuard.
+  testing::FileCheck()
+      .check("TensorExprDynamicGuard")
+      ->check_next("prim::If")
+      ->check_not("prim::If") // no other IFs due to constants.
+      ->check("TensorExprGroup")
+      ->check("block1")
+      ->check("FallbackGraph")
+      ->run(*g);
+}
+
 } // namespace jit
 } // namespace torch