Skip to content

dolfinx.fem.petsc.LinearProblem for nest and block systems #3684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
196ea85
Proposal for supporting blocked problems (not nest) in `LinearProblem`.
jorgensd Apr 3, 2025
22ef528
Add asserts for typehinting to demos
jorgensd Apr 3, 2025
c08bc72
Add kind as input argument and allow for nest
jorgensd Apr 3, 2025
c7b1fca
Use same logic for both blocked and non-blocked problems
jorgensd Apr 3, 2025
7b926b8
Test all use-cases
jorgensd Apr 3, 2025
f4636c3
Centralize ghostupdate function
jorgensd Apr 3, 2025
ece7d96
Simplify vector creation
jorgensd Apr 3, 2025
ba209bf
add element dtypes
jorgensd Apr 3, 2025
4313473
Ruff formatting
jorgensd Apr 3, 2025
1b13c23
Use realtype
jorgensd Apr 3, 2025
ff5e96a
Use inner
jorgensd Apr 3, 2025
83a6548
Add kind which is nested list of types
jorgensd Apr 3, 2025
1196ea2
Make ghostupdate private
jorgensd Apr 4, 2025
e7310fd
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd Apr 4, 2025
366253b
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd Apr 8, 2025
7fb4875
Ruff format comment and add Preconditioner input to `LinearProblem`
jorgensd Apr 9, 2025
6ae380d
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd Apr 9, 2025
22918c6
Ruff format
jorgensd Apr 9, 2025
f383ca9
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd Apr 10, 2025
77b8a39
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd Apr 30, 2025
1da4825
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd May 2, 2025
446a23a
Merge branch 'main' into dokken/blocked_linear_solver
jorgensd May 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/demo/demo_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ def create_eps_mu(pml, rho, eps_bkg, mu_bkg):
},
)
Esh_m = problem.solve()
assert isinstance(Esh_m, fem.Function)
assert problem.solver.getConvergedReason() > 0, "Solver did not converge!"

# Scattered magnetic field
Expand Down
1 change: 1 addition & 0 deletions python/demo/demo_biharmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@

problem = LinearProblem(a, L, bcs=[bc], petsc_options={"ksp_type": "preonly", "pc_type": "lu"})
uh = problem.solve()
assert isinstance(uh, fem.Function)

# The solution can be written to a {py:class}`XDMFFile
# <dolfinx.io.XDMFFile>` file visualization with ParaView or VisIt
Expand Down
1 change: 1 addition & 0 deletions python/demo/demo_pml.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def create_eps_mu(
},
)
Esh = problem.solve()
assert isinstance(Esh, fem.Function)
assert problem.solver.getConvergedReason() > 0, "Solver did not converge!"
# -

Expand Down
1 change: 1 addition & 0 deletions python/demo/demo_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
# +
problem = LinearProblem(a, L, bcs=[bc], petsc_options={"ksp_type": "preonly", "pc_type": "lu"})
uh = problem.solve()
assert isinstance(uh, fem.Function)
# -

# The solution can be written to a {py:class}`XDMFFile
Expand Down
1 change: 1 addition & 0 deletions python/demo/demo_scattering_boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def curl_2d(f: fem.Function):
gdim = mesh_data.mesh.geometry.dim
V_dg = fem.functionspace(mesh_data.mesh, ("Discontinuous Lagrange", degree, (gdim,)))
Esh_dg = fem.Function(V_dg)
assert isinstance(Esh, fem.Function)
Esh_dg.interpolate(Esh)

with io.VTXWriter(mesh_data.mesh.comm, "Esh.bp", Esh_dg) as vtx:
Expand Down
108 changes: 81 additions & 27 deletions python/dolfinx/fem/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,20 +742,22 @@ class LinearProblem:

