Skip to content

Commit 796e9d1

Browse files
authored
feat: add support for CTK 13 (#410)
* fix: make cuda 13 compile * feat: add if_cuda_toolkit_version_ge * fix: address split of culibos * fix: address cccl subdirectory * fix: update toolchain configuration for standalone nvvm repo * fix: add npp uber target which depends on nppc, nppi and npps * fix: depend on crt after 13 * fix: make cicc_label and libdevice_label optional * fix: make version attr mandatory for nvcc Since cuda 13, the nvcc deliverable no longer contains any header file that has a version string hardcoded. `repository_ctx.execute` is not an option in our case, as the binary may not be compatible with the host. * ci: add cuda 13 in build tests * test: stop specifying components as it is optional * fix: make cuda_toolkit_info cicc and libdevice optional for local ctk * fix: also add CICC_PATH and NVVMIR_LIBRARY_DIR env for nvcc msvc toolchain configuration * ci: add crt and nvvm starting with ctk 13 * style: format * ci: disable clang test in ci for ctk 13 due to bug upstream
1 parent 27e407d commit 796e9d1

File tree

24 files changed

+252
-53
lines changed

24 files changed

+252
-53
lines changed

.github/actions/set-build-env/action.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ runs:
6868
if: ${{ startsWith(inputs.os, 'windows') }}
6969
with:
7070
cuda: ${{ inputs.cuda-version }}
71-
sub-packages: '["nvcc", "cudart"]'
71+
sub-packages: ${{ (startsWith(inputs.cuda-version, '11') || startsWith(inputs.cuda-version, '12')) && '["nvcc", "cudart"]' || '["nvcc", "cudart", "crt", "nvvm"]' }}
7272
method: network
7373
- name: Show bin, include, lib64 (Windows)
7474
if: ${{ startsWith(inputs.os, 'windows') }}

.github/workflows/build-tests.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ jobs:
1717
cases:
1818
- { os: "ubuntu-22.04", cuda-version: "11.7.0", source: "nvidia" }
1919
- { os: "ubuntu-24.04", cuda-version: "12.6.3", source: "nvidia" }
20+
- { os: "ubuntu-24.04", cuda-version: "13.0.2", source: "nvidia" }
2021
- {
2122
os: "ubuntu-22.04",
2223
cuda-version: "11.7.0",
@@ -31,13 +32,21 @@ jobs:
3132
toolchain: "llvm_host_device", # clang as cuda compiler driver
3233
toolchain-version: "19",
3334
}
35+
# - { # FIXME: enable once llvm 22 is released
36+
# os: "ubuntu-24.04",
37+
# cuda-version: "13.0.2",
38+
# source: "nvidia",
39+
# toolchain: "llvm_host_only", # clang as host compiler
40+
# toolchain-version: "21",
41+
# }
3442
- {
3543
os: "ubuntu-22.04",
3644
cuda-version: "11.5.1-1ubuntu1",
3745
source: "ubuntu",
3846
}
3947
- { os: "windows-2022", cuda-version: "11.5.2", source: "nvidia" }
4048
- { os: "windows-2025", cuda-version: "12.6.3", source: "nvidia" }
49+
- { os: "windows-2025", cuda-version: "13.0.2", source: "nvidia" }
4150
steps:
4251
- uses: actions/checkout@v5
4352

cuda/private/providers.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ CudaToolkitInfo = provider(
107107
"link_stub": "File to the link.stub file",
108108
"bin2c": "File to the bin2c executable",
109109
"fatbinary": "File to the fatbinary executable",
110+
"cicc": "File to the cicc executable",
111+
"libdevice": "File to the libdevice LLVM bitcode library (libdevice.10.bc)",
110112
},
111113
)
112114

cuda/private/repositories.bzl

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,26 @@ def _detect_local_cuda_toolkit(repository_ctx):
8484
link_stub_label = link_stub,
8585
bin2c_label = bin2c,
8686
fatbinary_label = fatbinary,
87+
cicc_label = None, # a local CTK does not need this
88+
libdevice_label = None, # a local CTK does not need this
8789
)
8890

8991
def _detect_deliverable_cuda_toolkit(repository_ctx):
92+
cuda_version_str = repository_ctx.attr.version
93+
if cuda_version_str == None or cuda_version_str == "":
94+
fail("attr version is required.")
95+
96+
nvcc_version_str = repository_ctx.attr.nvcc_version
97+
if nvcc_version_str == None or nvcc_version_str == "":
98+
nvcc_version_str = cuda_version_str
99+
100+
cuda_version_major, cuda_version_minor = cuda_version_str.split(".")[:2]
101+
nvcc_version_major, nvcc_version_minor = nvcc_version_str.split(".")[:2]
102+
90103
# NOTE: component nvcc contains some headers that will be used.
91104
required_components = ["cccl", "cudart", "nvcc"]
105+
if int(cuda_version_major) >= 13:
106+
required_components.extend(["crt", "nvvm"])
92107
for rc in required_components:
93108
if rc not in repository_ctx.attr.components_mapping:
94109
fail('component "{}" is required.'.format(rc))
@@ -102,16 +117,12 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
102117
bin2c = "{}//:nvcc/bin/bin2c{}".format(nvcc_repo, bin_ext)
103118
fatbinary = "{}//:nvcc/bin/fatbinary{}".format(nvcc_repo, bin_ext)
104119

