Skip to content
Draft
52 changes: 52 additions & 0 deletions components/segmenter/src/complex/lstm/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> {
}
}
}

/// Reinterprets this matrix's backing data under a new shape `dims`.
///
/// The returned matrix mutably reborrows the same data for the duration of the
/// borrow of `self`; no elements are copied or reordered. The rank may change
/// (`M` does not need to equal `D`).
///
/// The product of `dims` must equal the number of elements in the backing data.
/// This is checked only in debug builds; in release builds a mismatched shape is
/// accepted as-is.
#[inline]
pub(super) fn reshape<const M: usize>(&mut self, dims: [usize; M]) -> MatrixBorrowedMut<'_, M> {
    // A reshape must preserve the element count; otherwise the new dims would
    // describe a different amount of data than is actually stored.
    debug_assert_eq!(
        dims.iter().product::<usize>(),
        self.data.len(),
        "reshape must preserve the element count (new dims: {:?})",
        dims
    );
    MatrixBorrowedMut {
        data: self.data,
        dims,
    }
}
}

impl<'a> MatrixBorrowed<'a, 1> {
Expand Down Expand Up @@ -317,12 +325,47 @@ impl<'a> MatrixBorrowedMut<'a, 1> {
}
}
}

/// Accumulates the matrix-vector product of `b` and `a` into `self`.
///
/// For every index `i`, `self[i]` is incremented by the dot product of `a`
/// with the `i`-th submatrix (row) of `b`.
///
/// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N
/// (and `self` should be M); this is the opposite of standard practice.
///
/// Dimension mismatches are reported through debug assertions only.
pub(super) fn add_dot_2d_1(&mut self, a: MatrixZero<1>, b: MatrixZero<2>) {
    let out_len = self.as_borrowed().dim();
    let in_len = a.dim();
    // `a` must be as long as each row of `b`.
    debug_assert_eq!(
        in_len,
        b.dim().1,
        "dims: {:?}/{:?}/{:?}",
        self.as_borrowed().dim(),
        a.dim(),
        b.dim()
    );
    // `self` must have one element per row of `b`.
    debug_assert_eq!(
        out_len,
        b.dim().0,
        "dims: {:?}/{:?}/{:?}",
        self.as_borrowed().dim(),
        a.dim(),
        b.dim()
    );
    let vector = a.as_slice();
    for row in 0..out_len {
        match (self.as_mut_slice().get_mut(row), b.submatrix::<1>(row)) {
            (Some(acc), Some(b_row)) => {
                *acc += unrolled_dot_2(vector, b_row.data);
            }
            _ => {
                debug_assert!(false, "unreachable: dims checked above");
            }
        }
    }
}
}

impl<'a> MatrixBorrowedMut<'a, 2> {
/// Calculate the dot product of a and b, adding the result to self.
///
/// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
#[allow(dead_code)] // could be useful
pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) {
let m = a.dim();
let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
Expand Down Expand Up @@ -363,6 +406,7 @@ impl<'a> MatrixBorrowedMut<'a, 2> {
/// Calculate the dot product of a and b, adding the result to self.
///
/// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
#[allow(dead_code)] // could be useful
pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) {
let m = a.dim();
let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
Expand Down Expand Up @@ -475,6 +519,14 @@ impl<'a, const D: usize> MatrixZero<'a, D> {
let n = sub_dims.iter().product::<usize>();
(n * index..n * (index + 1), sub_dims)
}

/// Reinterprets this matrix's backing data under a new shape `dims`.
///
/// The returned matrix refers to the same data; no elements are copied or
/// reordered. The rank may change (`M` does not need to equal `D`).
///
/// The product of `dims` must equal the number of elements in the backing data.
/// This is checked only in debug builds; in release builds a mismatched shape is
/// accepted as-is.
#[inline]
pub(super) fn reshape<const M: usize>(self, dims: [usize; M]) -> MatrixZero<'a, M> {
    // A reshape must preserve the element count; otherwise the new dims would
    // describe a different amount of data than is actually stored.
    debug_assert_eq!(
        dims.iter().product::<usize>(),
        self.data.len(),
        "reshape must preserve the element count (new dims: {:?})",
        dims
    );
    MatrixZero {
        data: self.data,
        dims,
    }
}
}

