
Commit 1e2148d

checks for Windows, where the NCCL backend is not supported
1 parent 61c93b2 commit 1e2148d

2 files changed: +14 -20 lines

py/torch_tensorrt/dynamo/utils.py

Lines changed: 8 additions & 20 deletions
@@ -255,19 +255,6 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
     """
     if parent_logger:
         parent_logger.setLevel(level)
-        print("Handlers for parent_logger:", parent_logger.handlers)
-        print("bool check--", parent_logger.hasHandlers())
-        if parent_logger.hasHandlers():
-            ch = logging.StreamHandler()
-            ch.setLevel(logging.DEBUG)  # Allow debug messages on handler
-            formatter = logging.Formatter(
-                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-            )
-            ch.setFormatter(formatter)
-            parent_logger.addHandler(ch)
-        print("Logger level:", parent_logger.level)
-        # print("Parent logger level:", logger.parent.level)
-        print("Root logger level:", logging.getLogger().level)
 
     if ENABLED_FEATURES.torch_tensorrt_runtime:
         if level == logging.DEBUG:
@@ -885,7 +872,6 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
         # Downloading TRT-LLM lib
         base_url = "https://pypi.nvidia.com/tensorrt-llm/"
         download_url = base_url + file_name
-        print("Downloading TRT-LLM wheel")
         try:
             logger.debug(f"Downloading {download_url} ...")
             urllib.request.urlretrieve(download_url, downloaded_file_path)
@@ -937,7 +923,7 @@ def download_plugin_lib_path(platform: str) -> Iterator[str]:
     yield plugin_lib_path
 
 
-def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
+def load_and_initialize_trtllm_plugin(plugin_lib_path: str, platform: str) -> bool:
     """
     Loads and initializes the TensorRT-LLM plugin from the given shared library path.
 
@@ -947,6 +933,9 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
     Returns:
         bool: True if successful, False otherwise.
     """
+    if "windows" in platform:
+        logger.info("NCCL backend is not supported on Windows")
+        return False
     try:
         handle = ctypes.CDLL(plugin_lib_path)
         logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
@@ -1002,8 +991,10 @@ def load_tensorrt_llm() -> bool:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
     plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    platform = Platform.current_platform()
+    platform = str(platform).lower()
     if plugin_lib_path:
-        return load_and_initialize_trtllm_plugin(plugin_lib_path)
+        return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
     else:
         # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
         use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
@@ -1017,10 +1008,7 @@ def load_tensorrt_llm() -> bool:
                 "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
            )
            return False
-        else:
-            platform = Platform.current_platform()
-            platform = str(platform).lower()
 
         with download_plugin_lib_path(platform) as plugin_lib_path:
-            return load_and_initialize_trtllm_plugin(plugin_lib_path)
+            return load_and_initialize_trtllm_plugin(plugin_lib_path, platform)
     return False
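
The guard introduced above can be illustrated in isolation. A minimal sketch, assuming only the Python standard library; the helper name and plugin path are hypothetical, and the actual utils.py resolves the platform string via torch_tensorrt._enums.Platform rather than the platform module:

import ctypes
import logging
import platform as _platform

logger = logging.getLogger(__name__)


def load_plugin_if_supported(plugin_lib_path: str) -> bool:
    """Hypothetical helper mirroring the Windows guard added in this commit."""
    if "windows" in _platform.system().lower():
        # NCCL has no Windows support, so NCCL-backed plugins cannot be used here.
        logger.info("NCCL backend is not supported on Windows")
        return False
    try:
        # Attempt to load the shared library; failure is reported, not raised.
        ctypes.CDLL(plugin_lib_path)
    except OSError as err:
        logger.warning(f"Failed to load plugin library {plugin_lib_path}: {err}")
        return False
    return True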

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@
 from distributed_utils import set_environment_variables_pytest
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt._enums import Platform
 
 set_environment_variables_pytest()
 dist.init_process_group(backend="nccl", init_method="env://")
@@ -15,7 +16,12 @@
 
 from conversion.harness import DispatchTestCase
 
+platform_str = str(Platform.current_platform()).lower()
 
+
+@unittest.skipIf(
+    "win" in platform_str, "Skipped on Windows: NCCL backend is not supported."
+)
 class TestGatherNcclOpsConverter(DispatchTestCase):
     @parameterized.expand([8])
     def test_nccl_ops(self, linear_layer_dim):
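
The skip decorator added to the test follows the generic pattern for platform-gated unittest suites. A minimal, self-contained sketch using only the standard library; the class and test names are placeholders, not part of the suite above:

import platform
import unittest

ON_WINDOWS = platform.system().lower() == "windows"


@unittest.skipIf(ON_WINDOWS, "Skipped on Windows: NCCL backend is not supported.")
class ExampleNcclDependentTest(unittest.TestCase):
    def test_placeholder(self) -> None:
        # On non-Windows platforms this would exercise an NCCL-backed op.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()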
