Skip to content

Commit c130e6f

Browse files
committed
Add call_at_end and save_positions to callbacks
1 parent 95e6109 commit c130e6f

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

src/Callbacks.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,13 @@ Trigger `f!(integrator)` every `Δt` simulation time.
8888
8989
If `atinit=true`, then `f!` will additionally be triggered at initialization. Otherwise
9090
the first trigger will be after `Δt` simulation time.
91+
92+
If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise
93+
there is no call to `f!` at the end of the time span.
94+
95+
The tuple `save_positions` determines whether to save before or after `f!`.
9196
"""
92-
function EveryXSimulationTime(f!, Δt; atinit = false)
97+
function EveryXSimulationTime(f!, Δt; atinit = false, call_at_end = false, save_positions = (true, true))
9398
t_next = zero(Δt)
9499

95100
function _initialize(c, u, t, integrator)
@@ -111,14 +116,22 @@ function EveryXSimulationTime(f!, Δt; atinit = false)
111116
t_next += Δt
112117
end
113118
return true
119+
elseif (call_at_end && t == integrator.sol.prob.tspan[2])
120+
return true
114121
else
115122
return false
116123
end
117124
end
118125
if isdefined(DiffEqBase, :finalize!)
119-
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
126+
SciMLBase.DiscreteCallback(
127+
condition,
128+
f!;
129+
initialize = _initialize,
130+
finalize = _finalize,
131+
save_positions = save_positions,
132+
)
120133
else
121-
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
134+
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions)
122135
end
123136
end
124137

@@ -131,8 +144,13 @@ Trigger `f!(integrator)` every `Δsteps` simulation steps.
131144
132145
If `atinit==true`, then `f!` will additionally be triggered at initialization. Otherwise
133146
the first trigger will be after `Δsteps`.
147+
148+
If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise
149+
there is no call to `f!` at the end of the time span.
150+
151+
The tuple `save_positions` determines whether to save before or after `f!`.
134152
"""
135-
function EveryXSimulationSteps(f!, Δsteps; atinit = false)
153+
function EveryXSimulationSteps(f!, Δsteps; atinit = false, call_at_end = false, save_positions = (true, true))
136154
steps = 0
137155
steps_next = 0
138156

@@ -154,15 +172,23 @@ function EveryXSimulationSteps(f!, Δsteps; atinit = false)
154172
if steps >= steps_next
155173
steps_next += Δsteps
156174
return true
175+
elseif (call_at_end && t == integrator.sol.prob.tspan[2])
176+
return true
157177
else
158178
return false
159179
end
160180
end
161181

162182
if isdefined(DiffEqBase, :finalize!)
163-
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
183+
SciMLBase.DiscreteCallback(
184+
condition,
185+
f!;
186+
initialize = _initialize,
187+
finalize = _finalize,
188+
save_positions = save_positions,
189+
)
164190
else
165-
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
191+
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions)
166192
end
167193
end
168194

test/callbacks.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ mutable struct MyCallback
1414
initialized::Bool
1515
calls::Int
1616
finalized::Bool
17+
last_t::Real
1718
end
18-
MyCallback() = MyCallback(false, 0, false)
19+
MyCallback() = MyCallback(false, 0, false, -1.0)
1920

2021
function Callbacks.initialize!(cb::MyCallback, integrator)
2122
cb.initialized = true
@@ -25,13 +26,18 @@ function Callbacks.finalize!(cb::MyCallback, integrator)
2526
end
2627
function (cb::MyCallback)(integrator)
2728
cb.calls += 1
29+
cb.last_t = integrator.t
2830
end
2931

3032
cb1 = MyCallback()
3133
cb2 = MyCallback()
3234
cb3 = MyCallback()
3335
cb4 = MyCallback()
3436
cb5 = MyCallback()
37+
cb6 = MyCallback()
38+
cb7 = MyCallback()
39+
cb8 = MyCallback()
40+
cb9 = MyCallback()
3541

3642
cbs = CallbackSet(
3743
EveryXSimulationTime(cb1, 1 / 4),
@@ -40,6 +46,10 @@ cbs = CallbackSet(
4046
EveryXSimulationSteps(cb4, 4, atinit = true),
4147
EveryXSimulationSteps(_ -> sleep(1 / 32), 1),
4248
EveryXWallTimeSeconds(cb5, 0.49, comm_ctx),
49+
EveryXSimulationTime(cb6, 0.49, call_at_end = true),
50+
EveryXSimulationSteps(cb7, 3, call_at_end = true),
51+
EveryXSimulationTime(cb8, 0.3, call_at_end = false),
52+
EveryXSimulationSteps(cb9, 3, call_at_end = false),
4353
)
4454

4555
const_prob_inc = ODEProblem(
@@ -63,6 +73,11 @@ solve(const_prob_inc, LSRKEulerMethod(), dt = 1 / 32, callback = cbs)
6373
@test cb4.calls == 9
6474
@test cb5.calls >= 2
6575

76+
@test cb6.last_t == 1.0
77+
@test cb7.last_t == 1.0
78+
@test cb8.last_t == (1 / 32) * 29
79+
@test cb9.last_t == (1 / 32) * 30
80+
6681
if isdefined(DiffEqBase, :finalize!)
6782

6883
@test cb1.finalized

0 commit comments

Comments
 (0)