Skip to content

Add ndims type parameter to AbstractArrayInterface #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 10, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DerivableInterfaces"
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.5"
version = "0.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[compat]
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
Documenter = "1"
Literate = "2"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

[compat]
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
20 changes: 12 additions & 8 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
# TODO: Add `ndims` type parameter.
abstract type AbstractArrayInterface <: AbstractInterface end
abstract type AbstractArrayInterface{N} <: AbstractInterface end

function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N}
return DefaultArrayInterface{N}()
end
function interface(::Type{<:Broadcast.AbstractArrayStyle})
return DefaultArrayInterface()
end

function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface()
function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}})
return DefaultArrayInterface{ndims(BC)}()
end

function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
return interface(Style)
end

# TODO: Define as `Array{T}`.
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")
# TODO: Define as `similar(Array{T}, ax)`.
function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple)
return error("Not implemented.")
end

using ArrayLayouts: ArrayLayouts

Expand Down Expand Up @@ -85,7 +90,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
)
return similar(arraytype(interface, T), size)
return similar(interface, T, size)
end

@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
Expand All @@ -105,8 +110,7 @@ end
@interface interface::AbstractArrayInterface function Base.similar(
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
)
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
return similar(arraytype(interface, T), axes)
return similar(interface, T, axes)
end

using MapBroadcast: Mapped
Expand Down
27 changes: 19 additions & 8 deletions src/concatenate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ export concatenate
@compat public Concatenated, cat, cat!, concatenated

using Base: promote_eltypeof
using ..DerivableInterfaces:
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero!

unval(x) = x
unval(::Val{x}) where {x} = x
Expand All @@ -53,13 +52,17 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
end
end

function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
function Concatenated(
interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple
)
return _Concatenated(interface, dims, args)
end
function Concatenated(dims::Val, args::Tuple)
return Concatenated(interface(args...), dims, args)
return Concatenated(cat_interface(dims, args...), dims, args)
end
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
function Concatenated{Interface}(
dims::Val, args::Tuple
) where {Interface<:Union{AbstractArrayInterface,Nothing}}
return Concatenated(Interface(), dims, args)
end

Expand All @@ -81,8 +84,11 @@ end
# ------------------------------------
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
return similar(arraytype(interface(concat), T), ax)
function Base.similar(concat::Concatenated, ax::Tuple)
return similar(interface(concat), eltype(concat), ax)
end
function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T}
return similar(interface(concat), T, ax)
end

function cat_axis(
Expand All @@ -108,10 +114,15 @@ function cat_axes(dims::Val, as::AbstractArray...)
return cat_axes(unval(dims), as...)
end

function cat_interface(dims, as::AbstractArray...)
N = cat_ndims(dims, as...)
return typeof(interface(as...))(Val(N))
end

Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
Base.size(concat::Concatenated) = length.(axes(concat))
Base.ndims(concat::Concatenated) = length(axes(concat))
Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...)

# Main logic
# ----------
Expand Down
26 changes: 22 additions & 4 deletions src/defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
# TODO: Add `ndims` type parameter.
struct DefaultArrayInterface <: AbstractArrayInterface end
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end

DefaultArrayInterface() = DefaultArrayInterface{Any}()
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()

using TypeParameterAccessors: parenttype
function interface(a::Type{<:AbstractArray})
parenttype(a) === a && return DefaultArrayInterface()
return interface(parenttype(a))
end
function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N}
parenttype(a) === a && return DefaultArrayInterface{N}()
return interface(parenttype(a))
end

function combine_interface_rule(
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
) where {N}
return DefaultArrayInterface{N}()
end
function combine_interface_rule(
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
)
return DefaultArrayInterface{Any}()
end

@interface ::DefaultArrayInterface function Base.getindex(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
Expand All @@ -31,6 +49,6 @@ end
return Base.mapreduce(f, op, as...; kwargs...)
end

function arraytype(::DefaultArrayInterface, T::Type)
return Array{T}
function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple)
return similar(Array{T}, ax)
end
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8"
ArrayLayouts = "1"
DerivableInterfaces = "0.4"
DerivableInterfaces = "0.5"
LinearAlgebra = "1"
SafeTestsets = "0.1"
Suppressor = "0.2"
LinearAlgebra = "1"
Test = "1"
TestExtras = "0.3"
16 changes: 12 additions & 4 deletions test/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ using DerivableInterfaces:
using LinearAlgebra: LinearAlgebra

# Define an interface.
struct SparseArrayInterface <: AbstractArrayInterface end
struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end
SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}()
SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}()

# Define interface functions.
@interface ::SparseArrayInterface function Base.getindex(
Expand All @@ -66,11 +68,15 @@ end
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()

DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()
function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N}
return SparseArrayInterface{N}()
end

@derive SparseArrayStyle AbstractArrayStyleOps

DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}
function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple)
return similar(SparseArrayDOK{T}, ax)
end

# Interface functions.
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
Expand Down Expand Up @@ -260,7 +266,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK)
end

# Specify the interface the type adheres to.
DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
SparseArrayInterface{ndims(arrayt)}()
end

# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
@array_aliases SparseArrayDOK
Expand Down
35 changes: 29 additions & 6 deletions test/test_defaultarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test: @inferred, @testset, @test
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
using Test: @testset, @test
using TestExtras: @constinferred

# function wrappers to test type-stability
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
Expand All @@ -11,28 +12,50 @@ end

@testset "indexing" begin
for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3)))
a = @inferred _getindex(A, i...)
a = @constinferred _getindex(A, i...)
@test a == A[i...]
v = 1.1
A′ = @inferred _setindex!(A, v, i...)
A′ = @constinferred _setindex!(A, v, i...)
@test A′ == (A[i...] = v)
end
end

@testset "map!" begin
A = zeros(3)
a = @inferred _map!(Returns(2), copy(A), A)
a = @constinferred _map!(Returns(2), copy(A), A)
@test a == map!(Returns(2), copy(A), A)
end

@testset "mapreduce" begin
A = zeros(3)
a = @inferred _mapreduce(Returns(2), +, A)
a = @constinferred _mapreduce(Returns(2), +, A)
@test a == mapreduce(Returns(2), +, A)
end

@testset "DefaultArrayInterface" begin
@test interface(Array) === DefaultArrayInterface{Any}()
@test interface(Array{Float32}) === DefaultArrayInterface{Any}()
@test interface(Matrix) === DefaultArrayInterface{2}()
@test interface(Matrix{Float32}) === DefaultArrayInterface{2}()
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
end

@testset "similar(::DefaultArrayInterface, ...)" begin
a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)

a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2))
@test typeof(a) === Matrix{Float32}
@test size(a) == (2, 2)
end

@testset "Broadcast.DefaultArrayStyle" begin
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
DefaultArrayInterface()
DefaultArrayInterface{1}()
end
Loading