
Commit 3823b9e
enable odml-torch as default

chunnienc authored and copybara-github committed
PiperOrigin-RevId: 705931084
1 parent 0ee967e

24 files changed: +170 additions, -139 deletions
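
In effect, this commit makes odml-torch the default path for lowering PyTorch ops to StableHLO; torch_xla becomes opt-in through the USE_TORCH_XLA environment variable read by the new ai_edge_torch/_config.py shown below. A minimal sketch of opting back into torch_xla; because the flag is cached on first read, the variable must be set before ai_edge_torch is imported:

    # Sketch: opting back into the torch_xla backend after this change.
    # USE_TORCH_XLA must be set before ai_edge_torch is imported, because
    # _Config.use_torch_xla caches the first value it reads.
    import os

    os.environ["USE_TORCH_XLA"] = "true"

    import ai_edge_torch

    print(ai_edge_torch.config.use_torch_xla)  # True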

.github/workflows/model_coverage.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
   test-model-coverage:
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
     runs-on:
       labels: Linux_runner_8_core
     steps:

.github/workflows/nightly_release.yml

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ jobs:
     needs: build-and-release-nightly
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]

     name: Test Install and Import ai-edge-torch
     runs-on: ubuntu-latest

.github/workflows/unittests_python.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
   test:
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
     runs-on:
       labels: Linux_runner_8_core
     steps:

README.md

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ Nightly Release | [![](https://github.com/google-ai-edge/ai-edge-torch/action

 ### Requirements and Dependencies

-* Python versions: 3.9, 3.10, 3.11
+* Python versions: >=3.10
 * Operating system: Linux
 * PyTorch: [![torch](https://img.shields.io/badge/torch->=2.4.0-blue)](https://pypi.org/project/torch/)
 * TensorFlow: [![tf-nightly](https://img.shields.io/badge/tf--nightly-latest-blue)](https://pypi.org/project/tf-nightly/)

ai_edge_torch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,13 +13,13 @@
 # limitations under the License.
 # ==============================================================================

+from ai_edge_torch._config import config
 from ai_edge_torch._convert.converter import convert
 from ai_edge_torch._convert.converter import signature
 from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
 from ai_edge_torch.model import Model
 from ai_edge_torch.version import __version__

-
 def load(path: str) -> Model:
   """Imports an ai_edge_torch model from disk.

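
Re-exporting config from the package root lets callers gate on the active backend without touching the private _config module. A small illustrative check; the attribute names come from _config.py, the next file below:

    import ai_edge_torch

    if ai_edge_torch.config.use_torch_xla:
      print("lowering with torch_xla")
    else:
      print("lowering with odml-torch (the new default)")

    # in_oss is hard-coded to True in the open-source release.
    assert ai_edge_torch.config.in_oss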

ai_edge_torch/_config.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides a configuration for the ai-edge-torch."""
+
+import functools
+import logging
+import os
+
+__all__ = ["config"]
+
+
+class _Config:
+  """ai-edge-torch global configs."""
+
+  @property
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def use_torch_xla(self) -> bool:
+    """True if using torch_xla to lower torch ops to StableHLO.
+
+    To use torch_xla as the lowering backend, set environment variable
+    `USE_TORCH_XLA` to "true".
+    """
+    var = os.environ.get("USE_TORCH_XLA", "false")
+    var = var.lower().strip()
+    if var in ("y", "yes", "t", "true", "on", "1"):
+      return True
+    elif var in ("n", "no", "f", "false", "off", "0"):
+      return False
+    else:
+      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
+      return False
+
+  @property
+  def in_oss(self) -> bool:
+    """True if the code is not running in google internal environment."""
+    return True
+
+
+config = _Config()
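
Two details worth noting: stacking @functools.cache under @property means the environment variable is read once per instance (and config is a module-level singleton, so effectively once per process), and the parser accepts the usual boolean spellings case-insensitively with surrounding whitespace stripped. A standalone illustration of the parsing rules, mirroring the code above:

    import logging

    # Mirrors _Config.use_torch_xla's string handling, for illustration only.
    def parse_use_torch_xla(raw: str) -> bool:
      var = raw.lower().strip()
      if var in ("y", "yes", "t", "true", "on", "1"):
        return True
      if var in ("n", "no", "f", "false", "off", "0"):
        return False
      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
      return False

    assert parse_use_torch_xla(" ON ") is True
    assert parse_use_torch_xla("0") is False
    assert parse_use_torch_xla("maybe") is False  # warns, then defaults to False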

ai_edge_torch/_convert/test/test_convert.py

Lines changed: 1 addition & 2 deletions
@@ -19,7 +19,6 @@
 from typing import Tuple

 import ai_edge_torch
-from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.quantize import pt2e_quantizer
 from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ def test_convert_conv_transpose_batch_norm(self):
     self.assertTrue(result)

   @googletest.skipIf(
-      not config.Config.use_torch_xla,
+      not ai_edge_torch.config.use_torch_xla,
       reason="Shape polymorphism is not yet support with odml_torch.",
   )
   def test_convert_model_with_dynamic_batch(self):
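
The same mechanical migration repeats across the test files below: the old class attribute config.Config.use_torch_xla becomes the ai_edge_torch.config.use_torch_xla property, and tests that rely on internal custom ops now skip on in_oss instead. A hedged sketch of the resulting guard pattern; the googletest import path is assumed, since the hunks do not show it:

    import ai_edge_torch
    from absl.testing import absltest as googletest  # assumed import path


    class BackendGatedTest(googletest.TestCase):

      @googletest.skipIf(
          not ai_edge_torch.config.use_torch_xla,
          reason="only meaningful on the torch_xla backend",
      )
      def test_torch_xla_only(self):
        ...

      @googletest.skipIf(
          ai_edge_torch.config.in_oss,
          reason="tests with custom ops are not supported in oss",
      )
      def test_internal_only(self):
        ...


    if __name__ == "__main__":
      googletest.main()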

ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import re
 from typing import Callable, Union

-from ai_edge_torch import config
+import ai_edge_torch
 from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch.generative.fx_passes import CanonicalizePass
@@ -112,7 +112,7 @@ def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
         (torch.rand(1, 512, 64, 64),),
     )

-    if config.Config.use_torch_xla:
+    if ai_edge_torch.config.use_torch_xla:
       self.assertTrue(
           re.search(
               'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+,'

ai_edge_torch/generative/test/test_model_conversion.py

Lines changed: 8 additions & 9 deletions
@@ -16,7 +16,6 @@
 """Testing model conversion for a few gen-ai models."""

 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
 from ai_edge_torch.generative.layers import kv_cache
@@ -83,22 +82,22 @@ def _test_model_with_kv_cache(self, enable_hlfb: bool):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache(self):
     self._test_model_with_kv_cache(enable_hlfb=False)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     self._test_model_with_kv_cache(enable_hlfb=True)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_toy_model_has_dus_op(self):
     """Tests that the model has the dynamic update slice op."""
@@ -179,8 +178,8 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()

ai_edge_torch/generative/test/test_model_conversion_large.py

Lines changed: 26 additions & 27 deletions
@@ -16,7 +16,6 @@
 """Testing model conversion for a few gen-ai models."""

 import ai_edge_torch
-from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
@@ -91,35 +90,35 @@ def _test_model(self, config, model, signature_name, atol, rtol):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
     pytorch_model = gemma1.Gemma1(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_llama(self):
     config = llama.get_fake_model_config()
     pytorch_model = llama.Llama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
@@ -128,53 +127,53 @@ def test_phi2(self):
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
     pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def disabled_test_paligemma(self):
     config = paligemma.get_fake_model_config()
@@ -222,8 +221,8 @@ def disabled_test_paligemma(self):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
@@ -254,8 +253,8 @@ def test_stable_diffusion_clip(self):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
@@ -296,8 +295,8 @@ def test_stable_diffusion_diffusion(self):
     )

   @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
   )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()

ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py

Lines changed: 5 additions & 2 deletions
@@ -16,7 +16,8 @@

 import math

-from ai_edge_torch import config
+import ai_edge_torch
+from ai_edge_torch import hlfb
 from ai_edge_torch import lowertools
 from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
@@ -29,9 +30,11 @@ def _export_stablehlo_mlir(model, args):
   ep = torch.export.export(model, args)
   return lowertools.exported_program_to_mlir_text(ep)

+StableHLOCompositeBuilder = hlfb.StableHLOCompositeBuilder
+

 @googletest.skipIf(
-    not config.Config.use_torch_xla,
+    not ai_edge_torch.config.use_torch_xla,
     reason="The odml_torch counter part is in odml_torch.",
 )
 class TestStableHLOCompositeBuilder(googletest.TestCase):

ai_edge_torch/lowertools/_shim.py

Lines changed: 4 additions & 2 deletions
@@ -15,13 +15,15 @@

 from typing import Any, Optional

-from ai_edge_torch import config
+from ai_edge_torch import _config
 from ai_edge_torch._convert import signature
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch

+config = _config.config
+
 # isort: off
-if config.Config.use_torch_xla:
+if config.use_torch_xla:
   from ai_edge_torch.lowertools import torch_xla_utils as utils
   from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
   from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder