233 changes: 229 additions & 4 deletions docsrc/contributors/complex_number_support.rst
@@ -135,19 +135,244 @@ runtime modules handle the conversion:
Key Implementation Invariants
-------------------------------

* **originally_complex set** — the set of nodes that were complex-dtype
*before* any rewrites. After ``replace_input_node``, complex placeholders become
``float32`` so ``is_complex_dtype()`` returns ``False``. The ``originally_complex``
set is used to decide which ``mul.Tensor`` nodes need the complex mul rewrite.
* **node.meta["is_complex_layout"]** — every node that represents a complex
quantity (either originally complex-dtype, or a real ``(..., 2)`` tensor produced
by the rewriter) is annotated with ``node.meta["is_complex_layout"] = True``.
This annotation is set during the detection phase (before any rewrites begin) and
propagated by every rewrite handler as it emits new nodes. It survives dtype
changes: after ``replace_input_node`` converts a ``placeholder`` from complex to
``float32``, the dtype-based check ``is_complex_dtype()`` would return ``False``,
but the metadata flag remains. ``_is_complex_layout_node(n)`` is simply
``n.meta.get("is_complex_layout", False)`` — no shape heuristics or recursion.
* **FakeTensorMode reuse** — ``propagate_metadata`` must use the ``FakeTensorMode``
from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch
errors under ``torch.compile`` and to preserve SymInt for dynamic shapes.
* **Dotted buffer names** — ``register_buffer`` rejects names containing ``.``.
Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.``
replaced with ``__`` before registration.
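
The dotted-name rule in the last invariant reduces to a one-line substitution. A minimal sketch (``sanitize_buffer_name`` is a hypothetical helper name, not the pass's actual API):

```python
def sanitize_buffer_name(name: str) -> str:
    # register_buffer rejects names containing ".", so nested
    # submodule parameter names are flattened before registration.
    return name.replace(".", "__")

print(sanitize_buffer_name("layers.0.weight"))  # layers__0__weight
```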

The Decomposition System — How It Is Built
-------------------------------------------

The rewriter is split across two classes and wired together by a lightweight
dispatch mechanism. This section walks through each piece and explains the
design decisions.

ComplexOpDetector — Subgraph Discovery
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexOpDetector`` walks the graph to find the set of nodes that participate
in complex arithmetic.

``node_include_in_subgraph``
""""""""""""""""""""""""""""

A node is included in a complex subgraph if:

1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or**
2. Any of its inputs are complex (``has_complex_input``).

The second condition is necessary to catch real-output ops — ``abs``, ``angle``,
``real``, ``imag`` — whose inputs are complex. These must be rewritten alongside
the rest of the subgraph even though their outputs are real.
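
The two conditions can be sketched on a toy node representation (plain dicts standing in for FX nodes; the real predicate inspects ``node.meta["val"]``):

```python
COMPLEX_DTYPES = {"complex64", "complex128"}

def is_complex_dtype(node) -> bool:
    return node["dtype"] in COMPLEX_DTYPES

def has_complex_input(node) -> bool:
    return any(is_complex_dtype(i) for i in node["inputs"])

def node_include_in_subgraph(node) -> bool:
    # Condition 1 (complex output) or condition 2 (complex inputs);
    # the latter catches real-output ops such as abs/angle/real/imag.
    return is_complex_dtype(node) or has_complex_input(node)

ph = {"dtype": "complex64", "inputs": []}
abs_node = {"dtype": "float32", "inputs": [ph]}  # real output, complex input
print(node_include_in_subgraph(abs_node))        # True
```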

``subgraph_from_anchor``
""""""""""""""""""""""""

For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real``
*anchor* node and performs a backward BFS:

.. code-block:: text

view_as_real ← mul (complex) ← reshape ← placeholder (complex)
↑ anchor ↑ subgraph ↑ subgraph ↑ input

At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is
added to the subgraph; otherwise it becomes an *input node* (the boundary). The
result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and
input nodes.

After collection the subgraph is **sorted in topological order** (by position in
the graph's node list). This is critical: without it a ``mul`` node could be
processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the
original complex node instead of the already-rewritten real node.
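
The backward BFS plus the final topological sort can be sketched with a toy ``Node`` class rather than real FX nodes (here placeholders are treated as boundary inputs, matching the diagram above; the real traversal uses the ``node_include_in_subgraph`` predicate):

```python
from collections import deque

class Node:
    def __init__(self, name, op, args=(), is_complex=False):
        self.name, self.op, self.args, self.is_complex = name, op, list(args), is_complex

def subgraph_from_anchor(anchor, graph_nodes):
    subgraph, input_nodes, seen = [], [], set()
    queue = deque(anchor.args)
    while queue:
        n = queue.popleft()
        if n in seen:
            continue
        seen.add(n)
        if n.is_complex and n.op == "call_function":
            subgraph.append(n)        # keep walking upstream
            queue.extend(n.args)
        else:
            input_nodes.append(n)     # boundary of the subgraph
    # Position in the graph's node list == topological order.
    subgraph.sort(key=graph_nodes.index)
    return subgraph, input_nodes

ph   = Node("freqs_cis", "placeholder", is_complex=True)
resh = Node("reshape", "call_function", [ph], is_complex=True)
mul  = Node("mul", "call_function", [resh], is_complex=True)
vr   = Node("view_as_real", "call_function", [mul])
sub, inputs = subgraph_from_anchor(vr, [ph, resh, mul, vr])
print([n.name for n in sub], [n.name for n in inputs])
# ['reshape', 'mul'] ['freqs_cis']
```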

``find_complex_op_subgraphs`` and subgraph merging
"""""""""""""""""""""""""""""""""""""""""""""""""""

When a model has multiple ``view_as_real`` anchors that share upstream nodes
(e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same
``freqs_cis`` placeholder), their subgraphs would otherwise be detected
separately. ``find_complex_op_subgraphs`` detects the overlap via node-set
intersection and merges the overlapping subgraphs into one, so each node is
rewritten exactly once.
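
The merge can be sketched on plain node-name sets (overlap detected via intersection, merged via union). This single-pass version assumes overlapping subgraphs are adjacent in the list; the real implementation may iterate until a fixed point:

```python
def merge_subgraphs(subgraphs):
    merged = []
    for sg in map(set, subgraphs):
        for existing in merged:
            if existing & sg:      # shared upstream nodes
                existing |= sg     # merge so each node is rewritten once
                break
        else:
            merged.append(sg)
    return merged

xq = {"freqs_cis", "mul_q", "view_as_real_q"}
xk = {"freqs_cis", "mul_k", "view_as_real_k"}
# Both RoPE outputs share the freqs_cis placeholder -> one merged subgraph
print(merge_subgraphs([xq, xk]))
```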

``find_all_complex_subgraphs`` — unbounded complex ops
"""""""""""""""""""""""""""""""""""""""""""""""""""""""

Some models produce a complex tensor as a graph *output* without passing it
through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that
collects every ``call_function`` node with a complex output, regardless of
anchoring. The resulting subgraph is processed the same way as an
anchor-bounded one.
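
The forward scan reduces to a filter over the graph's node list (dict nodes again standing in for FX nodes):

```python
def find_all_complex_subgraphs(graph_nodes):
    # Every call_function node with a complex output participates,
    # whether or not a view_as_real anchor exists downstream.
    return [
        n for n in graph_nodes
        if n["op"] == "call_function" and n["dtype"] in ("complex64", "complex128")
    ]

nodes = [
    {"op": "placeholder",   "dtype": "complex64"},
    {"op": "call_function", "dtype": "complex64"},  # e.g. a complex mul
    {"op": "call_function", "dtype": "float32"},
]
print(len(find_all_complex_subgraphs(nodes)))  # 1
```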

ComplexGraphRewriter — Dispatch-Based Rewriting
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at
class-definition time scans every method for the ``@_complex_unpacker(op, ...)``
decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite
methods.

.. code-block:: python

@_complex_unpacker(torch.ops.aten.mul.Tensor)
def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...):
...
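
The decorator pair behind this can be sketched in plain Python (simplified; the real decorators may carry extra arguments such as op schema details):

```python
def _complex_unpacker(op):
    """Tag a method with the aten op it rewrites."""
    def mark(fn):
        fn._unpacks_op = op
        return fn
    return mark

def _register_unpackers(cls):
    """At class-definition time, collect tagged methods into a dispatch table."""
    cls._DISPATCH = {
        fn._unpacks_op: fn
        for fn in vars(cls).values()
        if callable(fn) and hasattr(fn, "_unpacks_op")
    }
    return cls

@_register_unpackers
class Rewriter:
    @_complex_unpacker("aten.mul.Tensor")
    def _rewrite_mul(self, node):
        return "rewrote mul"

r = Rewriter()
print(r._DISPATCH["aten.mul.Tensor"](r, None))  # rewrote mul
```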

The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically
ordered) subgraph nodes and for each node:

1. Looks up ``node.target`` in ``_DISPATCH``.
2. If found, calls the corresponding rewrite method.
3. If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies
independently to every scalar, so the ``(..., 2)`` real layout is already
correct).
4. Otherwise logs a warning and leaves the node unchanged.
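
The four steps above amount to a short dispatch loop. A behavioral sketch, with dict nodes and an injected logger in place of the real rewriter state:

```python
def rewrite_subgraph_nodes(rewriter, subgraph_nodes, dispatch, elementwise_safe, warn):
    for node in subgraph_nodes:            # already in topological order
        handler = dispatch.get(node["target"])
        if handler is not None:
            handler(rewriter, node)        # 1-2. dedicated rewrite
        elif node["target"] in elementwise_safe:
            pass                           # 3. (..., 2) layout already correct
        else:
            warn(f"no complex rewrite for {node['target']}")  # 4. leave unchanged

rewritten, warnings = [], []
dispatch = {"aten.mul.Tensor": lambda rw, n: rewritten.append(n["target"])}
rewrite_subgraph_nodes(
    None,
    [{"target": "aten.mul.Tensor"}, {"target": "aten.add.Tensor"}, {"target": "aten.baddbmm"}],
    dispatch,
    elementwise_safe={"aten.add.Tensor"},
    warn=warnings.append,
)
print(rewritten, warnings)
# ['aten.mul.Tensor'] ['no complex rewrite for aten.baddbmm']
```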

``_ELEMENTWISE_SAFE``
"""""""""""""""""""""

The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the
tensor independently — ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``,
``clone``, ``where``, etc. On the ``(..., 2)`` real layout these are already
correct: adding two complex tensors element-wise is the same as adding their
real and imaginary parts independently.

Notably **excluded** from this set:

* ``permute.default`` — must append the trailing real/imag dim index.
* ``add.Scalar`` / ``sub.Scalar`` — a scalar added to a complex number only
shifts the real part; on the ``(..., 2)`` layout both parts would be shifted.
* ``reshape`` / ``view`` — shape arguments need updating for the extra ``2`` dim.
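
The distinction can be checked numerically with ``(re, im)`` pairs standing in for the ``(..., 2)`` layout:

```python
# Complex numbers in the (..., 2) real layout: (re, im) pairs.
a, b = (1.0, 2.0), (3.0, 4.0)

# add.Tensor is elementwise-safe: componentwise add equals complex add.
added = (a[0] + b[0], a[1] + b[1])
assert complex(*added) == complex(*a) + complex(*b)   # 4+6j

# add.Scalar is NOT: a scalar shifts only the real part of a complex
# number, but a naive elementwise add shifts both layout components.
naive   = (a[0] + 10.0, a[1] + 10.0)  # (11.0, 12.0) -- wrong imag part
correct = (a[0] + 10.0, a[1])         # (11.0, 2.0)  == (1+2j) + 10
assert complex(*correct) == complex(*a) + 10
assert naive != correct
```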

Complex Multiply Decomposition
"""""""""""""""""""""""""""""""

The most important rewrite is ``mul.Tensor`` between two complex operands.
The rewriter calls ``complex_mul_replacement``:

.. code-block:: python

# inputs a, b have shape (..., 2) — last dim is [real, imag]
re_a = select(a, -1, 0); im_a = select(a, -1, 1)
re_b = select(b, -1, 0); im_b = select(b, -1, 1)
real_out = re_a * re_b - im_a * im_b # ac - bd
imag_out = re_a * im_b + im_a * re_b # ad + bc
result = stack([real_out, imag_out], dim=-1)

Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node,
so all of the newly inserted nodes appear immediately after it in topological
order. The original ``mul`` node is then replaced and erased.

See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages
cursor-based insertion.
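
The decomposition can be verified against Python's built-in complex arithmetic, with ``(re, im)`` tuples standing in for the select/stack node sequence:

```python
def complex_mul(a, b):
    """(a_re + a_im*j) * (b_re + b_im*j), on (re, im) pairs."""
    re_a, im_a = a
    re_b, im_b = b
    return (re_a * re_b - im_a * im_b,   # ac - bd
            re_a * im_b + im_a * re_b)   # ad + bc

a, b = (1.0, 2.0), (3.0, -1.0)
assert complex(*complex_mul(a, b)) == complex(*a) * complex(*b)  # 5+5j
```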

The ``is_complex_layout`` Metadata Invariant
"""""""""""""""""""""""""""""""""""""""""""""

Input replacement (Stage 2) converts complex ``placeholder`` nodes to
``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those
nodes even though they logically represent complex quantities.

To avoid missed rewrites, every node that represents a complex quantity is
annotated with ``node.meta["is_complex_layout"] = True`` during the detection
phase (in ``rewrite_subgraph_nodes``, before any rewrites begin). The
annotation is then propagated forward by every rewrite handler:

* ``replace_input_node`` stamps it on the new placeholder and ``get_attr`` nodes.
* ``_inline_cat_re_im`` stamps it on every ``[re_u, im_u]`` concatenation node,
covering all math handlers (``exp``, ``log``, ``sin``, ``mul``, etc.) at once.
* Each shape-manipulation handler (``reshape``, ``permute``, ``unsqueeze``,
``cat``, ``stack``, etc.) stamps it on its output node explicitly.

``_is_complex_layout_node(n)`` is therefore a direct metadata lookup — no shape
heuristics (``val.shape[-1] == 2``), no recursive ``_SHAPE_TRANSPARENT_OPS``
propagation. This also eliminates false-positives on real parameters that
coincidentally have a trailing dimension of size 2.

FakeTensorMode Reuse for Dynamic Shapes
"""""""""""""""""""""""""""""""""""""""""

When inserting a new ``placeholder`` for a complex input, the pass must populate
``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh
``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible
with the one that ``torch.export`` used to encode dynamic shape constraints
(SymInt ranges).

The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's
``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed
by appending a concrete ``2`` to the symbolic shape list:

.. code-block:: python

orig_fake = input_node.meta["val"]
sym_shape = list(orig_fake.shape) + [2]
with orig_fake.fake_mode:
fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device)

This preserves all SymInt identity across the graph and keeps
dynamic-shape exports working correctly.

Entry Point: ``complex_graph_detection``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The public entry point called by the lowering pipeline is
``complex_graph_detection(gm, settings)``. It:

1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``.
2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find
bounded complex subgraphs.
3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that
are not ``view_as_real``-bounded.
4. For each subgraph:

a. Calls ``replace_input_node`` on every boundary input node (Stage 2).
b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3).
c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes.

5. Returns the modified ``GraphModule``.

Adding New Op Rewrites
^^^^^^^^^^^^^^^^^^^^^^^

To teach the rewriter about a new complex op, add a method to
``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``:

.. code-block:: python

@_complex_unpacker(torch.ops.aten.my_new_op.default)
def _rewrite_my_new_op(self, node: Node) -> bool:
inp = node.args[0]
with SubgraphBuilder(self.gm.graph, node) as b:
re = b(torch.ops.aten.select.int, inp, -1, 0)
im = b(torch.ops.aten.select.int, inp, -1, 1)
out = b(my_real_impl, re, im)
# If the output is still a complex-layout [..., 2] tensor, annotate it.
# (Not needed if using _inline_cat_re_im, which sets the flag automatically.)
out.meta["is_complex_layout"] = True
node.replace_all_uses_with(out)
self.gm.graph.erase_node(node)
return True

``@_register_unpackers`` (applied to the class) picks up the new entry
automatically at import time — no other registration is required.

If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts
independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead.

Related
-------

* :ref:`lowering` — the complex rewrite is a lowering pass.
* :ref:`subgraph_builder` — the ``SubgraphBuilder`` helper used in every rewrite method.
* :ref:`lowering_passes_catalog` — pass ordering and management.
3 changes: 2 additions & 1 deletion docsrc/tutorials/advanced_usage.rst
@@ -2,7 +2,7 @@ Advanced Usage
==============

Step-by-step tutorials covering engine caching, quantization, custom kernels,
dynamic shapes, weight streaming, debugging, and more.
dynamic shapes, weight streaming, debugging, complex numerics, and more.

.. toctree::
:maxdepth: 2
@@ -14,5 +14,6 @@ dynamic shapes, weight streaming, debugging, and more.
weight_refit/index
runtime_opt/index
deployment/index
complex_numerics/index
Example: Distributed Inference <_rendered_examples/distributed_inference/index>
../indices/supported_ops
@@ -11,6 +11,12 @@ compilation.
This page explains what the rewriter does, which patterns are supported, and what
limitations to be aware of when compiling models with complex inputs.

.. seealso::

:doc:`../_rendered_examples/dynamo/torch_export_3d_rope` — a runnable
end-to-end example compiling a video-transformer 3D RoPE attention block
(CogVideoX / Wan / HunyuanVideo style) with dynamic T×H×W shapes.

----

How the Rewriter Works
10 changes: 10 additions & 0 deletions docsrc/tutorials/complex_numerics/index.rst
@@ -0,0 +1,10 @@
Complex Numerics
===================

Compatibility support for numerical datatypes, such as complex numbers, that are not natively supported by TensorRT.

.. toctree::
:maxdepth: 1

complex_tensors
Example: 3D RoPE with Complex Numerics <../_rendered_examples/dynamo/torch_export_3d_rope>
1 change: 0 additions & 1 deletion docsrc/tutorials/deployment/index.rst
@@ -12,4 +12,3 @@ complex-valued model support.
cross_compile_windows
Example: Cross-runtime Compilation for Windows <../_rendered_examples/dynamo/cross_runtime_compilation_for_windows>
distributed_inference
complex_tensors
1 change: 1 addition & 0 deletions docsrc/tutorials/extensibility/lowering/index.rst
@@ -8,3 +8,4 @@ rewrite ATen ops before TensorRT compilation.
:maxdepth: 1

writing_dynamo_aten_lowering_passes
subgraph_builder