
Commit 8acfeca

Jessica Lin and malfet authored
[1.6] Add optimizer_for_mobile doc into python api root doc (pytorch#41491)
* Add optimizer_for_mobile doc into python api root doc
* Apply suggestions from code review: remove all references to `optimization_blacklist` as it's missing in 1.6

Co-authored-by: Nikita Shulga <[email protected]>
1 parent 860e18a commit 8acfeca

3 files changed (+26, -5 lines)

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
    torch.utils.cpp_extension <cpp_extension>
    torch.utils.data <data>
    torch.utils.dlpack <dlpack>
+   torch.utils.mobile_optimizer <mobile_optimizer>
    torch.utils.model_zoo <model_zoo>
    torch.utils.tensorboard <tensorboard>
    type_info

docs/source/mobile_optimizer.rst

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+torch.utils.mobile_optimizer
+===================================
+
+.. warning::
+    This API is in beta and may change in the near future.
+
+Torch mobile supports the ``torch.utils.mobile_optimizer.optimize_for_mobile`` utility, which runs a list of optimization passes on modules in eval mode.
+The method takes the following parameters: a ``torch.jit.ScriptModule`` object, a blacklisting optimization set, and a preserved method list.
+
+By default, if the optimization blacklist is None or empty, ``optimize_for_mobile`` will run the following optimizations:
+- **Conv2D + BatchNorm fusion** (blacklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in the ``forward`` method of this module and of all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated.
+- **Insert and Fold prepacked ops** (blacklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops: they require some state, such as prepacked weights, to be created ahead of time and then used during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking the weights enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` rewrites the graph to replace ``Conv2D/Linear`` with 1) an op that prepacks the weights for the XNNPACK conv2d/linear ops and 2) an op that takes the prepacked weights and the activation as input and generates the output activations. Since 1) needs to be done only once, the weight prepacking is folded so that it happens only once, at model load time. This pass of ``optimize_for_mobile`` performs 1) and 2) and then folds, i.e. removes, the weight prepacking ops.
+- **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping, that is, clamping of the output activation is done as part of the kernel, including for 2D convolution and linear op kernels, so clamping effectively comes for free. Therefore any op that can be expressed as a clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with the preceding ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites the graph by finding ``ReLU``/``hardtanh`` ops that follow the XNNPACK ``Conv2D``/``linear`` ops written by the previous pass, and fuses them together.
+- **Dropout removal** (blacklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false.
+
+``optimize_for_mobile`` will also invoke the freeze_module pass, which preserves only the ``forward`` method.
+
+
+.. currentmodule:: torch.utils.mobile_optimizer
+.. autofunction:: optimize_for_mobile
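
For reference, a minimal usage sketch of the API documented above (not part of this commit; it assumes torchvision is installed and the output file name is illustrative):

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Build an example model and switch to eval mode so that the eval-only
# passes (Conv2d+BatchNorm folding, dropout removal) can apply.
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()

# optimize_for_mobile expects a torch.jit.ScriptModule.
script_module = torch.jit.script(model)

# Run the default set of mobile optimization passes.
optimized_module = optimize_for_mobile(script_module)

# Save the optimized TorchScript module for mobile deployment.
torch.jit.save(optimized_module, "mobilenet_v2_mobile.pt")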

torch/utils/mobile_optimizer.py

Lines changed: 5 additions & 5 deletions
@@ -16,12 +16,12 @@ class LintCode(Enum):
 def optimize_for_mobile(script_module, optimization_blacklist: Set[MobileOptimizerType] = None):
     """
     Args:
-        script_module: An instance of torch script module with type of ScriptModule
-        optimization_blacklist: A set with type of MobileOptimizerType.
-        When set is not passed, optimization method will run all the optimizer pass; otherwise, optimizer
-        method will run the optimization pass that is not included inside optimization_blacklist.
+        script_module: An instance of torch script module with type of ScriptModule.
+        optimization_blacklist: A set with type of MobileOptimizerType. When set is not passed,
+            optimization method will run all the optimizer pass; otherwise, optimizer
+            method will run the optimization pass that is not included inside optimization_blacklist.
     Returns:
-        script_module: A new optimized torch script module
+        A new optimized torch script module
     """
     if not isinstance(script_module, torch.jit.ScriptModule):
         raise TypeError(
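
A sketch of skipping an individual pass via the ``optimization_blacklist`` parameter shown in the signature above. This is illustrative only: the import path of ``MobileOptimizerType`` (assumed here to be ``torch._C``) is an assumption, and per the commit message the parameter may be absent from the 1.6 release.

import torch
from torch._C import MobileOptimizerType  # assumed import path for the enum
from torch.utils.mobile_optimizer import optimize_for_mobile

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

# The input must be a ScriptModule; anything else raises TypeError (see the check above).
script_module = torch.jit.script(TinyNet().eval())

# Passes listed in the blacklist are skipped; the other default passes still run,
# so Conv2d+BatchNorm folding and dropout removal are applied, but the
# XNNPACK prepacked-op rewrite is not.
blacklist = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
optimized_module = optimize_for_mobile(script_module, optimization_blacklist=blacklist)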
