Skip to content

Commit b9725d2

Browse files
committed
Refactor plotting code
1 parent 7c79e4a commit b9725d2

File tree

4 files changed

+44
-58
lines changed

4 files changed

+44
-58
lines changed

src/gemdat/plots/_shared.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from scipy.stats import skewnorm
1212

1313
if TYPE_CHECKING:
14+
from gemdat.jumps import Jumps
1415
from gemdat.orientations import Orientations
1516
from gemdat.trajectory import Trajectory
1617

@@ -127,7 +128,7 @@ def dataframe(self):
127128
def _get_vibrational_amplitudes_hist(
128129
*, trajectories: list[Trajectory], bins: int
129130
) -> VibrationalAmplitudeHist:
130-
"""Calculate vabrational amplitudes histogram.
131+
"""Calculate vibrational amplitudes histogram.
131132
132133
Helper for `vibrational_amplitudes`.
133134
"""
@@ -150,3 +151,38 @@ def _get_vibrational_amplitudes_hist(
150151
std = np.std(data, axis=0)
151152

152153
return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std)
154+
155+
156+
def _jumps_vs_distance(jumps: Jumps, *, resolution: float, n_parts: int) -> pd.DataFrame:
157+
"""Calculate jumps vs distance histogram.
158+
159+
Helper for `jumps_vs_distance`.
160+
"""
161+
sites = jumps.sites
162+
trajectory = jumps.trajectory
163+
lattice = trajectory.get_lattice()
164+
165+
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
166+
167+
bin_max = (1 + pdist.max() // resolution) * resolution
168+
n_bins = int(bin_max / resolution) + 1
169+
x = np.linspace(0, bin_max, n_bins)
170+
171+
bin_idx = np.digitize(pdist, bins=x)
172+
data = []
173+
for transitions_part in jumps.split(n_parts=n_parts):
174+
counts = np.zeros_like(x)
175+
for idx, n in zip(bin_idx.flatten(), transitions_part.matrix().flatten()):
176+
counts[idx] += n
177+
for idx in range(n_bins):
178+
if counts[idx] > 0:
179+
data.append((x[idx], counts[idx]))
180+
181+
df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
182+
183+
grouped = df.groupby(['Displacement'])
184+
mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
185+
std = grouped.std().reset_index().rename(columns={'count': 'std'})
186+
df = mean.merge(std, how='inner')
187+
188+
return df

src/gemdat/plots/matplotlib/_jumps_vs_distance.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from typing import TYPE_CHECKING
44

55
import matplotlib.pyplot as plt
6-
import numpy as np
7-
import pandas as pd
6+
7+
from .._shared import _jumps_vs_distance
88

99
if TYPE_CHECKING:
1010
import matplotlib.figure
1111

12-
from gemdat import Jumps
12+
from gemdat.jumps import Jumps
1313

1414

1515
def jumps_vs_distance(
@@ -34,32 +34,7 @@ def jumps_vs_distance(
3434
fig : matplotlib.figure.Figure
3535
Output figure
3636
"""
37-
sites = jumps.sites
38-
trajectory = jumps.trajectory
39-
lattice = trajectory.get_lattice()
40-
41-
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
42-
43-
bin_max = (1 + pdist.max() // jump_res) * jump_res
44-
n_bins = int(bin_max / jump_res) + 1
45-
x = np.linspace(0, bin_max, n_bins)
46-
47-
bin_idx = np.digitize(pdist, bins=x)
48-
data = []
49-
for transitions_part in jumps.split(n_parts=n_parts):
50-
counts = np.zeros_like(x)
51-
for idx, n in zip(bin_idx.flatten(), transitions_part.matrix().flatten()):
52-
counts[idx] += n
53-
for idx in range(n_bins):
54-
if counts[idx] > 0:
55-
data.append((x[idx], counts[idx]))
56-
57-
df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
58-
59-
grouped = df.groupby(['Displacement'])
60-
mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
61-
std = grouped.std().reset_index().rename(columns={'count': 'std'})
62-
df = mean.merge(std, how='inner')
37+
df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)
6338

6439
fig, ax = plt.subplots()
6540

src/gemdat/plots/plotly/_jumps_vs_distance.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from typing import TYPE_CHECKING
44

5-
import numpy as np
6-
import pandas as pd
75
import plotly.express as px
86
import plotly.graph_objects as go
97

8+
from .._shared import _jumps_vs_distance
9+
1010
if TYPE_CHECKING:
1111
from gemdat import Jumps
1212

@@ -33,32 +33,7 @@ def jumps_vs_distance(
3333
fig : plotly.graph_objects.Figure
3434
Output figure
3535
"""
36-
sites = jumps.sites
37-
trajectory = jumps.trajectory
38-
lattice = trajectory.get_lattice()
39-
40-
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
41-
42-
bin_max = (1 + pdist.max() // jump_res) * jump_res
43-
n_bins = int(bin_max / jump_res) + 1
44-
x = np.linspace(0, bin_max, n_bins)
45-
46-
bin_idx = np.digitize(pdist, bins=x)
47-
data = []
48-
for transitions_part in jumps.split(n_parts=n_parts):
49-
counts = np.zeros_like(x)
50-
for idx, n in zip(bin_idx.flatten(), transitions_part.matrix().flatten()):
51-
counts[idx] += n
52-
for idx in range(n_bins):
53-
if counts[idx] > 0:
54-
data.append((x[idx], counts[idx]))
55-
56-
df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
57-
58-
grouped = df.groupby(['Displacement'])
59-
mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
60-
std = grouped.std().reset_index().rename(columns={'count': 'std'})
61-
df = mean.merge(std, how='inner')
36+
df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)
6237

6338
if n_parts == 1:
6439
fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
Loading

0 commit comments

Comments
 (0)