def __init__(
self,
a: ufl.Form,
L: ufl.Form,
a: typing.Union[ufl.Form, Iterable[Iterable[ufl.Form]]],
L: typing.Union[ufl.Form, Iterable[ufl.Form]],
bcs: typing.Optional[Iterable[DirichletBC]] = None,
u: typing.Optional[_Function] = None,
u: typing.Optional[typing.Union[_Function, Iterable[_Function]]] = None,
petsc_options: typing.Optional[dict] = None,
form_compiler_options: typing.Optional[dict] = None,
jit_options: typing.Optional[dict] = None,
):
kind: typing.Optional[typing.Union[str, Iterable[Iterable[str]]]] = None,
P: typing.Optional[typing.Union[ufl.Form, Iterable[Iterable[ufl.Form]]]] = None,
) -> None:
"""Initialize solver for a linear variational problem.

Args:
a: Bilinear UFL form, the left hand side of the
a: Bilinear UFL form or a rectangular array of bilinear forms, the left hand side of the
variational problem.
L: Linear UFL form, the right hand side of the variational
L: Linear UFL form or a sequence of lienar forms, the right hand side of the variational
problem.
bcs: Sequecne of Dirichlet boundary conditions.
u: Solution function. It is be created if not
Expand All @@ -771,42 +773,67 @@ def __init__(
code generated by FFCx. See `python/dolfinx/jit.py` for
all available options. Takes priority over all other
option values.
kind: The PETSc matrix and vector type. See :func:`create_matrix` for options.
P: UFL form or a rectangular array of bilinear forms used as a preconditioner.

Example::

problem = LinearProblem(a, L, [bc0, bc1], petsc_options={"ksp_type": "preonly",
"pc_type": "lu",
"pc_factor_mat_solver_type":
"mumps"})

problem = LinearProblem([[a00, a01], [None, a11]], [L0, L1], bcs=[bc0, bc1],
u=[uh0, uh1])
"""
self._a = _create_form(
a,
dtype=PETSc.ScalarType,
form_compiler_options=form_compiler_options,
jit_options=jit_options,
)
self._A = create_matrix(self._a)
self._L = _create_form(
L,
dtype=PETSc.ScalarType,
form_compiler_options=form_compiler_options,
jit_options=jit_options,
)
self._b = create_vector(self._L)
self._A = create_matrix(self._a, kind=kind)
self._preconditioner = _create_form(
P,
dtype=PETSc.ScalarType,
form_compiler_options=form_compiler_options,
jit_options=jit_options,
)
self._P = (
create_matrix(self._preconditioner, kind=kind)
if self._preconditioner is not None
else None
)

# For nest matrices kind can be a nested list.
kind = "nest" if self._A.getType() == PETSc.Mat.Type.NEST else kind
self._b = create_vector(self._L, kind=kind)
self._x = create_vector(self._L, kind=kind)

if u is None:
# Extract function space from TrialFunction (which is at the
# end of the argument list as it is numbered as 1, while the
# Test function is numbered as 0)
self.u = _Function(a.arguments()[-1].ufl_function_space())
try:
# Extract function space for unknown from the right hand side of the equation.
self._u = _Function(L.arguments()[0].ufl_function_space())
except AttributeError:
self._u = [_Function(Li.arguments()[0].ufl_function_space()) for Li in L]
else:
self.u = u
self._u = u

self._x = dolfinx.la.petsc.create_vector_wrap(self.u.x)
self.bcs = bcs

self._solver = PETSc.KSP().create(self.u.function_space.mesh.comm)
self._solver.setOperators(self._A)
try:
comm = self._u.function_space.mesh.comm
except AttributeError:
comm = self._u[0].function_space.mesh.comm

self._solver = PETSc.KSP().create(comm)
self._solver.setOperators(self._A, self._P)

# Give PETSc solver options a unique prefix
problem_prefix = f"dolfinx_solve_{id(self)}"
Expand All @@ -833,33 +860,55 @@ def __del__(self):
self._b.destroy()
self._x.destroy()

def solve(self) -> _Function:
def solve(self) -> typing.Union[_Function, Iterable[_Function]]:
"""Solve the problem."""

# Assemble lhs
self._A.zeroEntries()
assemble_matrix(self._A, self._a, bcs=self.bcs)
self._A.assemble()

