Skip to content

Commit 1f3c5c1

Browse files
committed
added function for computing jacobian of interpolation between finite element spaces
1 parent f6677ee commit 1f3c5c1

File tree

9 files changed

+377
-12
lines changed

9 files changed

+377
-12
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "ExtendableFEMBase"
22
uuid = "12fb9182-3d4c-4424-8fd1-727a0899810c"
3-
authors = ["Christian Merdon <[email protected]>", "Patrick Jaap <[email protected]>"]
43
version = "1.4.0"
4+
authors = ["Christian Merdon <[email protected]>", "Patrick Jaap <[email protected]>"]
55

66
[deps]
77
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
8+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8"
1011
ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
@@ -13,6 +14,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1415
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1516
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
17+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
18+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
1619
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
1720

1821
[weakdeps]
@@ -23,6 +26,7 @@ ExtendableFEMBaseUnicodePlotsExt = ["UnicodePlots"]
2326

2427
[compat]
2528
DiffResults = "1"
29+
DifferentiationInterface = "0.7.10"
2630
DocStringExtensions = "0.8,0.9"
2731
ExtendableGrids = "1.13.0"
2832
ExtendableSparse = "1.5.1"
@@ -31,6 +35,8 @@ LinearAlgebra = "1.9"
3135
Polynomials = "2.0.21, 3, 4"
3236
Printf = "1.9"
3337
SparseArrays = "1.9"
38+
SparseConnectivityTracer = "1.1.2"
39+
SparseMatrixColorings = "0.4.22"
3440
SpecialPolynomials = "0.4.9, 0.5"
3541
UnicodePlots = "3.6"
3642
julia = "1.9"

src/ExtendableFEMBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,16 @@ using ExtendableGrids: ExtendableGrids, AT_NODES, AbstractElementGeometry,
4646
using ExtendableSparse: ExtendableSparse, ExtendableSparseMatrix, flush!,
4747
AbstractExtendableSparseMatrixCSC, ExtendableSparseMatrixCSC, MTExtendableSparseMatrixCSC,
4848
rawupdateindex!
49+
using DifferentiationInterface: AutoForwardDiff, AutoSparse, jacobian
4950
using ForwardDiff: ForwardDiff, DiffResults
5051
using LinearAlgebra: LinearAlgebra, convert, det, diagm, dot, eigen, ldiv!, lu,
5152
mul!, norm, transpose
5253
using Polynomials: Polynomials, Polynomial, coeffs
5354
using Printf: Printf, @printf
5455
using SparseArrays: SparseArrays, AbstractSparseArray, AbstractSparseMatrix,
5556
SparseMatrixCSC, nzrange, rowvals
57+
using SparseConnectivityTracer: TracerSparsityDetector
58+
using SparseMatrixColorings: GreedyColoringAlgorithm
5659
using SpecialPolynomials: SpecialPolynomials, ShiftedLegendre, basis
5760

5861
include("functionoperators.jl")
@@ -114,6 +117,7 @@ export interpolate! # must be defined separately by each FEdefinition
114117
export nodevalues, continuify
115118
export nodevalues!, nodevalues_subset!
116119
export nodevalues_view
120+
export compute_interpolation_jacobian
117121

118122
export interpolator_matrix
119123

src/interpolations.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,62 @@ function ExtendableGrids.interpolate!(target::FEVectorBlock, source; kwargs...)
9191
return interpolate!(target, ON_CELLS, source; kwargs...)
9292
end
9393

94+
"""
95+
````
96+
function compute_interpolation_jacobian(
97+
target_space::FESpace,
98+
source_space::FESpace;
99+
use_cellparents::Bool = false,
100+
kwargs...
101+
)
102+
````
103+
104+
Compute the Jacobian of the [`lazy_interpolate!`](@ref) call with respect to the
105+
`source_space` degrees of freedom, i.e. for functions ``v = \\sum_j \\alpha_j \\, \\varphi_j`` of the
106+
`source_space` and the interpolation operator ``I(v) = \\sum_k L_k(v)\\,\\phi_k = \\sum_k L_k\\left(\\sum_j \\alpha_j \\varphi_j\\right) \\, \\phi_k``
107+
into the `target_space`, this function computes the jacobian ``\\left[\\frac{\\partial L_k}{\\partial \\alpha_j}\\right]_{k,\\,j}``
108+
and returns its sparse matrix representation.
109+
110+
# Arguments
111+
- `target_space::FESpace`: Finite element space into which the interpolation ``I(v)`` is directed.
112+
- `source_space::FESpace`: Finite element space from which ``v`` is taken.
113+
114+
# Keyword Arguments
115+
- `use_cellparents`: Use parent cell information if available (can speed up the calculation if the `target_space` is defined on a subgrid of `source_space`).
116+
- `kwargs...`: Additional keyword arguments passed to lower-level `lazy_interpolate!` call.
117+
118+
# Notes
119+
- This function can be used for computing prolongation or restriction operators if the `FESpace`s are defined on coarser/finer grids, respectively.
120+
121+
"""
122+
function compute_interpolation_jacobian(target_space::FESpace, source_space::FESpace; use_cellparents::Bool = false, kwargs...)
123+
# DifferentiationInterface.jacobian needs a function of signature
124+
# AbstractVector -> AbstractVector
125+
function do_interpolation(source_vector::AbstractVector; use_cellparents = use_cellparents)
126+
T = valtype(source_vector)
127+
target_vector = FEVector{T}(target_space)[1]
128+
source_FE_Vector = FEVector{T}(source_space)
129+
source_FE_Vector.entries .= source_vector
130+
131+
lazy_interpolate!(target_vector, source_FE_Vector, [(1, Identity)]; use_cellparents, kwargs...)
132+
133+
return target_vector.entries
134+
end
135+
136+
n = ndofs(source_space)
137+
138+
dense_forward_backend = AutoForwardDiff()
139+
sparse_forward_backend = AutoSparse(
140+
dense_forward_backend;
141+
sparsity_detector = TracerSparsityDetector(),
142+
coloring_algorithm = GreedyColoringAlgorithm(),
143+
)
144+
145+
M = jacobian(x -> do_interpolation(x; use_cellparents), sparse_forward_backend, ones(n))
146+
147+
return M
148+
end
149+
94150

