Skip to content

Commit f2b2f62

Browse files
authored
Merge pull request #156 from tharittk/mdd-lsm-nccl-tutorial
tutorials for LSM and MDD using NCCL
2 parents 19e873a + 40ddb34 commit f2b2f62

File tree

2 files changed

+446
-0
lines changed

2 files changed

+446
-0
lines changed

tutorials_nccl/lsm_nccl.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
r"""
2+
Least-squares Migration with NCCL
3+
=================================
4+
This tutorial is an extension of the :ref:`sphx_glr_tutorials_lsm.py`
5+
tutorial where PyLops-MPI is run in multi-GPU setting with GPUs communicating
6+
via NCCL.
7+
"""
8+
9+
import warnings
10+
warnings.filterwarnings('ignore')
11+
12+
import numpy as np
13+
import cupy as cp
14+
from matplotlib import pyplot as plt
15+
from mpi4py import MPI
16+
17+
from pylops.utils.wavelets import ricker
18+
from pylops.waveeqprocessing.lsm import LSM
19+
20+
import pylops_mpi
21+
###############################################################################
# An NCCL communicator can be easily initialized with the
# :py:func:`pylops_mpi.utils._nccl.initialize_nccl_comm` function.
# One can think of it as the GPU counterpart of :code:`MPI.COMM_WORLD`.

np.random.seed(42)
plt.close("all")

# NCCL communicator, passed later on to the distributed arrays
nccl_comm = pylops_mpi.utils._nccl.initialize_nccl_comm()
# Rank of this process in the MPI world communicator
rank = MPI.COMM_WORLD.Get_rank()
31+
###############################################################################
# Let's start by defining all the parameters required by the
# :py:class:`pylops.waveeqprocessing.LSM` operator.
# Note that this section is exactly the same as the one in the MPI example,
# as we will keep using MPI for transferring metadata (i.e., shapes, dims, etc.)

# Velocity Model
nx = 81
nz = 60
dx = 4
dz = 4
x = np.arange(nx) * dx
z = np.arange(nz) * dz
v0 = 1000  # initial velocity
kv = 0.0  # gradient
vel = np.outer(np.ones(nx), v0 + kv * z)

# Reflectivity Model: two flat reflectors at depth samples 30 and 50
refl = np.zeros((nx, nz))
refl[:, 30] = -1
refl[:, 50] = 0.5

# Receivers
nr = 11
rx = np.linspace(10 * dx, (nx - 10) * dx, nr)
rz = np.full(nr, 20.0)
recs = np.vstack((rx, rz))

# Sources (ns is the number of sources handled by each rank)
ns = 10
# Total number of sources at all ranks
nstot = MPI.COMM_WORLD.allreduce(ns, op=MPI.SUM)
sxtot = np.linspace(dx * 10, (nx - 10) * dx, nstot)
sztot = np.full(nstot, 10.0)
# Each rank keeps its own contiguous chunk of ns sources
sx = sxtot[rank * ns:(rank + 1) * ns]
sz = np.full(ns, 10.0)
sources = np.vstack((sx, sz))
sources_tot = np.vstack((sxtot, sztot))
67+
# Only rank 0 draws the velocity and reflectivity models, with all the
# receivers and sources overlaid on top of each of them.
if rank == 0:
    for model, cmap, title, cbar_label in (
        (vel, "summer", "Velocity", "[m/s]"),
        (refl, "gray", "Reflectivity", None),
    ):
        plt.figure(figsize=(10, 5))
        im = plt.imshow(model.T, cmap=cmap, extent=(x[0], x[-1], z[-1], z[0]))
        plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k")
        plt.scatter(sources_tot[0], sources_tot[1], marker="*", s=150, c="r", edgecolors="k")
        cb = plt.colorbar(im)
        if cbar_label is not None:
            cb.set_label(cbar_label)
        plt.axis("tight")
        plt.xlabel("x [m]")
        plt.ylabel("z [m]")
        plt.title(title)
        plt.xlim(x[0], x[-1])
        plt.tight_layout()
91+
###############################################################################
# We create a :py:class:`pylops.waveeqprocessing.LSM` at each rank and then push them
# into a :py:class:`pylops_mpi.basicoperators.MPIVStack` to perform a matrix-vector
# product with the broadcasted reflectivity at every location on the subsurface.
# Note that we must use :code:`engine="cuda"` and move the wavelet :code:`wav` to the
# GPU prior to creating the operator. Moreover, we move the traveltime tables
# (:code:`lsm.Demop.trav_srcs` and :code:`lsm.Demop.trav_recs`) to the GPU prior to
# applying the operator, to avoid incurring the penalty of performing host-to-device
# memory copies every time the operator is applied. Finally, we must pass
# :code:`nccl_comm` to the DistributedArray constructor used to create
# :code:`refl_dist` in order to use NCCL for communications.

# Wavelet
nt = 651
dt = 0.004
t = np.arange(nt) * dt
wav, wavt, wavc = ricker(t[:41], f0=20)
wav_gpu = cp.asarray(wav.astype(np.float32))

