
Commit 61c93b2

Addressing review comments: tmp dir for wheel download and wheel extraction, variable for py_version
1 parent 66f0c88
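
In short, the review asked for the wheel download to be cached in a per-user temp directory, for extraction to happen in a throwaway directory, and for the hard-coded Python tag to become a module-level variable. A minimal illustrative sketch of that caching pattern (the helper name here is hypothetical and the download/extract steps are elided; the real code is in the diff below):

# Illustrative sketch only; not the torch_tensorrt implementation.
import getpass
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator


@contextmanager
def cached_download_and_extract(file_name: str) -> Iterator[Path]:
    # Cache the wheel per user under the system temp dir to avoid re-downloading.
    cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{getpass.getuser()}"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached_whl = cache_dir / file_name
    if not cached_whl.exists():
        ...  # download the wheel into cached_whl
    # Extract into a temporary directory that is deleted when the block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        ...  # unzip cached_whl into tmpdir
        yield Path(tmpdir)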

File tree

py/torch_tensorrt/dynamo/utils.py
tests/py/dynamo/distributed/test_nccl_ops.py

2 files changed: +136 -59 lines


py/torch_tensorrt/dynamo/utils.py

Lines changed: 135 additions & 59 deletions
@@ -2,13 +2,27 @@
 
 import ctypes
 import gc
+import getpass
 import logging
 import os
+import tempfile
 import urllib.request
 import warnings
+from contextlib import contextmanager
 from dataclasses import fields, replace
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from pathlib import Path
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 import sympy
@@ -37,6 +51,7 @@
 RTOL = 5e-3
 ATOL = 5e-3
 CPU_DEVICE = "cpu"
+_WHL_CPYTHON_VERSION = "cp310"
 
 
 class Frameworks(Enum):
@@ -240,6 +255,19 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
     """
     if parent_logger:
         parent_logger.setLevel(level)
+        print("Handlers for parent_logger:", parent_logger.handlers)
+        print("bool check--", parent_logger.hasHandlers())
+        if parent_logger.hasHandlers():
+            ch = logging.StreamHandler()
+            ch.setLevel(logging.DEBUG)  # Allow debug messages on handler
+            formatter = logging.Formatter(
+                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+            )
+            ch.setFormatter(formatter)
+            parent_logger.addHandler(ch)
+        print("Logger level:", parent_logger.level)
+        # print("Parent logger level:", logger.parent.level)
+        print("Root logger level:", logging.getLogger().level)
 
     if ENABLED_FEATURES.torch_tensorrt_runtime:
         if level == logging.DEBUG:
@@ -826,17 +854,41 @@ def is_tegra_platform() -> bool:
     return False
 
 
-def download_plugin_lib_path(py_version: str, platform: str) -> str:
-    plugin_lib_path = None
+@contextmanager
+def download_plugin_lib_path(platform: str) -> Iterator[str]:
+    """
+    Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
+    then yields the path to the extracted shared library (.so or .dll).
 
-    # Downloading TRT-LLM lib
-    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
-    download_url = base_url + file_name
-    if not (os.path.exists(file_name)):
+    The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
+    Extraction happens in a temporary directory that is cleaned up after use.
+
+    Args:
+        platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
+
+    Yields:
+        str: The full path to the extracted TensorRT-LLM shared library file.
+
+    Raises:
+        ImportError: If the 'zipfile' module is not available.
+        RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
+    """
+    plugin_lib_path = None
+    username = getpass.getuser()
+    torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
+    torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
+    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
+    torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
+    downloaded_file_path = torchtrt_cache_trtllm_whl
+
+    if not torchtrt_cache_trtllm_whl.exists():
+        # Downloading TRT-LLM lib
+        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+        download_url = base_url + file_name
+        print("Downloading TRT-LLM wheel")
         try:
             logger.debug(f"Downloading {download_url} ...")
-            urllib.request.urlretrieve(download_url, file_name)
+            urllib.request.urlretrieve(download_url, downloaded_file_path)
             logger.debug("Download succeeded and TRT-LLM wheel is now present")
         except urllib.error.HTTPError as e:
             logger.error(
@@ -849,60 +901,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str:
         except OSError as e:
             logger.error(f"Local file write error: {e}")
 
-    # Proceeding with the unzip of the wheel file
-    # This will exist if the filename was already downloaded
+    # Proceeding with the unzip of the wheel file in tmpdir
     if "linux" in platform:
         lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
     else:
         lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
-    plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
-    if os.path.exists(plugin_lib_path):
-        return plugin_lib_path
-    try:
-        import zipfile
-    except ImportError as e:
-        raise ImportError(
-            "zipfile module is required but not found. Please install zipfile"
-        )
-    with zipfile.ZipFile(file_name, "r") as zip_ref:
-        zip_ref.extractall(".")  # Extract to a folder named 'tensorrt_llm'
-    plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
-    return plugin_lib_path
-
 
-def load_tensorrt_llm() -> bool:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        try:
+            import zipfile
+        except ImportError:
+            raise ImportError(
+                "zipfile module is required but not found. Please install zipfile"
+            )
+        try:
+            with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
+                zip_ref.extractall(tmpdir)  # Extract to a folder named 'tensorrt_llm'
+        except FileNotFoundError as e:
+            # This should capture the errors in the download failure above
+            logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
+            raise RuntimeError(
+                f"Failed to find downloaded wheel file at {downloaded_file_path}"
+            ) from e
+        except zipfile.BadZipFile as e:
+            logger.error(f"Invalid or corrupted wheel file: {e}")
+            raise RuntimeError(
+                "Downloaded wheel file is corrupted or not a valid zip archive"
+            ) from e
+        except Exception as e:
+            logger.error(f"Unexpected error while extracting wheel: {e}")
+            raise RuntimeError(
+                "Unexpected error during extraction of TensorRT-LLM wheel"
+            ) from e
+        plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
+        yield plugin_lib_path
+
+
+def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
     """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
-    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
+    Loads and initializes the TensorRT-LLM plugin from the given shared library path.
+
+    Args:
+        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
 
     Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+        bool: True if successful, False otherwise.
     """
-    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-    if not plugin_lib_path:
-        # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
-        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
-            "1",
-            "true",
-            "yes",
-            "on",
-        )
-        if not use_trtllm_plugin:
-            logger.warning(
-                "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
-            )
-            return False
-        else:
-            # this is used as the default py version
-            py_version = "cp310"
-            platform = Platform.current_platform()
-
-            platform = str(platform).lower()
-            plugin_lib_path = download_plugin_lib_path(py_version, platform)
-
     try:
-        # Load the shared TRT-LLM file
         handle = ctypes.CDLL(plugin_lib_path)
         logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
     except OSError as e_os_error:
@@ -915,14 +960,13 @@ def load_tensorrt_llm() -> bool:
             )
         else:
             logger.warning(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"Ensure the path is correct and the library is compatible.",
                 exc_info=e_os_error,
             )
         return False
 
     try:
-        # Configure plugin initialization arguments
         handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
         handle.initTrtLlmPlugins.restype = ctypes.c_bool
     except AttributeError as e_plugin_unavailable:
@@ -933,9 +977,7 @@ def load_tensorrt_llm() -> bool:
         return False
 
     try:
-        # Initialize the plugin
-        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
+        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
             logger.info("TensorRT-LLM plugin successfully initialized")
             return True
         else:
@@ -948,3 +990,37 @@ def load_tensorrt_llm() -> bool:
         )
         return False
     return False
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
+    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if plugin_lib_path:
+        return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    else:
+        # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
+            )
+            return False
+        else:
+            platform = Platform.current_platform()
+            platform = str(platform).lower()
+
+            with download_plugin_lib_path(platform) as plugin_lib_path:
+                return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    return False
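
Taken together, the refactor splits the old load_tensorrt_llm() into a download/extract context manager (download_plugin_lib_path) and a separate loader (load_and_initialize_trtllm_plugin). A minimal usage sketch, assuming the module is importable as torch_tensorrt.dynamo.utils and using a placeholder library path:

# Hedged usage sketch; the .so path below is a placeholder, not a real location.
import os

from torch_tensorrt.dynamo import utils

# Option 1: point TRTLLM_PLUGINS_PATH at an existing shared library.
os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"
ok = utils.load_tensorrt_llm()

# Option 2: opt in to downloading the wheel; load_tensorrt_llm() then enters the
# download_plugin_lib_path() context manager, which caches and extracts it.
os.environ.pop("TRTLLM_PLUGINS_PATH", None)
os.environ["USE_TRTLLM_PLUGINS"] = "1"
ok = utils.load_tensorrt_llm()
print("TensorRT-LLM plugin initialized:", ok)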

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def forward(self, x):
             use_dynamo_tracer=True,
             enable_passes=True,
         )
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
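
The added dist.destroy_process_group() pairs the test's process-group setup with an explicit teardown, so NCCL resources are released before the worker exits. A small sketch of the pattern (the backend and test body are illustrative, not copied from test_nccl_ops.py):

import torch.distributed as dist


def run_nccl_test_case() -> None:
    # Assumes rank/world-size env vars are provided by torchrun or the test harness.
    dist.init_process_group(backend="nccl")
    try:
        ...  # compile the module and run assertions
    finally:
        dist.destroy_process_group()  # tear down NCCL state so later tests start clean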