# Assemble preconditioner
if self._P is not None:
self._P.zeroEntries()
assemble_matrix(self._P, self._preconditioner, bcs=self.bcs)
self._P.assemble()

# Assemble rhs
with self._b.localForm() as b_loc:
b_loc.set(0)
if self._b.getType() == PETSc.Vec.Type.NEST:
for b_sub in self._b.getNestSubVecs():
with b_sub.localForm() as b_local:
b_local.set(0.0)
else:
with self._b.localForm() as b_loc:
b_loc.set(0)
assemble_vector(self._b, self._L)

# Apply boundary conditions to the rhs
if self.bcs is not None:
apply_lifting(self._b, [self._a], bcs=[self.bcs])
self._b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
for bc in self.bcs:
bc.set(self._b.array_w)
try:
apply_lifting(self._b, [self._a], bcs=[self.bcs])
dolfinx.la.petsc._ghost_update(
self._b, PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE
)
for bc in self.bcs:
bc.set(self._b.array_w)
except RuntimeError:
bcs1 = _bcs_by_block(_extract_spaces(self._a, 1), self.bcs) # type: ignore
apply_lifting(self._b, self._a, bcs=bcs1) # type: ignore
dolfinx.la.petsc._ghost_update(
self._b, PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE
)
bcs0 = _bcs_by_block(_extract_spaces(self._L), self.bcs) # type: ignore
dolfinx.fem.petsc.set_bc(self._b, bcs0)
else:
self._b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
dolfinx.la.petsc._ghost_update(self._b, PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)

# Solve linear system and update ghost values in the solution
self._solver.solve(self._b, self._x)
self.u.x.scatter_forward()

return self.u
dolfinx.la.petsc._ghost_update(self._x, PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)
dolfinx.fem.petsc.assign(self._x, self._u)
return self._u

@property
def L(self) -> Form:
Expand All @@ -886,6 +935,11 @@ def solver(self) -> PETSc.KSP:
"""Linear solver object"""
return self._solver

@property
def u(self) -> typing.Union[_Function, list[_Function]]:
"""Solution function"""
return self._u


class NonlinearProblem:
"""Nonlinear problem class for solving the non-linear problems.
Expand Down
9 changes: 9 additions & 0 deletions python/dolfinx/la/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@
__all__ = ["assign", "create_vector", "create_vector_wrap"]


def _ghost_update(x: PETSc.Vec, insert_mode: PETSc.InsertMode, scatter_mode: PETSc.ScatterMode): # type: ignore
"""Helper function for ghost updating PETSc vectors"""
if x.getType() == PETSc.Vec.Type.NEST: # type: ignore[attr-defined]
for x_sub in x.getNestSubVecs():
x_sub.ghostUpdate(addv=insert_mode, mode=scatter_mode)
else:
x.ghostUpdate(addv=insert_mode, mode=scatter_mode)


def create_vector(index_map: IndexMap, bs: int) -> PETSc.Vec: # type: ignore[name-defined]
"""Create a distributed PETSc vector.

Expand Down
98 changes: 97 additions & 1 deletion python/test/unit/fem/test_petsc_solver_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Jørgen S. Dokken
# Copyright (C) 2024-2025 Jørgen S. Dokken
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
Expand All @@ -10,6 +10,7 @@
import numpy as np
import pytest

import basix.ufl
import dolfinx
import ufl

Expand Down Expand Up @@ -70,3 +71,98 @@ def test_compare_solvers(self, mode):
solver.rtol = eps
solver.solve(uh)
assert np.allclose(u_lin.x.array, uh.x.array, atol=eps, rtol=eps)

@pytest.mark.parametrize(
"mode", [dolfinx.mesh.GhostMode.none, dolfinx.mesh.GhostMode.shared_facet]
)
@pytest.mark.parametrize("kind", [None, "mpi", "nest", [["aij", None], [None, "baij"]]])
def test_mixed_system(self, mode, kind):
from petsc4py import PETSc

msh = dolfinx.mesh.create_unit_square(
MPI.COMM_WORLD, 12, 12, ghost_mode=mode, dtype=PETSc.RealType
)

