@@ -84,11 +84,26 @@ def _detect_local_cuda_toolkit(repository_ctx):
8484 link_stub_label = link_stub ,
8585 bin2c_label = bin2c ,
8686 fatbinary_label = fatbinary ,
87+ cicc_label = None , # local CTK do not need this
88+ libdevice_label = None , # local CTK do not need this
8789 )
8890
8991def _detect_deliverable_cuda_toolkit (repository_ctx ):
92+ cuda_version_str = repository_ctx .attr .version
93+ if cuda_version_str == None or cuda_version_str == "" :
94+ fail ("attr version is required." )
95+
96+ nvcc_version_str = repository_ctx .attr .nvcc_version
97+ if nvcc_version_str == None or nvcc_version_str == "" :
98+ nvcc_version_str = cuda_version_str
99+
100+ cuda_version_major , cuda_version_minor = cuda_version_str .split ("." )[:2 ]
101+ nvcc_version_major , nvcc_version_minor = nvcc_version_str .split ("." )[:2 ]
102+
90103 # NOTE: component nvcc contains some headers that will be used.
91104 required_components = ["cccl" , "cudart" , "nvcc" ]
105+ if int (cuda_version_major ) >= 13 :
106+ required_components .extend (["crt" , "nvvm" ])
92107 for rc in required_components :
93108 if rc not in repository_ctx .attr .components_mapping :
94109 fail ('component "{}" is required.' .format (rc ))
@@ -102,16 +117,12 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
102117 bin2c = "{}//:nvcc/bin/bin2c{}" .format (nvcc_repo , bin_ext )
103118 fatbinary = "{}//:nvcc/bin/fatbinary{}" .format (nvcc_repo , bin_ext )
104119
105- cuda_version_str = repository_ctx .attr .version
106- if cuda_version_str == None or cuda_version_str == "" :
107- fail ("attr version is required." )
108-
109- nvcc_version_str = repository_ctx .attr .nvcc_version
110- if nvcc_version_str == None or nvcc_version_str == "" :
111- nvcc_version_str = cuda_version_str
112-
113- cuda_version_major , cuda_version_minor = cuda_version_str .split ("." )[:2 ]
114- nvcc_version_major , nvcc_version_minor = nvcc_version_str .split ("." )[:2 ]
120+ cicc = None
121+ libdevice = None
122+ if int (cuda_version_major ) >= 13 :
123+ nvvm_repo = repository_ctx .attr .components_mapping ["nvvm" ]
124+ cicc = "{}//:nvvm/nvvm/bin/cicc{}" .format (nvvm_repo , bin_ext ) # TODO: can we use @cuda//:cicc?
125+ libdevice = "{}//:nvvm/nvvm/libdevice/libdevice.10.bc" .format (nvvm_repo ) # TODO: can we use @cuda//:libdevice?
115126
116127 return struct (
117128 path = None , # scattered components
@@ -124,6 +135,8 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
124135 link_stub_label = link_stub ,
125136 bin2c_label = bin2c ,
126137 fatbinary_label = fatbinary ,
138+ cicc_label = cicc ,
139+ libdevice_label = libdevice ,
127140 )
128141
129142def detect_cuda_toolkit (repository_ctx ):
@@ -188,7 +201,7 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
188201 )
189202
190203 # Generate @cuda//defs.bzl
191- template_helper .generate_defs_bzl (repository_ctx , is_local_ctk == True )
204+ template_helper .generate_defs_bzl (repository_ctx , cuda . version_major , cuda . version_minor , is_local_ctk == True )
192205
193206 # Generate @cuda//toolchain/BUILD
194207 template_helper .generate_toolchain_build (repository_ctx , cuda )
@@ -283,6 +296,44 @@ cuda_toolkit = repository_rule(
283296 # remotable = True,
284297)
285298
299+ def _patch_nvcc_profile_pre (repository_ctx , component_name ):
300+ """nvcc after 13 needs to adjust nvcc.profile to support distributed components"""
301+
302+ patch_nvcc_profile = False
303+ rename_files = {}
304+
305+ if component_name != "nvcc" :
306+ return patch_nvcc_profile , rename_files
307+ if getattr (repository_ctx .attr , "version" ) == None or repository_ctx .attr .version == "" :
308+ fail ("attribute `version` must be filled for 'nvcc' component" )
309+
310+ nvcc_major_version = int (repository_ctx .attr .version .split ("." )[0 ])
311+ if nvcc_major_version >= 13 :
312+ patch_nvcc_profile = True
313+
314+ if patch_nvcc_profile :
315+ nvcc_profile = repository_ctx .attr .strip_prefix + "/bin/nvcc.profile"
316+ rename_files [nvcc_profile ] = nvcc_profile + ".renamed_by_rules_cuda"
317+
318+ return patch_nvcc_profile , rename_files
319+
320+ def _patch_nvcc_profile_post (repository_ctx , patch_nvcc_profile ):
321+ if not patch_nvcc_profile :
322+ return
323+
324+ nvcc_profile_content = repository_ctx .read ("nvcc/bin/nvcc.profile.renamed_by_rules_cuda" )
325+ lines = nvcc_profile_content .split ("\n " )
326+ for i , line in enumerate (lines ):
327+ key_to_replace = ["CICC_PATH" , "NVVMIR_LIBRARY_DIR" ]
328+ for key in key_to_replace :
329+ s = line .find (key )
330+ if s == 0 or (s > 0 and line [s - 1 ] != "(" and line [s - 1 ] != "%" ): # ensure it is a env key assignment, not a reference
331+ # we will then pass the env from outside to
332+ new_line = key + " ?= " + key + "/in/nvcc.profile/replaced/by/rules_cuda/but/not/set/at/runtime"
333+ lines [i ] = new_line
334+ nvcc_profile_content = "\n " .join (lines )
335+ repository_ctx .file ("nvcc/bin/nvcc.profile" , nvcc_profile_content )
336+
286337def _cuda_component_impl (repository_ctx ):
287338 component_name = None
288339 if repository_ctx .attr .component_name :
@@ -300,14 +351,19 @@ def _cuda_component_impl(repository_ctx):
300351 if repository_ctx .attr .url and repository_ctx .attr .urls :
301352 fail ("attributes `url` and `urls` cannot be used at the same time" )
302353
354+ patch_nvcc_profile , rename_files = _patch_nvcc_profile_pre (repository_ctx , component_name )
355+
303356 repository_ctx .download_and_extract (
304357 url = repository_ctx .attr .url or repository_ctx .attr .urls ,
305358 output = component_name ,
306359 integrity = repository_ctx .attr .integrity ,
307360 sha256 = repository_ctx .attr .sha256 ,
308361 stripPrefix = repository_ctx .attr .strip_prefix ,
362+ rename_files = rename_files ,
309363 )
310364
365+ _patch_nvcc_profile_post (repository_ctx , patch_nvcc_profile )
366+
311367 template_helper .generate_build (
312368 repository_ctx ,
313369 libpath = "lib" ,
0 commit comments