Skip to content

Commit 950825b

Browse files
authored
Merge pull request #83 from SymbolicML/fixed-promotion
Fix promotion rules in `+, -, *, /`
2 parents 29454c1 + c291f5f commit 950825b

File tree

3 files changed

+84
-17
lines changed

3 files changed

+84
-17
lines changed

src/math.jl

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,52 @@
11
for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES
22
@eval begin
3-
Base.:*(l::$type, r::$type) = new_quantity(typeof(l), ustrip(l) * ustrip(r), dimension(l) * dimension(r))
4-
Base.:/(l::$type, r::$type) = new_quantity(typeof(l), ustrip(l) / ustrip(r), dimension(l) / dimension(r))
5-
Base.div(x::$type, y::$type, r::RoundingMode=RoundToZero) = new_quantity(typeof(x), div(ustrip(x), ustrip(y), r), dimension(x) / dimension(y))
3+
function Base.:*(l::$type, r::$type)
4+
l, r = promote_except_value(l, r)
5+
new_quantity(typeof(l), ustrip(l) * ustrip(r), dimension(l) * dimension(r))
6+
end
7+
function Base.:/(l::$type, r::$type)
8+
l, r = promote_except_value(l, r)
9+
new_quantity(typeof(l), ustrip(l) / ustrip(r), dimension(l) / dimension(r))
10+
end
11+
function Base.div(x::$type, y::$type, r::RoundingMode=RoundToZero)
12+
x, y = promote_except_value(x, y)
13+
new_quantity(typeof(x), div(ustrip(x), ustrip(y), r), dimension(x) / dimension(y))
14+
end
615

7-
Base.:*(l::$type, r::$base_type) = new_quantity(typeof(l), ustrip(l) * r, dimension(l))
8-
Base.:/(l::$type, r::$base_type) = new_quantity(typeof(l), ustrip(l) / r, dimension(l))
9-
Base.div(x::$type, y::Number, r::RoundingMode=RoundToZero) = new_quantity(typeof(x), div(ustrip(x), y, r), dimension(x))
16+
# The rest of the functions are unchanged because they do not operate on two variables of the custom type
17+
function Base.:*(l::$type, r::$base_type)
18+
new_quantity(typeof(l), ustrip(l) * r, dimension(l))
19+
end
20+
function Base.:/(l::$type, r::$base_type)
21+
new_quantity(typeof(l), ustrip(l) / r, dimension(l))
22+
end
23+
function Base.div(x::$type, y::Number, r::RoundingMode=RoundToZero)
24+
new_quantity(typeof(x), div(ustrip(x), y, r), dimension(x))
25+
end
1026

11-
Base.:*(l::$base_type, r::$type) = new_quantity(typeof(r), l * ustrip(r), dimension(r))
12-
Base.:/(l::$base_type, r::$type) = new_quantity(typeof(r), l / ustrip(r), inv(dimension(r)))
13-
Base.div(x::Number, y::$type, r::RoundingMode=RoundToZero) = new_quantity(typeof(y), div(x, ustrip(y), r), inv(dimension(y)))
27+
function Base.:*(l::$base_type, r::$type)
28+
new_quantity(typeof(r), l * ustrip(r), dimension(r))
29+
end
30+
function Base.:/(l::$base_type, r::$type)
31+
new_quantity(typeof(r), l / ustrip(r), inv(dimension(r)))
32+
end
33+
function Base.div(x::Number, y::$type, r::RoundingMode=RoundToZero)
34+
new_quantity(typeof(y), div(x, ustrip(y), r), inv(dimension(y)))
35+
end
1436

15-
Base.:*(l::$type, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) * r)
16-
Base.:/(l::$type, r::AbstractDimensions) = new_quantity(typeof(l), ustrip(l), dimension(l) / r)
37+
function Base.:*(l::$type, r::AbstractDimensions)
38+
new_quantity(typeof(l), ustrip(l), dimension(l) * r)
39+
end
40+
function Base.:/(l::$type, r::AbstractDimensions)
41+
new_quantity(typeof(l), ustrip(l), dimension(l) / r)
42+
end
1743

18-
Base.:*(l::AbstractDimensions, r::$type) = new_quantity(typeof(r), ustrip(r), l * dimension(r))
19-
Base.:/(l::AbstractDimensions, r::$type) = new_quantity(typeof(r), inv(ustrip(r)), l / dimension(r))
44+
function Base.:*(l::AbstractDimensions, r::$type)
45+
new_quantity(typeof(r), ustrip(r), l * dimension(r))
46+
end
47+
function Base.:/(l::AbstractDimensions, r::$type)
48+
new_quantity(typeof(r), inv(ustrip(r)), l / dimension(r))
49+
end
2050
end
2151
end
2252

@@ -27,6 +57,7 @@ Base.:/(l::AbstractDimensions, r::AbstractDimensions) = map_dimensions(-, l, r)
2757
for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, op in (:+, :-)
2858
@eval begin
2959
function Base.$op(l::$type, r::$type)
60+
l, r = promote_except_value(l, r)
3061
dimension(l) == dimension(r) || throw(DimensionError(l, r))
3162
return new_quantity(typeof(l), $op(ustrip(l), ustrip(r)), dimension(l))
3263
end
@@ -50,7 +81,7 @@ for op in (:*, :/, :+, :-, :div, :atan, :atand, :copysign, :flipsign, :mod),
5081

5182
t1 == t2 && continue
5283

53-
@eval Base.$op(l::$t1, r::$t2) = $op(promote(l, r)...)
84+
@eval Base.$op(l::$t1, r::$t2) = $op(promote_except_value(l, r)...)
5485
end
5586

