Skip to content

Commit dac2357

Browse files
authored
Fix open issues (#355)
* Fix crash in plotly function * Update pre-commit hooks * Fix latp cif file (349) * Add std to mpl plot: jumps vs distance (345) * Refactor plotting code * Fix Trajectory.radial_distribution_between_species() error (344) * Fix formatting
1 parent 00dbc19 commit dac2357

16 files changed

+80
-65
lines changed

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@ repos:
1616
- id: debug-statements
1717
- id: double-quote-string-fixer
1818
- repo: https://github.com/stefsmeets/nbcheckorder/
19-
rev: v0.2.0
19+
rev: v0.3.0
2020
hooks:
2121
- id: nbcheckorder
2222
- repo: https://github.com/myint/docformatter
2323
rev: 06907d0
2424
hooks:
2525
- id: docformatter
2626
- repo: https://github.com/astral-sh/ruff-pre-commit
27-
rev: v0.6.9
27+
rev: v0.11.0
2828
hooks:
2929
- id: ruff
3030
args: [--fix]
3131
types_or: [python, pyi, jupyter]
3232
- id: ruff-format
3333
types_or: [python, pyi, jupyter]
3434
- repo: https://github.com/pre-commit/mirrors-mypy
35-
rev: v1.11.2
35+
rev: v1.15.0
3636
hooks:
3737
- id: mypy
3838
additional_dependencies: [matplotlib, MDAnalysis, numpy, pymatgen, rich, scikit-image, scipy]

src/gemdat/data/latp.cif

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,5 @@ loop_
5757
_atom_site_fract_y
5858
_atom_site_fract_z
5959
_atom_site_occupancy
60-
Li 6b 6 0.00 0.00 0.00
61-
Li 18e 18 0.66 0.00 0.25
60+
Li 6b 6 0.00 0.00 0.00 1.00
61+
Li 18e 18 0.66 0.00 0.25 1.00

src/gemdat/plots/_shared.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from scipy.stats import skewnorm
1212

1313
if TYPE_CHECKING:
14+
from gemdat.jumps import Jumps
1415
from gemdat.orientations import Orientations
1516
from gemdat.trajectory import Trajectory
1617

@@ -127,7 +128,7 @@ def dataframe(self):
127128
def _get_vibrational_amplitudes_hist(
128129
*, trajectories: list[Trajectory], bins: int
129130
) -> VibrationalAmplitudeHist:
130-
"""Calculate vabrational amplitudes histogram.
131+
"""Calculate vibrational amplitudes histogram.
131132
132133
Helper for `vibrational_amplitudes`.
133134
"""
@@ -150,3 +151,38 @@ def _get_vibrational_amplitudes_hist(
150151
std = np.std(data, axis=0)
151152

152153
return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std)
154+
155+
156+
def _jumps_vs_distance(jumps: Jumps, *, resolution: float, n_parts: int) -> pd.DataFrame:
157+
"""Calculate jumps vs distance histogram.
158+
159+
Helper for `jumps_vs_distance`.
160+
"""
161+
sites = jumps.sites
162+
trajectory = jumps.trajectory
163+
lattice = trajectory.get_lattice()
164+
165+
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
166+
167+
bin_max = (1 + pdist.max() // resolution) * resolution
168+
n_bins = int(bin_max / resolution) + 1
169+
x = np.linspace(0, bin_max, n_bins)
170+
171+
bin_idx = np.digitize(pdist, bins=x)
172+
data = []
173+
for transitions_part in jumps.split(n_parts=n_parts):
174+
counts = np.zeros_like(x)
175+
for idx, n in zip(bin_idx.flatten(), transitions_part.matrix().flatten()):
176+
counts[idx] += n
177+
for idx in range(n_bins):
178+
if counts[idx] > 0:
179+
data.append((x[idx], counts[idx]))
180+
181+
df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
182+
183+
grouped = df.groupby(['Displacement'])
184+
mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
185+
std = grouped.std().reset_index().rename(columns={'count': 'std'})
186+
df = mean.merge(std, how='inner')
187+
188+
return df

src/gemdat/plots/matplotlib/_energy_along_path.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def energy_along_path(
8080
for idx, path in enumerate(other_paths):
8181
if path.energy is None:
8282
raise ValueError('Pathway does not contain energy data')
83-
ax.plot(range(len(path.energy)), path.energy, label=f'Alternative {idx+1}')
83+
ax.plot(range(len(path.energy)), path.energy, label=f'Alternative {idx + 1}')
8484

8585
ax.legend(fontsize=8)
8686

src/gemdat/plots/matplotlib/_jumps_vs_distance.py

+16-19
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33
from typing import TYPE_CHECKING
44

55
import matplotlib.pyplot as plt
6-
import numpy as np
6+
7+
from .._shared import _jumps_vs_distance
78

89
if TYPE_CHECKING:
910
import matplotlib.figure
1011

11-
from gemdat import Jumps
12+
from gemdat.jumps import Jumps
1213

1314

1415
def jumps_vs_distance(
1516
*,
1617
jumps: Jumps,
1718
jump_res: float = 0.1,
19+
n_parts: int = 1,
1820
) -> matplotlib.figure.Figure:
1921
"""Plot jumps vs. distance histogram.
2022
@@ -24,32 +26,27 @@ def jumps_vs_distance(
2426
Input data
2527
jump_res : float, optional
2628
Resolution of the bins in Angstrom
29+
n_parts : int
30+
Number of parts for error analysis
2731
2832
Returns
2933
-------
3034
fig : matplotlib.figure.Figure
3135
Output figure
3236
"""
33-
sites = jumps.sites
34-
35-
trajectory = jumps.trajectory
36-
lattice = trajectory.get_lattice()
37-
38-
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
39-
40-
bin_max = (1 + pdist.max() // jump_res) * jump_res
41-
n_bins = int(bin_max / jump_res) + 1
42-
x = np.linspace(0, bin_max, n_bins)
43-
counts = np.zeros_like(x)
44-
45-
bin_idx = np.digitize(pdist, bins=x)
46-
for idx, n in zip(bin_idx.flatten(), jumps.matrix().flatten()):
47-
counts[idx] += n
37+
df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)
4838

4939
fig, ax = plt.subplots()
5040

51-
ax.bar(x, counts, width=(jump_res * 0.8))
41+
if n_parts == 1:
42+
ax.bar('Displacement', 'mean', data=df, width=(jump_res * 0.8))
43+
else:
44+
ax.bar('Displacement', 'mean', yerr='std', data=df, width=(jump_res * 0.8))
5245

53-
ax.set(title='Jumps vs. Distance', xlabel='Distance (Å)', ylabel='Number of jumps')
46+
ax.set(
47+
title='Jumps vs. Distance',
48+
xlabel='Distance (Å)',
49+
ylabel='Number of jumps',
50+
)
5451

5552
return fig

src/gemdat/plots/plotly/_displacement_histogram.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def displacement_histogram(trajectory: Trajectory, n_parts: int = 1) -> go.Figur
8787

8888
fig.update_layout(
8989
title=(
90-
'Displacement per element after ' f'{int(interval[1]-interval[0])} timesteps'
90+
f'Displacement per element after {int(interval[1] - interval[0])} timesteps'
9191
),
9292
xaxis_title='Displacement (Å)',
9393
yaxis_title='Nr. of atoms',

src/gemdat/plots/plotly/_energy_along_path.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def energy_along_path(
8888
go.Scatter(
8989
x=np.arange(len(path.energy)),
9090
y=path.energy,
91-
name=f'Alternative {idx+1}',
91+
name=f'Alternative {idx + 1}',
9292
mode='lines',
9393
line={'width': 1},
9494
)

src/gemdat/plots/plotly/_jumps_vs_distance.py

+9-29
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22

33
from typing import TYPE_CHECKING
44

5-
import numpy as np
6-
import pandas as pd
75
import plotly.express as px
86
import plotly.graph_objects as go
97

8+
from .._shared import _jumps_vs_distance
9+
1010
if TYPE_CHECKING:
1111
from gemdat import Jumps
1212

1313

14-
def jumps_vs_distance(*, jumps: Jumps, jump_res: float = 0.1, n_parts: int = 1) -> go.Figure:
14+
def jumps_vs_distance(
15+
*,
16+
jumps: Jumps,
17+
jump_res: float = 0.1,
18+
n_parts: int = 1,
19+
) -> go.Figure:
1520
"""Plot jumps vs. distance histogram.
1621
1722
Parameters
@@ -28,32 +33,7 @@ def jumps_vs_distance(*, jumps: Jumps, jump_res: float = 0.1, n_parts: int = 1)
2833
fig : plotly.graph_objects.Figure
2934
Output figure
3035
"""
31-
sites = jumps.sites
32-
trajectory = jumps.trajectory
33-
lattice = trajectory.get_lattice()
34-
35-
pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
36-
37-
bin_max = (1 + pdist.max() // jump_res) * jump_res
38-
n_bins = int(bin_max / jump_res) + 1
39-
x = np.linspace(0, bin_max, n_bins)
40-
41-
bin_idx = np.digitize(pdist, bins=x)
42-
data = []
43-
for transitions_part in jumps.split(n_parts=n_parts):
44-
counts = np.zeros_like(x)
45-
for idx, n in zip(bin_idx.flatten(), transitions_part.matrix().flatten()):
46-
counts[idx] += n
47-
for idx in range(n_bins):
48-
if counts[idx] > 0:
49-
data.append((x[idx], counts[idx]))
50-
51-
df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
52-
53-
grouped = df.groupby(['Displacement'])
54-
mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
55-
std = grouped.std().reset_index().rename(columns={'count': 'std'})
56-
df = mean.merge(std, how='inner')
36+
df = _jumps_vs_distance(jumps=jumps, resolution=jump_res, n_parts=n_parts)
5737

5838
if n_parts == 1:
5939
fig = px.bar(df, x='Displacement', y='mean', barmode='stack')

src/gemdat/plots/plotly/_plot3d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def plot_paths(
212212
size = 6
213213
color = 'teal'
214214
else:
215-
name = f'Alternative {idx+1}'
215+
name = f'Alternative {idx + 1}'
216216
size = 5
217217
color = None
218218

src/gemdat/plots/plotly/_rectilinear.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def rectilinear(
5858
z=hist,
5959
colorbar={
6060
'title': 'Areal probability',
61-
'titleside': 'right',
61+
'title_side': 'right',
6262
},
6363
)
6464
)

src/gemdat/shape.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def to_str(x):
129129
'Spacegroup',
130130
f' {symbol} ({self.spacegroup.int_number})',
131131
'Lattice',
132-
f" abc : {' '.join(to_str(val) for val in self.lattice.abc)}",
133-
f" angles: {' '.join(to_str(val) for val in self.lattice.angles)}",
132+
f' abc : {" ".join(to_str(val) for val in self.lattice.abc)}',
133+
f' angles: {" ".join(to_str(val) for val in self.lattice.angles)}',
134134
f'Unique sites ({len(self.sites)})',
135135
]
136136
for site in self.sites:

src/gemdat/trajectory.py

+1
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def metrics(self) -> TrajectoryMetrics:
749749

750750
return TrajectoryMetrics(trajectory=self)
751751

752+
@plot_backend
752753
def radial_distribution_between_species(self, *, module, **kwargs) -> RDFData:
753754
"""See [gemdat.rdf.radial_distribution_between_species][] for more
754755
info."""

src/gemdat/transitions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def _compute_site_radius(
465465

466466
raise ValueError(
467467
'Crystallographic sites are too close together '
468-
f'(expected: >{site_radius*2:.4f}, '
468+
f'(expected: >{site_radius * 2:.4f}, '
469469
f'got: {min_dist:.4f} for {msg}'
470470
)
471471

src/gemdat/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def warn_lattice_not_close(a: Lattice, b: Lattice):
202202
"""Raises a userwarning if lattices are not close."""
203203
if not is_lattice_similar(a, b):
204204
warnings.warn(
205-
'Lattices are not similar.' f'a: {a.parameters}, b: {b.parameters}',
205+
f'Lattices are not similar.a: {a.parameters}, b: {b.parameters}',
206206
UserWarning,
207207
)
208208

tests/helpers/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def assert_figures_similar(fig, *, name: str, ext: str = 'png', rms: float = 0.0
5555
err[key] = Path(err[key]).relative_to('.')
5656
raise AssertionError(
5757
(
58-
'images not close (RMS {rms:.3f}):'
59-
'\n\t{actual}\n\t{expected}\n\t{diff}'.format(**err)
58+
'images not close (RMS {rms:.3f}):\n\t{actual}\n\t{expected}\n\t{diff}'.format(
59+
**err
60+
)
6061
)
6162
)
Loading

0 commit comments

Comments
 (0)