@@ -4,8 +4,8 @@ use std::os::raw::c_char;
4
4
use std:: ptr;
5
5
6
6
use cust:: stream:: Stream ;
7
- use cust_raw:: cublas_sys ;
8
- use cust_raw:: driver_sys ;
7
+ use cust_raw:: cublas ;
8
+ use cust_raw:: driver ;
9
9
10
10
use super :: error:: DropResult ;
11
11
use super :: error:: ToResult as _;
@@ -73,7 +73,7 @@ bitflags::bitflags! {
73
73
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
74
74
#[ derive( Debug ) ]
75
75
pub struct CublasContext {
76
- pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
76
+ pub ( crate ) raw : cublas :: cublasHandle_t ,
77
77
}
78
78
79
79
impl CublasContext {
@@ -92,10 +92,10 @@ impl CublasContext {
92
92
pub fn new ( ) -> Result < Self > {
93
93
let mut raw = MaybeUninit :: uninit ( ) ;
94
94
unsafe {
95
- cublas_sys :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96
- cublas_sys :: cublasSetPointerMode (
95
+ cublas :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96
+ cublas :: cublasSetPointerMode (
97
97
raw. assume_init ( ) ,
98
- cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98
+ cublas :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
99
99
)
100
100
. to_result ( ) ?;
101
101
Ok ( Self {
@@ -112,7 +112,7 @@ impl CublasContext {
112
112
113
113
unsafe {
114
114
let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
115
- match cublas_sys :: cublasDestroy ( inner) . to_result ( ) {
115
+ match cublas :: cublasDestroy ( inner) . to_result ( ) {
116
116
Ok ( ( ) ) => {
117
117
mem:: forget ( ctx) ;
118
118
Ok ( ( ) )
@@ -127,7 +127,7 @@ impl CublasContext {
127
127
let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
128
128
unsafe {
129
129
// getVersion can't fail
130
- cublas_sys :: cublasGetVersion ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130
+ cublas :: cublasGetVersion ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
131
131
. to_result ( )
132
132
. unwrap ( ) ;
133
133
@@ -145,17 +145,15 @@ impl CublasContext {
145
145
) -> Result < T > {
146
146
unsafe {
147
147
// cudaStream_t is the same as CUstream
148
- cublas_sys :: cublasSetStream (
148
+ cublas :: cublasSetStream (
149
149
self . raw ,
150
- mem:: transmute :: < * mut driver_sys:: CUstream_st , * mut cublas_sys:: CUstream_st > (
151
- stream. as_inner ( ) ,
152
- ) ,
150
+ mem:: transmute :: < driver:: CUstream , cublas:: cudaStream_t > ( stream. as_inner ( ) ) ,
153
151
)
154
152
. to_result ( ) ?;
155
153
let res = func ( self ) ?;
156
154
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
157
155
// execute a raw sys function with the context's handle.
158
- cublas_sys :: cublasSetStream ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
156
+ cublas :: cublasSetStream ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
159
157
Ok ( res)
160
158
}
161
159
}
@@ -185,12 +183,12 @@ impl CublasContext {
185
183
/// ```
186
184
pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
187
185
unsafe {
188
- Ok ( cublas_sys :: cublasSetAtomicsMode (
186
+ Ok ( cublas :: cublasSetAtomicsMode (
189
187
self . raw ,
190
188
if allowed {
191
- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
189
+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
192
190
} else {
193
- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
191
+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
194
192
} ,
195
193
)
196
194
. to_result ( ) ?)
@@ -215,10 +213,10 @@ impl CublasContext {
215
213
pub fn get_atomics_mode ( & self ) -> Result < bool > {
216
214
let mut mode = MaybeUninit :: uninit ( ) ;
217
215
unsafe {
218
- cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
216
+ cublas :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
219
217
Ok ( match mode. assume_init ( ) {
220
- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221
- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
218
+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
219
+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
222
220
} )
223
221
}
224
222
}
@@ -238,9 +236,9 @@ impl CublasContext {
238
236
/// ```
239
237
pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
240
238
unsafe {
241
- Ok ( cublas_sys :: cublasSetMathMode (
239
+ Ok ( cublas :: cublasSetMathMode (
242
240
self . raw ,
243
- mem:: transmute :: < u32 , cublas_sys :: cublasMath_t > ( math_mode. bits ( ) ) ,
241
+ mem:: transmute :: < u32 , cublas :: cublasMath_t > ( math_mode. bits ( ) ) ,
244
242
)
245
243
. to_result ( ) ?)
246
244
}
@@ -263,7 +261,7 @@ impl CublasContext {
263
261
pub fn get_math_mode ( & self ) -> Result < MathMode > {
264
262
let mut mode = MaybeUninit :: uninit ( ) ;
265
263
unsafe {
266
- cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
264
+ cublas :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
267
265
Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
268
266
. expect ( "Invalid MathMode from cuBLAS" ) )
269
267
}
@@ -303,7 +301,7 @@ impl CublasContext {
303
301
let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
304
302
let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
305
303
306
- cublas_sys :: cublasLoggerConfigure (
304
+ cublas :: cublasLoggerConfigure (
307
305
enable as i32 ,
308
306
log_to_stdout as i32 ,
309
307
log_to_stderr as i32 ,
@@ -320,7 +318,7 @@ impl CublasContext {
320
318
///
321
319
/// The callback must not panic and unwind.
322
320
pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
323
- cublas_sys :: cublasSetLoggerCallback ( callback)
321
+ cublas :: cublasSetLoggerCallback ( callback)
324
322
. to_result ( )
325
323
. unwrap ( ) ;
326
324
}
@@ -329,7 +327,7 @@ impl CublasContext {
329
327
pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
330
328
let mut cb = MaybeUninit :: uninit ( ) ;
331
329
unsafe {
332
- cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
330
+ cublas :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
333
331
. to_result ( )
334
332
. unwrap ( ) ;
335
333
cb. assume_init ( )
@@ -340,7 +338,7 @@ impl CublasContext {
340
338
impl Drop for CublasContext {
341
339
fn drop ( & mut self ) {
342
340
unsafe {
343
- let _ = cublas_sys :: cublasDestroy ( self . raw ) ;
341
+ let _ = cublas :: cublasDestroy ( self . raw ) ;
344
342
}
345
343
}
346
344
}
0 commit comments