|
# Reverse rule for `map(f, xs...)` over tuples. Each call `f(args...)` is differentiated
# with `rrule_via_ad`, and the stored pullbacks are replayed in reverse order.
#
# Like `zip`, the result has `minimum(length, xs)` elements. Tangents for the trailing,
# unused elements of longer tuples are padded with `NoTangent()`. The `@error` that used to
# flag mismatched lengths said it should be removed once JuliaLang/julia#42216 was fixed;
# that issue is resolved, so the check is gone.
# NOTE(review): on Julia versions that predate that fix, the primal `map` may still throw
# on mismatched lengths where this rule does not.
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tuple...) where {F}
    length_y = minimum(length, xs)
    # One `(y_i, pullback_i)` pair per output element, as returned by `rrule_via_ad`.
    hobbits = ntuple(length_y) do i
        args = getindex.(xs, i)
        rrule_via_ad(config, f, args...)
    end
    y = map(first, hobbits)
    num_xs = Val(length(xs))
    # `NoTangent()` padding for the elements of each `x` past `length_y`; `f` never saw them.
    paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)

    function map_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # Call the pullbacks from `rrule_via_ad` in reverse sequence to the forward pass,
        # then `reverse` so that `backevals[i]` again corresponds to output element `i`.
        backevals = ntuple(length_y) do i
            rev_i = length_y - i + 1
            last(hobbits[rev_i])(dy[rev_i])
        end |> reverse
        # `sum` over an empty tuple throws, so with no calls to `f` its tangent is
        # `NoTangent()`, matching the `AbstractZero` method below.
        # This df doesn't infer; could test Base.issingletontype(F), but it's not the only
        # inference problem.
        df_raw = isempty(backevals) ? NoTangent() : sum(first, backevals)
        df = ProjectTo(f)(df_raw)
        # Unzip: entry 1 of each pullback result is the tangent for `f`, and entry `k + 1`
        # is the tangent for `xs[k]`. Append padding for elements beyond `length_y`.
        dxs = ntuple(num_xs) do k
            dx_short = map(bv -> bv[k+1], backevals)
            ProjectTo(xs[k])((dx_short..., paddings[k]...))  # ProjectTo makes the Tangent for us
        end
        return (NoTangent(), df, dxs...)
    end
    # Fast path for a zero cotangent. This must be a method of `map_pullback`, the closure
    # returned below, or it would never be reached.
    function map_pullback(::AbstractZero)
        return (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
    end

    return y, map_pullback
end
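Source: `ChainRules.jl/src/rulesets/Base/base.jl`, lines 243 to 273 at commit 9dd39bd.
The referenced Julia issue, JuliaLang/julia#42216, has been resolved. The `@error` about mismatched lengths in this rule says it should be removed once that issue is fixed, so it is now obsolete.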