def top_bc(x):
return np.isclose(x[1], 1.0)

msh.topology.create_connectivity(msh.topology.dim - 1, msh.topology.dim)
bndry_facets = dolfinx.mesh.locate_entities_boundary(msh, msh.topology.dim - 1, top_bc)

el_0 = basix.ufl.element("Lagrange", msh.basix_cell(), 1, dtype=PETSc.RealType)
el_1 = basix.ufl.element("Lagrange", msh.basix_cell(), 2, dtype=PETSc.RealType)

if kind is None:
me = basix.ufl.mixed_element([el_0, el_1])
W = dolfinx.fem.functionspace(msh, me)
V, _ = W.sub(0).collapse()
Q, _ = W.sub(1).collapse()
else:
V = dolfinx.fem.functionspace(msh, el_0)
Q = dolfinx.fem.functionspace(msh, el_1)
W = ufl.MixedFunctionSpace(V, Q)

u, p = ufl.TrialFunctions(W)
v, q = ufl.TestFunctions(W)

a00 = ufl.inner(u, v) * ufl.dx
a11 = ufl.inner(p, q) * ufl.dx
x = ufl.SpatialCoordinate(msh)
f = x[0] + 3 * x[1]
g = -(x[1] ** 2) + x[0]
L0 = ufl.inner(f, v) * ufl.dx
L1 = ufl.inner(g, q) * ufl.dx

f_expr = dolfinx.fem.Expression(f, V.element.interpolation_points)
g_expr = dolfinx.fem.Expression(g, Q.element.interpolation_points)
u_bc = dolfinx.fem.Function(V)
u_bc.interpolate(f_expr)
p_bc = dolfinx.fem.Function(Q)
p_bc.interpolate(g_expr)

if kind is None:
a = a00 + a11
L = L0 + L1
dofs_V = dolfinx.fem.locate_dofs_topological(
(W.sub(0), V), msh.topology.dim - 1, bndry_facets
)
dofs_Q = dolfinx.fem.locate_dofs_topological(
(W.sub(1), Q), msh.topology.dim - 1, bndry_facets
)
bcs = [
dolfinx.fem.dirichletbc(u_bc, dofs_V, W.sub(0)),
dolfinx.fem.dirichletbc(p_bc, dofs_Q, W.sub(1)),
]
else:
a = [[a00, None], [None, a11]]
L = [L0, L1]
dofs_V = dolfinx.fem.locate_dofs_topological(V, msh.topology.dim - 1, bndry_facets)
dofs_Q = dolfinx.fem.locate_dofs_topological(Q, msh.topology.dim - 1, bndry_facets)
bcs = [
dolfinx.fem.dirichletbc(u_bc, dofs_V),
dolfinx.fem.dirichletbc(p_bc, dofs_Q),
]

petsc_options = {
"ksp_type": "preonly",
"pc_type": "lu",
"pc_factor_mat_solver_type": "mumps",
"ksp_error_if_not_converged": True,
}

problem = dolfinx.fem.petsc.LinearProblem(
a, L, bcs=bcs, petsc_options=petsc_options, kind=kind
)
wh = problem.solve()
if kind is None:
uh, ph = wh.split()
else:
uh, ph = wh
error_uh = dolfinx.fem.form(ufl.inner(uh - f, uh - f) * ufl.dx)
error_ph = dolfinx.fem.form(ufl.inner(ph - g, ph - g) * ufl.dx)
local_uh_L2 = dolfinx.fem.assemble_scalar(error_uh)
local_ph_L2 = dolfinx.fem.assemble_scalar(error_ph)
global_uh_L2 = np.sqrt(msh.comm.allreduce(local_uh_L2, op=MPI.SUM))
global_ph_L2 = np.sqrt(msh.comm.allreduce(local_ph_L2, op=MPI.SUM))
tol = 500 * np.finfo(dolfinx.default_scalar_type).eps
assert global_uh_L2 < tol and global_ph_L2 < tol
Loading