Skip to content

Commit 7a6dc8f

Browse files
authored
Merge pull request #108 from JuliaGaussianProcesses/tgf/fix-rrule-fill
Fix `rrule` for more than 2 Fill
2 parents fa8b4af + e94f668 commit 7a6dc8f

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TemporalGPs"
22
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
33
authors = ["willtebbutt <[email protected]> and contributors"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

src/util/chainrules.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,20 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_ma
181181
return Fill(y_el, axes(x)), _map_Fill_rrule
182182
end
183183

184-
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, x::Fill, y::Fill)
185-
z_el, back = ChainRulesCore.rrule_via_ad(config, f, x.value, y.value)
184+
# Somehow needed to avoid the _map -> map indirection
185+
function _map(f, xs::Fill...)
186+
all(==(axes(first(xs))), axes.(xs)) || error("All axes should be the same")
187+
Fill(f(FillArrays.getindex_value.(xs)...), axes(first(xs)))
188+
end
189+
190+
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_map), f, xs::Fill...)
191+
z_el, back = ChainRulesCore.rrule_via_ad(config, f, FillArrays.getindex_value.(xs)...)
186192
function _map_Fill_rrule(Δ)
187-
Δf, Δx_el, Δy_el = back(unthunk(Δ).value)
188-
return NoTangent(), Δf, Fill(Δx_el, axes(x)), Fill(Δy_el, axes(x))
193+
Δf, Δxs_el... = back(unthunk(Δ).value)
194+
return NoTangent(), Δf, Fill.(Δxs_el, axes.(xs))...
189195
end
190-
return Fill(z_el, axes(x)), _map_Fill_rrule
196+
return Fill(z_el, axes(first(xs))), _map_Fill_rrule
191197
end
192-
193198
### Same thing for `StructArray`
194199

195200

test/util/chainrules.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,17 @@ include("../test_util.jl")
8080
test_rrule(_map, x -> 2.0 * x, x; check_inferred=false)
8181
test_rrule(ZygoteRuleConfig(), (x,a)-> _map(x -> x * a, x), x, 2.0; check_inferred=false, rrule_f=rrule_via_ad)
8282
end
83-
@testset "_map(f, x1::Fill, x2::Fill)" begin
83+
@testset "_map(f, x::Fill....)" begin
8484
x1 = Fill(randn(3, 4), 3)
8585
x2 = Fill(randn(3, 4), 3)
86+
x3 = Fill(randn(3, 4), 3)
8687

8788
@test _map(+, x1, x2) == _map(+, collect(x1), collect(x2))
8889
test_rrule(_map, +, x1, x2; check_inferred=true)
8990

91+
@test _map(+, x1, x2, x3) == _map(+, collect(x1), collect(x2), collect(x3))
92+
test_rrule(_map, +, x1, x2, x3; check_inferred=true)
93+
9094
fsin(x, y) = sin.(x .* y)
9195
test_rrule(_map, fsin, x1, x2; check_inferred=false)
9296

0 commit comments

Comments
 (0)