Skip to content

Commit 3ee02bb

Browse files
committed
reactor(cust_raw): Generate bindings for cuda types and restructure cust_raw's crates
- Allow list files instead of types/var/functions. - Split out type headers from runtime/cublas to their own crates.
1 parent 290d711 commit 3ee02bb

20 files changed

+230
-57
lines changed

crates/blastoff/src/context.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ impl CublasContext {
147147
// cudaStream_t is the same as CUstream
148148
cublas_sys::cublasSetStream(
149149
self.raw,
150-
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
151-
stream.as_inner(),
152-
),
150+
mem::transmute::<driver_sys::CUstream, cublas_sys::cudaStream_t>(stream.as_inner()),
153151
)
154152
.to_result()?;
155153
let res = func(self)?;

crates/cust_raw/Cargo.toml

+16-9
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ name = "cust_raw"
33
version = "0.11.3"
44
edition = "2021"
55
license = "MIT OR Apache-2.0"
6-
description = "Low level bindings to the CUDA Driver API"
6+
description = "Low level bindings to the CUDA Toolkit SDK"
77
repository = "https://github.com/Rust-GPU/Rust-CUDA"
88
readme = "../../README.md"
99
links = "cuda"
1010
build = "build/main.rs"
1111

12+
[dependencies]
13+
libc = { version = "0.2", optional = true }
14+
1215
[build-dependencies]
1316
bindgen = "0.71.1"
1417
bimap = "0.6.3"
@@ -19,20 +22,24 @@ features = [
1922
"driver",
2023
"runtime",
2124
"cublas",
22-
"cublaslt",
23-
"cublasxt",
24-
"cudnn",
25+
"cublasLt",
26+
"cublasXt",
2527
"nvptx-compiler",
2628
"nvvm",
2729
]
2830

2931
[features]
3032
default = ["driver"]
3133
driver = []
32-
runtime = []
33-
cublas = []
34-
cublaslt = []
35-
cublasxt = []
36-
cudnn = []
34+
runtime = ["driver_types", "vector_types", "texture_types", "surface_types"]
35+
cuComplex = ["vector_types"]
36+
driver_types = []
37+
library_types = []
38+
surface_types = []
39+
texture_types = []
40+
vector_types = []
41+
cublas = ["runtime", "cuComplex", "library_types"]
42+
cublasLt = ["cublas", "libc"]
43+
cublasXt = ["cublas"]
3744
nvptx-compiler = []
3845
nvvm = []
+1-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
#include "cuComplex.h"
21
#include "cuda.h"
3-
#include "cudaProfiler.h"
4-
#include "vector_types.h"
2+
#include "cudaProfiler.h"

crates/cust_raw/build/main.rs

+111-32
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ use std::env;
2828
use std::fs;
2929
use std::path;
3030

31-
pub mod callbacks;
32-
pub mod cuda_sdk;
31+
mod callbacks;
32+
mod cuda_sdk;
3333