lsm = LSM(z, x, t, sources, recs, v0, wav_gpu, wavc,
          mode="analytic", engine="cuda", dtype=np.float32)

# Single-precision traveltime tables stored as CuPy arrays
demop = lsm.Demop
demop.trav_srcs = cp.asarray(demop.trav_srcs.astype(np.float32))
demop.trav_recs = cp.asarray(demop.trav_recs.astype(np.float32))

VStack = pylops_mpi.MPIVStack(ops=[demop])

# Broadcasted reflectivity living on the GPU
refl_dist = pylops_mpi.DistributedArray(
    global_shape=nx * nz,
    partition=pylops_mpi.Partition.BROADCAST,
    base_comm_nccl=nccl_comm,
    engine="cupy",
)
refl_dist[:] = cp.asarray(refl.ravel())

# Modelled data for all sources
d_dist = VStack @ refl_dist
d = d_dist.asarray().reshape(nstot, nr, nt)
132+
###############################################################################
# We now calculate the adjoint and model the data using the adjoint
# reflectivity as input.

# Adjoint (migrated) reflectivity
madj_dist = VStack.H @ d_dist
madj = madj_dist.asarray().reshape(nx, nz)

# Data re-modelled from the adjoint reflectivity
d_adj_dist = VStack @ madj_dist
d_adj = d_adj_dist.asarray().reshape(nstot, nr, nt)
140+
###############################################################################
# We calculate the inverse using the :py:func:`pylops_mpi.optimization.basic.cgls`
# solver. Here, we pass :code:`nccl_comm` to :code:`x0` to use NCCL as the communicator.
# In this particular case, the local computation will be done on the GPU, and
# collective communication calls will be carried out through NCCL, GPU-to-GPU.

# Inverse
# Initial guess x0, filled with zeros
x0 = pylops_mpi.DistributedArray(
    global_shape=VStack.shape[1],
    partition=pylops_mpi.Partition.BROADCAST,
    base_comm_nccl=nccl_comm,
    engine="cupy",
)
x0[:] = 0

# The first item returned by the solver is the estimated model
cgls_out = pylops_mpi.cgls(VStack, d_dist, x0=x0, niter=100, show=True)
minv_dist = cgls_out[0]
minv = minv_dist.asarray().reshape(nx, nz)

# Data modelled from the inverted reflectivity
d_inv_dist = VStack @ minv_dist
d_inv = d_inv_dist.asarray().reshape((nstot, nr, nt))
158+
###############################################################################
# Finally, we visualize the results. Note that the arrays must be copied back
# to the CPU by calling the :code:`get()` method on the CuPy arrays.

if rank == 0:
    # Colour limits: reduce each array once and convert the result to a
    # Python float, instead of recomputing max() for every panel and handing
    # device scalars to matplotlib.
    madj_max = float(madj.max())
    d_max = float(d.max())
    d_adj_max = float(d_adj.max())

    # Models: true, adjoint and inverted reflectivity
    fig1, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(refl.T, cmap="gray", vmin=-1, vmax=1)
    axs[0].axis("tight")
    axs[0].set_title(r"$m$")
    axs[1].imshow(madj.T.get(), cmap="gray", vmin=-madj_max, vmax=madj_max)
    axs[1].set_title(r"$m_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(minv.T.get(), cmap="gray", vmin=-1, vmax=1)
    axs[2].axis("tight")
    axs[2].set_title(r"$m_{inv}$")
    fig1.tight_layout()
    fig1.savefig("model.png")

    # Data for the first source (first 300 time samples)
    fig2, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(d[0, :, :300].T.get(), cmap="gray", vmin=-d_max, vmax=d_max)
    axs[0].set_title(r"$d$")
    axs[0].axis("tight")
    axs[1].imshow(d_adj[0, :, :300].T.get(), cmap="gray", vmin=-d_adj_max, vmax=d_adj_max)
    axs[1].set_title(r"$d_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(d_inv[0, :, :300].T.get(), cmap="gray", vmin=-d_max, vmax=d_max)
    axs[2].set_title(r"$d_{inv}$")
    axs[2].axis("tight")
    # Same layout adjustment as the other two figures before saving
    fig2.tight_layout()
    fig2.savefig("data1.png")

    # Data for the central source (first 300 time samples)
    ishot = nstot // 2
    fig3, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].imshow(d[ishot, :, :300].T.get(), cmap="gray", vmin=-d_max, vmax=d_max)
    axs[0].set_title(r"$d$")
    axs[0].axis("tight")
    axs[1].imshow(d_adj[ishot, :, :300].T.get(), cmap="gray", vmin=-d_adj_max, vmax=d_adj_max)
    axs[1].set_title(r"$d_{adj}$")
    axs[1].axis("tight")
    axs[2].imshow(d_inv[ishot, :, :300].T.get(), cmap="gray", vmin=-d_max, vmax=d_max)
    axs[2].set_title(r"$d_{inv}$")
    axs[2].axis("tight")
    fig3.tight_layout()
    fig3.savefig("data2.png")

0 commit comments

Comments
 (0)