Add DecomposeGruPass for ARM backend (#17137) #19463

Open
apullin wants to merge 2 commits into pytorch:main from apullin:export-D92058313

@apullin apullin commented May 11, 2026

Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
h_t = n_t + z_t * (h_{t-1} - n_t)
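
For readers comparing against the `torch.nn.GRU` documentation, the last line is algebraically the same as the usual convex-combination form of the hidden-state update. Expanding the product, with $\odot$ denoting the elementwise multiply written as `*` above:

$$
\begin{aligned}
h_t &= n_t + z_t \odot (h_{t-1} - n_t) \\
    &= n_t - z_t \odot n_t + z_t \odot h_{t-1} \\
    &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$

The form used in this summary needs one elementwise multiply instead of two and no constant tensor of ones. Whether that is the motivation for the choice is an assumption, not something the summary states.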

Features:

  • Multi-layer GRU support
  • Bidirectional GRU support
  • With/without bias
  • batch_first support
  • Batched gate computation (2 mm ops per timestep instead of 6)
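
The batched gate computation can be illustrated with a minimal sketch. This is not the pass's code: it uses plain Python lists so it runs without PyTorch, the function names are hypothetical, and it assumes the gate weights are stacked in `[r, z, n]` order along the first dimension, as `torch.nn.GRU` stores them. It checks that two matmuls over the stacked weights give the same hidden state as six per-gate matmuls.

```python
import math
import random


def matmul_t(x, w):
    """Return x @ w.T for x of shape (B, K) and w of shape (N, K), as nested lists."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_step_batched(x, h, w_ih, w_hh, b_ih, b_hh):
    """One GRU timestep with two matmuls over stacked [r, z, n] weights.

    Shapes: w_ih (3*H, I), w_hh (3*H, H), b_ih and b_hh of length 3*H.
    """
    H = len(h[0])
    gi = matmul_t(x, w_ih)  # (B, 3H): all input projections at once
    gh = matmul_t(h, w_hh)  # (B, 3H): all hidden projections at once
    h_next = []
    for b in range(len(x)):
        gi_b = [g + bias for g, bias in zip(gi[b], b_ih)]
        gh_b = [g + bias for g, bias in zip(gh[b], b_hh)]
        row = []
        for j in range(H):
            # Slice the stacked result back into the r, z, n gates.
            r = sigmoid(gi_b[j] + gh_b[j])
            z = sigmoid(gi_b[H + j] + gh_b[H + j])
            n = math.tanh(gi_b[2 * H + j] + r * gh_b[2 * H + j])
            row.append(n + z * (h[b][j] - n))
        h_next.append(row)
    return h_next


def gru_step_per_gate(x, h, w_ih, w_hh, b_ih, b_hh):
    """Reference: the same timestep with six matmuls, one per gate weight."""
    H = len(h[0])

    def proj(inp, w, bias, k):
        out = matmul_t(inp, w[k * H:(k + 1) * H])
        return [[v + c for v, c in zip(row, bias[k * H:(k + 1) * H])] for row in out]

    i_r, i_z, i_n = (proj(x, w_ih, b_ih, k) for k in range(3))
    h_r, h_z, h_n = (proj(h, w_hh, b_hh, k) for k in range(3))
    h_next = []
    for b in range(len(x)):
        row = []
        for j in range(H):
            r = sigmoid(i_r[b][j] + h_r[b][j])
            z = sigmoid(i_z[b][j] + h_z[b][j])
            n = math.tanh(i_n[b][j] + r * h_n[b][j])
            row.append(n + z * (h[b][j] - n))
        h_next.append(row)
    return h_next


# Small random problem: batch 2, input size 3, hidden size 4.
rng = random.Random(0)
B, I, H = 2, 3, 4


def rand(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


x, h = rand(B, I), rand(B, H)
w_ih, w_hh = rand(3 * H, I), rand(3 * H, H)
b_ih, b_hh = rand(1, 3 * H)[0], rand(1, 3 * H)[0]

batched = gru_step_batched(x, h, w_ih, w_hh, b_ih, b_hh)
per_gate = gru_step_per_gate(x, h, w_ih, w_hh, b_ih, b_hh)
max_diff = max(
    abs(a - b) for ra, rb in zip(batched, per_gate) for a, b in zip(ra, rb)
)
print(f"max abs difference: {max_diff:.2e}")
```

A full sequence would apply this step once per timestep on a slice of the input and concatenate the per-step outputs, which is where the `slice` and `cat` ops in the list above would come in.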

Differential Revision: D92058313

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

Andrew Pullin added 2 commits May 11, 2026 11:49
Summary:

Adds quantizable versions of GRU and RNN modules that can be used with
PyTorch quantization-aware training (QAT) for the ARM backend.

The standard nn.GRU and nn.RNN are opaque composite ops that the quantizer
cannot annotate. These modules decompose the RNN operations into
nn.Linear + FloatFunctional so that QAT observers can be inserted at
each arithmetic boundary.
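
As a rough illustration of that decomposition (not the code in this commit), a GRU cell can be written with two `nn.Linear` layers and one `FloatFunctional` per arithmetic op. The class name `SketchGRUCell` and its attribute names are hypothetical, and the `[r, z, n]` stacking is assumed to match `nn.GRUCell`'s weight layout. The check at the end compares the sketch against `nn.GRUCell` in float mode.

```python
import torch
from torch import nn
from torch.ao.nn.quantized import FloatFunctional


class SketchGRUCell(nn.Module):
    """Hypothetical GRU cell built from nn.Linear + FloatFunctional."""

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
        super().__init__()
        # Stacked [r, z, n] projections, same layout as nn.GRUCell weights.
        self.input_linear = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.hidden_linear = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)
        # One FloatFunctional per op so each can carry its own observer.
        self.add_r = FloatFunctional()
        self.add_z = FloatFunctional()
        self.mul_rn = FloatFunctional()
        self.add_n = FloatFunctional()
        self.neg_z = FloatFunctional()
        self.one_minus_z = FloatFunctional()
        self.mul_zn = FloatFunctional()
        self.mul_zh = FloatFunctional()
        self.add_h = FloatFunctional()

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        i_r, i_z, i_n = self.input_linear(x).chunk(3, dim=1)
        h_r, h_z, h_n = self.hidden_linear(h).chunk(3, dim=1)
        r = torch.sigmoid(self.add_r.add(i_r, h_r))
        z = torch.sigmoid(self.add_z.add(i_z, h_z))
        n = torch.tanh(self.add_n.add(i_n, self.mul_rn.mul(r, h_n)))
        # h_t = (1 - z) * n + z * h, expressed with FloatFunctional ops only.
        one_minus_z = self.one_minus_z.add_scalar(
            self.neg_z.mul_scalar(z, -1.0), 1.0
        )
        return self.add_h.add(
            self.mul_zn.mul(one_minus_z, n), self.mul_zh.mul(z, h)
        )

    @classmethod
    def from_float(cls, float_cell: nn.GRUCell) -> "SketchGRUCell":
        """Copy weights from a float nn.GRUCell into the decomposed cell."""
        cell = cls(float_cell.input_size, float_cell.hidden_size, float_cell.bias)
        with torch.no_grad():
            cell.input_linear.weight.copy_(float_cell.weight_ih)
            cell.hidden_linear.weight.copy_(float_cell.weight_hh)
            if float_cell.bias:
                cell.input_linear.bias.copy_(float_cell.bias_ih)
                cell.hidden_linear.bias.copy_(float_cell.bias_hh)
        return cell


torch.manual_seed(0)
reference = nn.GRUCell(10, 20)
sketch = SketchGRUCell.from_float(reference)
x, h = torch.randn(4, 10), torch.randn(4, 20)
with torch.no_grad():
    max_diff = (sketch(x, h) - reference(x, h)).abs().max().item()
print(f"max abs difference vs nn.GRUCell: {max_diff:.2e}")
```

Giving every add and mul its own `FloatFunctional` instance is what lets a QAT flow attach a separate observer to each intermediate result; sharing one instance across ops would make them share quantization parameters.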

## New modules:
- `GRUCell`, `_GRUSingleLayer`, `_GRULayer`, `GRU`
- `RNNCell`, `_RNNSingleLayer`, `_RNNLayer`, `RNN`

## Features:
- `from_float()` class method to convert from nn.GRU/nn.RNN
- Multi-layer support
- Bidirectional support
- Both tanh and relu nonlinearities (for RNN)

## Usage:
```python
import torch

from executorch.backends.arm.quantizable import GRU, RNN

# Create quantizable GRU
model = GRU(input_size=10, hidden_size=20, num_layers=2)

# Or convert from existing nn.GRU
eager_model = torch.nn.GRU(10, 20, 2)
eager_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
quantizable_model = GRU.from_float(eager_model)
```

Differential Revision: D92059608
Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)

Differential Revision: D92058313
@apullin apullin force-pushed the export-D92058313 branch from ac16b1c to 60d56eb Compare May 11, 2026 18:49
@apullin apullin requested a review from digantdesai as a code owner May 11, 2026 18:49
@pytorch-bot

pytorch-bot Bot commented May 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19463

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 1 Unrelated Failure

As of commit 60d56eb with merge base 126507c (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 11, 2026
@meta-codesync

meta-codesync Bot commented May 11, 2026

@apullin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92058313.

@pytorch-bot

pytorch-bot Bot commented May 11, 2026

Workflows were awaiting approval. CI has now been triggered for the ciflow labels on this PR.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Labels

ciflow/trunk, CLA Signed, fb-exported, meta-exported, module: arm
