Clone appctx for adjoint replay solvers (release) #5171
…ation

When a NonlinearVariationalSolver is created with an appctx dict, the adjoint machinery deep-copies the form coefficients (F, J) for the clone solvers used during tape replay, but the appctx was passed through unchanged (forward clone) or dropped outright (adjoint clone). Preconditioners that read from appctx (e.g. MassInvPC for Schur complement Stokes) would see stale end-of-forward-run values instead of the tape-replayed values at each step.

_ad_problem_clone and _ad_adj_lvs_problem now return their coefficient replace maps alongside the cloned problems. A new _ad_clone_kwargs static method applies ufl.replace to every UFL expression in appctx using the relevant map, so each clone solver's appctx entries point at the cloned Function objects that _ad_solver_replace_forms keeps in sync with the tape. This happens once at clone creation time, not at every replay solve.

The appctx pop in solve_init_params is kept: that function is shared with GenericSolveBlock, whose adjoint solve goes through firedrake.solve with an assembled matrix, which rejects an appctx kwarg. Instead, _ad_clone_kwargs takes a default_appctx fallback and the adjoint LinearVariationalSolver reinstates the forward appctx, cloned with the adjoint replace map.

Tests cover structural identity (the clone solvers' appctx mu is the same object as the mu inside their forms, and the forward and adjoint clones are independent), behavioural forward and adjoint replay with a recording preconditioner, a Taylor sanity check, and a regression guard that the legacy solve(F == 0, ...) path with appctx still computes derivatives.
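For readers who want the replace-map idea in isolation, here is a minimal plain-Python sketch. It does not use Firedrake or UFL: Coefficient, Scaled, replace and clone_appctx are hypothetical stand-ins invented for illustration, and replace is only a toy analogue of ufl.replace. It shows why rebuilding the appctx once, at clone creation time, is enough for a preconditioner to see tape-synced values.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like a Function
class Coefficient:
    """Hypothetical stand-in for a Function that carries a value."""
    name: str
    value: float


@dataclass(eq=False)
class Scaled:
    """Hypothetical stand-in for a UFL expression: factor * coefficient."""
    factor: float
    coeff: Coefficient


def replace(expr, mapping):
    """Toy analogue of ufl.replace for the two stand-in types above."""
    if isinstance(expr, Coefficient):
        return mapping.get(expr, expr)
    if isinstance(expr, Scaled):
        return Scaled(expr.factor, replace(expr.coeff, mapping))
    return expr


def clone_appctx(appctx, replace_map):
    """Return a new dict whose expression entries point at the clones.

    Entries that are not expressions (strings, numbers, ...) pass through.
    """
    return {
        key: replace(val, replace_map)
        if isinstance(val, (Coefficient, Scaled)) else val
        for key, val in appctx.items()
    }


# The user's coefficient and the clone that tape replay keeps in sync.
mu = Coefficient("mu", 1.0)
mu_clone = Coefficient("mu", 1.0)

appctx = {"mu": mu, "inv_mu": Scaled(2.0, mu), "tag": "stokes"}
cloned = clone_appctx(appctx, {mu: mu_clone})

mu.value = 5.0        # end-of-forward-run value on the user's object
mu_clone.value = 3.0  # value written by tape replay for the current step
```

After this, cloned["mu"] is the clone object itself, so a preconditioner reading it sees 3.0 rather than the stale 5.0, and the original appctx dict is left untouched.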
# The legacy adjoint solve goes through ``firedrake.solve``
# with an assembled matrix, which does not accept appctx.
# The variational solver mixin reinstates it (suitably
# cloned) for the adjoint LinearVariationalSolver.
Is this true? A LinearSolver is now just a LinearVariationalSolver and passes through kwargs: https://github.com/firedrakeproject/firedrake/blob/main/firedrake/linear_solver.py#L12
Yes, that comment was garbage, sorry! LinearSolver does take appctx as a LinearVariationalSolver subclass. The thing that actually bites is one layer up: the adjoint solve here goes through solve(A, x, b), which dispatches to _la_solve and then _extract_linear_solver_args, and that validates kwargs against a fixed list that doesn't include appctx, so a top-level appctx raises "Illegal keyword argument 'appctx'". On that path appctx only gets picked up out of solver_parameters, never as a top-level kwarg. So the pop is still needed; I'd just written down the wrong reason for it.
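To make the failure mode concrete, here is a toy reproduction of the behaviour described above. It is not the Firedrake source: ALLOWED_KWARGS, extract_linear_solver_args and adjoint_solve_kwargs are hypothetical names, and both the contents of the allowed list and the RuntimeError type are assumptions. Only the shape of the check matters: a fixed whitelist that omits appctx, with appctx read from solver_parameters instead.

```python
# Illustrative whitelist; the real list in solving.py may differ.
ALLOWED_KWARGS = {"bcs", "solver_parameters", "nullspace", "options_prefix"}


def extract_linear_solver_args(**kwargs):
    """Validate kwargs against a fixed list, then pull appctx from
    solver_parameters (the only place this path looks for it)."""
    for key in kwargs:
        if key not in ALLOWED_KWARGS:
            # Exception type is an assumption made for this sketch.
            raise RuntimeError(f"Illegal keyword argument '{key}'")
    solver_parameters = dict(kwargs.get("solver_parameters") or {})
    appctx = solver_parameters.pop("appctx", {})
    return solver_parameters, appctx


def adjoint_solve_kwargs(forward_kwargs):
    """Copy the forward kwargs and drop the top-level appctx so the
    assembled-matrix solve path does not reject them."""
    kwargs = dict(forward_kwargs)
    kwargs.pop("appctx", None)
    return kwargs


forward_kwargs = {
    "solver_parameters": {"ksp_type": "cg"},
    "appctx": {"mu": 1.0},
}

# A top-level appctx is rejected by the whitelist check.
try:
    extract_linear_solver_args(**forward_kwargs)
    rejected = False
except RuntimeError as exc:
    rejected = True
    message = str(exc)

# After the pop, the same call goes through.
params, appctx = extract_linear_solver_args(**adjoint_solve_kwargs(forward_kwargs))
```

The pop keeps the call valid but loses the appctx, which is why the variational solver path has to reinstate it separately.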
I am wondering if the linear solver path could be adapted to accept an appctx argument. Then we wouldn't need any special handling here.
In general I think that this PR makes a lot of sense. I just don't want to bake in any unneeded technical debt.
Connor is correct, the right fix here is just to make _extract_linear_solver_args in solving.py accept an appctx kwarg.
We already allow using an appctx with the LinearSolver code path in the solve free function, but for some reason we expect it to be stuffed into the solver_parameters dictionary. This obviously doesn't match the pattern of any other solver/solve call.
Should just be a few lines to change this to expecting appctx in the kwargs like normal.
Yes, we should allow this now and then deprecate passing the appctx the old way on main. That permanently fixes our API as well as sorting this out in the short term.
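A rough sketch of what "accept the kwarg, then deprecate the old route" could look like, in plain Python. The function name extract_appctx, its signature, the warning text, and the rule that the top-level kwarg wins when both are given are all assumptions made for illustration. None of this is taken from solving.py.

```python
import warnings


def extract_appctx(solver_parameters=None, appctx=None):
    """Prefer a top-level appctx; still honour the legacy location
    inside solver_parameters, but warn about it."""
    params = dict(solver_parameters or {})
    legacy = params.pop("appctx", None)
    if legacy is not None:
        warnings.warn(
            "Passing appctx inside solver_parameters is deprecated; "
            "pass appctx=... as a keyword argument instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if appctx is None:
            # Only fall back to the legacy value when no kwarg was given.
            appctx = legacy
    return params, ({} if appctx is None else appctx)


# New style: no warning, solver_parameters left as the user wrote them.
new_params, new_appctx = extract_appctx(
    {"ksp_type": "cg"}, appctx={"mu": 1.0}
)

# Old style: still works, but emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_params, old_appctx = extract_appctx(
        {"ksp_type": "cg", "appctx": {"mu": 2.0}}
    )
```

With something of this shape in place, the adjoint code would not need to pop appctx before calling the assembled-matrix solve.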
The previous comment claimed the assembled adjoint solve does not accept appctx, which is misleading now that LinearSolver is a LinearVariationalSolver subclass and forwards appctx. The actual rejection is one layer up: the solve(A, x, b) dispatch validates kwargs in _extract_linear_solver_args and refuses a top-level appctx, reading it only from solver_parameters.
b89e6af to fcb5021
During tape replay of an annotated NonlinearVariationalSolver, the appctx is passed to the clone solvers still pointing at the user's original coefficients (forward) or dropped (adjoint), so preconditioners reading UFL expressions from it, e.g. MassInvPC, see stale values instead of the tape-synced clones. This clones the appctx with the same coefficient replace maps used for the forms. Same fix as #5170, which addresses this for main on top of the #4638 refactor; this is the equivalent for the current release. Carrying appctx into the adjoint solver is safe since the adjoint equation uses J^T (https://firedrakeproject.slack.com/archives/C1Q0Y6H8A/p1776171359074139).
Disclosure: parts of this change were prepared with the help of Claude Code (Opus 4.8).