@@ -181,15 +181,20 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(_ma
181
181
return Fill (y_el, axes (x)), _map_Fill_rrule
182
182
end
183
183
184
- function ChainRulesCore. rrule (config:: RuleConfig{>:HasReverseMode} , :: typeof (_map), f, x:: Fill , y:: Fill )
185
- z_el, back = ChainRulesCore. rrule_via_ad (config, f, x. value, y. value)
184
+ # Somehow needed to avoid the _map -> map indirection
185
+ function _map (f, xs:: Fill... )
186
+ all (== (axes (first (xs))), axes .(xs)) || error (" All axes should be the same" )
187
+ Fill (f (FillArrays. getindex_value .(xs)... ), axes (first (xs)))
188
+ end
189
+
190
+ function ChainRulesCore. rrule (config:: RuleConfig{>:HasReverseMode} , :: typeof (_map), f, xs:: Fill... )
191
+ z_el, back = ChainRulesCore. rrule_via_ad (config, f, FillArrays. getindex_value .(xs)... )
186
192
function _map_Fill_rrule (Δ)
187
- Δf, Δx_el, Δy_el = back (unthunk (Δ). value)
188
- return NoTangent (), Δf, Fill (Δx_el , axes (x)), Fill (Δy_el, axes (x))
193
+ Δf, Δxs_el ... = back (unthunk (Δ). value)
194
+ return NoTangent (), Δf, Fill .(Δxs_el , axes .(xs)) ...
189
195
end
190
- return Fill (z_el, axes (x )), _map_Fill_rrule
196
+ return Fill (z_el, axes (first (xs) )), _map_Fill_rrule
191
197
end
192
-
193
198
# ## Same thing for `StructArray`
194
199
195
200
0 commit comments