LSTM forward pass using matrix muls #3219
base: main
Conversation
Looks great from my end, but I'll defer final approval to Shane.
Great that you got more perf out of this! I'm puzzled about where the perf is coming from, though.
sb.view(0..hunits),
sb.view(2 * hunits..3 * hunits),
sb.view(hunits..2 * hunits),
Nit: I would really like to keep the 4 as a dimension in the matrix. If reshaping the matrix from [hunits, 4, embedd_dim] to [4, hunits, embedd_dim] gives you some perf, that's great, but then let's express it that way, not just flattening the two numbers together, because then it's not clear whether the arrangement is [A1, A2, A3, A4, B1, B2, B3, B4, ...] or [A1, B1, ..., A2, B2, ...].
For example, instead of
s_t.as_mut().sigmoid(3 * hunits..4 * hunits);
I would much prefer if we can keep it as
s_t.submatrix_mut(3)?.sigmoid_transform();
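To make the layout question concrete, here is a minimal sketch of the gate-major arrangement being argued for; the `GateVector` type and its helpers are invented for illustration and are not the crate's actual matrix API.

```rust
/// Illustrative only: a gate-major buffer of length 4 * hunits, laid out as
/// [A1..A_h, B1..B_h, C1..C_h, D1..D_h] (i.e. the [4, hunits] arrangement).
struct GateVector {
    data: Vec<f32>,
    hunits: usize,
}

impl GateVector {
    /// View gate `g` (0..4) as a mutable slice. On the flat layout these are
    /// the same elements as the range `g * hunits..(g + 1) * hunits`, but the
    /// gate index stays an explicit dimension at the call site.
    fn submatrix_mut(&mut self, g: usize) -> Option<&mut [f32]> {
        let h = self.hunits;
        self.data.get_mut(g * h..(g + 1) * h)
    }
}

/// Elementwise logistic sigmoid over one gate slice.
fn sigmoid_transform(xs: &mut [f32]) {
    for x in xs {
        *x = 1.0 / (1.0 + (-*x).exp());
    }
}

fn main() {
    let mut s_t = GateVector { data: vec![0.0; 4 * 27], hunits: 27 };
    // The reviewer-preferred spelling: address gate 3, then transform it in place.
    sigmoid_transform(s_t.submatrix_mut(3).unwrap());
}
```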
for idx in 0..self.data.len() {
    *self.data.get_mut(idx)? =
        i.data.get(idx)? * c.data.get(idx)? + self.data.get(idx)? * f.data.get(idx)?
}
Note: I'm surprised this is faster than the current memory layout. You need to jump all around to get the data at the four indices. The previous memory layout put i, c, and f all right next to each other. Actually, the memory layout before I started hacking at it was in that memory order, and I got a perf boost when changing it to the current order. So I'm a bit suspicious of where the perf is actually coming from in this PR.
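For readers following along, here is a hedged sketch of the two layouts being compared; the gate names and their ordering are illustrative, not taken from the crate.

```rust
// "Interleaved" (previous layout): all four gate values for unit j sit next to
// each other, so one cache line serves i[j], f[j], c~[j], o[j]:
//     [i0, f0, c0, o0, i1, f1, c1, o1, ...]
//
// "Planar" (this PR): each gate is a contiguous block, so the cell-state update
// reads from separated regions, but each read is a unit-stride stream:
//     [i0..i_h, f0..f_h, c0..c_h, o0..o_h]

/// Cell-state update c = i * c~ + c * f over the planar layout: three streams.
fn update_planar(cell: &mut [f32], i: &[f32], c_tilde: &[f32], f: &[f32]) {
    for idx in 0..cell.len() {
        cell[idx] = i[idx] * c_tilde[idx] + cell[idx] * f[idx];
    }
}

/// The same update over the interleaved layout: one stream, stride 4 per unit.
/// `gates` is assumed to hold [i, f, c~, o] for each unit, in that order.
fn update_interleaved(cell: &mut [f32], gates: &[f32]) {
    for (j, c) in cell.iter_mut().enumerate() {
        let (i, f, c_tilde) = (gates[4 * j], gates[4 * j + 1], gates[4 * j + 2]);
        *c = i * c_tilde + *c * f;
    }
}
```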
After compiling the crate with RUSTFLAGS='-C opt-level=2 -Cllvm-args=--pass-remarks=.*vector.* -Cllvm-args=--pass-remarks-analysis=.*vector.*' and extracting the remarks related to "math_helper.rs":
llvm opt remarks for math_helper.rs at CL
remark: experimental/segmenter/src/math_helper.rs:520:9: loop not vectorized: cannot prove it is safe to reorder floating-point operations
remark: experimental/segmenter/src/math_helper.rs:557:9: loop not vectorized: cannot prove it is safe to reorder floating-point operations
remark: experimental/segmenter/src/math_helper.rs:339:36: loop not vectorized: value that could not be identified as reduction is used outside the loop
remark: experimental/segmenter/src/math_helper.rs:339:36: loop not vectorized: could not determine number of loop iterations
remark: experimental/segmenter/src/math_helper.rs:346:65: loop not vectorized: instruction cannot be vectorized
remark: experimental/segmenter/src/math_helper.rs:339:17: the cost-model indicates that interleaving is not beneficial
remark: experimental/segmenter/src/math_helper.rs:339:17: vectorized loop (vectorization width: 4, interleaved count: 1)
remark: experimental/segmenter/src/math_helper.rs:346:40: loop not vectorized: could not determine number of loop iterations
math_helper.rs:520 -> inner loop of unrolled_dot_1 (not vectorized)
math_helper.rs:557 -> inner loop of unrolled_dot_2 (not vectorized)
math_helper.rs:339 -> convolve loop (vectorized) as the loop has no data dependencies across indices
math_helper.rs:346 -> mul_tanh loop (not vectorized); tanh is an unvectorizable function call
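To illustrate why the remarks split this way, here is a sketch (not the actual math_helper.rs code) contrasting an elementwise loop, which LLVM can vectorize, with a floating-point dot-product reduction, which it will not vectorize without permission to reorder float additions.

```rust
/// Elementwise loop: no cross-iteration dependency, so LLVM can vectorize it
/// without reordering floating-point operations (compare the convolve remark).
fn elementwise_mul_add(out: &mut [f32], a: &[f32], b: &[f32]) {
    let n = out.len().min(a.len()).min(b.len());
    for i in 0..n {
        out[i] += a[i] * b[i];
    }
}

/// Dot-product reduction: every iteration feeds the same accumulator, so
/// vectorizing it would reassociate float additions. Without a fast-math style
/// opt-in, LLVM reports "cannot prove it is safe to reorder floating-point
/// operations" (compare the unrolled_dot_1 / unrolled_dot_2 remarks).
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for i in 0..a.len().min(b.len()) {
        acc += a[i] * b[i];
    }
    acc
}
```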
I think the speedup over c4152c5 comes from a semi-vectorized version of https://github.com/unicode-org/icu4x/blob/main/experimental/segmenter/src/lstm.rs#L175-L191.
After removing the loops and only computing the add_dot product at both HEAD and CL, we see nearly identical performance.
Benchmark HEAD/CL only computing add_dot
Line Break/UTF8/Th/lstm time: [246.39 µs 246.60 µs 246.85 µs]
change: [+3.1495% +3.3542% +3.5639%] (p = 0.00 < 0.05)
Performance has regressed.
Found 18 outliers among 100 measurements (18.00%)
3 (3.00%) high mild
15 (15.00%) high severe
Line Break/UTF16/Th/lstm
time: [248.21 µs 248.50 µs 248.82 µs]
change: [+3.6773% +3.9000% +4.1699%] (p = 0.00 < 0.05)
Performance has regressed.
Found 18 outliers among 100 measurements (18.00%)
1 (1.00%) low severe
3 (3.00%) high mild
14 (14.00%) high severe
Additionally, if we use an approximation $\mathrm{approxtanh}(x)$ in place of the exact tanh:
Benchmark with $\mathrm{approxtanh}(x)$
Line Break/UTF8/Th/lstm time: [313.47 µs 313.75 µs 314.08 µs]
change: [-34.471% -34.320% -34.183%] (p = 0.00 < 0.05)
Performance has improved.
Found 10 outliers among 100 measurements (10.00%)
4 (4.00%) high mild
6 (6.00%) high severe
Line Break/UTF16/Th/lstm
time: [313.34 µs 313.47 µs 313.61 µs]
change: [-34.703% -34.552% -34.410%] (p = 0.00 < 0.05)
Performance has improved.
Found 15 outliers among 100 measurements (15.00%)
5 (5.00%) high mild
10 (10.00%) high severe
This passes the segmenter test suite with cargo test --all-features, but I'm not certain whether we should use this approximation, given its absolute error.
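The thread doesn't show which approximation was benchmarked; as one hedged possibility, a clamped Padé-style rational approximation is the kind of branch-light polynomial form that autovectorizes where `f32::tanh` does not.

```rust
/// One possible approxtanh(x): a clamped rational (Padé-style) approximation.
/// Illustrative stand-in only, not necessarily the formula benchmarked above.
/// Its absolute error grows toward the clamp boundaries (roughly 5e-3 near
/// |x| = 3), which is why the accuracy question raised here matters.
fn approx_tanh(x: f32) -> f32 {
    if x <= -3.0 {
        -1.0
    } else if x >= 3.0 {
        1.0
    } else {
        // tanh(x) ≈ x * (27 + x^2) / (27 + 9 * x^2) on [-3, 3]
        let x2 = x * x;
        x * (27.0 + x2) / (27.0 + 9.0 * x2)
    }
}
```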
OK, so as far as this PR is concerned, we believe that the perf comes primarily from the vectorization of the convolve function?
Hmm, I wonder whether the most efficient option would be a memory layout like the one we currently have, but where we carry the matrix in chunks of 4 so that we can vectorize those operations together; that way we get the benefits of both memory locality (which I found was faster) and vectorization (which you found was faster).
In other words, a dimension of [hunits / 4, 4, 4], which doesn't work all the time since hunits is 27 (not divisible by 4). I just want to clarify.
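A hedged sketch of the "chunks of 4" idea on a single hunits-sized vector, with the remainder handled by a scalar tail since 27 is not a multiple of 4; the function and its role are illustrative, not part of the crate.

```rust
/// Process the hidden dimension in fixed chunks of 4, plus a scalar tail.
/// With hunits = 27, the main loop covers 24 elements and the tail covers 3.
fn scaled_add_chunks_of_4(dst: &mut [f32], src: &[f32], k: f32) {
    let n = dst.len().min(src.len());
    let split = n - n % 4;
    let (dst_body, dst_tail) = dst[..n].split_at_mut(split);
    let (src_body, src_tail) = src[..n].split_at(split);

    // Fixed-width chunks: each iteration touches 4 adjacent elements, which is
    // easy for LLVM to turn into a 4-wide vector operation.
    for (d, s) in dst_body.chunks_exact_mut(4).zip(src_body.chunks_exact(4)) {
        for lane in 0..4 {
            d[lane] += k * s[lane];
        }
    }
    // Scalar tail for the last hunits % 4 elements (3 when hunits == 27).
    for (d, s) in dst_tail.iter_mut().zip(src_tail) {
        *d += k * *s;
    }
}
```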
Notice: the branch changed across the force-push!
~ Your Friendly Jira-GitHub PR Checker Bot
@pdogr If you can resolve the remaining open questions on this PR in the next ~day, it would be nice to land it so that we can ship it in the release.
Current benches on M2 MacBook Air:
Please note that the memory layout of
This PR is empty; are you still working on it?
I'd be happy to see whether @pdogr can get back to the 203.37 µs. It should be achievable without touching the data model.
let hunits = fw_w.dim().1;
let embedd_dim = fw_w.dim().2;

let fw_w = fw_w.reshape([4 * hunits, embedd_dim]);
Suggestion: instead of a general reshape, define a collapse_1_2 to collapse the first two dimensions. Then you don't have to compute the sizes here.
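A hedged sketch of what such a collapse_1_2 helper could look like on a simple owned 3-D tensor; the Tensor3 and Matrix2 types are invented for illustration and are not the crate's actual matrix types.

```rust
/// Illustrative 3-D tensor backed by a flat, row-major buffer.
struct Tensor3 {
    data: Vec<f32>,
    dim: (usize, usize, usize),
}

/// Illustrative 2-D matrix over the same kind of flat buffer.
struct Matrix2 {
    data: Vec<f32>,
    dim: (usize, usize),
}

impl Tensor3 {
    /// Collapse the first two dimensions: [d0, d1, d2] -> [d0 * d1, d2].
    /// The backing data is untouched; only the shape changes, and the caller
    /// never has to spell out 4 * hunits as with a general reshape.
    fn collapse_1_2(self) -> Matrix2 {
        let (d0, d1, d2) = self.dim;
        Matrix2 {
            data: self.data,
            dim: (d0 * d1, d2),
        }
    }
}
```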
Eager to see how much performance you can get out of this. I will reiterate my position, though, that I don't currently see a logical reason why a 4*HxD matrix should behave any differently than a 4xHxD tensor. The underlying data is identical; it is only an abstraction in the type system. I would prefer a solution where you keep the 4xHxD but address potential issues that make it less performant than the 2d matrix. I'd like to merge the 2d matrix only if you can both prove it is faster and explain why those performance improvements cannot be mapped onto the 3d matrix.
We're seeing a 5% performance improvement on M1, but nothing on x86. I do see an advantage of using a 4*HxD matrix instead of a 4xHxD tensor: we can use matrix multiplication libraries like
Why can't we use
Apparently it's a 5% abstraction on ARM
This PR changes the shape of warr from (hunits, 4, embedd_dim) to (4 * hunits, embedd_dim), and uarr from (hunits, 4, hunits) to (4 * hunits, hunits). Now the forward pass can be expressed in the form of matrix multiplications and elementwise products (which might be vectorized).
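With the weights flattened this way, one time step's gate pre-activation becomes a pair of matrix-vector products over (4 * hunits)-row matrices. The following is a hedged, dependency-free sketch of that shape, not the PR's actual implementation; names and the row-major storage convention are assumptions.

```rust
/// Sketch of one LSTM step after the reshape: s_t = W x_t + U h_prev + b,
/// where W is (4 * hunits, embedd_dim) and U is (4 * hunits, hunits), both
/// stored row-major, so the whole pre-activation is two matrix-vector products.
fn lstm_step(
    w: &[f32],      // 4 * hunits rows, embedd_dim cols
    u: &[f32],      // 4 * hunits rows, hunits cols
    b: &[f32],      // 4 * hunits
    x_t: &[f32],    // embedd_dim
    h_prev: &[f32], // hunits
    hunits: usize,
) -> Vec<f32> {
    let embedd_dim = x_t.len();
    let mut s_t = b.to_vec();
    for row in 0..4 * hunits {
        // add_dot-style accumulation: W[row] · x_t + U[row] · h_prev
        let mut acc = 0.0f32;
        for k in 0..embedd_dim {
            acc += w[row * embedd_dim + k] * x_t[k];
        }
        for k in 0..hunits {
            acc += u[row * hunits + k] * h_prev[k];
        }
        s_t[row] += acc;
    }
    // The four gates are then the slices s_t[0..h], s_t[h..2h], s_t[2h..3h],
    // s_t[3h..4h], transformed elementwise with sigmoid / tanh as in the PR.
    s_t
}
```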
Ran the benchmark using cargo bench --all-features -- "lstm".