
LSTM forward pass using matrix muls #3219


Draft · wants to merge 11 commits into base: main

Conversation

@pdogr (Contributor) commented Mar 24, 2023

This PR changes the shape of warr from (hunits, 4, embedd_dim) to (4 * hunits, embedd_dim) and uarr from (hunits, 4, hunits) to (4 * hunits, hunits).
The forward pass can now be expressed as matrix multiplications and elementwise products (which may be vectorized).
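The idea can be sketched in plain Rust (a minimal sketch with illustrative names, not the crate's actual API): with the four gate weight blocks stored contiguously as one (4 * hunits, embedd_dim) matrix, a single matrix-vector product yields all four gates' pre-activations at once, and equals the four per-gate products computed separately.

```rust
// Naive row-major matrix-vector product (illustrative helper).
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| m[r * cols + c] * x[c]).sum())
        .collect()
}

fn main() {
    let (hunits, embedd_dim) = (2, 3);
    // warr flattened to (4 * hunits, embedd_dim): rows 0..hunits belong to
    // the first gate, hunits..2*hunits to the second, and so on.
    let warr: Vec<f32> = (0..4 * hunits * embedd_dim).map(|i| i as f32).collect();
    let x = vec![1.0, 2.0, 3.0];

    // One big matvec over the flattened matrix...
    let all = matvec(&warr, 4 * hunits, embedd_dim, &x);

    // ...matches the four per-gate matvecs computed separately.
    for g in 0..4 {
        let block = &warr[g * hunits * embedd_dim..(g + 1) * hunits * embedd_dim];
        let per_gate = matvec(block, hunits, embedd_dim, &x);
        assert_eq!(all[g * hunits..(g + 1) * hunits], per_gate[..]);
    }
    println!("flattened matvec matches per-gate matvecs");
}
```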

Ran the benchmark using `cargo bench --all-features -- "lstm"`:

Line Break/UTF8/Th/lstm time:   [423.87 µs 426.95 µs 430.10 µs]
                        change: [-13.213% -12.386% -11.444%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 6 outliers among 100 measurements (6.00%)
  3 (3.00%) high mild
  3 (3.00%) high severe

Line Break/UTF16/Th/lstm
                        time:   [415.31 µs 417.42 µs 419.68 µs]
                        change: [-14.506% -13.635% -12.525%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
  3 (3.00%) high mild
  4 (4.00%) high severe

@pdogr pdogr requested review from sffc and robertbastian March 24, 2023 15:30
@pdogr pdogr marked this pull request as ready for review March 24, 2023 15:30
@pdogr pdogr requested review from Manishearth, aethanyc, makotokato and a team as code owners March 24, 2023 15:30
@robertbastian (Member) previously approved these changes Mar 24, 2023, commenting:

Looks great from my end, but I'll defer final approval to Shane

@sffc (Member) left a comment:

Great that you got more perf out of this! Although I'm puzzled about where the perf is coming from.

Comment on lines 175 to 177
sb.view(0..hunits),
sb.view(2 * hunits..3 * hunits),
sb.view(hunits..2 * hunits),

Nit: I would really like to keep the 4 as a dimension in the matrix. If reshaping the matrix from [hunits, 4, embedd_dim] to [4, hunits, embedd_dim] gives you some perf, that's great, but then let's express it that way, not just flattening the two numbers together, because then it's not clear whether the arrangement is [A1, A2, A3, A4, B1, B2, B3, B4, ...] or [A1, B1, ..., A2, B2, ...].

For example, instead of

s_t.as_mut().sigmoid(3 * hunits..4 * hunits);

I would much prefer if we can keep it as

s_t.submatrix_mut(3)?.sigmoid_transform();
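The layout ambiguity mentioned above can be made concrete with a self-contained sketch (not project code): the same logical per-gate values flatten to two different 1-D buffers depending on whether the layout is gate-major ([4, hunits]) or unit-major ([hunits, 4]), which is exactly the information a flat 4 * hunits dimension no longer records.

```rust
// values[g][h]: value of gate g at hidden unit h (hunits = 2 here).

// Gate-major flattening: [A1, A2, B1, B2, C1, C2, D1, D2]
fn gate_major(values: &[[i32; 2]; 4]) -> Vec<i32> {
    values.iter().flat_map(|row| row.iter().copied()).collect()
}

// Unit-major flattening: [A1, B1, C1, D1, A2, B2, C2, D2]
fn unit_major(values: &[[i32; 2]; 4]) -> Vec<i32> {
    (0..2).flat_map(|h| values.iter().map(move |row| row[h])).collect()
}

fn main() {
    let values = [[10, 11], [20, 21], [30, 31], [40, 41]];
    let g = gate_major(&values);
    let u = unit_major(&values);
    // Same logical data, two different flat buffers.
    assert_eq!(g, vec![10, 11, 20, 21, 30, 31, 40, 41]);
    assert_eq!(u, vec![10, 20, 30, 40, 11, 21, 31, 41]);
    assert_ne!(g, u);
    println!("gate-major and unit-major flattenings differ");
}
```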

Comment on lines 298 to 336
for idx in 0..self.data.len() {
*self.data.get_mut(idx)? =
i.data.get(idx)? * c.data.get(idx)? + self.data.get(idx)? * f.data.get(idx)?
}

Note: I'm surprised this is faster than the current memory layout. You need to jump all around to get the data at the four indices. The previous memory layout put i, c, and f all right next to each other. Actually, the memory layout before I started hacking at it was in that memory order, and I got a perf boost when changing it to the current order. So I'm a bit suspect of where the perf is actually coming from in this PR.
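For context, the cell-state update in question is c_t = i ⊙ c̃ + f ⊙ c_{t-1}. A sketch with separate contiguous buffers (illustrative names, not the PR's exact types): every iteration is independent of the others, which is the loop shape auto-vectorizers generally handle, whereas the interleaved layout reads an [i, c, f, o] group per hidden unit.

```rust
// Elementwise cell-state update over separate contiguous buffers:
// c_t[idx] = i[idx] * c_tilde[idx] + c_t[idx] * f[idx]
fn cell_update(c_t: &mut [f32], i: &[f32], c_tilde: &[f32], f: &[f32]) {
    for idx in 0..c_t.len() {
        c_t[idx] = i[idx] * c_tilde[idx] + c_t[idx] * f[idx];
    }
}

fn main() {
    let mut c_t = vec![1.0_f32; 4]; // previous cell state
    let i = vec![0.5; 4];
    let c_tilde = vec![2.0; 4];
    let f = vec![0.25; 4];
    cell_update(&mut c_t, &i, &c_tilde, &f);
    // 0.5 * 2.0 + 1.0 * 0.25 = 1.25 for every unit
    assert!(c_t.iter().all(|&v| (v - 1.25).abs() < 1e-6));
    println!("{c_t:?}");
}
```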

@pdogr (Contributor, Author) replied:

After compiling the crate with `RUSTFLAGS='-C opt-level=2 -Cllvm-args=--pass-remarks=.*vector.* -Cllvm-args=--pass-remarks-analysis=.*vector.*'` and extracting the remarks related to `math_helper.rs`:

llvm opt remarks for math_helper.rs at CL
remark: experimental/segmenter/src/math_helper.rs:520:9: loop not vectorized: cannot prove it is safe to reorder floating-point operations
remark: experimental/segmenter/src/math_helper.rs:557:9: loop not vectorized: cannot prove it is safe to reorder floating-point operations
remark: experimental/segmenter/src/math_helper.rs:339:36: loop not vectorized: value that could not be identified as reduction is used outside the loop
remark: experimental/segmenter/src/math_helper.rs:339:36: loop not vectorized: could not determine number of loop iterations
remark: experimental/segmenter/src/math_helper.rs:346:65: loop not vectorized: instruction cannot be vectorized
remark: experimental/segmenter/src/math_helper.rs:339:17: the cost-model indicates that interleaving is not beneficial
remark: experimental/segmenter/src/math_helper.rs:339:17: vectorized loop (vectorization width: 4, interleaved count: 1)
remark: experimental/segmenter/src/math_helper.rs:346:40: loop not vectorized: could not determine number of loop iterations

math_helper.rs:520 -> inner loop of unrolled_dot_1 (not vectorized)
math_helper.rs:557 -> inner loop of unrolled_dot_2 (not vectorized)
math_helper.rs:339 -> convolve loop (vectorized) as the loop has no data dependencies across indices
math_helper.rs:346 -> mul_tanh loop (not vectorized) tanh is an unvectorizable function

I think the speedup over c4152c5 comes from a semi-vectorized version of https://github.com/unicode-org/icu4x/blob/main/experimental/segmenter/src/lstm.rs#L175-L191.

After removing the loops and only computing the add_dot product at both HEAD and this CL, we see nearly identical performance.

Benchmark HEAD/CL only computing add_dot
Line Break/UTF8/Th/lstm time:   [246.39 µs 246.60 µs 246.85 µs]
                        change: [+3.1495% +3.3542% +3.5639%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 18 outliers among 100 measurements (18.00%)
  3 (3.00%) high mild
  15 (15.00%) high severe

Line Break/UTF16/Th/lstm
                        time:   [248.21 µs 248.50 µs 248.82 µs]
                        change: [+3.6773% +3.9000% +4.1699%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 18 outliers among 100 measurements (18.00%)
  1 (1.00%) low severe
  3 (3.00%) high mild
  14 (14.00%) high severe

Additionally, if we use $\tanh_{\text{approx}}(x) = \frac{2}{1+e^{-2x}} - 1$, mul_tanh vectorizes.

Benchmark with $\tanh_{\text{approx}}(x)$
Line Break/UTF8/Th/lstm time:   [313.47 µs 313.75 µs 314.08 µs]
                        change: [-34.471% -34.320% -34.183%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 10 outliers among 100 measurements (10.00%)
  4 (4.00%) high mild
  6 (6.00%) high severe

Line Break/UTF16/Th/lstm
                        time:   [313.34 µs 313.47 µs 313.61 µs]
                        change: [-34.703% -34.552% -34.410%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 15 outliers among 100 measurements (15.00%)
  5 (5.00%) high mild
  10 (10.00%) high severe

This passes the segmenter test suite with `cargo test --all-features`, but I'm not certain we should use this approximation, as the absolute error $|\tanh(x) - \tanh_{\text{approx}}(x)|$ is as large as $10^{-3}$ on the test suite.
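For reference, the rewritten form is an exact identity over the reals:

$$\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} = \frac{1 - e^{-2x}}{1 + e^{-2x}} = \frac{2}{1 + e^{-2x}} - 1$$

so the reformulation itself introduces no error; the observed $10^{-3}$ deviation presumably comes from however $e^{-2x}$ is approximated in the vectorized code path, not from the algebra.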


OK, so as far as this PR is concerned, we believe that the perf comes primarily from the vectorization of the convolve function?

Hmm, I wonder if the most efficient approach would be a memory layout like the one we currently have, but where we carry the matrix in chunks of 4 so that we can vectorize those operations together; we'd get the benefits of both memory locality (which I found was faster) and vectorization (which you found was faster).


In other words, a dimension of [hunits / 4, 4, 4], which doesn't work all the time, since hunits is 27 (not divisible by 4).

@sffc (Member) commented Mar 24, 2023

Now the forward pass can be expressed in the form of matrix multiplication. and elementwise dot product (which might be vectorized).

I just want to clarify: dot_3d is exactly equivalent to dot_2d, except that it is more explicit about the memory order of the second dimension. In other words, this calculation is already being "expressed in the form of matrix multiplication" and "elementwise dot product".

@jira-pull-request-webhook commented:

Notice: the branch changed across the force-push!

  • experimental/segmenter/src/lstm_bies.rs is no longer changed in the branch
  • experimental/segmenter/src/lstm.rs is now changed in the branch
  • experimental/segmenter/src/math_helper.rs is different
  • experimental/segmenter/src/provider/lstm.rs is now changed in the branch
  • experimental/segmenter/tests/testdata/provider/segmenter/lstm/wl_auto@1/km.postcard is now changed in the branch
  • experimental/segmenter/tests/testdata/provider/segmenter/lstm/wl_auto@1/lo.postcard is now changed in the branch
  • experimental/segmenter/tests/testdata/provider/segmenter/lstm/wl_auto@1/my.postcard is now changed in the branch
  • experimental/segmenter/tests/testdata/provider/segmenter/lstm/wl_auto@1/th.postcard is now changed in the branch
  • provider/datagen/src/transform/segmenter/lstm.rs is different
  • provider/repodata/data/json/fingerprints.csv is different
  • provider/repodata/data/json/segmenter/lstm@1/th.json is no longer changed in the branch
  • provider/repodata/data/json/segmenter/lstm/wl_auto@1/th.json is now changed in the branch
  • provider/testdata/data/baked/segmenter/lstm_v1/th.rs.data is no longer changed in the branch
  • provider/testdata/data/baked/segmenter/lstm/wl_auto_v1/th.rs.data is now changed in the branch
  • provider/testdata/data/postcard/fingerprints.csv is different
  • provider/testdata/data/testdata.postcard is different

View Diff Across Force-Push

~ Your Friendly Jira-GitHub PR Checker Bot

@sffc (Member) commented Apr 10, 2023

@pdogr If you can resolve the remaining open questions on this PR in the next ~day, it would be nice to land it so that we can ship it in the release.

@sffc (Member) commented Apr 12, 2023

Current benches on M2 MacBook Air:

# Main
Line Break/UTF8/Th/lstm time:   [231.19 µs 231.31 µs 231.43 µs]

# sffc/vectorize-convolve
Line Break/UTF8/Th/lstm time:   [215.82 µs 216.38 µs 217.30 µs]

# sffc/lstm-unroll
Line Break/UTF8/Th/lstm time:   [220.93 µs 220.97 µs 221.03 µs]

# pdogr/2d-mat
Line Break/UTF8/Th/lstm time:   [203.32 µs 203.37 µs 203.42 µs]

@sffc (Member) commented Apr 12, 2023

Please note that the memory layout of sffc/vectorize-convolve and pdogr/2d-mat is the same, so careful coding should be able to bring the performance together without touching the data model:

https://raw.githubusercontent.com/unicode-org/icu4x/532a71bc79a0c3d4bab0761a1ee3499256528bf0/provider/repodata/data/json/segmenter/lstm/wl_auto%401/th.json

https://raw.githubusercontent.com/unicode-org/icu4x/8a247fc7f0e2500789a58fa750b8580a55eec3c7/provider/repodata/data/json/segmenter/lstm/wl_auto%401/th.json

@sffc sffc added the waiting-on-author PRs waiting for action from the author for >7 days label Apr 13, 2023
@robertbastian robertbastian added the C-segmentation Component: Segmentation label Apr 13, 2023
@robertbastian (Member) commented:
This PR is empty, are you still working on it?

@sffc (Member) commented May 3, 2023

I'd be happy to see whether @pdogr can get back to the 203.37 µs. It should be achievable without touching the data model.

let hunits = fw_w.dim().1;
let embedd_dim = fw_w.dim().2;

let fw_w = fw_w.reshape([4 * hunits, embedd_dim]);

suggestion: instead of a general `reshape`, define a `collapse_1_2` to collapse the first two dimensions. Then you don't have to compute the sizes here.
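A hypothetical sketch of what such a `collapse_1_2` could look like (illustrative names and types, not the crate's actual API): because the data is row-major, collapsing the first two dimensions moves no data at all; only the stored dimensions change from (d0, d1, d2) to (d0 * d1, d2), so the call site never recomputes 4 * hunits.

```rust
// Minimal row-major 3-D and 2-D matrix stand-ins (illustrative only).
struct Matrix3 {
    data: Vec<f32>,
    dims: (usize, usize, usize), // row-major (d0, d1, d2)
}

struct Matrix2 {
    data: Vec<f32>,
    dims: (usize, usize),
}

impl Matrix3 {
    // Collapse the first two dimensions: no data movement, since the
    // row-major layout of (d0, d1, d2) and (d0 * d1, d2) is identical.
    fn collapse_1_2(self) -> Matrix2 {
        let (d0, d1, d2) = self.dims;
        Matrix2 { data: self.data, dims: (d0 * d1, d2) }
    }
}

fn main() {
    let (hunits, embedd_dim) = (3, 5);
    let fw_w = Matrix3 {
        data: vec![0.0; 4 * hunits * embedd_dim],
        dims: (4, hunits, embedd_dim),
    };
    let flat = fw_w.collapse_1_2();
    assert_eq!(flat.dims, (4 * hunits, embedd_dim));
    assert_eq!(flat.data.len(), 4 * hunits * embedd_dim);
    println!("collapsed dims: {:?}", flat.dims);
}
```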

@sffc (Member) commented May 8, 2023

Eager to see how much performance you can get out of this. I will reiterate my position, though, that I don't currently see a logical reason why a (4*H)xD matrix should behave any differently than a 4xHxD tensor. The underlying data is identical; the shape is only an abstraction in the type system. I would prefer a solution that keeps the 4xHxD form but addresses whatever makes it less performant than the 2-D matrix. I'd like to merge the 2-D matrix only if you can both prove it is faster and explain why those performance improvements cannot be mapped onto the 3-D matrix.

@robertbastian (Member) commented:
We're seeing a 5% performance improvement on M1, but nothing on x86.

I do see an advantage of using a 4*HxD matrix instead of a 4xHxD tensor: we can use matrix multiplication libraries like blis to improve performance further.

@sffc (Member) commented May 9, 2023

Why can't we use blis with the upgraded 3-D type? You can just drop the third dimension when you call into the library or perform any other operation. What I'm saying is that the 3-D nature of the matrix/tensor is a zero-cost abstraction.

@robertbastian (Member) commented:
What I'm saying is that the 3-D nature of the matrix/tensor is a zero-cost abstraction.

Apparently it's a 5% abstraction on ARM

@robertbastian robertbastian removed the waiting-on-author PRs waiting for action from the author for >7 days label May 6, 2025
@robertbastian robertbastian marked this pull request as draft May 6, 2025 21:04