95151
"""
96152
````

src/interpolators.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ function NodalInterpolator(
6565
xCellDofs = FES[CellDofs]
6666
ncells = num_cells(grid)
6767
function evaluate_broken!(target, exact_function!, items; time = 0, params = [], kwargs...)
68+
if !(eltype(target) <: T)
69+
result = zeros(eltype(target), ncomponents)
70+
end
6871
QP.time = time
6972
QP.params = params === nothing ? [] : params
7073
if isempty(items)
@@ -96,6 +99,9 @@ function NodalInterpolator(
9699
nnodes = num_nodes(grid)
97100
xNodeCells = atranspose(grid[CellNodes])
98101
function evaluate!(target, exact_function!, items; time = 0, params = [], kwargs...)
102+
if !(eltype(target) <: T)
103+
result = zeros(eltype(target), ncomponents)
104+
end
99105
QP.time = time
100106
QP.params = params === nothing ? [] : params
101107
if isempty(items)
@@ -370,6 +376,10 @@ function MomentInterpolator(
370376
interiordofs = zeros(Int, length(idofs))
371377

372378
function assembly_loop!(target, f_moments, items, exact_function!, QF, L2G, FEB, FEB_moments)
379+
if !(eltype(target) <: Tv)
380+
result_f = zeros(eltype(target), ncomponents)
381+
f_moments = zeros(eltype(target), nmoments)
382+
end
373383
weights, xref = QF.w, QF.xref
374384
nweights = length(weights)
375385
for item::Int in items
@@ -582,6 +592,10 @@ function FunctionalInterpolator(
582592
FEB = FEEvaluator(FE, operator, QF; AT = AT, T = Tv, L2G = L2G)
583593

584594
function assembly_loop!(target, f_fluxes, items, exact_function!, QF, L2G, FEB)
595+
if !(eltype(target) <: Tv)
596+
result_f = zeros(eltype(target), ncomponents)
597+
f_fluxes = zeros(eltype(target), nfluxes)
598+
end
585599
weights, xref = QF.w, QF.xref
586600
nweights = length(weights)
587601
for item::Int in items

src/lazy_interpolate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function lazy_interpolate!(
7979
if xdim_source != xdim_target
8080
@assert xtrafo !== nothing "grids have different coordinate dimensions, need xtrafo!"
8181
end
82-
PE = PointEvaluator(postprocess, operators, source)
82+
PE = PointEvaluator(postprocess, operators, source; TCoeff = T1)
8383
xref = zeros(Tv, xdim_source)
8484
x_source = zeros(Tv, xdim_source)
8585
cell::Int = start_cell

src/point_evaluator.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct PointEvaluator{Tv <: Real, UT, KFT <: Function}
1+
mutable struct PointEvaluator{Tv <: Real, TCoeff <: Real, UT, KFT <: Function}
22
u_args::Array{UT, 1}
33
ops_args::Array{DataType, 1}
44
kernel::KFT
@@ -61,11 +61,11 @@ $(_myprint(default_peval_kwargs()))
6161
After construction, call `initialize!` to prepare the evaluator for a given solution, then use `evaluate!` or `evaluate_bary!` to perform point evaluations.
6262
6363
"""
64-
function PointEvaluator(kernel, u_args, ops_args, sol = nothing; Tv = Float64, kwargs...)
64+
function PointEvaluator(kernel, u_args, ops_args, sol = nothing; Tv = Float64, TCoeff = Float64, kwargs...)
6565
parameters = Dict{Symbol, Any}(k => v[1] for (k, v) in default_peval_kwargs())
6666
_update_params!(parameters, kwargs)
6767
@assert length(u_args) == length(ops_args)
68-
PE = PointEvaluator{Tv, typeof(u_args[1]), typeof(kernel)}(u_args, ops_args, kernel, nothing, nothing, nothing, 1, nothing, nothing, nothing, zeros(Tv, 2), parameters)
68+
PE = PointEvaluator{Tv, TCoeff, typeof(u_args[1]), typeof(kernel)}(u_args, ops_args, kernel, nothing, nothing, nothing, 1, nothing, nothing, nothing, zeros(Tv, 2), parameters)
6969
if sol !== nothing
7070
initialize!(PE, sol)
7171
end
@@ -121,7 +121,7 @@ $(_myprint(default_peval_kwargs()))
121121
# Notes
122122
- This function must be called before using `evaluate!` or `evaluate_bary!` with the `PointEvaluator`.
123123
"""
124-
function initialize!(O::PointEvaluator{T, UT}, sol; time = 0, kwargs...) where {T, UT}
124+
function initialize!(O::PointEvaluator{T, TCoeff, UT}, sol; time = 0, kwargs...) where {T, TCoeff, UT}
125125
_update_params!(O.parameters, kwargs)
126126
if UT <: Integer
127127
ind_args = O.u_args
@@ -159,7 +159,7 @@ function initialize!(O::PointEvaluator{T, UT}, sol; time = 0, kwargs...) where {
159159
op_lengths_args = [size(O.BE_args[1][j].cvals, 1) for j in 1:nargs]
160160
op_offsets_args = [0]
161161
append!(op_offsets_args, cumsum(op_lengths_args))
162-
input_args = zeros(T, op_offsets_args[end])
162+
input_args = zeros(TCoeff, op_offsets_args[end])
163163

164164
FEATs_args = [EffAT4AssemblyType(get_AT(FES_args[j]), AT) for j in 1:nargs]
165165
itemdofs_args::Array{Union{Adjacency{Ti}, SerialVariableTargetAdjacency{Ti}}, 1} = [FES_args[j][Dofmap4AssemblyType(FEATs_args[j])] for j in 1:nargs]

test/Project.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
[deps]
2-
ExtendableFEMBase = "12fb9182-3d4c-4424-8fd1-727a0899810c"
3-
ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
4-
ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8"
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
53
ExampleJuggler = "3bbe58f8-ed81-4c4e-a134-03e85fcf4a1a"
64
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
7-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5+
ExtendableGrids = "cfc395e8-590f-11e8-1f13-43a2532b2fa8"
6+
ExtendableSparse = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
87
GridVisualize = "5eed8a63-0fb0-45eb-886d-8d5a387d12b8"
8+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1110
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1211

1312
[compat]

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using ExtendableGrids
33
using ExtendableFEMBase
4+
using ExtendableSparse
45
using ExplicitImports
56
using ExampleJuggler
67
using SparseArrays
@@ -28,6 +29,7 @@ end
2829

2930
include("test_quadrature.jl")
3031
include("test_interpolators.jl")
32+
include("test_interpolation_matrix.jl")
3133
include("test_operators.jl")
3234
include("test_febasis.jl")
3335
include("test_segmentintegrator.jl")
@@ -150,6 +152,8 @@ function run_all_tests()
150152
run_operator_tests()
151153
run_quadrature_tests()
152154
run_interpolator_tests()
155+
run_grid_interpolation_matrix_tests()
156+
run_space_interpolation_matrix_tests()
153157
run_segmentintegrator_tests()
154158
run_pointevaluator_tests()
155159
end

0 commit comments

Comments
 (0)