Add DecomposeGruPass for ARM backend (#17137) #19463
Summary:
Adds quantizable versions of GRU and RNN modules that can be used with
PyTorch quantization-aware training (QAT) for the ARM backend.
The standard nn.GRU and nn.RNN are opaque composite ops that the quantizer
cannot annotate. These modules decompose the RNN operations into
nn.Linear + FloatFunctional so that QAT observers can be inserted at
each arithmetic boundary.
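The decomposition pattern can be sketched as follows. This is an illustrative example (the class name `ObservableRNNCell` is hypothetical, not the module added by this PR): every matmul becomes an `nn.Linear` and every elementwise add goes through a `FloatFunctional`, so QAT observers have a concrete module to attach to at each arithmetic boundary.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq

class ObservableRNNCell(nn.Module):
    """Hypothetical sketch: one tanh-RNN cell built from nn.Linear plus
    FloatFunctional, instead of the opaque composite nn.RNN op."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.ih = nn.Linear(input_size, hidden_size)   # x_t @ W_ih.T + b_ih
        self.hh = nn.Linear(hidden_size, hidden_size)  # h_{t-1} @ W_hh.T + b_hh
        self.add = nnq.FloatFunctional()               # observable add

    def forward(self, x_t, h_prev):
        # h_t = tanh(Linear(x_t) + Linear(h_prev)); the FloatFunctional add
        # lets an observer sit on the sum before the nonlinearity.
        return torch.tanh(self.add.add(self.ih(x_t), self.hh(h_prev)))
```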
## New modules:
- `GRUCell`, `_GRUSingleLayer`, `_GRULayer`, `GRU`
- `RNNCell`, `_RNNSingleLayer`, `_RNNLayer`, `RNN`
## Features:
- `from_float()` class method to convert from nn.GRU/nn.RNN
- Multi-layer support
- Bidirectional support
- Both tanh and relu nonlinearities (for RNN)
## Usage:
```python
from executorch.backends.arm.quantizable import GRU, RNN
# Create quantizable GRU
model = GRU(input_size=10, hidden_size=20, num_layers=2)
# Or convert from existing nn.GRU
eager_model = torch.nn.GRU(10, 20, 2)
eager_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
quantizable_model = GRU.from_float(eager_model)
```
Differential Revision: D92059608
Summary:
Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).
GRU cell equations per timestep:
r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
h_t = n_t + z_t * (h_{t-1} - n_t)
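The equations above translate directly into the elementary ops the pass targets. A minimal reference sketch (not the pass itself; weights follow the `torch.nn.GRUCell` per-gate split):

```python
import torch

def gru_step_reference(x_t, h_prev,
                       W_ir, W_iz, W_in, W_hr, W_hz, W_hn,
                       b_ir, b_iz, b_in, b_hr, b_hz, b_hn):
    """Literal translation of the GRU cell equations, using only
    matmul, sigmoid, tanh, mul, and add."""
    r = torch.sigmoid(x_t @ W_ir.T + b_ir + h_prev @ W_hr.T + b_hr)
    z = torch.sigmoid(x_t @ W_iz.T + b_iz + h_prev @ W_hz.T + b_hz)
    n = torch.tanh(x_t @ W_in.T + b_in + r * (h_prev @ W_hn.T + b_hn))
    # h_t = (1 - z) * n + z * h_{t-1}, rewritten as n + z * (h_{t-1} - n)
    return n + z * (h_prev - n)
```

The per-gate weights can be obtained from an `nn.GRUCell` by chunking `weight_ih`/`weight_hh` (and biases) into three rows ordered (reset, update, new).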
Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)
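The batched gate computation in the last bullet can be sketched like this (an illustrative example, not the pass implementation): the three gates share one matmul against the stacked `(3H, *)` weight per input, giving 2 mm ops per timestep instead of 6.

```python
import torch

def gru_step_batched(x_t, h_prev, w_ih, w_hh, b_ih, b_hh):
    """One GRU timestep with batched gates. Weight layout follows
    torch.nn.GRU: w_ih is (3H, input_size), rows ordered
    (reset, update, new); likewise for w_hh and the biases."""
    H = h_prev.shape[-1]
    gi = x_t @ w_ih.T + b_ih      # all three input-side gates in one mm
    gh = h_prev @ w_hh.T + b_hh   # all three hidden-side gates in one mm
    i_r, i_z, i_n = gi.split(H, dim=-1)
    h_r, h_z, h_n = gh.split(H, dim=-1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)
    return n + z * (h_prev - n)
```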
Differential Revision: D92058313
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19463
Note: links to docs will display an error until the docs builds have been completed.
❌ 6 new failures, 1 unrelated failure as of commit 60d56eb with merge base 126507c.
FLAKY: one job failed, but likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani