Skip to content

Commit d43b328

Browse files
committed
fix mpi solution
1 parent 4892fdf commit d43b328

9 files changed

+154
-16
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@ Manifest.toml
2929
activate.sh
3030
deactivate.sh
3131
ext/GrayScott.jl
32+
33+
34+
*.jld2

Manifest.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.4"
44
manifest_format = "2.0"
5-
project_hash = "fc814033df7bb2362d0b2cbe216991051dd7475f"
5+
project_hash = "f95ad848b2db86117d9a28403e2805e9e12ad741"
66

77
[[deps.AbstractFFTs]]
88
deps = ["LinearAlgebra"]
@@ -839,6 +839,12 @@ git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
839839
uuid = "82899510-4779-5014-852e-03e436cf321d"
840840
version = "1.0.0"
841841

842+
[[deps.JLD2]]
843+
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
844+
git-tree-sha1 = "ce5737c0d4490b0e0040b5dc77fbb6a351ddf188"
845+
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
846+
version = "0.5.8"
847+
842848
[[deps.JLFzf]]
843849
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
844850
git-tree-sha1 = "39d64b09147620f5ffbf6b2d3255be3c901bec63"

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
66
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
77
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
88
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
9+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
910
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1011
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1112
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"

parts/mpi/diffusion_2d_mpi.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using Printf
33
using JLD2
44
using MPI
5-
include(joinpath(@__DIR__, "../shared.jl"))
5+
include(joinpath(@__DIR__, "shared.jl"))
66

77
# convenience macros simply to avoid writing nested finite-difference expression
88
macro qx(ix, iy) esc(:(-D * (C[$ix+1, $iy] - C[$ix, $iy]) / dx)) end

