Skip to content

Commit 31192f2

Browse files
[mps] revamp torchao mps ops build (#3477)
1 parent acc612d commit 31192f2

File tree

9 files changed

+150
-56
lines changed

9 files changed

+150
-56
lines changed

.github/workflows/metal_test.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: Run TorchAO Experimental MPS Tests
2+
on:
3+
push:
4+
branches:
5+
- main
6+
- 'gh/**'
7+
pull_request:
8+
branches:
9+
- main
10+
- 'gh/**'
11+
12+
jobs:
13+
test-mps-ops:
14+
name: test-mps-ops
15+
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
16+
with:
17+
runner: macos-m1-stable
18+
python-version: '3.11'
19+
submodules: 'recursive'
20+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
21+
timeout: 90
22+
script: |
23+
set -eux
24+
25+
echo "::group::Install Torch"
26+
${CONDA_RUN} pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
27+
echo "::endgroup::"
28+
29+
echo "::group::Install requirements"
30+
${CONDA_RUN} pip install -r dev-requirements.txt
31+
echo "::endgroup::"
32+
33+
echo "::group::Install experimental MPS ops"
34+
${CONDA_RUN} USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation
35+
echo "::endgroup::"
36+
37+
echo "::group::Run lowbit tests"
38+
${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_lowbit.py
39+
echo "::endgroup::"
40+
41+
echo "::group::Run quantizer tests"
42+
${CONDA_RUN} python -m pytest torchao/experimental/ops/mps/test/test_quantizer.py
43+
echo "::endgroup::"

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ transformers
77
hypothesis # Avoid test derandomization warning
88
sentencepiece # for gpt-fast tokenizer
99
expecttest
10+
pyyaml
1011

1112
# For prototype features and benchmarks
1213
bitsandbytes # needed for testing triton quant / dequant ops for 8-bit optimizers

setup.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ def get_cutlass_build_flags():
329329
)
330330

331331

332+
def bool_to_on_off(value):
333+
return "ON" if value else "OFF"
334+
335+
332336
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
333337
class TorchAOBuildExt(BuildExtension):
334338
def __init__(self, *args, **kwargs) -> None:
@@ -353,16 +357,19 @@ def build_extensions(self):
353357
def build_cmake(self, ext):
354358
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
355359

356-
if not os.path.exists(self.build_temp):
357-
os.makedirs(self.build_temp)
360+
# Use a unique build directory per CMake extension to avoid cache conflicts
361+
# when multiple extensions use different CMakeLists.txt source directories
362+
ext_build_temp = os.path.join(self.build_temp, ext.name.replace(".", "_"))
363+
if not os.path.exists(ext_build_temp):
364+
os.makedirs(ext_build_temp)
358365

359366
# Get the expected extension file name that Python will look for
360367
# We force CMake to use this library name
361368
ext_filename = os.path.basename(self.get_ext_filename(ext.name))
362369
ext_basename = os.path.splitext(ext_filename)[0]
363370

364371
print(
365-
"CMAKE COMMANG",
372+
"CMAKE COMMAND",
366373
[
367374
"cmake",
368375
ext.cmake_lists_dir,
@@ -384,9 +391,9 @@ def build_cmake(self, ext):
384391
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
385392
"-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename,
386393
],
387-
cwd=self.build_temp,
394+
cwd=ext_build_temp,
388395
)
389-
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)
396+
subprocess.check_call(["cmake", "--build", "."], cwd=ext_build_temp)
390397

391398

392399
class CMakeExtension(Extension):
@@ -772,9 +779,6 @@ def get_extensions():
772779
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
773780
build_options = BuildOptions()
774781

775-
def bool_to_on_off(value):
776-
return "ON" if value else "OFF"
777-
778782
from distutils.sysconfig import get_python_lib
779783

780784
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"
@@ -799,6 +803,21 @@ def bool_to_on_off(value):
799803
)
800804
)
801805

806+
if build_options.build_experimental_mps:
807+
ext_modules.append(
808+
CMakeExtension(
809+
"torchao._C_mps",
810+
cmake_lists_dir="torchao/experimental/ops/mps",
811+
cmake_args=(
812+
[
813+
f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}",
814+
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
815+
"-DTorch_DIR=" + torch_dir,
816+
]
817+
),
818+
)
819+
)
820+
802821
return ext_modules
803822

804823

torchao/experimental/ops/__init__.py

Whitespace-only changes.

