Skip to content

Commit 6bb40ca

Browse files
authored
Format source with new rules (google-ai-edge#118)
* update scripts and configs * fmt
1 parent 4c40530 commit 6bb40ca

File tree

95 files changed

+1529
-900
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+1529
-900
lines changed

.isort.cfg

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[settings]
2+
profile=google
3+
multi_line_output=7
4+
line_length=200
5+
skip=.downloads,venv,bazel
6+
known_third_party=ai_edge_torch
7+
known_internal=tensorflow.python.platform,tensorflow.compiler.tf2xla.python
8+
default_section=THIRDPARTY
9+
sections=FUTURE,STDLIB,LOCALFOLDER,THIRDPARTY,INTERNAL

ai_edge_torch/convert/conversion.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
import os
1919
from typing import Optional
2020

21-
import torch
22-
from torch.export import ExportedProgram
23-
from torch_xla import stablehlo
24-
2521
from ai_edge_torch import model
2622
from ai_edge_torch.convert import conversion_utils as cutils
2723
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
@@ -32,6 +28,9 @@
3228
from ai_edge_torch.convert.fx_passes import run_passes
3329
from ai_edge_torch.generative.fx_passes import run_generative_passes
3430
from ai_edge_torch.quantize import quant_config as qcfg
31+
import torch
32+
from torch.export import ExportedProgram
33+
from torch_xla import stablehlo
3534

3635
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
3736

@@ -61,8 +60,9 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
6160
continue
6261

6362
message = (
64-
"Your model {sig_name}is converted in training mode. "
65-
"Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
63+
"Your model {sig_name}is converted in training mode. Please set the"
64+
" module in evaluation mode with `module.eval()` for better on-device"
65+
" performance and compatibility."
6666
)
6767
if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
6868
# User does not specify any signature names explicitly.
@@ -88,7 +88,9 @@ def convert_signatures(
8888
_warn_training_modules(signatures)
8989

9090
exported_programs: torch.export.ExportedProgram = [
91-
torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
91+
torch.export.export(
92+
sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
93+
)
9294
for sig in signatures
9395
]
9496

@@ -100,7 +102,9 @@ def convert_signatures(
100102
]
101103

102104
merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
103-
cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
105+
cutils.merge_stablehlo_bundles(
106+
shlo_bundles, signatures, exported_programs
107+
)
104108
)
105109
del exported_programs
106110
del shlo_bundles

ai_edge_torch/convert/conversion_utils.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
import tempfile
2323
from typing import Any, Dict, List, Optional, Tuple, Union
2424

25+
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
26+
from ai_edge_torch.quantize import quant_config as qcfg
2527
import torch
2628
import torch.utils._pytree as pytree
2729
from torch_xla import stablehlo
2830

29-
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
30-
from ai_edge_torch.quantize import quant_config as qcfg
31-
3231
try:
3332
import tensorflow as tf
33+
3434
from tensorflow.compiler.tf2xla.python import xla as tfxla
3535

3636
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
@@ -90,18 +90,20 @@ def _flat_kwarg_names(self, specs, context) -> List[str]:
9090
if context is None:
9191
for i, spec in enumerate(specs):
9292
if spec.children_specs:
93-
flat_names.extend(
94-
[
95-
f"{i}_{name}"
96-
for name in self._flat_kwarg_names(spec.children_specs, spec.context)
97-
]
98-
)
93+
flat_names.extend([
94+
f"{i}_{name}"
95+
for name in self._flat_kwarg_names(
96+
spec.children_specs, spec.context
97+
)
98+
])
9999
else:
100100
flat_names.append(f"{i}")
101101
else:
102102
flat_ctx = self._flatten_list(context)
103103
for prefix, spec in zip(flat_ctx, specs):
104-
leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
104+
leaf_flat_names = self._flat_kwarg_names(
105+
spec.children_specs, spec.context
106+
)
105107
if leaf_flat_names:
106108
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
107109
else:
@@ -125,7 +127,8 @@ def flat_args(self) -> tuple[Any]:
125127

126128

127129
def exported_program_to_stablehlo_bundle(
128-
exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
130+
exported_program: torch.export.ExportedProgram,
131+
sample_args: tuple[torch.Tensor],
129132
) -> stablehlo.StableHLOModelBundle:
130133
# Setting export_weights to False here so that pytorch/xla avoids copying the weights
131134
# to a numpy array which would lead to memory bloat. This means that the state_dict
@@ -146,15 +149,18 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
146149
dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
147150
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
148151
except Exception:
149-
logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")
152+
logging.info(
153+
"Can not use dlpack to convert torch tensors. Falling back to numpy."
154+
)
150155
nparray = torch_tensor.cpu().detach().numpy()
151156
tf_tensor = tf.convert_to_tensor(nparray)
152157

153158
return tf_tensor
154159

155160

156161
def _get_states(
157-
exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
162+
exported_programs: list[torch.export.ExportedProgram],
163+
signatures: list[Signature],
158164
):
159165
for exported_program, signature in zip(exported_programs, signatures):
160166
args, _ = exported_program.example_inputs
@@ -166,7 +172,8 @@ def _get_states(
166172
# Only interested in Tensors that are part of the state (and not user input).
167173
if (
168174
not isinstance(tensor, torch.Tensor)
169-
or input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
175+
or input_spec.kind
176+
== torch.export.graph_signature.InputKind.USER_INPUT
170177
):
171178
continue
172179
yield signature, tensor, input_spec
@@ -192,9 +199,13 @@ def _gather_state_dict(
192199
deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
193200

194201
state_dict = {}
195-
for signature, tensor, input_spec in _get_states(exported_programs, signatures):
202+
for signature, tensor, input_spec in _get_states(
203+
exported_programs, signatures
204+
):
196205
unique_id = _tensor_unique_id(tensor)
197-
state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]
206+
state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
207+
unique_id
208+
]
198209

199210
return state_dict
200211

@@ -236,7 +247,9 @@ def _wrap_as_tf_func(
236247
):
237248
def inner(*args):
238249
type_info = [sig.dtype for sig in func.meta.output_signature]
239-
shape_info = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
250+
shape_info = [
251+
_get_shape_with_dynamic(sig) for sig in func.meta.output_signature
252+
]
240253
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
241254
return tfxla.call_module(
242255
tuple(call_args),
@@ -369,7 +382,9 @@ def convert_stablehlo_to_tflite(
369382
)
370383
)
371384

372-
tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
385+
tf_module._variables = (
386+
list(bundle.state_dict.values()) + bundle.additional_constants
387+
)
373388
del bundle
374389
gc.collect()
375390

@@ -385,7 +400,8 @@ def convert_stablehlo_to_tflite(
385400
tf_module,
386401
temp_dir_path,
387402
signatures={
388-
sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
403+
sig.name: tf_concrete_funcs[idx]
404+
for idx, sig in enumerate(signatures)
389405
},
390406
)
391407
# Clean up intermediate memory early.
@@ -416,6 +432,8 @@ def convert_stablehlo_to_tflite(
416432
and quant_config._quantizer_mode
417433
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
418434
):
419-
tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
435+
tflite_model = translate_recipe.quantize_model(
436+
tflite_model, translated_recipe
437+
)
420438

421439
return tflite_model

ai_edge_torch/convert/converter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717

1818
from typing import Any, Dict, Optional, Tuple, Union
1919

20-
import torch
21-
2220
from ai_edge_torch import model
2321
from ai_edge_torch.convert import conversion
2422
from ai_edge_torch.convert import conversion_utils as cutils
2523
from ai_edge_torch.quantize import quant_config as qcfg
24+
import torch
2625

2726

2827
class Converter:
@@ -68,14 +67,20 @@ def add_signature(
6867
"""
6968

7069
if name in [sig.name for sig in self._signatures]:
71-
raise ValueError(f"A signature with the provided name ({name}) is already added.")
70+
raise ValueError(
71+
f"A signature with the provided name ({name}) is already added."
72+
)
7273

7374
if sample_args is None and sample_kwargs is None:
7475
raise ValueError("sample_args or sample_kwargs must be provided.")
7576

7677
self._signatures.append(
7778
cutils.Signature(
78-
name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
79+
name,
80+
module,
81+
sample_args,
82+
sample_kwargs,
83+
dynamic_shapes=dynamic_shapes,
7984
)
8085
)
8186
return self
@@ -128,7 +133,8 @@ def convert(
128133
)
129134
else: # module is provided but not args
130135
raise ValueError(
131-
"sample_args or sample_kwargs must be provided if a module is specified."
136+
"sample_args or sample_kwargs must be provided if a module is"
137+
" specified."
132138
)
133139

134140
return conversion.convert_signatures(

ai_edge_torch/convert/fx_passes/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515

1616
from typing import Sequence, Union
1717

18-
from torch.export import ExportedProgram
19-
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
20-
import torch.utils._pytree as pytree
21-
2218
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
2319
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
2420
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
@@ -28,6 +24,9 @@
2824
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
2925
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
3026
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
27+
from torch.export import ExportedProgram
28+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
29+
import torch.utils._pytree as pytree
3130

3231

3332
# TODO(cnchan): make a PassManager class.

ai_edge_torch/convert/fx_passes/_pass_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@ def __new__(cls, exported_program, modified):
3232

3333
class ExportedProgramPassBase(abc.ABC):
3434

35-
def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
35+
def __call__(
36+
self, exported_program: ExportedProgram
37+
) -> ExportedProgramPassResult:
3638
self.requires(exported_program)
3739
res = self.call(exported_program)
3840
self.ensures(exported_program)
3941
return res
4042

4143
@abc.abstractmethod
42-
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
44+
def call(
45+
self, exported_program: ExportedProgram
46+
) -> ExportedProgramPassResult:
4347
pass
4448

4549
def requires(self, exported_program: ExportedProgram) -> None:

0 commit comments

Comments
 (0)