105-
cuda_version_str = repository_ctx.attr.version
106-
if cuda_version_str == None or cuda_version_str == "":
107-
fail("attr version is required.")
108-
109-
nvcc_version_str = repository_ctx.attr.nvcc_version
110-
if nvcc_version_str == None or nvcc_version_str == "":
111-
nvcc_version_str = cuda_version_str
112-
113-
cuda_version_major, cuda_version_minor = cuda_version_str.split(".")[:2]
114-
nvcc_version_major, nvcc_version_minor = nvcc_version_str.split(".")[:2]
120+
cicc = None
121+
libdevice = None
122+
if int(cuda_version_major) >= 13:
123+
nvvm_repo = repository_ctx.attr.components_mapping["nvvm"]
124+
cicc = "{}//:nvvm/nvvm/bin/cicc{}".format(nvvm_repo, bin_ext) # TODO: can we use @cuda//:cicc?
125+
libdevice = "{}//:nvvm/nvvm/libdevice/libdevice.10.bc".format(nvvm_repo) # TODO: can we use @cuda//:libdevice?
115126

116127
return struct(
117128
path = None, # scattered components
@@ -124,6 +135,8 @@ def _detect_deliverable_cuda_toolkit(repository_ctx):
124135
link_stub_label = link_stub,
125136
bin2c_label = bin2c,
126137
fatbinary_label = fatbinary,
138+
cicc_label = cicc,
139+
libdevice_label = libdevice,
127140
)
128141

129142
def detect_cuda_toolkit(repository_ctx):
@@ -188,7 +201,7 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
188201
)
189202

190203
# Generate @cuda//defs.bzl
191-
template_helper.generate_defs_bzl(repository_ctx, is_local_ctk == True)
204+
template_helper.generate_defs_bzl(repository_ctx, cuda.version_major, cuda.version_minor, is_local_ctk == True)
192205

193206
# Generate @cuda//toolchain/BUILD
194207
template_helper.generate_toolchain_build(repository_ctx, cuda)
@@ -283,6 +296,44 @@ cuda_toolkit = repository_rule(
283296
# remotable = True,
284297
)
285298

299+
def _patch_nvcc_profile_pre(repository_ctx, component_name):
300+
"""nvcc after 13 needs to adjust nvcc.profile to support distributed components"""
301+
302+
patch_nvcc_profile = False
303+
rename_files = {}
304+
305+
if component_name != "nvcc":
306+
return patch_nvcc_profile, rename_files
307+
if getattr(repository_ctx.attr, "version") == None or repository_ctx.attr.version == "":
308+
fail("attribute `version` must be filled for 'nvcc' component")
309+
310+
nvcc_major_version = int(repository_ctx.attr.version.split(".")[0])
311+
if nvcc_major_version >= 13:
312+
patch_nvcc_profile = True
313+
314+
if patch_nvcc_profile:
315+
nvcc_profile = repository_ctx.attr.strip_prefix + "/bin/nvcc.profile"
316+
rename_files[nvcc_profile] = nvcc_profile + ".renamed_by_rules_cuda"
317+
318+
return patch_nvcc_profile, rename_files
319+
320+
def _patch_nvcc_profile_post(repository_ctx, patch_nvcc_profile):
321+
if not patch_nvcc_profile:
322+
return
323+
324+
nvcc_profile_content = repository_ctx.read("nvcc/bin/nvcc.profile.renamed_by_rules_cuda")
325+
lines = nvcc_profile_content.split("\n")
326+
for i, line in enumerate(lines):
327+
key_to_replace = ["CICC_PATH", "NVVMIR_LIBRARY_DIR"]
328+
for key in key_to_replace:
329+
s = line.find(key)
330+
if s == 0 or (s > 0 and line[s-1] != "(" and line[s-1] != "%"): # ensure it is a env key assignment, not a reference
331+
# the real value is then passed in from outside via the env (set by the toolchain configuration)
332+
new_line = key + " ?= " + key + "/in/nvcc.profile/replaced/by/rules_cuda/but/not/set/at/runtime"
333+
lines[i] = new_line
334+
nvcc_profile_content = "\n".join(lines)
335+
repository_ctx.file("nvcc/bin/nvcc.profile", nvcc_profile_content)
336+
286337
def _cuda_component_impl(repository_ctx):
287338
component_name = None
288339
if repository_ctx.attr.component_name:
@@ -300,14 +351,19 @@ def _cuda_component_impl(repository_ctx):
300351
if repository_ctx.attr.url and repository_ctx.attr.urls:
301352
fail("attributes `url` and `urls` cannot be used at the same time")
302353

