Skip to content

Commit 7b4151b

Browse files
committed
fix: 更新 cuda driver,改正异常和警告
Signed-off-by: YdrMaster <[email protected]>
1 parent 3c7d1da commit 7b4151b

File tree

4 files changed

+14
-41
lines changed

4 files changed

+14
-41
lines changed

Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ infini-op = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e83
1212
infini-ccl = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" }
1313
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-toolkit", rev = "e8362c3" }
1414

15-
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a815e41" }
16-
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a815e41" }
17-
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a815e41" }
18-
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a815e41" }
19-
search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a815e41" }
15+
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "db4692c" }
16+
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "db4692c" }
17+
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "db4692c" }
18+
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "db4692c" }
19+
search-corex-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "db4692c" }

operators/src/all_reduce/nccl.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::{AllReduce, Args, ReduceOp, args::Meta};
22
use crate::{
3-
ByteOf, LaunchError, LaunchError, QueueAlloc,
3+
ByteOf, LaunchError, QueueAlloc,
44
cuda::{Gpu, NcclNode},
55
rearrange,
66
};
@@ -26,14 +26,6 @@ impl crate::Operator for Operator {
2626
}
2727
}
2828

29-
fn scheme(
30-
&mut self,
31-
_args: &Self::Args,
32-
_max_workspace_size: usize,
33-
) -> Result<usize, LaunchError> {
34-
Ok(0)
35-
}
36-
3729
fn launch<QA>(
3830
&self,
3931
args: &Self::Args,

operators/src/broadcast/nccl/mod.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::{Args, Broadcast, args::Meta};
22
use crate::{
3-
ByteOf, LaunchError, LaunchError, QueueAlloc,
3+
ByteOf, LaunchError, QueueAlloc,
44
cuda::{Gpu, NcclNode},
55
rearrange,
66
};
@@ -26,14 +26,6 @@ impl crate::Operator for Operator {
2626
}
2727
}
2828

29-
fn scheme(
30-
&mut self,
31-
_args: &Self::Args,
32-
_max_workspace_size: usize,
33-
) -> Result<usize, LaunchError> {
34-
Ok(0)
35-
}
36-
3729
fn launch<QA>(
3830
&self,
3931
args: &Self::Args,

operators/src/rearrange/cuda/mod.rs

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use crate::{
77
use itertools::Itertools;
88
use lru::LruCache;
99
use std::cmp::max;
10-
use std::iter::repeat;
1110
use std::slice::{from_raw_parts, from_raw_parts_mut};
1211
use std::{
1312
ffi::CString,
@@ -71,21 +70,11 @@ impl ArrayStruct {
7170
}
7271

7372
fn try_into_array<const N: usize>(self) -> Result<[ArrayType; N], LaunchError> {
74-
if self.0.len() > N {
75-
Err(rank_not_support("ArrayStruct::try_into_array"))
73+
let ArrayStruct(vec) = self;
74+
if vec.len() <= N {
75+
Ok(std::array::from_fn(|i| vec.get(i).copied().unwrap_or(0)))
7676
} else {
77-
let ArrayStruct(vec) = self;
78-
if vec.len() == N {
79-
Ok(vec.try_into().unwrap())
80-
} else {
81-
let vec_len = vec.len();
82-
Ok(vec
83-
.into_iter()
84-
.chain(repeat(0).take(N - vec_len))
85-
.collect::<Vec<_>>()
86-
.try_into()
87-
.unwrap())
88-
}
77+
Err(rank_not_support("over length"))
8978
}
9079
}
9180
}
@@ -158,7 +147,7 @@ impl crate::Operator for Operator {
158147
.collect::<Vec<_>>();
159148

160149
//分离维度,分成grid处理的维度和block处理的维度,与dst的维度相对应
161-
let mut block_dim_choose = repeat(false).take(ndim).collect::<Vec<_>>();
150+
let mut block_dim_choose = vec![false; ndim];
162151

163152
//TODO 需要优化
164153
let max_block_size = 256;
@@ -596,9 +585,9 @@ mod test {
596585
let dt = ty::U64;
597586

598587
let cpu_op = RefOp::new(&Cpu);
599-
let gpu_op = Operator::new(&gpu);
588+
let gpu_op = Operator::new(gpu);
600589

601-
let mut r_shape: [usize; N] = shape.clone();
590+
let mut r_shape = shape;
602591
r_shape[0..TRANS_N].reverse();
603592

604593
let trans_param: [usize; TRANS_N] =

0 commit comments

Comments
 (0)