2
2
3
3
import ctypes
4
4
import gc
5
+ import getpass
5
6
import logging
6
7
import os
8
+ import tempfile
7
9
import urllib .request
8
10
import warnings
11
+ from contextlib import contextmanager
9
12
from dataclasses import fields , replace
10
13
from enum import Enum
11
- from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
14
+ from pathlib import Path
15
+ from typing import (
16
+ Any ,
17
+ Callable ,
18
+ Dict ,
19
+ Iterator ,
20
+ List ,
21
+ Optional ,
22
+ Sequence ,
23
+ Tuple ,
24
+ Union ,
25
+ )
12
26
13
27
import numpy as np
14
28
import sympy
37
51
RTOL = 5e-3
38
52
ATOL = 5e-3
39
53
CPU_DEVICE = "cpu"
54
+ _WHL_CPYTHON_VERSION = "cp310"
40
55
41
56
42
57
class Frameworks (Enum ):
@@ -240,6 +255,19 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
240
255
"""
241
256
if parent_logger :
242
257
parent_logger .setLevel (level )
258
+ print ("Handlers for parent_logger:" , parent_logger .handlers )
259
+ print ("bool check--" , parent_logger .hasHandlers ())
260
+ if parent_logger .hasHandlers ():
261
+ ch = logging .StreamHandler ()
262
+ ch .setLevel (logging .DEBUG ) # Allow debug messages on handler
263
+ formatter = logging .Formatter (
264
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
265
+ )
266
+ ch .setFormatter (formatter )
267
+ parent_logger .addHandler (ch )
268
+ print ("Logger level:" , parent_logger .level )
269
+ # print("Parent logger level:", logger.parent.level)
270
+ print ("Root logger level:" , logging .getLogger ().level )
243
271
244
272
if ENABLED_FEATURES .torch_tensorrt_runtime :
245
273
if level == logging .DEBUG :
@@ -826,17 +854,41 @@ def is_tegra_platform() -> bool:
826
854
return False
827
855
828
856
829
- def download_plugin_lib_path (py_version : str , platform : str ) -> str :
830
- plugin_lib_path = None
857
+ @contextmanager
858
+ def download_plugin_lib_path (platform : str ) -> Iterator [str ]:
859
+ """
860
+ Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
861
+ then yields the path to the extracted shared library (.so or .dll).
831
862
832
- # Downloading TRT-LLM lib
833
- base_url = "https://pypi.nvidia.com/tensorrt-llm/"
834
- file_name = f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ py_version } -{ py_version } -{ platform } .whl"
835
- download_url = base_url + file_name
836
- if not (os .path .exists (file_name )):
863
+ The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
864
+ Extraction happens in a temporary directory that is cleaned up after use.
865
+
866
+ Args:
867
+ platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
868
+
869
+ Yields:
870
+ str: The full path to the extracted TensorRT-LLM shared library file.
871
+
872
+ Raises:
873
+ ImportError: If the 'zipfile' module is not available.
874
+ RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
875
+ """
876
+ plugin_lib_path = None
877
+ username = getpass .getuser ()
878
+ torchtrt_cache_dir = Path (tempfile .gettempdir ()) / f"torch_tensorrt_{ username } "
879
+ torchtrt_cache_dir .mkdir (parents = True , exist_ok = True )
880
+ file_name = f"tensorrt_llm-{ __tensorrt_llm_version__ } -{ _WHL_CPYTHON_VERSION } -{ _WHL_CPYTHON_VERSION } -{ platform } .whl"
881
+ torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
882
+ downloaded_file_path = torchtrt_cache_trtllm_whl
883
+
884
+ if not torchtrt_cache_trtllm_whl .exists ():
885
+ # Downloading TRT-LLM lib
886
+ base_url = "https://pypi.nvidia.com/tensorrt-llm/"
887
+ download_url = base_url + file_name
888
+ print ("Downloading TRT-LLM wheel" )
837
889
try :
838
890
logger .debug (f"Downloading { download_url } ..." )
839
- urllib .request .urlretrieve (download_url , file_name )
891
+ urllib .request .urlretrieve (download_url , downloaded_file_path )
840
892
logger .debug ("Download succeeded and TRT-LLM wheel is now present" )
841
893
except urllib .error .HTTPError as e :
842
894
logger .error (
@@ -849,60 +901,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str:
849
901
except OSError as e :
850
902
logger .error (f"Local file write error: { e } " )
851
903
852
- # Proceeding with the unzip of the wheel file
853
- # This will exist if the filename was already downloaded
904
+ # Proceeding with the unzip of the wheel file in tmpdir
854
905
if "linux" in platform :
855
906
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
856
907
else :
857
908
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
858
- plugin_lib_path = os .path .join ("./tensorrt_llm/libs" , lib_filename )
859
- if os .path .exists (plugin_lib_path ):
860
- return plugin_lib_path
861
- try :
862
- import zipfile
863
- except ImportError as e :
864
- raise ImportError (
865
- "zipfile module is required but not found. Please install zipfile"
866
- )
867
- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
868
- zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
869
- plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
870
- return plugin_lib_path
871
-
872
909
873
- def load_tensorrt_llm () -> bool :
910
+ with tempfile .TemporaryDirectory () as tmpdir :
911
+ try :
912
+ import zipfile
913
+ except ImportError :
914
+ raise ImportError (
915
+ "zipfile module is required but not found. Please install zipfile"
916
+ )
917
+ try :
918
+ with zipfile .ZipFile (downloaded_file_path , "r" ) as zip_ref :
919
+ zip_ref .extractall (tmpdir ) # Extract to a folder named 'tensorrt_llm'
920
+ except FileNotFoundError as e :
921
+ # This should capture the errors in the download failure above
922
+ logger .error (f"Wheel file not found at { downloaded_file_path } : { e } " )
923
+ raise RuntimeError (
924
+ f"Failed to find downloaded wheel file at { downloaded_file_path } "
925
+ ) from e
926
+ except zipfile .BadZipFile as e :
927
+ logger .error (f"Invalid or corrupted wheel file: { e } " )
928
+ raise RuntimeError (
929
+ "Downloaded wheel file is corrupted or not a valid zip archive"
930
+ ) from e
931
+ except Exception as e :
932
+ logger .error (f"Unexpected error while extracting wheel: { e } " )
933
+ raise RuntimeError (
934
+ "Unexpected error during extraction of TensorRT-LLM wheel"
935
+ ) from e
936
+ plugin_lib_path = os .path .join (tmpdir , "tensorrt_llm/libs" , lib_filename )
937
+ yield plugin_lib_path
938
+
939
+
940
+ def load_and_initialize_trtllm_plugin (plugin_lib_path : str ) -> bool :
874
941
"""
875
- Attempts to load the TensorRT-LLM plugin and initialize it.
876
- Either the env variable TRTLLM_PLUGINS_PATH can specify the path
877
- Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
942
+ Loads and initializes the TensorRT-LLM plugin from the given shared library path.
943
+
944
+ Args:
945
+ plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
878
946
879
947
Returns:
880
- bool: True if the plugin was successfully loaded and initialized , False otherwise.
948
+ bool: True if successful , False otherwise.
881
949
"""
882
- plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
883
- if not plugin_lib_path :
884
- # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
885
- use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
886
- "1" ,
887
- "true" ,
888
- "yes" ,
889
- "on" ,
890
- )
891
- if not use_trtllm_plugin :
892
- logger .warning (
893
- "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
894
- )
895
- return False
896
- else :
897
- # this is used as the default py version
898
- py_version = "cp310"
899
- platform = Platform .current_platform ()
900
-
901
- platform = str (platform ).lower ()
902
- plugin_lib_path = download_plugin_lib_path (py_version , platform )
903
-
904
950
try :
905
- # Load the shared TRT-LLM file
906
951
handle = ctypes .CDLL (plugin_lib_path )
907
952
logger .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
908
953
except OSError as e_os_error :
@@ -915,14 +960,13 @@ def load_tensorrt_llm() -> bool:
915
960
)
916
961
else :
917
962
logger .warning (
918
- f"Failed to load libnvinfer_plugin_tensorrt_llm.so from { plugin_lib_path } "
919
- f"Ensure the path is correct and the library is compatible" ,
963
+ f"Failed to load libnvinfer_plugin_tensorrt_llm.so from { plugin_lib_path } . "
964
+ f"Ensure the path is correct and the library is compatible. " ,
920
965
exc_info = e_os_error ,
921
966
)
922
967
return False
923
968
924
969
try :
925
- # Configure plugin initialization arguments
926
970
handle .initTrtLlmPlugins .argtypes = [ctypes .c_void_p , ctypes .c_char_p ]
927
971
handle .initTrtLlmPlugins .restype = ctypes .c_bool
928
972
except AttributeError as e_plugin_unavailable :
@@ -933,9 +977,7 @@ def load_tensorrt_llm() -> bool:
933
977
return False
934
978
935
979
try :
936
- # Initialize the plugin
937
- TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
938
- if handle .initTrtLlmPlugins (None , TRT_LLM_PLUGIN_NAMESPACE .encode ("utf-8" )):
980
+ if handle .initTrtLlmPlugins (None , b"tensorrt_llm" ):
939
981
logger .info ("TensorRT-LLM plugin successfully initialized" )
940
982
return True
941
983
else :
@@ -948,3 +990,37 @@ def load_tensorrt_llm() -> bool:
948
990
)
949
991
return False
950
992
return False
def load_tensorrt_llm() -> bool:
    """
    Attempt to load and initialize the TensorRT-LLM plugin library.

    The shared library is located in one of two ways:
      * the ``TRTLLM_PLUGINS_PATH`` environment variable, which points
        directly at an existing plugin library on disk, or
      * the ``USE_TRTLLM_PLUGINS`` environment variable (any of
        "1", "true", "yes", "on"), which opts in to downloading the
        TensorRT-LLM wheel for the current platform and loading the
        library extracted from it.

    Returns:
        bool: True if the plugin was successfully loaded and initialized,
        False otherwise.
    """
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        # An explicit path wins; no download needed.
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # TRTLLM_PLUGINS_PATH is unset; fall back to the opt-in download path.
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        logger.warning(
            "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    platform = Platform.current_platform()
    platform = str(platform).lower()

    # download_plugin_lib_path is a context manager: the extracted library
    # lives in a temporary directory that is cleaned up on exit, so the
    # load must happen inside the `with` block.
    with download_plugin_lib_path(platform) as plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)