diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 7aff2db324..9152229d2f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -88,7 +88,7 @@ def setup_pytorch_extension( # Libraries library_dirs = [] libraries = [] - if os.getenv("NVTE_UB_WITH_MPI"): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"