
Commit 992ad4d

Add a shim layer to allow usage of odml_torch (google-ai-edge#121)
Co-authored-by: Google AI Edge <[email protected]>
1 parent 7300dd1 commit 992ad4d

File tree

129 files changed: +2471 -34913 lines


.github/workflows/formatting.yml

+1 -1

@@ -29,7 +29,7 @@ jobs:
           ref: ${{ inputs.trigger-sha }}
       - name: Install dependencies
         run: |
-          pip3 install pyink isort
+          pip3 install pyink
       - name: Check code style
         run: |
           ci/test_code_style.sh

WORKSPACE

+4 -4

@@ -46,14 +46,14 @@ http_archive(
     urls = [
         "https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"
     ],
-    build_file = "@//third_party:sentencepiece.BUILD",
-    patches = ["@//third_party:com_google_sentencepiece.diff"],
+    build_file = "@//bazel:sentencepiece.BUILD",
+    patches = ["@//bazel:com_google_sentencepiece.diff"],
     patch_args = ["-p1"],
 )

 http_archive(
     name = "darts_clone",
-    build_file = "@//third_party:darts_clone.BUILD",
+    build_file = "@//bazel:darts_clone.BUILD",
     sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
     strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
     urls = [
@@ -92,7 +92,7 @@ http_archive(
         "https://github.com/tensorflow/tensorflow/archive/%s.tar.gz" % _TENSORFLOW_GIT_COMMIT,
     ],
     patches = [
-        "@//third_party:org_tensorflow_system_python.diff",
+        "@//bazel:org_tensorflow_system_python.diff",
     ],
     patch_args = [
         "-p1",

ai_edge_torch/README.md

-1

@@ -1,4 +1,3 @@

 * Documentation of the [PyTorch Converter](../docs/pytorch_converter/README.md)
 * Documentation of the [Generative API](generative/)
-
ai_edge_torch/__init__.py

+5 -5

@@ -13,11 +13,11 @@
 # limitations under the License.
 # ==============================================================================

-from .convert.converter import convert
-from .convert.converter import signature
-from .convert.to_channel_last_io import to_channel_last_io
-from .model import Model
-from .version import __version__
+from ai_edge_torch._convert.converter import convert
+from ai_edge_torch._convert.converter import signature
+from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
+from ai_edge_torch.model import Model
+from ai_edge_torch.version import __version__


 def load(path: str) -> Model:
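The package-level imports above keep re-exporting `convert`, `signature`, `to_channel_last_io`, `Model`, and `__version__`, so user code going through the top-level `ai_edge_torch` namespace is unaffected by the internal `convert` -> `_convert` rename. A minimal usage sketch (the toy `torch.nn.Linear` module and output path are illustrative, not from this commit):

import torch
import ai_edge_torch

# Toy module; any nn.Module in eval mode works the same way.
module = torch.nn.Linear(4, 2).eval()
sample_args = (torch.randn(1, 4),)

# Public entry point re-exported by ai_edge_torch/__init__.py.
edge_model = ai_edge_torch.convert(module, sample_args)
edge_model.export("/tmp/linear.tflite")  # illustrative output path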

ai_edge_torch/convert/conversion.py renamed to ai_edge_torch/_convert/conversion.py

+40 -50

@@ -13,48 +13,44 @@
 # limitations under the License.
 # ==============================================================================

-import gc
 import logging
 import os
-from typing import Optional
+from typing import Any, Optional

+from ai_edge_torch import lowertools
 from ai_edge_torch import model
-from ai_edge_torch.convert import conversion_utils as cutils
-from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
-from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass  # NOQA
-from ai_edge_torch.convert.fx_passes import CanonicalizePass
-from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
-from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
-from ai_edge_torch.convert.fx_passes import run_passes
-from ai_edge_torch.generative.fx_passes import run_generative_passes
+from ai_edge_torch._convert import fx_passes
+from ai_edge_torch._convert import signature
+from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
-from torch.export import ExportedProgram
-from torch_xla import stablehlo

 os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


 def _run_convert_passes(
-    exported_program: ExportedProgram,
-) -> ExportedProgram:
-  exported_program = run_generative_passes(exported_program)
-  return run_passes(
+    exported_program: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+  exported_program = generative_fx_passes.run_generative_passes(
+      exported_program
+  )
+  return fx_passes.run_passes(
       exported_program,
       [
-          BuildInterpolateCompositePass(),
-          CanonicalizePass(),
-          OptimizeLayoutTransposesPass(),
-          CanonicalizePass(),
-          BuildAtenCompositePass(),
-          CanonicalizePass(),
-          InjectMlirDebuginfoPass(),
-          CanonicalizePass(),
+          fx_passes.BuildInterpolateCompositePass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.OptimizeLayoutTransposesPass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.BuildAtenCompositePass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.InjectMlirDebuginfoPass(),
+          fx_passes.CanonicalizePass(),
       ],
   )


-def _warn_training_modules(signatures: list[cutils.Signature]):
+def _warn_training_modules(signatures: list[signature.Signature]):
+  """Warns the user if the module is in training mode (.eval not called)."""
  for sig in signatures:
    if not sig.module.training:
      continue
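The updated pass list above follows each graph-rewriting pass with `fx_passes.CanonicalizePass()`, so every subsequent pass sees a re-canonicalized program. A rough sketch of that run pattern, not the library's actual `run_passes` implementation (the pass objects are modeled as plain callables for illustration):

from typing import Callable, Sequence

import torch


def run_passes_sketch(
    exported_program: torch.export.ExportedProgram,
    passes: Sequence[
        Callable[[torch.export.ExportedProgram], torch.export.ExportedProgram]
    ],
) -> torch.export.ExportedProgram:
  # Apply each pass in order; each consumes and returns an ExportedProgram.
  for fx_pass in passes:
    exported_program = fx_pass(exported_program)
  return exported_program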
@@ -64,30 +60,39 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
         " module in evaluation mode with `module.eval()` for better on-device"
         " performance and compatibility."
     )
-    if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
+    if len(signatures) == 1 and sig.name == model.DEFAULT_SIGNATURE_NAME:
       # User does not specify any signature names explicitly.
       message = message.format(sig_name="")
     else:
       message = message.format(sig_name=f'"{sig.name}" ')

-    logging.warn(message)
+    logging.warning(message)


 def convert_signatures(
-    signatures: list[cutils.Signature],
+    signatures: list[signature.Signature],
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
-    _tfl_converter_flags: dict = {},
+    _tfl_converter_flags: Optional[dict[str, Any]],
 ) -> model.TfLiteModel:
-  """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.
+  """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
+
   Args:
-    signatures: The list of 'Signature' objects containing PyTorch modules to be converted.
+    signatures: The list of 'signature.Signature' objects containing PyTorch
+      modules to be converted.
     quant_config: User-defined quantization method and scheme of the model.
-    _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+      underlying tflite converter.
+
+  Returns:
+    The converted `model.TfLiteModel` object.
   """
+  if _tfl_converter_flags is None:
+    _tfl_converter_flags = {}
+
   _warn_training_modules(signatures)

-  exported_programs: torch.export.ExportedProgram = [
+  exported_programs: torch.export.torch.export.ExportedProgram = [
       torch.export.export(
           sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
       )
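The `_tfl_converter_flags` change above removes a mutable `{}` default in favor of an `Optional` parameter plus an explicit `None` check inside the function, the usual Python guard against a shared mutable default. A toy illustration of the pitfall (names are hypothetical, not from the library):

def add_flag_buggy(name, flags={}):
  # The same dict object is reused by every call that relies on the default.
  flags[name] = True
  return flags


def add_flag_fixed(name, flags=None):
  if flags is None:
    flags = {}  # a fresh dict per call
  flags[name] = True
  return flags


print(add_flag_buggy("a"))  # {'a': True}
print(add_flag_buggy("b"))  # {'a': True, 'b': True}  <- state leaked between calls
print(add_flag_fixed("a"))  # {'a': True}
print(add_flag_fixed("b"))  # {'b': True}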
@@ -96,23 +101,8 @@ def convert_signatures(

   # Apply default fx passes
   exported_programs = list(map(_run_convert_passes, exported_programs))
-  shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
-      cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
-      for exported, sig in zip(exported_programs, signatures)
-  ]
-
-  merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
-      cutils.merge_stablehlo_bundles(
-          shlo_bundles, signatures, exported_programs
-      )
-  )
-  del exported_programs
-  del shlo_bundles
-
-  gc.collect()
-
-  tflite_model = cutils.convert_stablehlo_to_tflite(
-      merged_shlo_graph_module,
+  tflite_model = lowertools.exported_programs_to_tflite(
+      exported_programs,
       signatures,
       quant_config=quant_config,
       _tfl_converter_flags=_tfl_converter_flags,
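The hunk above is the core of the commit: the inline torch_xla/StableHLO bundling is replaced by one call into the new `lowertools` package, the shim that can route lowering through either torch_xla or odml_torch. A loose, hypothetical sketch of such a dispatch (the environment variable, function names, and backend callables below are illustrative only, not the repository's actual API):

import os
from typing import Any, Callable, Optional, Sequence


def exported_programs_to_tflite_sketch(
    exported_programs: Sequence[Any],
    signatures: Sequence[Any],
    *,
    torch_xla_lower: Callable[..., Any],  # hypothetical backend entry point
    odml_torch_lower: Callable[..., Any],  # hypothetical backend entry point
    quant_config: Optional[Any] = None,
    _tfl_converter_flags: Optional[dict[str, Any]] = None,
) -> Any:
  # Hypothetical selection switch; the real shim's mechanism is not shown in this diff.
  lower = (
      odml_torch_lower
      if os.environ.get("USE_ODML_TORCH", "0") == "1"
      else torch_xla_lower
  )
  return lower(
      exported_programs,
      signatures,
      quant_config=quant_config,
      _tfl_converter_flags=_tfl_converter_flags,
  )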
+64 (new file)

@@ -0,0 +1,64 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+from ai_edge_torch.quantize import quant_config as qcfg
+import tensorflow as tf
+
+
+def apply_tfl_converter_flags(
+    converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict[str, Any]
+):
+  """Applies TFLite converter flags to the converter.
+
+  Args:
+    converter: TFLite converter.
+    tfl_converter_flags: TFLite converter flags.
+  """
+
+  def _set_converter_flag(path: list[Any]):
+    if len(path) < 2:
+      raise ValueError("Expecting at least two values in the path.")
+
+    target_obj = converter
+    for idx in range(len(path) - 2):
+      target_obj = getattr(target_obj, path[idx])
+
+    setattr(target_obj, path[-2], path[-1])
+
+  def _iterate_dict_tree(flags_dict: dict[str, Any], path: list[Any]):
+    for key, value in flags_dict.items():
+      path.append(key)
+      if isinstance(value, dict):
+        _iterate_dict_tree(value, path)
+      else:
+        path.append(value)
+        _set_converter_flag(path)
+        path.pop()
+      path.pop()
+
+  _iterate_dict_tree(tfl_converter_flags, [])
+
+
+def set_tfl_converter_quant_flags(
+    converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+):
+  if quant_config is not None:
+    quantizer_mode = quant_config._quantizer_mode
+    if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+      converter._experimental_qdq_conversion_mode = "DYNAMIC"
+    elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+      converter._experimental_qdq_conversion_mode = "STATIC"
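A usage sketch of `apply_tfl_converter_flags` from the new file above: nested dictionary keys are walked as an attribute path on the converter, and the leaf value is assigned to the final attribute. The converter construction and flag values below are illustrative; the function would be imported from the new module (whose file name is not shown in this view):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")  # illustrative source

flags = {
    "optimizations": [tf.lite.Optimize.DEFAULT],
    "target_spec": {"supported_ops": [tf.lite.OpsSet.TFLITE_BUILTINS]},
}

apply_tfl_converter_flags(converter, flags)
# Effectively:
#   converter.optimizations = [tf.lite.Optimize.DEFAULT]
#   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]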