5687
# We don't promote on the dimension types:
@@ -125,6 +156,7 @@ for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, f in (:atan, :atand)
125156
return $f(ustrip(x))
126157
end
127158
function Base.$f(y::$type, x::$type)
159+
y, x = promote_except_value(y, x)
128160
dimension(y) == dimension(x) || throw(DimensionError(y, x))
129161
return $f(ustrip(y), ustrip(x))
130162
end
@@ -154,6 +186,7 @@ for (type, base_type, _) in ABSTRACT_QUANTITY_TYPES, f in (:copysign, :flipsign,
154186
# and ignore any dimensions on y, since those will cancel out.
155187
@eval begin
156188
function Base.$f(x::$type, y::$type)
189+
x, y = promote_except_value(x, y)
157190
return new_quantity(typeof(x), $f(ustrip(x), ustrip(y)), dimension(x))
158191
end
159192
function Base.$f(x::$type, y::$base_type)

src/utils.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ function Base.promote_rule(::Type{<:AbstractQuantity}, ::Type{<:Number})
6969
return Number
7070
end
7171

72+
"""
73+
promote_except_value(q1::UnionAbstractQuantity, q2::UnionAbstractQuantity)
74+
75+
This applies a promotion to the quantity type, and the dimension type,
76+
but *not* the value type. This is necessary because sometimes we would
77+
want to multiply a quantity array with a scalar quantity, and wish to use
78+
promotion on the quantity type itself, but don't want to promote to a
79+
single value type.
80+
"""
81+
@inline function promote_except_value(q1::Q1, q2::Q2) where {T1,D1,T2,D2,Q1<:UnionAbstractQuantity{T1,D1},Q2<:UnionAbstractQuantity{T2,D2}}
82+
Q = promote_type(Q1, Q2)
83+
D = promote_type(D1, D2)
84+
85+
Q1_out = with_type_parameters(Q, T1, D)
86+
Q2_out = with_type_parameters(Q, T2, D)
87+
return convert(Q1_out, q1), convert(Q2_out, q2)
88+
end
89+
@inline promote_except_value(q1::Q, q2::Q) where {Q<:UnionAbstractQuantity} = (q1, q2)
90+
7291
Base.keys(d::AbstractDimensions) = dimension_names(typeof(d))
7392
Base.getindex(d::AbstractDimensions, k::Symbol) = getfield(d, k)
7493

@@ -100,6 +119,7 @@ Base.keys(q::UnionAbstractQuantity) = keys(ustrip(q))
100119

101120
# Numeric checks
102121
function Base.isapprox(l::UnionAbstractQuantity, r::UnionAbstractQuantity; kws...)
122+
l, r = promote_except_value(l, r)
103123
return isapprox(ustrip(l), ustrip(r); kws...) && dimension(l) == dimension(r)
104124
end
105125
function Base.isapprox(l::Number, r::UnionAbstractQuantity; kws...)
@@ -111,11 +131,15 @@ function Base.isapprox(l::UnionAbstractQuantity, r::Number; kws...)
111131
return isapprox(ustrip(l), r; kws...)
112132
end
113133
Base.iszero(d::AbstractDimensions) = all_dimensions(iszero, d)
114-
Base.:(==)(l::AbstractDimensions, r::AbstractDimensions) = all_dimensions(==, l, r)
115-
Base.:(==)(l::UnionAbstractQuantity, r::UnionAbstractQuantity) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
134+
function Base.:(==)(l::UnionAbstractQuantity, r::UnionAbstractQuantity)
135+
l, r = promote_except_value(l, r)
136+
ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
137+
end
116138
Base.:(==)(l::Number, r::UnionAbstractQuantity) = ustrip(l) == ustrip(r) && iszero(dimension(r))
117139
Base.:(==)(l::UnionAbstractQuantity, r::Number) = ustrip(l) == ustrip(r) && iszero(dimension(l))
140+
Base.:(==)(l::AbstractDimensions, r::AbstractDimensions) = all_dimensions(==, l, r)
118141
function Base.isless(l::UnionAbstractQuantity, r::UnionAbstractQuantity)
142+
l, r = promote_except_value(l, r)
119143
dimension(l) == dimension(r) || throw(DimensionError(l, r))
120144
return isless(ustrip(l), ustrip(r))
121145
end

test/unittests.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,16 @@ end
663663
qa = [x, y]
664664
@test qa isa Vector{Quantity{Float64,SymbolicDimensions{Rational{Int}}}}
665665
DynamicQuantities.with_type_parameters(SymbolicDimensions{Float64}, Rational{Int}) == SymbolicDimensions{Rational{Int}}
666+
667+
@testset "Promotion with Dimensions" begin
668+
x = 0.5u"cm"
669+
y = -0.03u"m"
670+
x_s = 0.5us"cm"
671+
for op in (+, -, *, /, atan, atand, copysign, flipsign, mod)
672+
@test op(x, y) == op(x_s, y)
673+
@test op(y, x) == op(y, x_s)
674+
end
675+
end
666676
end
667677

668678
@testset "uconvert" begin
@@ -696,7 +706,7 @@ end
696706
q = convert(Q{Float16}, 1.5u"g")
697707
qs = uconvert(convert(Q{Float16}, us"g"), 5 * q)
698708
@test typeof(qs) <: Q{Float16,<:SymbolicDimensions{<:Any}}
699-
@test qs 7.5us"g"
709+
@test isapprox(qs, 7.5us"g"; atol=0.01)
700710

701711
# Arrays
702712
x = [1.0, 2.0, 3.0] .* Q(u"kg")

0 commit comments

Comments
 (0)