@@ -28,8 +28,8 @@ use std::env;
28
28
use std:: fs;
29
29
use std:: path;
30
30
31
- pub mod callbacks;
32
- pub mod cuda_sdk;
31
+ mod callbacks;
32
+ mod cuda_sdk;
33
33
34
34
fn main ( ) {
35
35
let outdir = path:: PathBuf :: from (
@@ -63,8 +63,9 @@ fn main() {
63
63
println ! ( "cargo::rerun-if-env-changed={}" , e) ;
64
64
}
65
65
66
- create_cuda_driver_bindings ( & sdk, outdir. as_path ( ) ) ;
67
- create_cuda_runtime_bindings ( & sdk, outdir. as_path ( ) ) ;
66
+ create_driver_bindings ( & sdk, outdir. as_path ( ) ) ;
67
+ create_runtime_bindings ( & sdk, outdir. as_path ( ) ) ;
68
+ create_runtime_types_bindings ( & sdk, outdir. as_path ( ) ) ;
68
69
create_cublas_bindings ( & sdk, outdir. as_path ( ) ) ;
69
70
create_nptx_compiler_bindings ( & sdk, outdir. as_path ( ) ) ;
70
71
create_nvvm_bindings ( & sdk, outdir. as_path ( ) ) ;
@@ -73,8 +74,8 @@ fn main() {
73
74
feature = "driver" ,
74
75
feature = "runtime" ,
75
76
feature = "cublas" ,
76
- feature = "cublaslt " ,
77
- feature = "cublasxt "
77
+ feature = "cublasLt " ,
78
+ feature = "cublasXt "
78
79
) ) {
79
80
for libdir in sdk. cuda_library_paths ( ) {
80
81
println ! ( "cargo::rustc-link-search=native={}" , libdir. display( ) ) ;
@@ -84,11 +85,11 @@ fn main() {
84
85
if cfg ! ( feature = "runtime" ) {
85
86
println ! ( "cargo::rustc-link-lib=dylib=cudart" ) ;
86
87
}
87
- if cfg ! ( feature = "cublas" ) || cfg ! ( feature = "cublasxt " ) {
88
+ if cfg ! ( feature = "cublas" ) || cfg ! ( feature = "cublasXt " ) {
88
89
println ! ( "cargo::rustc-link-lib=dylib=cublas" ) ;
89
90
}
90
- if cfg ! ( feature = "cublaslt " ) {
91
- println ! ( "cargo::rustc-link-lib=dylib=cublaslt " ) ;
91
+ if cfg ! ( feature = "cublasLt " ) {
92
+ println ! ( "cargo::rustc-link-lib=dylib=cublasLt " ) ;
92
93
}
93
94
if cfg ! ( feature = "nvvm" ) {
94
95
for libdir in sdk. nvvm_library_paths ( ) {
@@ -101,7 +102,53 @@ fn main() {
101
102
}
102
103
}
103
104
104
- fn create_cuda_driver_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
105
+ fn create_runtime_types_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
106
+ let params = & [
107
+ ( cfg ! ( feature = "driver_types" ) , "driver_types" ) ,
108
+ ( cfg ! ( feature = "library_types" ) , "library_types" ) ,
109
+ ( cfg ! ( feature = "vector_types" ) , "vector_types" ) ,
110
+ ( cfg ! ( feature = "texture_types" ) , "texture_types" ) ,
111
+ ( cfg ! ( feature = "surface_types" ) , "surface_types" ) ,
112
+ ( cfg ! ( feature = "cuComplex" ) , "cuComplex" ) ,
113
+ ] ;
114
+ for ( should_generate, pkg) in params {
115
+ if !should_generate {
116
+ continue ;
117
+ }
118
+ let bindgen_path = path:: PathBuf :: from ( format ! ( "{}/{}_sys.rs" , outdir. display( ) , pkg) ) ;
119
+ let header = sdk
120
+ . cuda_root ( )
121
+ . join ( format ! ( "include/{}.h" , pkg) )
122
+ . display ( )
123
+ . to_string ( ) ;
124
+ let bindings = bindgen:: Builder :: default ( )
125
+ . header ( & header)
126
+ . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
127
+ . clang_args (
128
+ sdk. cuda_include_paths ( )
129
+ . iter ( )
130
+ . map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
131
+ )
132
+ . allowlist_file ( format ! ( r".*{pkg}\.h" ) )
133
+ . allowlist_recursively ( false )
134
+ . default_enum_style ( bindgen:: EnumVariation :: Rust {
135
+ non_exhaustive : false ,
136
+ } )
137
+ . derive_default ( true )
138
+ . derive_eq ( true )
139
+ . derive_hash ( true )
140
+ . derive_ord ( true )
141
+ . size_t_is_usize ( true )
142
+ . layout_tests ( true )
143
+ . generate ( )
144
+ . unwrap_or_else ( |e| panic ! ( "Unable to generate {pkg} bindings: {e}" ) ) ;
145
+ bindings
146
+ . write_to_file ( bindgen_path. as_path ( ) )
147
+ . unwrap_or_else ( |e| panic ! ( "Cannot write {pkg} bindgen output to file: {e}" ) ) ;
148
+ }
149
+ }
150
+
151
+ fn create_driver_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
105
152
if !cfg ! ( feature = "driver" ) {
106
153
return ;
107
154
}
@@ -121,13 +168,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
121
168
. iter ( )
122
169
. map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
123
170
)
124
- . allowlist_type ( "^CU.*" )
125
- . allowlist_type ( "^cuuint(32|64)_t" )
126
- . allowlist_type ( "^cudaError_enum" )
127
- . allowlist_type ( "^cu.*Complex$" )
128
- . allowlist_type ( "^cuda.*" )
129
- . allowlist_var ( "^CU.*" )
130
- . allowlist_function ( "^cu.*" )
171
+ . allowlist_file ( r".*cuda[^/\\]*\.h" )
131
172
. default_enum_style ( bindgen:: EnumVariation :: Rust {
132
173
non_exhaustive : false ,
133
174
} )
@@ -145,7 +186,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
145
186
. expect ( "Cannot write CUDA driver bindgen output to file." ) ;
146
187
}
147
188
148
- fn create_cuda_runtime_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
189
+ fn create_runtime_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
149
190
if !cfg ! ( feature = "runtime" ) {
150
191
return ;
151
192
}
@@ -165,14 +206,13 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
165
206
. iter ( )
166
207
. map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
167
208
)
168
- . allowlist_type ( "^CU.*" )
169
- . allowlist_type ( "^cuda.*" )
170
- . allowlist_type ( "^libraryPropertyType.*" )
171
- . allowlist_var ( "^CU.*" )
172
- . allowlist_function ( "^cu.*" )
209
+ . allowlist_file ( r".*cuda[^/\\]*\.h" )
210
+ . allowlist_file ( r".*cuComplex\.h" )
211
+ . allowlist_recursively ( false )
173
212
. default_enum_style ( bindgen:: EnumVariation :: Rust {
174
213
non_exhaustive : false ,
175
214
} )
215
+ . disable_nested_struct_naming ( )
176
216
. derive_default ( true )
177
217
. derive_eq ( true )
178
218
. derive_hash ( true )
@@ -188,19 +228,51 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
188
228
}
189
229
190
230
fn create_cublas_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
191
- #[ rustfmt:: skip]
192
231
let params = & [
193
- ( cfg ! ( feature = "cublas" ) , "cublas" , "^cublas.*" , "^CUBLAS.*" ) ,
194
- ( cfg ! ( feature = "cublaslt" ) , "cublasLt" , "^cublasLt.*" , "^CUBLASLT.*" ) ,
195
- ( cfg ! ( feature = "cublasxt" ) , "cublasXt" , "^cublasXt.*" , "^CUBLASXT.*" ) ,
232
+ (
233
+ cfg ! ( feature = "cublas" ) ,
234
+ "cublas" ,
235
+ vec ! [ r".*cublas(_api|_v2)\.h" ] ,
236
+ vec ! [
237
+ r".*cuComplex\.h" ,
238
+ r".*driver_types\.h" ,
239
+ r".*library_types\.h" ,
240
+ r".*vector_types\.h" ,
241
+ ] ,
242
+ ) ,
243
+ (
244
+ cfg ! ( feature = "cublasLt" ) ,
245
+ "cublasLt" ,
246
+ vec ! [ r".*cublasLt\.h" ] ,
247
+ vec ! [
248
+ r".*cublas(_api|_v2)*\.h" ,
249
+ r".*cuComplex\.h" ,
250
+ r".*driver_types\.h" ,
251
+ r".*library_types\.h" ,
252
+ r".*vector_types\.h" ,
253
+ r".*std\w+\.h" ,
254
+ ] ,
255
+ ) ,
256
+ (
257
+ cfg ! ( feature = "cublasXt" ) ,
258
+ "cublasXt" ,
259
+ vec ! [ r".*cublasXt\.h" ] ,
260
+ vec ! [
261
+ r".*cublas(_api|_v2)*\.h" ,
262
+ r".*cuComplex\.h" ,
263
+ r".*driver_types\.h" ,
264
+ r".*library_types\.h" ,
265
+ r".*vector_types\.h" ,
266
+ ] ,
267
+ ) ,
196
268
] ;
197
- for ( should_generate, pkg, tf , var ) in params {
269
+ for ( should_generate, pkg, allowed , blocked ) in params {
198
270
if !should_generate {
199
271
continue ;
200
272
}
201
273
let bindgen_path = path:: PathBuf :: from ( format ! ( "{}/{pkg}_sys.rs" , outdir. display( ) ) ) ;
202
274
let header = format ! ( "build/{pkg}_wrapper.h" ) ;
203
- let bindings = bindgen:: Builder :: default ( )
275
+ let mut bindings = bindgen:: Builder :: default ( )
204
276
. header ( & header)
205
277
. parse_callbacks ( Box :: new ( callbacks:: FunctionRenames :: new (
206
278
pkg,
@@ -214,9 +286,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
214
286
. iter ( )
215
287
. map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
216
288
)
217
- . allowlist_type ( tf)
218
- . allowlist_function ( tf)
219
- . allowlist_var ( var)
289
+ . allowlist_recursively ( false ) ;
290
+
291
+ for file in allowed {
292
+ bindings = bindings. allowlist_file ( file) ;
293
+ }
294
+ for file in blocked {
295
+ bindings = bindings. blocklist_file ( file) ;
296
+ }
297
+
298
+ let bindings = bindings
220
299
. default_enum_style ( bindgen:: EnumVariation :: Rust {
221
300
non_exhaustive : false ,
222
301
} )
0 commit comments