torchao/experimental/ops/mps/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
3030
file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal)
3131
set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml)
3232
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
33-
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
33+
# Use the build directory for generated files to avoid permission issues during pip install
34+
set(GENERATED_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated_include)
35+
set(GENERATED_METAL_SHADER_LIB ${GENERATED_INCLUDE_DIR}/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
3436
add_custom_command(
3537
OUTPUT ${GENERATED_METAL_SHADER_LIB}
3638
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
@@ -45,7 +47,7 @@ endif()
4547
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
4648

4749
include_directories(${TORCHAO_INCLUDE_DIRS})
48-
include_directories(${CMAKE_INSTALL_PREFIX}/include)
50+
include_directories(${GENERATED_INCLUDE_DIR})
4951
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm)
5052
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)
5153

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from torchao.experimental.ops.mps.utils import _load_torchao_mps_lib
2+
3+
_load_torchao_mps_lib()

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,13 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import os
87
import unittest
98

109
import torch
1110
from parameterized import parameterized
1211

1312
# Need to import to load the ops
14-
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer # noqa: F401
15-
16-
try:
17-
for nbit in range(1, 8):
18-
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
19-
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
20-
except AttributeError:
21-
try:
22-
libname = "libtorchao_ops_mps_aten.dylib"
23-
libpath = os.path.abspath(
24-
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
25-
)
26-
torch.ops.load_library(libpath)
27-
except:
28-
raise RuntimeError(f"Failed to load library {libpath}")
29-
else:
30-
try:
31-
for nbit in range(1, 8):
32-
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
33-
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
34-
except AttributeError as e:
35-
raise e
13+
import torchao.experimental.ops.mps # noqa: F401
3614

3715

3816
class TestLowBitQuantWeightsLinear(unittest.TestCase):

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,15 @@
66

77
import copy
88
import itertools
9-
import os
109
import unittest
1110

1211
import torch
1312
from parameterized import parameterized
1413

15-
import torchao # noqa: F401
14+
# Need to import to load the ops
15+
import torchao.experimental.ops.mps # noqa: F401
1616
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize
1717

18-
try:
19-
for nbit in range(1, 8):
20-
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
21-
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
22-
except AttributeError:
23-
try:
24-
libname = "libtorchao_ops_mps_aten.dylib"
25-
libpath = os.path.abspath(
26-
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
27-
)
28-
torch.ops.load_library(libpath)
29-
except:
30-
raise RuntimeError(f"Failed to load library {libpath}")
31-
else:
32-
try:
33-
for nbit in range(1, 8):
34-
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
35-
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
36-
except AttributeError as e:
37-
raise e
38-
3918

4019
class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
4120
BITWIDTHS = range(1, 8)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import glob
8+
import os
9+
10+
import torch
11+
12+
13+
def _get_torchao_mps_lib_path():
14+
"""Get the path to the MPS ops library.
15+
16+
Searches in the following locations:
17+
1. The torchao package directory (for pip-installed packages)
18+
2. The build directory (for development installs from source)
19+
3. The cmake-out directory relative to this file (for standalone CMake builds)
20+
"""
21+
import torchao
22+
23+
libname = "libtorchao_ops_mps_aten.dylib"
24+
25+
# Try the torchao package directory first (pip install location)
26+
torchao_dir = os.path.dirname(torchao.__file__)
27+
pip_libpath = os.path.join(torchao_dir, libname)
28+
if os.path.exists(pip_libpath):
29+
return pip_libpath
30+
31+
# Try the build directory (for editable/development installs)
32+
# The build directory is typically at the repo root level
33+
repo_root = os.path.dirname(torchao_dir)
34+
build_pattern = os.path.join(repo_root, "build", "lib.*", "torchao", libname)
35+
build_matches = glob.glob(build_pattern)
36+
if build_matches:
37+
return build_matches[0]
38+
39+
# Fall back to cmake-out directory (standalone CMake build)
40+
cmake_libpath = os.path.abspath(
41+
os.path.join(os.path.dirname(__file__), "cmake-out/lib/", libname)
42+
)
43+
if os.path.exists(cmake_libpath):
44+
return cmake_libpath
45+
46+
return None
47+
48+
49+
def _load_torchao_mps_lib():
50+
"""Load the MPS ops library."""
51+
try:
52+
for nbit in range(1, 8):
53+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
54+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
55+
except AttributeError:
56+
libpath = _get_torchao_mps_lib_path()
57+
if libpath is None:
58+
raise RuntimeError(
59+
"Could not find libtorchao_ops_mps_aten.dylib. "
60+
"Please build with TORCHAO_BUILD_EXPERIMENTAL_MPS=1"
61+
)
62+
try:
63+
torch.ops.load_library(libpath)
64+
except Exception as e:
65+
raise RuntimeError(f"Failed to load library {libpath}: {e}")
66+
67+
for nbit in range(1, 8):
68+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
69+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")

0 commit comments

Comments
 (0)