Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 1154509

Browse files
Merge #636
636: Add BLAS.axpby! r=amontoison a=amontoison Co-authored-by: Alexis Montoison <[email protected]>
2 parents bc03269 + b51ed34 commit 1154509

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

src/blas/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ function LinearAlgebra.axpy!(alpha::Number, x::CuArray{T}, y::CuArray{T}) where
4646
axpy!(length(x), convert(T,alpha), x, 1, y, 1)
4747
end
4848

49+
function LinearAlgebra.axpby!(alpha::Number, x::CuArray{T}, beta::Number, y::CuArray{T}) where T<:CublasFloat
50+
length(x)==length(y) || throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))"))
51+
axpby!(length(x), convert(T,alpha), x, 1, convert(T,beta), y, 1)
52+
end
4953

5054

5155
#

src/blas/wrappers.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,33 @@ function axpy!(alpha::Ta,
200200
y
201201
end
202202

203+
function axpby!(n::Integer,
204+
alpha::T,
205+
dx::CuArray{T},
206+
incx::Integer,
207+
beta::T,
208+
dy::CuArray{T},
209+
incy::Integer) where T <: CublasFloat
210+
scal!(n, beta, dy, incy)
211+
axpy!(n, alpha, dx, incx, dy, incy)
212+
dy
213+
end
214+
215+
function axpby!(alpha::Ta,
216+
x::CuArray{T},
217+
rx::Union{UnitRange{Ti},AbstractRange{Ti}},
218+
beta::Tb,
219+
y::CuArray{T},
220+
ry::Union{UnitRange{Ti},AbstractRange{Ti}}) where {T<:CublasFloat,Ta<:Number,Tb<:Number,Ti<:Integer}
221+
length(rx)==length(ry) || throw(DimensionMismatch(""))
222+
if minimum(rx) < 1 || maximum(rx) > length(x) || minimum(ry) < 1 || maximum(ry) > length(y)
223+
throw(BoundsError())
224+
end
225+
axpby!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T),
226+
step(rx), convert(T, beta), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
227+
y
228+
end
229+
203230
## iamax
204231
# TODO: fix iamax in julia base
205232
for (fname, elty) in ((:cublasIdamax_v2,:Float64),

test/blas.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ k = 13
3232
@test testf(norm, rand(T, m))
3333
@test testf(BLAS.asum, rand(T, m))
3434
@test testf(BLAS.axpy!, Ref(rand()), rand(T, m), rand(T, m))
35+
@test testf(BLAS.axpby!, Ref(rand()), rand(T, m), Ref(rand()), rand(T, m))
3536

3637
if T <: Complex
3738
@test testf(BLAS.dotu, rand(T, m), rand(T, m))

0 commit comments

Comments
 (0)