Skip to content

Commit c4b0b83

Browse files
devmotionsethaxen
andauthored
Add derivatives for besselix, besseljx, and besselyx (#350)
* Add derivatives for `besselix`, `besseljx`, and `besselyx` * Bump version * Apply suggestions from @sethaxen's review (muladd + optimizations) * Improve `frule`s * Simplify `frule`s * Apply suggestions from code review Co-authored-by: Seth Axen <[email protected]> Co-authored-by: Seth Axen <[email protected]>
1 parent 1b29b0b commit c4b0b83

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/chainrules.jl

+90
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,93 @@ ChainRulesCore.@scalar_rule(
193193
ChainRulesCore.@scalar_rule(expinti(x), exp(x) / x)
194194
ChainRulesCore.@scalar_rule(sinint(x), sinc(invπ * x))
195195
ChainRulesCore.@scalar_rule(cosint(x), cos(x) / x)
196+
197+
# non-holomorphic functions
198+
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselix), ν::Number, x::Number)
199+
# primal
200+
Ω = besselix(ν, x)
201+
202+
# derivative
203+
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
204+
a = (besselix- 1, x) + besselix+ 1, x)) / 2
205+
ΔΩ = if Δx isa Real
206+
muladd(muladd(-sign(real(x)), Ω, a), Δx, ∂Ω_∂ν * Δν)
207+
else
208+
muladd(a, Δx, muladd(-sign(real(x)) * real(Δx), Ω, ∂Ω_∂ν * Δν))
209+
end
210+
211+
return Ω, ΔΩ
212+
end
213+
function ChainRulesCore.rrule(::typeof(besselix), ν::Number, x::Number)
214+
Ω = besselix(ν, x)
215+
project_x = ChainRulesCore.ProjectTo(x)
216+
function besselix_pullback(ΔΩ)
217+
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
218+
a = (besselix- 1, x) + besselix+ 1, x)) / 2
219+
= project_x(muladd(conj(a), ΔΩ, - sign(real(x)) * real(conj(Ω) * ΔΩ)))
220+
return ChainRulesCore.NoTangent(), ν̄, x̄
221+
end
222+
return Ω, besselix_pullback
223+
end
224+
225+
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besseljx), ν::Number, x::Number)
226+
# primal
227+
Ω = besseljx(ν, x)
228+
229+
# derivative
230+
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
231+
a = (besseljx- 1, x) - besseljx+ 1, x)) / 2
232+
ΔΩ = if Δx isa Real
233+
muladd(a, Δx, ∂Ω_∂ν * Δν)
234+
else
235+
muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν))
236+
end
237+
238+
return Ω, ΔΩ
239+
end
240+
function ChainRulesCore.rrule(::typeof(besseljx), ν::Number, x::Number)
241+
Ω = besseljx(ν, x)
242+
project_x = ChainRulesCore.ProjectTo(x)
243+
function besseljx_pullback(ΔΩ)
244+
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
245+
a = (besseljx- 1, x) - besseljx+ 1, x)) / 2
246+
= if x isa Real
247+
project_x(a * ΔΩ)
248+
else
249+
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im))
250+
end
251+
return ChainRulesCore.NoTangent(), ν̄, x̄
252+
end
253+
return Ω, besseljx_pullback
254+
end
255+
256+
function ChainRulesCore.frule((_, Δν, Δx), ::typeof(besselyx), ν::Number, x::Number)
257+
# primal
258+
Ω = besselyx(ν, x)
259+
260+
# derivative
261+
∂Ω_∂ν = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
262+
a = (besselyx- 1, x) - besselyx+ 1, x)) / 2
263+
ΔΩ = if Δx isa Real
264+
muladd(a, Δx, ∂Ω_∂ν * Δν)
265+
else
266+
muladd(a, Δx, muladd(-sign(imag(x)) * imag(Δx), Ω, ∂Ω_∂ν * Δν))
267+
end
268+
269+
return Ω, ΔΩ
270+
end
271+
function ChainRulesCore.rrule(::typeof(besselyx), ν::Number, x::Number)
272+
Ω = besselyx(ν, x)
273+
project_x = ChainRulesCore.ProjectTo(x)
274+
function besselyx_pullback(ΔΩ)
275+
ν̄ = ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO)
276+
a = (besselyx- 1, x) - besselyx+ 1, x)) / 2
277+
= if x isa Real
278+
project_x(a * ΔΩ)
279+
else
280+
project_x(muladd(conj(a), ΔΩ, - sign(imag(x)) * real(conj(Ω) * ΔΩ) * im))
281+
end
282+
return ChainRulesCore.NoTangent(), ν̄, x̄
283+
end
284+
return Ω, besselyx_pullback
285+
end

test/chainrules.jl

+9
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,15 @@
5353
for nu in (-1.5, 2.2, 4.0)
5454
test_frule(besseli, nu, x)
5555
test_rrule(besseli, nu, x)
56+
test_frule(besselix, nu, x) # derivative is `NotImplemented`
57+
test_frule(besselix, nu NoTangent(), x) # derivative is a number
58+
test_rrule(besselix, nu, x)
5659

5760
test_frule(besselj, nu, x)
5861
test_rrule(besselj, nu, x)
62+
test_frule(besseljx, nu, x) # derivative is `NotImplemented`
63+
test_frule(besseljx, nu NoTangent(), x) # derivative is a number
64+
test_rrule(besseljx, nu, x)
5965

6066
test_frule(besselk, nu, x)
6167
test_rrule(besselk, nu, x)
@@ -64,6 +70,9 @@
6470

6571
test_frule(bessely, nu, x)
6672
test_rrule(bessely, nu, x)
73+
test_frule(besselyx, nu, x) # derivative is `NotImplemented`
74+
test_frule(besselyx, nu NoTangent(), x) # derivative is a number
75+
test_rrule(besselyx, nu, x)
6776

6877
test_frule(hankelh1, nu, x)
6978
test_rrule(hankelh1, nu, x)

0 commit comments

Comments
 (0)