Clone appctx for adjoint replay solvers (release) #5171

Open
sghelichkhani wants to merge 2 commits into release from sghelichkhani/fix-ad-clone-kwargs-release

sghelichkhani commented Jun 12, 2026 (Contributor)

During tape replay of an annotated NonlinearVariationalSolver, the appctx reaches the clone solvers either still pointing at the user's original coefficients (forward clone) or not at all (adjoint clone, where it is dropped). Preconditioners that read UFL expressions from it, e.g. MassInvPC, therefore see stale values instead of the tape-synced clones. This PR clones the appctx with the same coefficient replace maps used for the forms. It is the same fix as #5170, which addresses this for main on top of the #4638 refactor; this is the equivalent for the current release. Carrying appctx into the adjoint solver is safe since the adjoint equation uses J^T (https://firedrakeproject.slack.com/archives/C1Q0Y6H8A/p1776171359074139).

Disclosure: parts of this change were prepared with the help of Claude Code (Opus 4.8).

…ation

When a NonlinearVariationalSolver is created with an appctx dict, the
adjoint machinery deep-copies the form coefficients (F, J) for the
clone solvers used during tape replay, but the appctx was passed
through unchanged (forward clone) or dropped outright (adjoint clone).
Preconditioners that read from appctx (e.g. MassInvPC for Schur
complement Stokes) would see stale end-of-forward-run values instead
of the tape-replayed values at each step.

_ad_problem_clone and _ad_adj_lvs_problem now return their coefficient
replace maps alongside the cloned problems. A new _ad_clone_kwargs
static method applies ufl.replace to every UFL expression in appctx
using the relevant map, so each clone solver's appctx entries point at
the cloned Function objects that _ad_solver_replace_forms keeps in
sync with the tape. This happens once at clone creation time, not at
every replay solve.

The appctx pop in solve_init_params is kept: that function is shared
with GenericSolveBlock, whose adjoint solve goes through
firedrake.solve with an assembled matrix, which rejects an appctx
kwarg. Instead, _ad_clone_kwargs takes a default_appctx fallback and
the adjoint LinearVariationalSolver reinstates the forward appctx,
cloned with the adjoint replace map.
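
The cloning and fallback described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual `_ad_clone_kwargs`: the function name `clone_appctx`, its parameters, and the toy `Coeff`, `toy_is_expr` and `toy_replace` stand-ins are all hypothetical, chosen so the sketch runs without Firedrake installed.

```python
from typing import Any, Callable, Mapping, Optional


def clone_appctx(
    appctx: Optional[Mapping[str, Any]],
    replace_map: Mapping[Any, Any],
    is_expr: Callable[[Any], bool],
    replace: Callable[[Any, Mapping[Any, Any]], Any],
    default_appctx: Optional[Mapping[str, Any]] = None,
) -> Optional[dict]:
    """Return a copy of appctx whose expression entries are remapped."""
    # Fall back to the forward appctx when the kwargs carry none
    # (the adjoint case, where appctx was popped earlier).
    source = appctx if appctx is not None else default_appctx
    if source is None:
        return None
    # Build a new dict so the user's appctx is never mutated.
    return {
        key: replace(value, replace_map) if is_expr(value) else value
        for key, value in source.items()
    }


# Toy stand-ins (assumptions): an "expression" is a tuple of coefficient
# objects, and replacing swaps any member found in the mapping.
class Coeff:
    def __init__(self, name):
        self.name = name


def toy_is_expr(value):
    return isinstance(value, tuple)


def toy_replace(expr, mapping):
    return tuple(mapping.get(term, term) for term in expr)


mu, mu_clone = Coeff("mu"), Coeff("mu_clone")
user_appctx = {"mu": (mu,), "tag": "stokes"}

# Forward clone: appctx is present and gets remapped.
forward = clone_appctx(user_appctx, {mu: mu_clone}, toy_is_expr, toy_replace)
# Adjoint clone: appctx is absent, so the forward one is reinstated.
adjoint = clone_appctx(None, {mu: mu_clone}, toy_is_expr, toy_replace,
                       default_appctx=user_appctx)
```

With real UFL objects one would presumably pass `ufl.replace` as `replace` and an `isinstance` check against `ufl.core.expr.Expr` as `is_expr`. Because the remapping runs once when the clone is created, the per-replay cost stays unchanged.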

Tests cover structural identity (the clone solvers' appctx mu is the
same object as the mu inside their forms, and the forward and adjoint
clones are independent), behavioural forward and adjoint replay with a
recording preconditioner, a Taylor sanity check, and a regression
guard that the legacy solve(F == 0, ...) path with appctx still
computes derivatives.
# The legacy adjoint solve goes through ``firedrake.solve``
# with an assembled matrix, which does not accept appctx.
# The variational solver mixin reinstates it (suitably
# cloned) for the adjoint LinearVariationalSolver.

Contributor

Is this true? A LinearSolver is now just a LinearVariationalSolver and passes through kwargs: https://github.com/firedrakeproject/firedrake/blob/main/firedrake/linear_solver.py#L12

Contributor Author

Yes, that comment was garbage, sorry! LinearSolver does take appctx as a LinearVariationalSolver subclass. The thing that actually bites is one layer up: the adjoint solve here goes through solve(A, x, b), which dispatches to _la_solve and then _extract_linear_solver_args, and that validates kwargs against a fixed list that doesn't include appctx, so a top-level appctx raises "Illegal keyword argument 'appctx'". On that path appctx only gets picked up out of solver_parameters, never as a top-level kwarg. So the pop is still needed, I'd just written down the wrong reason for it.
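
To make the failure mode concrete, here is a toy model of that validation. It is not the code in solving.py: the contents of `VALID_KWARGS`, the exception type, and the helper `try_call` are assumptions for illustration only.

```python
# Hypothetical allow-list; the real one in solving.py is not reproduced here.
VALID_KWARGS = ("bcs", "solver_parameters", "options_prefix")


def extract_linear_solver_args(**kwargs):
    """Toy validator: reject any kwarg that is not on the fixed list."""
    for name in kwargs:
        if name not in VALID_KWARGS:
            raise RuntimeError(f"Illegal keyword argument '{name}'")
    params = dict(kwargs.get("solver_parameters") or {})
    # On this path appctx is only read from inside solver_parameters.
    appctx = params.pop("appctx", None)
    return params, appctx


def try_call(**kwargs):
    """Return the result, or the error message if validation fails."""
    try:
        return extract_linear_solver_args(**kwargs)
    except RuntimeError as exc:
        return str(exc)


# A top-level appctx is rejected by the allow-list check...
rejected = try_call(appctx={"mu": 1.0})
# ...while the same dict nested in solver_parameters is picked up.
accepted = try_call(solver_parameters={"ksp_type": "cg", "appctx": {"mu": 1.0}})
```

This is why popping appctx before the legacy call avoids the error, even though the solver underneath could use it.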

Contributor

I am wondering if the linear solver path could be adapted to accept an appctx argument. Then we wouldn't need any special handling here.

In general I think that this PR makes a lot of sense. I just don't want to bake in any unneeded technical debt.

Member

Connor is correct, the right fix here is just to make _extract_linear_solver_args in solving.py accept an appctx kwarg.
We already allow using an appctx with the LinearSolver code path in the solve free function, but for some reason we expect it to be stuffed into the solver_parameters dictionary. This obviously doesn't match the pattern of any other solver/solve call.

Should just be a few lines to change this to expecting appctx in the kwargs like normal.

Member

Yes, we should allow this now and then deprecate passing the appctx the old way on main. That permanently fixes our API as well as sorting this out in the short term.
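
One possible shape for that change, as a rough sketch only: accept appctx as an ordinary keyword and warn when it arrives through solver_parameters. The function name `split_appctx`, the warning text, and the choice to let an explicit keyword win over the legacy entry are all assumptions, not the agreed design.

```python
import warnings


def split_appctx(solver_parameters=None, appctx=None):
    """Separate appctx from solver options, preferring the explicit kwarg."""
    params = dict(solver_parameters or {})
    legacy = params.pop("appctx", None)
    if legacy is not None:
        # Old style still works but is flagged for removal.
        warnings.warn(
            "Passing appctx inside solver_parameters is deprecated; "
            "pass appctx=... as a keyword argument instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if appctx is None:
            appctx = legacy
    return params, appctx


# New style: appctx as a normal keyword argument, no warning.
new_style = split_appctx({"ksp_type": "cg"}, appctx={"mu": 1.0})

# Old style: appctx nested in solver_parameters, accepted with a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = split_appctx({"ksp_type": "cg", "appctx": {"mu": 2.0}})
```

With something along these lines in the linear solve path, the special-case pop in solve_init_params would no longer be needed.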

The previous comment claimed the assembled adjoint solve does not accept
appctx, which is misleading now that LinearSolver is a LinearVariationalSolver
subclass and forwards appctx. The actual rejection is one layer up: the
solve(A, x, b) dispatch validates kwargs in _extract_linear_solver_args and
refuses a top-level appctx, reading it only from solver_parameters.
sghelichkhani force-pushed the sghelichkhani/fix-ad-clone-kwargs-release branch from b89e6af to fcb5021 on June 15, 2026 13:57
4 participants