354+
patch_nvcc_profile, rename_files = _patch_nvcc_profile_pre(repository_ctx, component_name)
355+
303356
repository_ctx.download_and_extract(
304357
url = repository_ctx.attr.url or repository_ctx.attr.urls,
305358
output = component_name,
306359
integrity = repository_ctx.attr.integrity,
307360
sha256 = repository_ctx.attr.sha256,
308361
stripPrefix = repository_ctx.attr.strip_prefix,
362+
rename_files = rename_files,
309363
)
310364

365+
_patch_nvcc_profile_post(repository_ctx, patch_nvcc_profile)
366+
311367
template_helper.generate_build(
312368
repository_ctx,
313369
libpath = "lib",

cuda/private/rules/cuda_toolkit_info.bzl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def _impl(ctx):
1212
link_stub = ctx.file.link_stub,
1313
bin2c = ctx.file.bin2c,
1414
fatbinary = ctx.file.fatbinary,
15+
cicc = ctx.file.cicc,
16+
libdevice = ctx.file.libdevice,
1517
)
1618

1719
cuda_toolkit_info = rule(
@@ -25,6 +27,8 @@ cuda_toolkit_info = rule(
2527
"link_stub": attr.label(allow_single_file = True, doc = "The link.stub text file."),
2628
"bin2c": attr.label(allow_single_file = True, doc = "The bin2c executable."),
2729
"fatbinary": attr.label(allow_single_file = True, doc = "The fatbinary executable."),
30+
"cicc": attr.label(default = None, allow_single_file = True, doc = "The cicc executable."),
31+
"libdevice": attr.label(default = None, allow_single_file = True, doc = "The libdevice LLVM bitcode library."),
2832
},
2933
provides = [CudaToolkitInfo],
3034
)

cuda/private/template_helper.bzl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ def _generate_build(repository_ctx, libpath, components = None, is_cuda_repo = T
127127

128128
_generate_build_impl(repository_ctx, libpath, components, is_cuda_repo, is_deliverable)
129129

130-
def _generate_defs_bzl(repository_ctx, is_local_ctk):
130+
def _generate_defs_bzl(repository_ctx, version_major, version_minor, is_local_ctk):
131131
tpl_label = Label("//cuda/private:templates/defs.bzl.tpl")
132132
substitutions = {
133+
"%{version_major}": str(version_major),
134+
"%{version_minor}": str(version_minor),
133135
"%{is_local_ctk}": str(is_local_ctk),
134136
}
135137
repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False)
@@ -185,7 +187,16 @@ def _generate_toolchain_build(repository_ctx, cuda):
185187
"//cuda/private:templates/BUILD.toolchain_" +
186188
("nvcc" if _is_linux(repository_ctx) else "nvcc_msvc"),
187189
)
190+
compiler_files = ["@cuda//:compiler_deps"]
191+
if int(cuda.version_major) >= 13:
192+
if cuda.cicc_label != None:
193+
compiler_files.append(cuda.cicc_label)
194+
if cuda.libdevice_label != None:
195+
compiler_files.append(cuda.libdevice_label)
196+
compiler_files_line = "compiler_files = " + repr(compiler_files) + ","
197+
188198
substitutions = {
199+
"# %{compiler_files_line}": compiler_files_line,
189200
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
190201
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
191202
"%{nvcc_version_major}": str(cuda.nvcc_version_major),
@@ -196,6 +207,11 @@ def _generate_toolchain_build(repository_ctx, cuda):
196207
"%{bin2c_label}": cuda.bin2c_label,
197208
"%{fatbinary_label}": cuda.fatbinary_label,
198209
}
210+
if cuda.cicc_label:
211+
substitutions["# %{cicc_line}"] = "cicc = " + repr(cuda.cicc_label)
212+
if cuda.libdevice_label:
213+
substitutions["# %{libdevice_line}"] = "libdevice = " + repr(cuda.libdevice_label)
214+
199215
env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None))
200216
if env_tmp != None:
201217
substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp)
@@ -256,6 +272,10 @@ def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path_or_label):
256272
"%{bin2c_label}": cuda.bin2c_label,
257273
"%{fatbinary_label}": cuda.fatbinary_label,
258274
}
275+
if cuda.cicc_label:
276+
substitutions["# %{cicc_line}"] = "cicc = " + repr(cuda.cicc_label)
277+
if cuda.libdevice_label:
278+
substitutions["# %{libdevice_line}"] = "libdevice = " + repr(cuda.libdevice_label)
259279

