@@ -1,18 +1,23 @@
 import copy
+import logging
 import operator
 import warnings
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.fx
+import torch.fx as fx
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch.fx.experimental.const_fold import split_const_subgraphs
 
 from ..observer import observable
 
 from ..tracer.acc_tracer import acc_ops
 from ..tracer.acc_tracer.acc_utils import get_attr
 from .pass_utils import log_before_after, validate_inference
 
+_LOGGER = logging.getLogger(__name__)
+
 # Create an alias for module input type to avoid littering pyre-ignore for Any
 # throughout the file.
 Input = Any
@@ -460,3 +465,146 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
     gm.graph.lint()
     gm.recompile()
     return gm
+
+
+def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule:
+    """\
+    TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256),
+    since the dynamic shape of the reshape comes from the dynamic shape of
+    another node (y). Compilation will fail with various memory-related
+    errors, depending on the size of the input tensor.
+
+    This pass fixes the issue by finding this reshape pattern and checking that:
+
+        x.size(0) == y.size(0)
+
+    It then replaces the reshape's batch size, y.size(0), with x.size(0).
+    """
+
+    def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]:
+        """\
+        Try to find the reshape op's batch size as an input node.
+
+        Match the graph structure below and return `node_y`:
+            node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}})
+        """
+        if (
+            maybe_reshape.op != "call_function"
+            or maybe_reshape.target != acc_ops.reshape
+        ):
+            return None
+        shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None)
+        if not shape:
+            return None
+        batch_size = shape[0]
+        if isinstance(batch_size, fx.Node):
+            return batch_size
+        return None
+
+    def get_reshape_batch_size_inferred_source(
+        batch_size_node: fx.Node,
+    ) -> Optional[fx.Node]:
+        """\
+        Given a node representing the batch size used for the reshape op, check
+        whether it comes from the pattern below:
+
+            batch_size_node = src.size()[0]
+
+        or, in the IR graph:
+
+            src -> size(input=_) -> getitem(input=_, idx=0)
+                                    ^ ~~~ batch_size_node
+
+        If so, return `src`. Otherwise, return `None`.
+        """
+        if (
+            batch_size_node.op != "call_function"
+            or batch_size_node.target != acc_ops.getitem
+            or batch_size_node.kwargs["idx"] != 0
+        ):
+            return None
+        maybe_size: fx.Node = batch_size_node.all_input_nodes[0]
+        if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size:
+            return None
+        return maybe_size.all_input_nodes[0]
+
+    maybe_reshape: fx.Node
+    for maybe_reshape in mod.graph.nodes:
+        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
+            maybe_reshape
+        )
+        if not reshape_batch_size:
+            continue
+        reshape_batch_size_inferred_source: Optional[
+            fx.Node
+        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        if not reshape_batch_size_inferred_source:
+            continue
+
+        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
+        if reshape_input == reshape_batch_size_inferred_source:
+            continue
+
+        if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source):
+            continue
+
+        _LOGGER.info(
+            f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}"
+        )
+
+        # Step 1: create a node to compute batch size, using the tensor which
+        # is being reshaped: reshape_input.size()[0]. This batch size is now
+        # derived from reshape_input, the same node as the reshape op's input.
+        with mod.graph.inserting_before(maybe_reshape):
+            reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function(
+                acc_ops.getitem,
+                kwargs={
+                    "idx": 0,
+                    "input": maybe_reshape.graph.call_function(
+                        acc_ops.size,
+                        kwargs={
+                            "input": reshape_input,
+                        },
+                    ),
+                },
+            )
+
+        # Step 2: update `maybe_reshape`'s shape argument to be
+        # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER)
+        maybe_reshape.kwargs = {
+            **maybe_reshape.kwargs,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                shape=(
+                    reshape_batch_size_2,
+                    *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]),
+                )
+            ),
+        }
+
+    mod.graph.eliminate_dead_code()
+    mod.recompile()
+    return mod
+
+
+def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool:
+    """\
+    Check that x.size(0) == y.size(0)
+    """
+    x_size, y_size = _get_shape(x), _get_shape(y)
+    return bool(
+        x_size
+        and y_size
+        # now both are non-empty
+        and x_size[0] == y_size[0]
+    )
+
+
+def _get_shape(node: fx.Node) -> Optional[torch.Size]:
+    if (
+        not getattr(node, "meta", None)
+        or not node.meta.get("tensor_meta", None)
+        or not getattr(node.meta["tensor_meta"], "shape", None)
+    ):
+        # shape info not available
+        return None
+    return node.meta["tensor_meta"].shape
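
For context, here is a minimal sketch of the pattern this pass rewrites and how it might be applied. The toy module, the tensor shapes, and the tracing call are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn

# Assumed import paths, matching this file's package layout.
from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer


class BadReshape(nn.Module):
    # Hypothetical module for illustration: the reshape's batch size is
    # taken from another input (y), the pattern TRT cannot reason about.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x.reshape(y.size(0), -1, 4)


x, y = torch.randn(8, 2, 4), torch.randn(8, 3)

# acc_tracer.trace normalizes the module to acc_ops and, given sample
# inputs, annotates nodes with tensor_meta, which _get_shape() reads.
traced = acc_tracer.trace(BadReshape(), [x, y])

# The pass rewires the reshape so its batch dim is computed from x
# itself, making the graph equivalent to x.reshape(x.size(0), -1, 4).
fixed = fix_reshape_batch_dim(traced)
print(torch.equal(fixed(x, y), x.reshape(x.size(0), -1, 4)))  # True
```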