add missing mul! implementation #47
Conversation
Overloading mul! seems fine. Re deleting *, note that @less *(rand(2,2), rand(2,2)) shows it calls similar(B), and that will give the wrong type here... although that could be fixed.
julia> using OneHotArrays, JLArrays
julia> onehotbatch([1,3,2], 0:5) |> jl
6×3 OneHotMatrix(::JLArray{UInt32, 1}) with eltype Bool:
⋅ ⋅ ⋅
1 ⋅ ⋅
⋅ ⋅ 1
⋅ 1 ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
julia> similar(ans, Float32)
6×3 Matrix{Float32}:
3.0f-44 3.66694f29 3.66699f29
0.0 1.0f-45 1.0f-45
2.7f-44 3.66696f29 3.66701f29
0.0 1.0f-45 1.0f-45
2.8f-44 3.66697f29 3.0f-45
0.0 1.0f-45 0.0
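Presumably the fix hinted at would be to make similar respect the device of the wrapped indices. A hypothetical sketch, not part of this PR, assuming the index vector is stored in a field called indices (as the display above suggests):

# hypothetical, not in this PR: allocate via the wrapped index array, so the
# result lives on the same device (JLArray, CuArray, ...) as the one-hot data
Base.similar(x::OneHotMatrix, ::Type{T}, dims::Dims) where {T} =
    similar(x.indices, T, dims)

With a method like that, the similar(ans, Float32) call above would return a JLArray{Float32, 2} rather than an uninitialized CPU Matrix.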
test/gpu.jl (Outdated)

end

# some specialized implementations call only mul! and not *, so we must ensure this works
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) isa CuArray
This tests only one case. I think that it should test vector & matrix output, each with a OneHotArray, and some reshaped array... and maybe something for which _isonehot(B) === false, to test the invoke path? And it should check that the results are correct, not just the type.
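A sketch of tests along those lines, on the CPU for concreteness (the values and label ranges here are illustrative, not taken from the PR):

using LinearAlgebra, OneHotArrays, Test

A = rand(Float32, 3, 3)
b = onehotbatch([1, 2, 3], 1:3)   # OneHotMatrix
v = onehot(2, 1:3)                # OneHotVector

# matrix and vector output, checking values rather than only types
@test mul!(similar(A, 3, 3), A, b) ≈ A * collect(b)
@test mul!(similar(A, 3), A, v) ≈ A * collect(v)

# a reshaped array for which _isonehot(B) === false, to hit the invoke path
B = reshape(onehotbatch([1, 2, 3, 1], 1:4), 2, 8)
A2 = rand(Float32, 3, 2)
@test mul!(zeros(Float32, 3, 8), A2, B) ≈ A2 * collect(B)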
Many of the tests in the GPU suite only check the type, but I now have this comparing to * for the time being.

The branching here opens up a huge can of worms, and fixing it is a non-goal of mine; see #48.
Sure. But new code really has to test that it produces the correct value. (And enough cases that obvious method ambiguities would be found.)
Fixed.

The overall situation with matrix operations in this package is likely still incredibly wonky; ideally somebody would carefully think through exactly how these operations should be done and replace all of them.
Also, am I correct in assuming that 1.6 is no longer relevant as it is no longer the LTS? If so I can remove the 1.6 tests in favor of 1.10.
Alright, I've swapped out 1.6 for 1.10. That's as much as I was planning to fix here; it should no longer immediately fail for someone trying to use it with Lux.
Bumping.
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@ Coverage Diff @@
##             main      #47      +/-   ##
==========================================
- Coverage   96.26%   95.07%   -1.20%
==========================================
  Files           3        3
  Lines         134      142       +8
==========================================
+ Hits          129      135       +6
- Misses          5        7       +2

☔ View full report in Codecov by Sentry.
I've admittedly stopped paying attention to this myself, but can we either merge this or set up an extended discussion about what the right thing to do here is? This PR definitely doesn't fix the whole situation, but the current latest tag is very broken on GPU right now, and (for me at least, probably for everyone?) renders OneHotArrays.jl totally broken on what I believe was its originally intended use case (deep learning).
test/gpu.jl

end

# some specialized implementations call only mul! and not *, so we must ensure this works
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y
I think it would be good to check a vector case too. And also a non-GPU case... not 100% sure what happens when tests are run without CUDA, but the new code is LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLike), which certainly includes A::Matrix.
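For instance, a minimal CPU check of that signature could look like this (illustrative only, not the PR's actual test):

using LinearAlgebra, OneHotArrays, Test

A = rand(3, 5)                    # a plain Matrix also matches A::AbstractMatrix
b = onehotbatch([4, 1, 5], 1:5)
@test mul!(zeros(3, 3), A, b) ≈ A * collect(b)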
Ok, I've added a test for vector.
Here's an example where this PR changes an error into a silently wrong answer:
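One plausible shape for such a case (a hypothetical sketch, not the original snippet, building on the _isonehot discussion above): a fast path that assumes one-hot structure without checking _isonehot would quietly produce wrong values for a reshaped array, where the old * at least raised an error.

using LinearAlgebra, OneHotArrays

b = onehotbatch([1, 3, 2], 0:5)   # 6×3 OneHotMatrix
B = reshape(b, 9, 2)              # still OneHotLike, but _isonehot(B) === false
A = rand(4, 9)

# if mul! wrongly took the one-hot fast path here, this would be false,
# with no error raised to warn the caller
LinearAlgebra.mul!(zeros(4, 2), A, B) ≈ A * collect(B)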
I realise it's annoying to write more tests than your use case needs, especially when other people adding features have not bothered to do so. But adding more ways to get wrong answers is something we should try pretty hard to avoid.
Alright, sorry it took me so long to get to that, especially after I had been complaining about it. Ok, I think everything should be about wrapped up now. There are still a lot of issues floating around because it hits getindex, but they will require a more comprehensive fix (and some cases are probably just not possible to handle elegantly with GPU arrays).
I pushed some Vector/Matrix tests like I keep suggesting, re-using some of the existing cases for *, and they fail with this PR (but pass without).
> There are still a lot of issues floating around because it hits getindex, but they will require a more comprehensive fix (and some cases are probably just not possible to handle elegantly with GPU arrays).

If you have other cases which fail, please make issues. Otherwise it's just vague feelings, and hard for anyone to act on.
Co-authored-by: Michael Abbott <[email protected]>
Prior to this PR, this package did not define a method for LinearAlgebra.mul!. This low-level method is used by some procedures instead of *, so without the new method included in this PR, such low-level calls could break in some cases. In particular, some of the optimized matmul calls used by Lux.jl would error out with a method ambiguity or a scalar-indexing-of-GPU-arrays error when used with GPU arrays.

Ultimately the * methods should probably be removed, but I have left them alone for now to get this in without worrying too much about breaking anything.

GPU tests pass for me on CUDA.
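For reference, the new method has roughly the shape below. This is a simplified sketch based on the signature quoted in the review, not the PR's exact code; _indices here is a stand-in for however the stored column labels are accessed inside the package.

using LinearAlgebra

function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLike)
    # reshaped arrays that are no longer truly one-hot fall back to generic matmul
    _isonehot(B) ||
        return invoke(mul!, Tuple{AbstractVecOrMat, AbstractMatrix, AbstractMatrix}, Y, A, B)
    size(A, 2) == size(B, 1) ||
        throw(DimensionMismatch("A has $(size(A, 2)) columns, B has $(size(B, 1)) rows"))
    # each one-hot column of B selects one column of A, so the product is a
    # column gather -- no kernel launch or scalar GPU indexing required
    copyto!(Y, view(A, :, _indices(B)))
end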
PR Checklist