macro_rules! f32c {
Expand Down
45 changes: 30 additions & 15 deletions components/segmenter/src/complex/lstm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ impl Iterator for LstmSegmenterIteratorUtf16<'_> {
pub(super) struct LstmSegmenter<'l> {
dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
embedding: MatrixZero<'l, 2>,
fw_w: MatrixZero<'l, 3>,
fw_u: MatrixZero<'l, 3>,
fw_w: MatrixZero<'l, 2>,
fw_u: MatrixZero<'l, 2>,
fw_b: MatrixZero<'l, 2>,
bw_w: MatrixZero<'l, 3>,
bw_u: MatrixZero<'l, 3>,
bw_w: MatrixZero<'l, 2>,
bw_u: MatrixZero<'l, 2>,
bw_b: MatrixZero<'l, 2>,
timew_fw: MatrixZero<'l, 2>,
timew_bw: MatrixZero<'l, 2>,
Expand All @@ -71,19 +71,32 @@ impl<'l> LstmSegmenter<'l> {
/// Creates an `LstmSegmenter` from the given LSTM model data and grapheme rule-break data.
pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
let LstmDataV1::Float32(lstm) = lstm;
let fw_w = MatrixZero::from(&lstm.fw_w);
let fw_u = MatrixZero::from(&lstm.fw_u);
let bw_w = MatrixZero::from(&lstm.bw_w);
let bw_u = MatrixZero::from(&lstm.bw_u);
let time_w = MatrixZero::from(&lstm.time_w);

let hunits = fw_w.dim().1;
let embedd_dim = fw_w.dim().2;

let fw_w = fw_w.reshape([4 * hunits, embedd_dim]);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: instead of a general reshape, define a collapse_1_2 to collapse the first two dimensions. Then you don't have to compute the sizes here.

let fw_u = fw_u.reshape([4 * hunits, hunits]);
let bw_w = bw_w.reshape([4 * hunits, embedd_dim]);
let bw_u = bw_u.reshape([4 * hunits, hunits]);

#[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
let timew_fw = time_w.submatrix(0).unwrap();
#[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
let timew_bw = time_w.submatrix(1).unwrap();
Self {
dic: lstm.dic.as_borrowed(),
embedding: MatrixZero::from(&lstm.embedding),
fw_w: MatrixZero::from(&lstm.fw_w),
fw_u: MatrixZero::from(&lstm.fw_u),
fw_w,
fw_u,
fw_b: MatrixZero::from(&lstm.fw_b),
bw_w: MatrixZero::from(&lstm.bw_w),
bw_u: MatrixZero::from(&lstm.bw_u),
bw_w,
bw_u,
bw_b: MatrixZero::from(&lstm.bw_b),
timew_fw,
timew_bw,
Expand Down Expand Up @@ -278,24 +291,26 @@ fn compute_hc<'a>(
x_t: MatrixZero<'a, 1>,
mut h_tm1: MatrixBorrowedMut<'a, 1>,
mut c_tm1: MatrixBorrowedMut<'a, 1>,
w: MatrixZero<'a, 3>,
u: MatrixZero<'a, 3>,
w: MatrixZero<'a, 2>,
u: MatrixZero<'a, 2>,
b: MatrixZero<'a, 2>,
) {
let hunits = h_tm1.dim();
#[cfg(debug_assertions)]
{
let hunits = h_tm1.dim();
let embedd_dim = x_t.dim();
c_tm1.as_borrowed().debug_assert_dims([hunits]);
w.debug_assert_dims([4, hunits, embedd_dim]);
u.debug_assert_dims([4, hunits, hunits]);
w.debug_assert_dims([4 * hunits, embedd_dim]);
u.debug_assert_dims([4 * hunits, hunits]);
b.debug_assert_dims([4, hunits]);
}

let mut s_t = b.to_owned();

s_t.as_mut().add_dot_3d_2(x_t, w);
s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
s_t.as_mut().reshape([4 * hunits]).add_dot_2d_1(x_t, w);
s_t.as_mut()
.reshape([4 * hunits])
.add_dot_2d(h_tm1.as_borrowed(), u);

#[allow(clippy::unwrap_used)] // first dimension is 4
s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
Expand Down