3434
fn main() {
3535
let outdir = path::PathBuf::from(
@@ -63,8 +63,9 @@ fn main() {
6363
println!("cargo::rerun-if-env-changed={}", e);
6464
}
6565

66-
create_cuda_driver_bindings(&sdk, outdir.as_path());
67-
create_cuda_runtime_bindings(&sdk, outdir.as_path());
66+
create_driver_bindings(&sdk, outdir.as_path());
67+
create_runtime_bindings(&sdk, outdir.as_path());
68+
create_runtime_types_bindings(&sdk, outdir.as_path());
6869
create_cublas_bindings(&sdk, outdir.as_path());
6970
create_nptx_compiler_bindings(&sdk, outdir.as_path());
7071
create_nvvm_bindings(&sdk, outdir.as_path());
@@ -73,8 +74,8 @@ fn main() {
7374
feature = "driver",
7475
feature = "runtime",
7576
feature = "cublas",
76-
feature = "cublaslt",
77-
feature = "cublasxt"
77+
feature = "cublasLt",
78+
feature = "cublasXt"
7879
)) {
7980
for libdir in sdk.cuda_library_paths() {
8081
println!("cargo::rustc-link-search=native={}", libdir.display());
@@ -84,11 +85,11 @@ fn main() {
8485
if cfg!(feature = "runtime") {
8586
println!("cargo::rustc-link-lib=dylib=cudart");
8687
}
87-
if cfg!(feature = "cublas") || cfg!(feature = "cublasxt") {
88+
if cfg!(feature = "cublas") || cfg!(feature = "cublasXt") {
8889
println!("cargo::rustc-link-lib=dylib=cublas");
8990
}
90-
if cfg!(feature = "cublaslt") {
91-
println!("cargo::rustc-link-lib=dylib=cublaslt");
91+
if cfg!(feature = "cublasLt") {
92+
println!("cargo::rustc-link-lib=dylib=cublasLt");
9293
}
9394
if cfg!(feature = "nvvm") {
9495
for libdir in sdk.nvvm_library_paths() {
@@ -101,7 +102,53 @@ fn main() {
101102
}
102103
}
103104

104-
fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
105+
fn create_runtime_types_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
106+
let params = &[
107+
(cfg!(feature = "driver_types"), "driver_types"),
108+
(cfg!(feature = "library_types"), "library_types"),
109+
(cfg!(feature = "vector_types"), "vector_types"),
110+
(cfg!(feature = "texture_types"), "texture_types"),
111+
(cfg!(feature = "surface_types"), "surface_types"),
112+
(cfg!(feature = "cuComplex"), "cuComplex"),
113+
];
114+
for (should_generate, pkg) in params {
115+
if !should_generate {
116+
continue;
117+
}
118+
let bindgen_path = path::PathBuf::from(format!("{}/{}_sys.rs", outdir.display(), pkg));
119+
let header = sdk
120+
.cuda_root()
121+
.join(format!("include/{}.h", pkg))
122+
.display()
123+
.to_string();
124+
let bindings = bindgen::Builder::default()
125+
.header(&header)
126+
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
127+
.clang_args(
128+
sdk.cuda_include_paths()
129+
.iter()
130+
.map(|p| format!("-I{}", p.display())),
131+
)
132+
.allowlist_file(format!(r".*{pkg}\.h"))
133+
.allowlist_recursively(false)
134+
.default_enum_style(bindgen::EnumVariation::Rust {
135+
non_exhaustive: false,
136+
})
137+
.derive_default(true)
138+
.derive_eq(true)
139+
.derive_hash(true)
140+
.derive_ord(true)
141+
.size_t_is_usize(true)
142+
.layout_tests(true)
143+
.generate()
144+
.unwrap_or_else(|e| panic!("Unable to generate {pkg} bindings: {e}"));
145+
bindings
146+
.write_to_file(bindgen_path.as_path())
147+
.unwrap_or_else(|e| panic!("Cannot write {pkg} bindgen output to file: {e}"));
148+
}
149+
}
150+
151+
fn create_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
105152
if !cfg!(feature = "driver") {
106153
return;
107154
}
@@ -121,13 +168,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
121168
.iter()
122169
.map(|p| format!("-I{}", p.display())),
123170
)
124-
.allowlist_type("^CU.*")
125-
.allowlist_type("^cuuint(32|64)_t")
126-
.allowlist_type("^cudaError_enum")
127-
.allowlist_type("^cu.*Complex$")
128-
.allowlist_type("^cuda.*")
129-
.allowlist_var("^CU.*")
130-
.allowlist_function("^cu.*")
171+
.allowlist_file(r".*cuda[^/\\]*\.h")
131172
.default_enum_style(bindgen::EnumVariation::Rust {
132173
non_exhaustive: false,
133174
})
@@ -145,7 +186,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
145186
.expect("Cannot write CUDA driver bindgen output to file.");
146187
}
147188

148-
fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
189+
fn create_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
149190
if !cfg!(feature = "runtime") {
150191
return;
151192
}
@@ -165,14 +206,13 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
165206
.iter()
166207
.map(|p| format!("-I{}", p.display())),
167208
)
168-
.allowlist_type("^CU.*")
169-
.allowlist_type("^cuda.*")
170-
.allowlist_type("^libraryPropertyType.*")
171-
.allowlist_var("^CU.*")
172-
.allowlist_function("^cu.*")
209+
.allowlist_file(r".*cuda[^/\\]*\.h")
210+
.allowlist_file(r".*cuComplex\.h")
211+
.allowlist_recursively(false)
173212
.default_enum_style(bindgen::EnumVariation::Rust {
174213
non_exhaustive: false,
175214
})
215+
.disable_nested_struct_naming()
176216
.derive_default(true)
177217
.derive_eq(true)
178218
.derive_hash(true)
@@ -188,19 +228,51 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
188228
}
189229

190230
fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
191-
#[rustfmt::skip]
192231
let params = &[
193-
(cfg!(feature = "cublas"), "cublas", "^cublas.*", "^CUBLAS.*"),
194-
(cfg!(feature = "cublaslt"), "cublasLt", "^cublasLt.*", "^CUBLASLT.*"),
195-
(cfg!(feature = "cublasxt"), "cublasXt", "^cublasXt.*", "^CUBLASXT.*"),
232+
(
233+
cfg!(feature = "cublas"),
234+
"cublas",
235+
vec![r".*cublas(_api|_v2)\.h"],
236+
vec![
237+
r".*cuComplex\.h",
238+
r".*driver_types\.h",
239+
r".*library_types\.h",
240+
r".*vector_types\.h",
241+
],
242+
),
243+
(
244+
cfg!(feature = "cublasLt"),
245+
"cublasLt",
246+
vec![r".*cublasLt\.h"],
247+
vec![
248+
r".*cublas(_api|_v2)*\.h",
249+
r".*cuComplex\.h",
250+
r".*driver_types\.h",
251+
r".*library_types\.h",
252+
r".*vector_types\.h",
253+
r".*std\w+\.h",
254+
],
255+
),
256+
(
257+
cfg!(feature = "cublasXt"),
258+
"cublasXt",
259+
vec![r".*cublasXt\.h"],
260+
vec![
261+
r".*cublas(_api|_v2)*\.h",
262+
r".*cuComplex\.h",
263+
r".*driver_types\.h",
264+
r".*library_types\.h",
265+
r".*vector_types\.h",
266+
],
267+
),
196268
];
197-
for (should_generate, pkg, tf, var) in params {
269+
for (should_generate, pkg, allowed, blocked) in params {
198270
if !should_generate {
199271
continue;
200272
}
201273
let bindgen_path = path::PathBuf::from(format!("{}/{pkg}_sys.rs", outdir.display()));
202274
let header = format!("build/{pkg}_wrapper.h");
203-
let bindings = bindgen::Builder::default()
275+
let mut bindings = bindgen::Builder::default()
204276
.header(&header)
205277
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
206278
pkg,
@@ -214,9 +286,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
214286
.iter()
215287
.map(|p| format!("-I{}", p.display())),
216288
)
217-
.allowlist_type(tf)
218-
.allowlist_function(tf)
219-
.allowlist_var(var)
289+
.allowlist_recursively(false);
290+
291+
for file in allowed {
292+
bindings = bindings.allowlist_file(file);
293+
}
294+
for file in blocked {
295+
bindings = bindings.blocklist_file(file);
296+
}
297+
298+
let bindings = bindings
220299
.default_enum_style(bindgen::EnumVariation::Rust {
221300
non_exhaustive: false,
222301
})

crates/cust_raw/src/cublas_sys.rs

-5
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
#![allow(non_upper_case_globals)]
21
#![allow(non_camel_case_types)]
32
#![allow(non_snake_case)]
43

4+
use libc::FILE;
5+
6+
use super::*;
7+
use crate::types::driver::*;
8+
use crate::types::library::*;
9+
510
include!(concat!(env!("OUT_DIR"), "/cublasLt_sys.rs"));

crates/cust_raw/src/cublas_sys/mod.rs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//! Bindings to the CUDA Basic Linear Algebra Subprograms (cuBLAS) library.
2+
#![allow(non_upper_case_globals)]
3+
#![allow(non_camel_case_types)]
4+
#![allow(non_snake_case)]
5+
6+
use crate::types::library::*;
7+
8+
pub use crate::runtime_sys::cudaStream_t;
9+
pub use crate::types::complex::*;
10+
pub use crate::types::library::cudaDataType;
11+
12+
include!(concat!(env!("OUT_DIR"), "/cublas_sys.rs"));
13+
14+
#[cfg(feature = "cublasLt")]
15+
pub mod lt;
16+
17+
#[cfg(feature = "cublasXt")]
18+
pub mod xt;

crates/cust_raw/src/cublasxt_sys.rs renamed to crates/cust_raw/src/cublas_sys/xt.rs

+2
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22
#![allow(non_camel_case_types)]
33
#![allow(non_snake_case)]
44

5+
use super::*;
6+
57
include!(concat!(env!("OUT_DIR"), "/cublasXt_sys.rs"));

crates/cust_raw/src/driver_sys.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Bindings to the CUDA Driver API
2+
13
#![allow(non_upper_case_globals)]
24
#![allow(non_camel_case_types)]
35
#![allow(non_snake_case)]

crates/cust_raw/src/lib.rs

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
1+
//! # `cust_raw`: Bindings to the CUDA Toolkit SDK
2+
//!
13
#[cfg(feature = "driver")]
24
pub mod driver_sys;
5+
36
#[cfg(feature = "runtime")]
47
pub mod runtime_sys;
58

9+
#[cfg(any(
10+
feature = "driver_types",
11+
feature = "vector_types",
12+
feature = "texture_types",
13+
feature = "surface_types",
14+
feature = "cuComplex",
15+
feature = "library_types"
16+
))]
17+
pub mod types;
18+
619
#[cfg(feature = "cublas")]
720
pub mod cublas_sys;
8-
#[cfg(feature = "cublaslt")]
9-
pub mod cublaslt_sys;
10-
#[cfg(feature = "cublasxt")]
11-
pub mod cublasxt_sys;
1221