parts/mpi/shared.jl

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
## PARAMETER INITIALIZATION
2+
function init_params(; ns=64, nt=100, kwargs...)
3+
L = 10.0 # physical domain length
4+
D = 1.0 # diffusion coefficient
5+
ds = L / ns # grid spacing
6+
dt = ds^2 / D / 8.2 # time step
7+
cs = range(start=ds / 2, stop=L - ds / 2, length=ns) .- 0.5 * L # vector of coord points
8+
nout = floor(Int, nt / 5) # plotting frequency
9+
return (; L, D, ns, nt, ds, dt, cs, nout, kwargs...)
10+
end
11+
12+
function init_params_mpi(; dims, coords, ns=64, nt=100, kwargs...)
13+
L = 10.0 # physical domain length
14+
D = 1.0 # diffusion coefficient
15+
nx_g = dims[1] * (ns - 2) + 2 # global number of grid points along dim 1
16+
ny_g = dims[2] * (ns - 2) + 2 # global number of grid points along dim 2
17+
dx = L / nx_g # grid spacing
18+
dy = L / ny_g # grid spacing
19+
dt = min(dx, dy)^2 / D / 8.2 # time step
20+
x0 = coords[1] * (ns - 2) * dx # coords shift to get global coords on local process
21+
y0 = coords[2] * (ns - 2) * dy # coords shift to get global coords on local process
22+
xcs = LinRange(x0 + dx / 2, x0 + ns * dx - dx / 2, ns) .- 0.5 .* L # local vector of global coord points
23+
ycs = LinRange(y0 + dy / 2, y0 + ns * dx - dy / 2, ns) .- 0.5 .* L # local vector of global coord points
24+
return (; L, D, ns, nt, dx, dy, dt, xcs, ycs, kwargs...)
25+
end
26+
27+
function init_params_gpu(; ns=64, nt=100, kwargs...)
28+
L = 10.0 # physical domain length
29+
D = 1.0 # diffusion coefficient
30+
ds = L / ns # grid spacing
31+
dt = ds^2 / D / 8.2 # time step
32+
cs = range(start=ds / 2, stop=L - ds / 2, length=ns) .- 0.5 * L # vector of coord points
33+
nout = floor(Int, nt / 5) # plotting frequency
34+
nthreads = 32, 8 # number of threads per block
35+
nblocks = cld.(ns, nthreads) # number of blocks
36+
return (; L, D, ns, nt, ds, dt, cs, nout, nthreads, nblocks, kwargs...)
37+
end
38+
39+
function init_params_gpu_mpi(; dims, coords, ns=64, nt=100, kwargs...)
40+
L = 10.0 # physical domain length
41+
D = 1.0 # diffusion coefficient
42+
nx_g = dims[1] * (ns - 2) + 2 # global number of grid points along dim 1
43+
ny_g = dims[2] * (ns - 2) + 2 # global number of grid points along dim 2
44+
dx = L / nx_g # grid spacing
45+
dy = L / ny_g # grid spacing
46+
dt = min(dx, dy)^2 / D / 8.2 # time step
47+
x0 = coords[1] * (ns - 2) * dx # coords shift to get global coords on local process
48+
y0 = coords[2] * (ns - 2) * dy # coords shift to get global coords on local process
49+
xcs = LinRange(x0 + dx / 2, x0 + ns * dx - dx / 2, ns) .- 0.5 * L # local vector of global coord points
50+
ycs = LinRange(y0 + dy / 2, y0 + ns * dy - dy / 2, ns) .- 0.5 * L # local vector of global coord points
51+
nthreads = 32, 8 # number of threads per block
52+
nblocks = cld.(ns, nthreads) # number of blocks
53+
return (; L, D, ns, nt, dx, dy, dt, xcs, ycs, nthreads, nblocks, kwargs...)
54+
end
55+
56+
## ARRAY INITIALIZATION
57+
function init_arrays_with_flux(params)
58+
(; cs, ns) = params
59+
C = @. exp(-cs^2 - (cs')^2)
60+
qx = zeros(ns - 1, ns - 2)
61+
qy = zeros(ns - 2, ns - 1)
62+
return C, qx, qy
63+
end
64+
65+
function init_arrays(params)
66+
(; cs) = params
67+
C = @. exp(-cs^2 - (cs')^2)
68+
C2 = copy(C)
69+
return C, C2
70+
end
71+
72+
function init_arrays_mpi(params)
73+
(; xcs, ycs) = params
74+
C = @. exp(-xcs^2 - (ycs')^2)
75+
C2 = copy(C)
76+
return C, C2
77+
end
78+
79+
function init_arrays_gpu(params)
80+
(; cs) = params
81+
C = CuArray(@. exp(-cs^2 - (cs')^2))
82+
C2 = copy(C)
83+
return C, C2
84+
end
85+
86+
function init_arrays_gpu_mpi(params)
87+
(; xcs, ycs) = params
88+
C = CuArray(@. exp(-xcs^2 - (ycs')^2))
89+
C2 = copy(C)
90+
return C, C2
91+
end
92+
93+
## VISUALIZATION & PRINTING
94+
function maybe_init_visualization(params, C)
95+
if params.do_visualize
96+
fig = Figure(; size=(500, 400), fontsize=14)
97+
ax = Axis(fig[1, 1][1, 1]; aspect=DataAspect(), title="C")
98+
plt = heatmap!(ax, params.cs, params.cs, Array(C); colormap=:turbo, colorrange=(0, 1))
99+
cb = Colorbar(fig[1, 1][1, 2], plt)
100+
display(fig)
101+
return fig, plt
102+
end
103+
return nothing, nothing
104+
end
105+
106+
function maybe_update_visualization(params, fig, plt, C, it)
107+
if params.do_visualize && (it % params.nout == 0)
108+
plt[3] = Array(C)
109+
display(fig)
110+
end
111+
return nothing
112+
end
113+
114+
function print_perf(params, t_toc)
115+
(; ns, nt) = params
116+
@printf("Time = %1.4e s, T_eff = %1.2f GB/s \n", t_toc, round((2 / 1e9 * ns^2 * sizeof(Float64)) / (t_toc / (nt - 10)), sigdigits=6))
117+
return nothing
118+
end

parts/mpi/solution/diffusion_2d_mpi.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using Printf
33
using JLD2
44
using MPI
5-
include(joinpath(@__DIR__, "../../shared.jl"))
5+
include(joinpath(@__DIR__, "../shared.jl"))
66

77
# convenience macros simply to avoid writing nested finite-difference expression
88
macro qx(ix, iy) esc(:(-D * (C[$ix+1, $iy] - C[$ix, $iy]) / dx)) end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
nprocs = 16, dims = (4, 4)
2+
Time = 1.3349e-02 s, T_eff = 7.07 GB/s
+10-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
1-
nprocs = 4, dims = [2, 2]
2-
Time = 1.2309e-02 s, T_eff = 7.67 GB/s
1+
nprocs = 4, dims = (2, 2)
2+
Time = 1.2562e-02 s, T_eff = 7.51 GB/s
3+
┌ Warning: Opening file with JLD2.MmapIO failed, falling back to IOStream
4+
└ @ JLD2 /pscratch/sd/b/blaschke/depot/packages/JLD2/KyKLQ/src/JLD2.jl:153
5+
┌ Warning: Opening file with JLD2.MmapIO failed, falling back to IOStream
6+
└ @ JLD2 /pscratch/sd/b/blaschke/depot/packages/JLD2/KyKLQ/src/JLD2.jl:153
7+
┌ Warning: Opening file with JLD2.MmapIO failed, falling back to IOStream
8+
└ @ JLD2 /pscratch/sd/b/blaschke/depot/packages/JLD2/KyKLQ/src/JLD2.jl:153
9+
┌ Warning: Opening file with JLD2.MmapIO failed, falling back to IOStream
10+
└ @ JLD2 /pscratch/sd/b/blaschke/depot/packages/JLD2/KyKLQ/src/JLD2.jl:153

parts/mpi/visualize_mpi.ipynb

+11-11
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)