Skip to content

Commit 608a845

Browse files
committed
reactor(cust_raw): Generate bindings for cuda types and restructure cust_raw's crates
- Allow list files instead of types/var/functions. - Split out type headers from runtime/cublas to their own crates. - Drop sys prefix from internal crates.
1 parent b0a4b7d commit 608a845

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+629
-511
lines changed

crates/blastoff/src/context.rs

+24-26
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use std::os::raw::c_char;
44
use std::ptr;
55

66
use cust::stream::Stream;
7-
use cust_raw::cublas_sys;
8-
use cust_raw::driver_sys;
7+
use cust_raw::cublas;
8+
use cust_raw::driver;
99

1010
use super::error::DropResult;
1111
use super::error::ToResult as _;
@@ -73,7 +73,7 @@ bitflags::bitflags! {
7373
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
7474
#[derive(Debug)]
7575
pub struct CublasContext {
76-
pub(crate) raw: cublas_sys::cublasHandle_t,
76+
pub(crate) raw: cublas::cublasHandle_t,
7777
}
7878

7979
impl CublasContext {
@@ -92,10 +92,10 @@ impl CublasContext {
9292
pub fn new() -> Result<Self> {
9393
let mut raw = MaybeUninit::uninit();
9494
unsafe {
95-
cublas_sys::cublasCreate(raw.as_mut_ptr()).to_result()?;
96-
cublas_sys::cublasSetPointerMode(
95+
cublas::cublasCreate(raw.as_mut_ptr()).to_result()?;
96+
cublas::cublasSetPointerMode(
9797
raw.assume_init(),
98-
cublas_sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
98+
cublas::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
9999
)
100100
.to_result()?;
101101
Ok(Self {
@@ -112,7 +112,7 @@ impl CublasContext {
112112

113113
unsafe {
114114
let inner = mem::replace(&mut ctx.raw, ptr::null_mut());
115-
match cublas_sys::cublasDestroy(inner).to_result() {
115+
match cublas::cublasDestroy(inner).to_result() {
116116
Ok(()) => {
117117
mem::forget(ctx);
118118
Ok(())
@@ -127,7 +127,7 @@ impl CublasContext {
127127
let mut raw = MaybeUninit::<u32>::uninit();
128128
unsafe {
129129
// getVersion can't fail
130-
cublas_sys::cublasGetVersion(self.raw, raw.as_mut_ptr().cast())
130+
cublas::cublasGetVersion(self.raw, raw.as_mut_ptr().cast())
131131
.to_result()
132132
.unwrap();
133133

@@ -145,17 +145,15 @@ impl CublasContext {
145145
) -> Result<T> {
146146
unsafe {
147147
// cudaStream_t is the same as CUstream
148-
cublas_sys::cublasSetStream(
148+
cublas::cublasSetStream(
149149
self.raw,
150-
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
151-
stream.as_inner(),
152-
),
150+
mem::transmute::<driver::CUstream, cublas::cudaStream_t>(stream.as_inner()),
153151
)
154152
.to_result()?;
155153
let res = func(self)?;
156154
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
157155
// execute a raw sys function with the context's handle.
158-
cublas_sys::cublasSetStream(self.raw, ptr::null_mut()).to_result()?;
156+
cublas::cublasSetStream(self.raw, ptr::null_mut()).to_result()?;
159157
Ok(res)
160158
}
161159
}
@@ -185,12 +183,12 @@ impl CublasContext {
185183
/// ```
186184
pub fn set_atomics_mode(&self, allowed: bool) -> Result<()> {
187185
unsafe {
188-
Ok(cublas_sys::cublasSetAtomicsMode(
186+
Ok(cublas::cublasSetAtomicsMode(
189187
self.raw,
190188
if allowed {
191-
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
189+
cublas::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED
192190
} else {
193-
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
191+
cublas::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED
194192
},
195193
)
196194
.to_result()?)
@@ -215,10 +213,10 @@ impl CublasContext {
215213
pub fn get_atomics_mode(&self) -> Result<bool> {
216214
let mut mode = MaybeUninit::uninit();
217215
unsafe {
218-
cublas_sys::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
216+
cublas::cublasGetAtomicsMode(self.raw, mode.as_mut_ptr()).to_result()?;
219217
Ok(match mode.assume_init() {
220-
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
221-
cublas_sys::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
218+
cublas::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => true,
219+
cublas::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => false,
222220
})
223221
}
224222
}
@@ -238,9 +236,9 @@ impl CublasContext {
238236
/// ```
239237
pub fn set_math_mode(&self, math_mode: MathMode) -> Result<()> {
240238
unsafe {
241-
Ok(cublas_sys::cublasSetMathMode(
239+
Ok(cublas::cublasSetMathMode(
242240
self.raw,
243-
mem::transmute::<u32, cublas_sys::cublasMath_t>(math_mode.bits()),
241+
mem::transmute::<u32, cublas::cublasMath_t>(math_mode.bits()),
244242
)
245243
.to_result()?)
246244
}
@@ -263,7 +261,7 @@ impl CublasContext {
263261
pub fn get_math_mode(&self) -> Result<MathMode> {
264262
let mut mode = MaybeUninit::uninit();
265263
unsafe {
266-
cublas_sys::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
264+
cublas::cublasGetMathMode(self.raw, mode.as_mut_ptr()).to_result()?;
267265
Ok(MathMode::from_bits(mode.assume_init() as u32)
268266
.expect("Invalid MathMode from cuBLAS"))
269267
}
@@ -303,7 +301,7 @@ impl CublasContext {
303301
let path = log_file_name.map(|p| CString::new(p).expect("nul in log_file_name"));
304302
let path_ptr = path.map_or(ptr::null(), |s| s.as_ptr());
305303

306-
cublas_sys::cublasLoggerConfigure(
304+
cublas::cublasLoggerConfigure(
307305
enable as i32,
308306
log_to_stdout as i32,
309307
log_to_stderr as i32,
@@ -320,7 +318,7 @@ impl CublasContext {
320318
///
321319
/// The callback must not panic and unwind.
322320
pub unsafe fn set_logger_callback(callback: Option<unsafe extern "C" fn(*const c_char)>) {
323-
cublas_sys::cublasSetLoggerCallback(callback)
321+
cublas::cublasSetLoggerCallback(callback)
324322
.to_result()
325323
.unwrap();
326324
}
@@ -329,7 +327,7 @@ impl CublasContext {
329327
pub fn get_logger_callback() -> Option<unsafe extern "C" fn(*const c_char)> {
330328
let mut cb = MaybeUninit::uninit();
331329
unsafe {
332-
cublas_sys::cublasGetLoggerCallback(cb.as_mut_ptr())
330+
cublas::cublasGetLoggerCallback(cb.as_mut_ptr())
333331
.to_result()
334332
.unwrap();
335333
cb.assume_init()
@@ -340,7 +338,7 @@ impl CublasContext {
340338
impl Drop for CublasContext {
341339
fn drop(&mut self) {
342340
unsafe {
343-
let _ = cublas_sys::cublasDestroy(self.raw);
341+
let _ = cublas::cublasDestroy(self.raw);
344342
}
345343
}
346344
}

crates/blastoff/src/error.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::{ffi::CStr, fmt::Display};
22

33
use cust::error::CudaError;
4-
use cust_raw::cublas_sys;
4+
use cust_raw::cublas;
55

66
/// Result that contains the un-dropped value on error.
77
pub type DropResult<T> = std::result::Result<(), (CublasError, T)>;
@@ -25,7 +25,7 @@ impl std::error::Error for CublasError {}
2525
impl Display for CublasError {
2626
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2727
unsafe {
28-
let ptr = cublas_sys::cublasGetStatusString(self.into_raw());
28+
let ptr = cublas::cublasGetStatusString(self.into_raw());
2929
let cow = CStr::from_ptr(ptr).to_string_lossy();
3030
f.write_str(cow.as_ref())
3131
}
@@ -36,9 +36,9 @@ pub trait ToResult {
3636
fn to_result(self) -> Result<(), CublasError>;
3737
}
3838

39-
impl ToResult for cublas_sys::cublasStatus_t {
39+
impl ToResult for cublas::cublasStatus_t {
4040
fn to_result(self) -> Result<(), CublasError> {
41-
use cust_raw::cublas_sys::cublasStatus_t::*;
41+
use cust_raw::cublas::cublasStatus_t::*;
4242
use CublasError::*;
4343

4444
Err(match self {
@@ -57,8 +57,8 @@ impl ToResult for cublas_sys::cublasStatus_t {
5757
}
5858

5959
impl CublasError {
60-
pub fn into_raw(self) -> cublas_sys::cublasStatus_t {
61-
use cust_raw::cublas_sys::cublasStatus_t::*;
60+
pub fn into_raw(self) -> cublas::cublasStatus_t {
61+
use cust_raw::cublas::cublasStatus_t::*;
6262
use CublasError::*;
6363

6464
match self {

crates/blastoff/src/lib.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#![allow(clippy::too_many_arguments)]
1111
#![cfg_attr(docsrs, feature(doc_cfg))]
1212

13-
pub use cust_raw::cublas_sys;
13+
pub use cust_raw::cublas;
1414
use num_complex::{Complex32, Complex64};
1515

1616
pub use context::*;
@@ -39,34 +39,34 @@ pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
3939
/// The corresponding float type. For complex numbers this means their backing
4040
/// precision, and for floats it is just themselves.
4141
type FloatTy: Float;
42-
fn to_raw(&self) -> cublas_sys::cudaDataType;
42+
fn to_raw(&self) -> cublas::cudaDataType;
4343
}
4444

4545
impl BlasDatatype for f32 {
4646
type FloatTy = f32;
47-
fn to_raw(&self) -> cublas_sys::cudaDataType {
48-
cublas_sys::cudaDataType::CUDA_R_32F
47+
fn to_raw(&self) -> cublas::cudaDataType {
48+
cublas::cudaDataType::CUDA_R_32F
4949
}
5050
}
5151

5252
impl BlasDatatype for f64 {
5353
type FloatTy = f64;
54-
fn to_raw(&self) -> cublas_sys::cudaDataType {
55-
cublas_sys::cudaDataType::CUDA_R_64F
54+
fn to_raw(&self) -> cublas::cudaDataType {
55+
cublas::cudaDataType::CUDA_R_64F
5656
}
5757
}
5858

5959
impl BlasDatatype for Complex32 {
6060
type FloatTy = f32;
61-
fn to_raw(&self) -> cublas_sys::cudaDataType {
62-
cublas_sys::cudaDataType::CUDA_C_32F
61+
fn to_raw(&self) -> cublas::cudaDataType {
62+
cublas::cudaDataType::CUDA_C_32F
6363
}
6464
}
6565

6666
impl BlasDatatype for Complex64 {
6767
type FloatTy = f64;
68-
fn to_raw(&self) -> cublas_sys::cudaDataType {
69-
cublas_sys::cudaDataType::CUDA_C_64F
68+
fn to_raw(&self) -> cublas::cudaDataType {
69+
cublas::cudaDataType::CUDA_C_64F
7070
}
7171
}
7272

@@ -106,11 +106,11 @@ pub enum MatrixOp {
106106

107107
impl MatrixOp {
108108
/// Returns the corresponding `cublasOperation_t` for this operation.
109-
pub fn to_raw(self) -> cublas_sys::cublasOperation_t {
109+
pub fn to_raw(self) -> cublas::cublasOperation_t {
110110
match self {
111-
MatrixOp::None => cublas_sys::cublasOperation_t::CUBLAS_OP_N,
112-
MatrixOp::Transpose => cublas_sys::cublasOperation_t::CUBLAS_OP_T,
113-
MatrixOp::ConjugateTranspose => cublas_sys::cublasOperation_t::CUBLAS_OP_C,
111+
MatrixOp::None => cublas::cublasOperation_t::CUBLAS_OP_N,
112+
MatrixOp::Transpose => cublas::cublasOperation_t::CUBLAS_OP_T,
113+
MatrixOp::ConjugateTranspose => cublas::cublasOperation_t::CUBLAS_OP_C,
114114
}
115115
}
116116
}

crates/blastoff/src/raw/level1.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::os::raw::c_int;
22

3-
use cust_raw::cublas_sys::*;
3+
use cust_raw::cublas::*;
44
use num_complex::{Complex32, Complex64};
55

66
use crate::BlasDatatype;

crates/blastoff/src/raw/level3.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::os::raw::c_int;
22

3-
use cust_raw::cublas_sys::*;
3+
use cust_raw::cublas::*;
44
use num_complex::{Complex32, Complex64};
55

66
use crate::GemmDatatype;

0 commit comments

Comments
 (0)