Commit e44cff1

feat: add mla_cache
1 parent 6989b92 commit e44cff1

File tree

9 files changed: +320 -7 lines changed


operators/src/attention_mla/operator.rs

Lines changed: 4 additions & 7 deletions
@@ -1,7 +1,7 @@
 use super::{args::Meta, Args, AttentionMLA};
 use crate::{
-    dyn_, fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
-    SchemeError, TensorLayout, Workspace, WorkspaceCollector,
+    fuesd_softmax, get_static, mat_mul, rearrange, ByteOf, Hardware, LaunchError, QueueAlloc,
+    SchemeError, TensorLayout, Workspace,
 };
 use ndarray_layout::ArrayLayout;
 use std::marker::PhantomData;
@@ -94,16 +94,13 @@ where
         let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else {
             unreachable!()
         };
-        let &[nh_so, seq_so, dv_so] = o_layout.strides() else {
-            unreachable!()
-        };
+
         let ele = dt.nbytes();
         get_static! {
             nh seq dkv dr
             nh_skv att_skv dkv_skv
             nh_skr att_skr dr_skr
             nh_sa dv_sa dkv_sa
-            nh_so seq_so dv_so
             dv att
         };

@@ -141,7 +138,7 @@ where
             workspace,
             queue_alloc,
         )?;
-
+
         self.mat_mul.launch(
             &mat_mul::Args {
                 c_layout: att_w_layout.clone(),

operators/src/attention_mla_cached/args.rs

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
use crate::{
    fuesd_softmax::AttnMask,
    utils::{dim_distinct, rank_error, type_distinct},
    ConstPtr, Hardware, MaybeDyn, MutPtr, SchemeError, TensorLayout,
};
use digit_layout::DigitLayout;

pub struct Args<H: Hardware> {
    // the q passed in is the post-absorption q
    pub q_layout: TensorLayout,
    pub q_base: MutPtr<H>,

    pub kv_layout: TensorLayout,
    pub kv_base: ConstPtr<H>,

    pub absorb_layout: TensorLayout,
    pub absorb_base: ConstPtr<H>,

    pub qr_layout: TensorLayout,
    pub qr_base: ConstPtr<H>,

    pub kr_layout: TensorLayout,
    pub kr_base: ConstPtr<H>,

    pub o_layout: TensorLayout,
    pub o_base: MutPtr<H>,
    pub kv_cache_layout: TensorLayout,
    pub kv_cache_base: MutPtr<H>,

    pub kr_cache_layout: TensorLayout,
    pub kr_cache_base: MutPtr<H>,

    pub mask: AttnMask,
    pub pos: MaybeDyn<usize>,
}

pub(super) struct Meta {
    pub dt: DigitLayout,
    pub nh: MaybeDyn<usize>,
    pub seq: MaybeDyn<usize>,
    pub att: MaybeDyn<usize>,
    pub dkv: MaybeDyn<usize>,
    pub dv: MaybeDyn<usize>,
    pub dr: MaybeDyn<usize>,
}

impl<H: Hardware> Args<H> {
    pub(super) fn meta(&self) -> Result<Meta, SchemeError> {
        let Self {
            q_layout,
            kv_layout,
            absorb_layout,
            qr_layout,
            kr_layout,
            o_layout,
            kv_cache_layout,
            kr_cache_layout,
            ..
        } = self;

        let &[nh_q, seq_q, dkv_q] = q_layout.shape() else {
            return Err(rank_error("q", 3, q_layout.ndim()));
        };

        let &[nh_kv, attn_kv, dkv_kv] = kv_layout.shape() else {
            return Err(rank_error("kv", 3, kv_layout.ndim()));
        };
        let &[nh_a, dv_a, dkv_a] = absorb_layout.shape() else {
            return Err(rank_error("absorb", 3, absorb_layout.ndim()));
        };
        let &[nh_qr, seq_qr, dr_qr] = qr_layout.shape() else {
            return Err(rank_error("qr", 3, qr_layout.ndim()));
        };
        let &[nh_kr, att_kr, dr_kr] = kr_layout.shape() else {
            return Err(rank_error("kr", 3, kr_layout.ndim()));
        };
        let &[nh_o, seq_o, dv_o] = o_layout.shape() else {
            return Err(rank_error("o", 3, o_layout.ndim()));
        };
        let &[nh_kvc, _buf, dkv_kvc] = kv_cache_layout.shape() else {
            return Err(rank_error("k_cache", 3, kv_cache_layout.ndim()));
        };
        let &[nh_krc, _buf, dr_krc] = kr_cache_layout.shape() else {
            return Err(rank_error("v_cache", 3, kr_cache_layout.ndim()));
        };

        Ok(Meta {
            dt: type_distinct(&[
                q_layout.dt(),
                kv_layout.dt(),
                qr_layout.dt(),
                kr_layout.dt(),
                o_layout.dt(),
            ])?,
            nh: dim_distinct(&[nh_q, nh_kv, nh_a, nh_qr, nh_kr, nh_o, nh_krc, nh_kvc])?,
            seq: dim_distinct(&[seq_q, seq_o, seq_qr])?,
            att: dim_distinct(&[attn_kv, att_kr])?,
            dkv: dim_distinct(&[dkv_a, dkv_kv, dkv_q, dkv_kvc])?,
            dv: dim_distinct(&[dv_a, dv_o])?,
            dr: dim_distinct(&[dr_kr, dr_qr, dr_krc])?,
        })
    }
}
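
For reference, the rank-3 shapes that meta() ties together can be summarized as follows. This is an illustrative sketch, not part of the commit; MlaCachedShapes and all_equal are hypothetical stand-ins for the crate's MaybeDyn-based dim_distinct/type_distinct utilities.

// Sketch only: the rank-3 shapes implied by Args::meta() above.
// nh = heads, seq = incoming tokens, att = attended tokens, buf = cache capacity,
// dkv = shared q/kv/absorb feature dim, dv = output value dim, dr = qr/kr feature dim.
struct MlaCachedShapes {
    q: [usize; 3],        // [nh, seq, dkv]  (q is passed in post-absorption)
    kv: [usize; 3],       // [nh, att, dkv]
    absorb: [usize; 3],   // [nh, dv, dkv]
    qr: [usize; 3],       // [nh, seq, dr]
    kr: [usize; 3],       // [nh, att, dr]
    o: [usize; 3],        // [nh, seq, dv]
    kv_cache: [usize; 3], // [nh, buf, dkv]
    kr_cache: [usize; 3], // [nh, buf, dr]
}

// dim_distinct appears to enforce "all listed dims agree"; a plain-usize stand-in:
fn all_equal(dims: &[usize]) -> Option<usize> {
    let first = *dims.first()?;
    dims.iter().all(|&d| d == first).then_some(first)
}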

operators/src/attention_mla_cached/common_cpu.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(common_cpu, Cpu);

operators/src/attention_mla_cached/cuda.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(cuda, Gpu);

operators/src/attention_mla_cached/infini.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(infini, Device);

operators/src/attention_mla_cached/mod.rs

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
mod args;
mod operator;

pub use args::Args;

crate::op_trait!(AttentionMLACached);

macro_rules! impl_op {
    ($dev:ident, $proc:ident) => {
        pub type Operator = super::operator::Operator<
            crate::$dev::$proc,
            crate::rearrange::$dev::Operator,
            crate::attention::$dev::Operator,
        >;
    };
}

#[cfg(any(use_cpu, test))]
pub mod common_cpu;
#[cfg(use_cuda)]
pub mod cuda;
#[cfg(use_infini)]
pub mod infini;
#[cfg(use_cl)]
pub mod opencl;
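
To make the macro concrete: `impl_op!(common_cpu, Cpu)`, the single line in each backend file, expands to the type alias below; the cuda, infini, and opencl files substitute (cuda, Gpu), (infini, Device), and (opencl, ClDevice) in the same way. Note that the third parameter is wired to crate::attention::$dev::Operator, while the operator's trait bound further down is attention_mla::AttentionMLA<H>; whether that pairing compiles depends on impls elsewhere in the crate.

// Expansion of impl_op!(common_cpu, Cpu) inside the common_cpu module (mechanical, for reference):
pub type Operator = super::operator::Operator<
    crate::common_cpu::Cpu,
    crate::rearrange::common_cpu::Operator,
    crate::attention::common_cpu::Operator,
>;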

operators/src/attention_mla_cached/opencl.rs

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
impl_op!(opencl, ClDevice);

operators/src/attention_mla_cached/operator.rs

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
use crate::attention_mla_cached::args::Meta;
use crate::attention_mla_cached::{Args, AttentionMLACached};
use crate::{
    attention_mla, get_static, rearrange, shape_mismatch, ByteOf, Hardware, LaunchError,
    QueueAlloc, SchemeError, TensorLayout,
};
use ndarray_layout::ArrayLayout;
use std::marker::PhantomData;

pub struct Operator<Hardware, Rearrange, Attention> {
    rearrange: Rearrange,
    attention: Attention,
    _phantom: PhantomData<Hardware>,
}

impl<H, R, A> AttentionMLACached<H> for Operator<H, R, A>
where
    H: Hardware,
    R: rearrange::Rearrange<H>,
    A: attention_mla::AttentionMLA<H>,
{
}

impl<H, R, A> crate::Operator for Operator<H, R, A>
where
    H: Hardware,
    R: rearrange::Rearrange<H>,
    A: attention_mla::AttentionMLA<H>,
{
    type Hardware = H;
    type TopoNode = H;
    type Args = crate::attention_mla_cached::Args<H>;
    fn new(node: &Self::TopoNode) -> Self {
        Self {
            rearrange: R::new(node),
            attention: A::new(node),
            _phantom: PhantomData,
        }
    }

    fn scheme(
        &mut self,
        args: &Self::Args,
        max_workspace_size: usize,
    ) -> Result<usize, SchemeError> {
        // TODO
        Ok(0)
    }

    fn launch<QA>(
        &self,
        args: &Self::Args,
        workspace: &mut [ByteOf<Self::Hardware>],
        queue_alloc: &QA,
    ) -> Result<(), LaunchError>
    where
        QA: QueueAlloc<Hardware = Self::Hardware>,
    {
        let Meta {
            dt,
            nh,
            seq,
            att,
            dkv,
            dv,
            dr,
        } = args.meta()?;
        let Args {
            q_layout,
            q_base,
            kv_layout,
            kv_base,
            absorb_layout,
            absorb_base,
            qr_layout,
            qr_base,
            kr_layout,
            kr_base,
            o_layout,
            o_base,
            kv_cache_layout,
            kv_cache_base,
            kr_cache_layout,
            kr_cache_base,
            mask,
            pos,
        } = args;
        let &[nh_skv, att_skv, dkv_skv] = kv_layout.strides() else {
            unreachable!()
        };
        let &[nh_skr, att_skr, dr_skr] = kr_layout.strides() else {
            unreachable!()
        };
        let &[nh_sa, dv_sa, dkv_sa] = absorb_layout.strides() else {
            unreachable!()
        };

        let &[_, buf_kv, _] = kv_cache_layout.shape() else {
            unreachable!()
        };
        let &[_, buf_kr, _] = kr_cache_layout.shape() else {
            unreachable!()
        };
        let &[nh_skvc, buf_skvc, dh_skvc] = kv_cache_layout.strides() else {
            unreachable!()
        };
        let &[nh_skrc, buf_skrc, dh_skrc] = kr_cache_layout.strides() else {
            unreachable!()
        };
        let ele = dt.nbytes();
        get_static! {
            nh seq dkv dr
            pos
            buf_kv buf_kr
            nh_skvc buf_skvc dh_skvc
            nh_skrc buf_skrc dh_skrc

        };

        // check cache capacity
        let att = pos + seq;
        if buf_kr < att || buf_kv < att {
            return Err(shape_mismatch("Out of cache buffer").into());
        }
        // concatenate into the kv cache
        #[inline(always)]
        fn layout(shape: [usize; 3], strides: [isize; 3]) -> ArrayLayout<3> {
            ArrayLayout::new(&shape, &strides, 0)
        }

        let kvc_layout = layout([nh, buf_kv, dkv], [nh_skvc, buf_skvc, dh_skvc]);
        let krc_layout = layout([nh, buf_kr, dr], [nh_skrc, buf_skrc, dh_skrc]);

        let kv_cat = kvc_layout.slice(1, pos, 1, seq);
        let kr_cat = krc_layout.slice(1, pos, 1, seq);

        self.rearrange.launch(
            &rearrange::Args {
                dst_layout: TensorLayout::new(dt, kv_cat.shape(), kv_cat.strides()),
                dst_base: unsafe { kv_cache_base.byte_add(kv_cat.offset() as _) },
                src_layout: kv_layout.clone(),
                src_base: *kv_base,
            },
            workspace,
            queue_alloc,
        )?;
        self.rearrange.launch(
            &rearrange::Args {
                dst_layout: TensorLayout::new(dt, kr_cat.shape(), kr_cat.strides()),
                dst_base: unsafe { kr_cache_base.byte_add(kr_cat.offset() as _) },
                src_layout: kr_layout.clone(),
                src_base: *kr_base,
            },
            workspace,
            queue_alloc,
        )?;
        // attention
        let kv_layout = kvc_layout.slice(1, 0, 1, att);
        let kr_layout = krc_layout.slice(1, 0, 1, att);
        assert_eq!(kv_layout.offset(), 0);
        assert_eq!(kr_layout.offset(), 0);
        self.attention.launch(
            &attention_mla::Args {
                mask: *mask,
                q_layout: q_layout.clone(),
                q_base: *q_base,
                kv_layout: TensorLayout::new(dt, kv_layout.shape(), kv_layout.strides()),
                kv_base: *kv_cache_base,
                kr_layout: TensorLayout::new(dt, kr_layout.shape(), kr_layout.strides()),
                kr_base: *kr_cache_base,
                absorb_layout: absorb_layout.clone(),
                absorb_base: *absorb_base,
                qr_layout: qr_layout.clone(),
                qr_base: *qr_base,
                o_layout: o_layout.clone(),
                o_base: *o_base,
            },
            workspace,
            queue_alloc,
        )?;
        Ok(())
    }
}
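
The launch path above does the cache bookkeeping in two steps: the two rearrange launches append the seq incoming rows at offset pos in each cache, then attention_mla runs over the first att = pos + seq cached rows. A minimal standalone sketch of that arithmetic (plain usize ranges, not the crate's ArrayLayout/TensorLayout API; cache_plan is a hypothetical helper):

use std::ops::Range;

// Sketch of the cache bookkeeping done by launch(): `pos` tokens are already
// cached, `seq` new tokens arrive, attention then covers att = pos + seq rows.
fn cache_plan(buf: usize, pos: usize, seq: usize) -> Result<(Range<usize>, Range<usize>), String> {
    let att = pos + seq;
    if buf < att {
        // corresponds to shape_mismatch("Out of cache buffer") above
        return Err("Out of cache buffer".into());
    }
    // rows written by the two rearrange launches, and rows read by attention_mla
    Ok((pos..att, 0..att))
}

fn main() {
    // e.g. decoding one token with 100 tokens already cached in a 4096-slot buffer
    let (write, read) = cache_plan(4096, 100, 1).unwrap();
    assert_eq!(write, 100..101);
    assert_eq!(read, 0..101);
}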

operators/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ pub mod all_reduce;
 pub mod attention;
 pub mod attention_kv_cached;
 pub mod attention_mla;
+pub mod attention_mla_cached;
 pub mod broadcast;
 pub mod conv;
 pub mod fuesd_softmax;
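
With the module exported from the crate root, downstream code can name the new operator by path. A hypothetical import, assuming the crate is published as `operators` and is built with the cfg that enables the CPU backend:

// Hypothetical downstream usage path (the crate name `operators` is an assumption):
use operators::attention_mla_cached::{common_cpu::Operator as AttnMlaCachedCpu, Args, AttentionMLACached};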
