Skip to content

Cross compile guard #3486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pybind11==2.6.2
torch>=2.8.0.dev,<2.9.0
torchvision>=0.22.0.dev,<0.23.0
--extra-index-url https://pypi.ngc.nvidia.com
pyyaml
pyyaml
dllist
3 changes: 2 additions & 1 deletion py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.fx
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
Expand Down Expand Up @@ -301,6 +301,7 @@ def compile(
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


@needs_cross_compile
def cross_compile_for_windows(
module: torch.nn.Module,
file_path: str,
Expand Down
30 changes: 28 additions & 2 deletions py/torch_tensorrt/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar

from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt._utils import (
check_cross_compile_trt_win_lib,
sanitized_torch_version,
)

from packaging import version

Expand All @@ -15,6 +18,7 @@
"dynamo_frontend",
"fx_frontend",
"refit",
"windows_cross_compile",
],
)

Expand All @@ -38,9 +42,15 @@
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
_FX_FE_AVAIL = True
_REFIT_AVAIL = True
_WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()

ENABLED_FEATURES = FeatureSet(
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
_TS_FE_AVAIL,
_TORCHTRT_RT_AVAIL,
_DYNAMO_FE_AVAIL,
_FX_FE_AVAIL,
_REFIT_AVAIL,
_WINDOWS_CROSS_COMPILE,
)


Expand Down Expand Up @@ -80,6 +90,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return wrapper


def needs_cross_compile(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.windows_cross_compile:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"Windows cross compilation feature is not available"
)

return not_implemented(*args, **kwargs)

return wrapper


T = TypeVar("T")


Expand Down
16 changes: 16 additions & 0 deletions py/torch_tensorrt/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

import torch
from torch_tensorrt._enums import Platform


def sanitized_torch_version() -> Any:
Expand All @@ -9,3 +10,18 @@ def sanitized_torch_version() -> Any:
if ".nv" not in torch.__version__
else torch.__version__.split(".nv")[0]
)


def check_cross_compile_trt_win_lib() -> bool:
# cross compile feature is only available on linux
# build engine on linux and run on windows
import dllist

platform = Platform.current_platform()
platform = str(platform).lower()
if platform.startswith("linux"):
loaded_libs = dllist.dllist()
target_lib = "libnvinfer_builder_resource_win.so.*"
if target_lib in loaded_libs:
return True
return False
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
Expand Down Expand Up @@ -49,6 +50,7 @@
logger = logging.getLogger(__name__)


@needs_cross_compile
def cross_compile_for_windows(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
Expand Down Expand Up @@ -1190,6 +1192,7 @@ def convert_exported_program_to_serialized_trt_engine(
return serialized_engine


@needs_cross_compile
def save_cross_compiled_exported_program(
gm: torch.fx.GraphModule,
file_path: str,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ requires = [
"pybind11==2.6.2",
"numpy",
"sympy",
"dllist",
]
build-backend = "setuptools.build_meta"

Expand Down Expand Up @@ -63,6 +64,7 @@ dependencies = [
"packaging>=23",
"numpy",
"typing-extensions>=4.7.0",
"dllist",
]

dynamic = ["version"]
Expand Down
13 changes: 13 additions & 0 deletions tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt._utils import check_cross_compile_trt_win_lib

from ..testing_utilities import DECIMALS_OF_AGREEMENT

Expand All @@ -16,6 +17,10 @@ class TestCrossCompileSaveForWindows(TestCase):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_cross_compile_for_windows(self):
class Add(torch.nn.Module):
Expand All @@ -40,6 +45,10 @@ def forward(self, a, b):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_dynamo_cross_compile_for_windows(self):
class Add(torch.nn.Module):
Expand Down Expand Up @@ -68,6 +77,10 @@ def forward(self, a, b):
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
"Cross compile for windows can only be enabled on linux x86-64 platform",
)
@unittest.skipIf(
not (check_cross_compile_trt_win_lib()),
"TRT windows lib for cross compile not found",
)
@pytest.mark.unit
def test_dynamo_cross_compile_for_windows_multiple_output(self):
class Add(torch.nn.Module):
Expand Down
Loading