-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathvis.py
More file actions
109 lines (94 loc) · 4.22 KB
/
vis.py
File metadata and controls
109 lines (94 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
# Load the colormap used for velocity-map visualization.
# Look for 'rainbow256.npy' next to this module first so importing it from a
# different working directory still works; fall back to the original
# CWD-relative name so existing setups that keep the file in the CWD are
# unaffected (and a missing file still raises the same FileNotFoundError).
_RAINBOW_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'rainbow256.npy')
if not os.path.exists(_RAINBOW_PATH):
    _RAINBOW_PATH = 'rainbow256.npy'
rainbow_cmap = ListedColormap(np.load(_RAINBOW_PATH))
def plot_velocity(output, target, path, vmin=None, vmax=None):
    """Save a side-by-side image of a predicted and a ground-truth velocity map.

    Both panels are drawn with ``matshow`` using the module's ``rainbow_cmap``
    and one shared colour scale, so they can be compared directly. A single
    colour bar is attached to both panels.

    Args:
        output: Predicted velocity map (must be 2-D for ``matshow``).
        target: Ground-truth velocity map, same layout as ``output``.
        path: Destination passed to ``savefig``.
        vmin: Lower colour limit; defaults to ``np.min(target)`` when None.
        vmax: Upper colour limit; defaults to ``np.max(target)`` when None.
    """
    # Fill in only the limits the caller left unset, so a single explicit
    # limit is honoured instead of being discarded.
    if vmin is None:
        vmin = np.min(target)
    if vmax is None:
        vmax = np.max(target)

    fig, ax = plt.subplots(1, 2, figsize=(11, 5))
    im = ax[0].matshow(output, cmap=rainbow_cmap, vmin=vmin, vmax=vmax)
    ax[0].set_title('Prediction', y=1.08)
    ax[1].matshow(target, cmap=rainbow_cmap, vmin=vmin, vmax=vmax)
    ax[1].set_title('Ground Truth', y=1.08)

    for axis in ax:
        # Fixed ticks every 10 cells labelled 0..600 m: the labels assume a
        # 10 m grid spacing and a map at least 61 cells wide/deep.
        axis.set_xticks(range(0, 70, 10))
        axis.set_xticklabels(range(0, 700, 100))
        axis.set_yticks(range(0, 70, 10))
        axis.set_yticklabels(range(0, 700, 100))
        axis.set_ylabel('Depth (m)', fontsize=12)
        axis.set_xlabel('Offset (m)', fontsize=12)

    fig.colorbar(im, ax=ax, shrink=0.75, label='Velocity(m/s)')
    fig.savefig(path)
    # Closes every open figure, not only this one.
    plt.close('all')
def plot_single_velocity(label, path):
    """Save a single velocity map with a colour bar.

    The colour scale spans the full value range of ``label`` and uses the
    module's ``rainbow_cmap``. Text is drawn at font size 16.

    Args:
        label: Velocity map to draw (must be 2-D for ``matshow``).
        path: Destination passed to ``savefig``.
    """
    # Scope the font size to this figure rather than mutating the global
    # rcParams for every figure drawn afterwards. savefig stays inside the
    # context so text created while rendering also picks up the setting.
    with plt.rc_context({'font.size': 16}):
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        vmin, vmax = np.min(label), np.max(label)
        im = ax.matshow(label, cmap=rainbow_cmap, vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax, shrink=1.0, label='Velocity(m/s)')
        fig.savefig(path)
    # Closes every open figure, not only this one.
    plt.close('all')
def plot_seismic(output, target, path, vmin=-1e-5, vmax=1e-5):
    """Save ground-truth, predicted and difference seismic panels side by side.

    All three panels are drawn in grayscale with the same fixed colour limits
    and the same pixel aspect, so their shapes and intensities are directly
    comparable. The aspect ``ncols / nrows`` of ``output`` makes each panel
    roughly square regardless of the array's proportions.

    Args:
        output: Predicted seismic section (2-D).
        target: Ground-truth seismic section; must support ``output - target``.
        path: Destination passed to ``savefig``.
        vmin: Lower colour limit shared by all panels (default -1e-5).
        vmax: Upper colour limit shared by all panels (default 1e-5).
    """
    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    aspect = output.shape[1] / output.shape[0]

    im = ax[0].matshow(target, aspect=aspect, cmap='gray', vmin=vmin, vmax=vmax)
    ax[0].set_title('Ground Truth')
    ax[1].matshow(output, aspect=aspect, cmap='gray', vmin=vmin, vmax=vmax)
    ax[1].set_title('Prediction')
    # Use the same aspect as the other panels so the row renders consistently.
    ax[2].matshow(output - target, aspect=aspect, cmap='gray', vmin=vmin, vmax=vmax)
    ax[2].set_title('Difference')

    fig.colorbar(im, ax=ax, shrink=0.75, label='Amplitude')
    fig.savefig(path)
    # Closes every open figure, not only this one.
    plt.close('all')
def plot_single_seismic(data, path):
    """Save a single seismic section in grayscale with offset/time axes.

    The colour limits are 1% of the data's min and max, which saturates large
    amplitudes so weaker events stay visible. Axis ticks are labelled as if the
    columns span 0-1050 m of offset and the rows span 0-1000 ms of time.

    Args:
        data: 2-D seismic section of shape ``(nz, nx)``.
        path: Destination passed to ``savefig``.
    """
    nz, nx = data.shape
    vmin, vmax = np.min(data), np.max(data)

    # Scope the font size to this figure rather than mutating the global
    # rcParams for every figure drawn afterwards. savefig stays inside the
    # context so text created while rendering also picks up the setting.
    with plt.rc_context({'font.size': 18}):
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        # aspect nx/nz makes the (nz, nx) image render roughly square.
        im = ax.matshow(data, aspect=nx / nz, cmap='gray',
                        vmin=vmin * 0.01, vmax=vmax * 0.01)

        # Tick every ~300 m of offset, assuming nx columns span 1050 m.
        ax.set_xticks(range(0, nx, int(300 // (1050 / nx)))[:5])
        ax.set_xticklabels(range(0, 1050, 300))
        ax.set_title('Offset (m)', y=1.08)
        # Tick every ~200 ms, assuming nz rows span 1000 ms.
        ax.set_yticks(range(0, nz, int(200 // (1000 / nz)))[:5])
        ax.set_yticklabels(range(0, 1000, 200))
        ax.set_ylabel('Time (ms)', fontsize=18)

        fig.colorbar(im, ax=ax, shrink=1.0, pad=0.01, label='Amplitude')
        fig.savefig(path)
    # Closes every open figure, not only this one.
    plt.close('all')