260280
if clang_label_for_subst:
261281
substitutions.pop("%{clang_path}")

cuda/private/templates/BUILD.cccl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,57 @@
11
cc_library(
22
name = "libcudacxx",
33
hdrs = glob(
4-
[
5-
"%{component_name}/include/cuda/**",
6-
"%{component_name}/include/nv/**",
7-
],
4+
if_cuda_toolkit_version_ge(
5+
(13, 0),
6+
[
7+
"%{component_name}/include/cccl/cuda/**",
8+
"%{component_name}/include/cccl/nv/**",
9+
],
10+
[
11+
"%{component_name}/include/cuda/**",
12+
"%{component_name}/include/nv/**",
13+
],
14+
),
815
allow_empty = True,
916
),
10-
includes = [
11-
"%{component_name}/include",
12-
],
17+
includes = if_cuda_toolkit_version_ge(
18+
(13, 0),
19+
["%{component_name}/include/cccl"],
20+
["%{component_name}/include"],
21+
),
1322
)
1423

1524
cc_library(
1625
name = "cub",
1726
hdrs = glob(
18-
["%{component_name}/include/cub/**"],
27+
if_cuda_toolkit_version_ge(
28+
(13, 0),
29+
["%{component_name}/include/cccl/cub/**"],
30+
["%{component_name}/include/cub/**"],
31+
),
1932
allow_empty = True,
2033
),
21-
includes = [
22-
"%{component_name}/include",
23-
],
34+
includes = if_cuda_toolkit_version_ge(
35+
(13, 0),
36+
["%{component_name}/include/cccl"],
37+
["%{component_name}/include"],
38+
),
2439
)
2540

2641
cc_library(
2742
name = "thrust",
2843
hdrs = glob(
29-
["%{component_name}/include/thrust/**"],
44+
if_cuda_toolkit_version_ge(
45+
(13, 0),
46+
["%{component_name}/include/cccl/thrust/**"],
47+
["%{component_name}/include/thrust/**"],
48+
),
3049
allow_empty = True,
3150
),
32-
includes = [
33-
"%{component_name}/include",
34-
],
51+
includes = if_cuda_toolkit_version_ge(
52+
(13, 0),
53+
["%{component_name}/include/cccl"],
54+
["%{component_name}/include"],
55+
),
3556
deps = [":cub"],
3657
)

cuda/private/templates/BUILD.crt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
cc_library(
2+
name = "crt",
3+
hdrs = glob(
4+
["%{component_name}/include/crt/**"],
5+
allow_empty = True,
6+
),
7+
includes = [
8+
"%{component_name}/include",
9+
],
10+
)

cuda/private/templates/BUILD.cuda_shared

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@bazel_skylib//rules:common_settings.bzl", "bool_setting") # @unused
2-
load("@cuda//:defs.bzl", "additional_header_deps", "if_local_cuda_toolkit") # @unused
2+
load("@cuda//:defs.bzl", "additional_header_deps", "if_cuda_toolkit_version_ge", "if_local_cuda_toolkit") # @unused
33
load("@rules_cuda//cuda:defs.bzl", "cc_import_versioned_sos", "if_linux", "if_windows") # @unused
44

55
package(

cuda/private/templates/BUILD.cudart

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,20 @@ cc_library(
99
target_compatible_with = ["@platforms//os:linux"],
1010
)
1111

12-
cc_library(
13-
name = "culibos_a",
14-
srcs = ["%{component_name}/%{libpath}/libculibos.a"],
15-
target_compatible_with = ["@platforms//os:linux"],
12+
# NOTE: for cuda toolkit older than 13.0, culibos_a is defined here; otherwise it is defined in BUILD.culibos.
13+
[
14+
cc_library(
15+
name = "culibos_a",
16+
srcs = ["%{component_name}/%{libpath}/libculibos.a"],
17+
target_compatible_with = ["@platforms//os:linux"],
18+
)
19+
for name in if_cuda_toolkit_version_ge((13,0), [], ["culibos"])
20+
]
21+
22+
conditional_culibos_a = if_cuda_toolkit_version_ge(
23+
(13,0),
24+
if_local_cuda_toolkit([":culibos_a"], ["@cuda//:culibos_a"]),
25+
[":culibos_a"],
1626
)
1727

1828
cc_import(
@@ -43,9 +53,8 @@ cc_library(
4353
] + if_linux([
4454
# devrt is required for jit linking when rdc is enabled
4555
":cudadevrt_a",
46-
":culibos_a",
4756
":cudart_so",
48-
]) + if_windows([
57+
] + conditional_culibos_a) + if_windows([
4958
# devrt is required for jit linking when rdc is enabled
5059
":cudadevrt_lib",
5160
":cudart_lib",

0 commit comments

Comments
 (0)