1322
#[cfg(feature = "nvptx-compiler")]
1423
pub mod nvptx_compiler_sys;
24+
1525
#[cfg(feature = "nvvm")]
1626
pub mod nvvm_sys;

crates/cust_raw/src/nvptx_compiler_sys.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Bindings to the NVPTX Compiler library.
2+
13
#![allow(non_upper_case_globals)]
24
#![allow(non_camel_case_types)]
35
#![allow(non_snake_case)]

crates/cust_raw/src/nvvm_sys.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
//! Bindings to the libNVVM API, an interface for generating PTX code from both
2+
//! binary and text NVVM IR inputs.
3+
14
#![allow(non_upper_case_globals)]
25
#![allow(non_camel_case_types)]
36
#![allow(non_snake_case)]

crates/cust_raw/src/runtime_sys.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
//! Bindings to the CUDA Runtime API
12
#![allow(non_upper_case_globals)]
23
#![allow(non_camel_case_types)]
34
#![allow(non_snake_case)]
45

6+
pub use crate::types::driver::*;
7+
pub use crate::types::surface::*;
8+
pub use crate::types::texture::*;
9+
pub use crate::types::vector::dim3;
10+
511
include!(concat!(env!("OUT_DIR"), "/runtime_sys.rs"));

0 commit comments

Comments
 (0)