Skip to content

Commit 493612e

Browse files
Add method to create a graph from jumps (#325)
* Add method to create a graph from jumps * Rename SimulationMetrics -> TrajectoryMetrics * Add method to get single activation energy * Add min/max energy filters * Do not depend on labels for graph * Attach label to node * Rename test module * Fix tests (there is one orphan node) * Update src/gemdat/jumps.py Co-authored-by: SCiarella <[email protected]> * Update src/gemdat/jumps.py Co-authored-by: SCiarella <[email protected]> * Update src/gemdat/jumps.py Co-authored-by: SCiarella <[email protected]> * Fix line lengths --------- Co-authored-by: SCiarella <[email protected]>
1 parent 28427d5 commit 493612e

16 files changed

+164
-56
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ jumps.jump_diffusivity(dimensions=3)
7575
To calculate different metrics, such as tracer diffusivity:
7676

7777
```python
78-
from gemdat import SimulationMetrics
78+
from gemdat import TrajectoryMetrics
7979

80-
metrics = SimulationMetrics(diff_trajectory)
80+
metrics = TrajectoryMetrics(diff_trajectory)
8181

8282
metrics.tracer_diffusivity(dimensions=3)
8383
metrics.haven_ratio(dimensions=3)

docs/api/gemdat.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
- [gemdat.read_cif][gemdat.io.read_cif]
22
- [gemdat.load_known_material][gemdat.io.load_known_material]
3-
- [gemdat.SimulationMetrics][gemdat.simulation_metrics.SimulationMetrics]
3+
- [gemdat.TrajectoryMetrics][gemdat.metrics.TrajectoryMetrics]
44
- [gemdat.Transitions][gemdat.transitions.Transitions]
55
- [gemdat.Jumps][gemdat.jumps.Jumps]
66
- [gemdat.Trajectory][gemdat.trajectory.Trajectory]

docs/api/gemdat_simulation_metrics.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
::: gemdat.simulation_metrics
1+
::: gemdat.metrics
22
options:
33
show_root_heading: false
44
show_root_toc_entry: false

mkdocs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ nav:
2525
- gemdat.io: api/gemdat_io.md
2626
- gemdat.plots: api/gemdat_plots.md
2727
- gemdat.rdf: api/gemdat_rdf.md
28-
- gemdat.simulation_metrics: api/gemdat_simulation_metrics.md
28+
- gemdat.metrics: api/gemdat_metrics.md
2929
- gemdat.trajectory: api/gemdat_trajectory.md
3030
- gemdat.transitions: api/gemdat_transitions.md
3131
- gemdat.jumps: api/gemdat_jumps.md

src/gemdat/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from .io import load_known_material, read_cif
44
from .jumps import Jumps
5+
from .metrics import TrajectoryMetrics
56
from .orientations import Orientations
67
from .rdf import radial_distribution
78
from .shape import ShapeAnalyzer
8-
from .simulation_metrics import SimulationMetrics
99
from .trajectory import Trajectory
1010
from .transitions import Transitions
1111
from .volume import Volume, trajectory_to_volume
@@ -18,7 +18,7 @@
1818
'radial_distribution',
1919
'read_cif',
2020
'ShapeAnalyzer',
21-
'SimulationMetrics',
21+
'TrajectoryMetrics',
2222
'Trajectory',
2323
'trajectory_to_volume',
2424
'Transitions',

src/gemdat/jumps.py

+104-19
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from math import ceil
66
from typing import TYPE_CHECKING, Callable
77

8+
import networkx as nx
89
import numpy as np
910
import pandas as pd
1011
from pymatgen.core.units import FloatWithUnit
@@ -13,7 +14,7 @@
1314
from ._plot_backend import plot_backend
1415
from .caching import weak_lru_cache
1516
from .collective import Collective
16-
from .simulation_metrics import SimulationMetrics
17+
from .metrics import TrajectoryMetrics
1718
from .transitions import Transitions, _calculate_transitions_matrix
1819

1920
if TYPE_CHECKING:
@@ -223,7 +224,7 @@ def collective(self, max_dist: float = 1) -> Collective:
223224
sites = self.transitions.sites
224225

225226
time_step = trajectory.time_step
226-
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
227+
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()
227228

228229
max_steps = ceil(1.0 / (attempt_freq * time_step))
229230

@@ -237,7 +238,7 @@ def collective(self, max_dist: float = 1) -> Collective:
237238

238239
@weak_lru_cache()
239240
def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
240-
"""Calculate activation energies for jumps (UNITS?).
241+
"""Calculate activation energies for jumps in eV.
241242
242243
Parameters
243244
----------
@@ -251,7 +252,7 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
251252
between site pairs.
252253
"""
253254
trajectory = self.trajectory
254-
attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
255+
attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()
255256

256257
dct = {}
257258

@@ -260,13 +261,13 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
260261
atom_locations_parts = [
261262
part.atom_locations() for part in self.transitions.split(n_parts)
262263
]
263-
jumps_counter_parts = [part.jumps_counter() for part in self.split(n_parts)]
264+
counter_parts = [part.counter() for part in self.split(n_parts)]
264265
n_floating = self.n_floating
265266

266267
for site_pair in self.site_pairs:
267268
site_start, site_stop = site_pair
268269

269-
n_jumps = np.array([part[site_pair] for part in jumps_counter_parts])
270+
n_jumps = np.array([part[site_pair] for part in counter_parts])
270271

271272
part_time = trajectory.total_time / n_parts
272273

@@ -292,22 +293,106 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
292293

293294
return df
294295

295-
def jumps_counter(self) -> Counter:
296-
"""Calculate number of jumps between sites.
296+
@weak_lru_cache()
297+
def counter(self) -> Counter[tuple[str, str]]:
298+
"""Count number of jumps between sites.
297299
298300
Returns
299301
-------
300-
jumps : dict[tuple[str, str], int]
301-
Dictionary with number of jumps per sites combination
302+
counter : Counter[tuple[str, str]]
303+
Dictionary with site pairs as keys and corresponding
304+
number of jumps as dictionary values
302305
"""
303306
labels = self.sites.labels
304-
jumps = Counter(
305-
[
306-
(labels[i], labels[j])
307-
for _, (i, j) in self.data[['start site', 'destination site']].iterrows()
308-
]
309-
)
310-
return jumps
307+
counter: Counter[tuple[str, str]] = Counter()
308+
for (i, j), val in self._counter().items():
309+
counter[labels[i], labels[j]] += val
310+
return counter
311+
312+
@weak_lru_cache()
313+
def _counter(self) -> Counter[tuple[int, int]]:
314+
"""Count number of jumps between sites. Keys are site indices.
315+
316+
Returns
317+
-------
318+
counter : Counter[tuple[int, int]]
319+
Dictionary with site pairs as keys and corresponding
320+
number of jumps as dictionary values
321+
"""
322+
counter = Counter(zip(self.data['start site'], self.data['destination site']))
323+
return counter
324+
325+
def activation_energy_between_sites(self, start: str, stop: str) -> float:
326+
"""Returns activation energy between two sites.
327+
328+
Uses `Jumps.to_graph()` in the background. For a large number of operations,
329+
it is more efficient to query the graph directly.
330+
331+
Parameters
332+
----------
333+
start : str
334+
Label of the start site
335+
stop : str
336+
Label of the stop site
337+
338+
Returns
339+
-------
340+
e_act : float
341+
Activation energy in eV
342+
"""
343+
G = self.to_graph()
344+
edge_data = G.get_edge_data(start, stop)
345+
if not edge_data:
346+
raise IndexError(f'No jumps between ({start}) and ({stop})')
347+
return edge_data['e_act']
348+
349+
@weak_lru_cache()
350+
def to_graph(
351+
self, min_e_act: float | None = None, max_e_act: float | None = None
352+
) -> nx.DiGraph:
353+
"""Create a graph from jumps data.
354+
355+
The edges are weighted by the activation energy. The nodes are indices that
356+
correspond to `Jumps.sites`.
357+
358+
Parameters
359+
----------
360+
min_e_act : float
361+
Reject edges with activation energy below this threshold
362+
max_e_act : float
363+
Reject edges with activation energy above this threshold
364+
365+
Returns
366+
-------
367+
G : nx.DiGraph
368+
A networkx DiGraph object.
369+
"""
370+
min_e_act = min_e_act if min_e_act else float('-inf')
371+
max_e_act = max_e_act if max_e_act else float('inf')
372+
373+
atom_percentage = [site.species.num_atoms for site in self.transitions.occupancy()]
374+
375+
attempt_freq, _ = self.trajectory.metrics().attempt_frequency()
376+
temperature = self.trajectory.metadata['temperature']
377+
kBT = Boltzmann * temperature
378+
379+
G = nx.DiGraph()
380+
381+
for i, site in enumerate(self.sites):
382+
G.add_node(i, label=site.label)
383+
384+
for (start, stop), n_jumps in self._counter().items():
385+
time_perc = atom_percentage[start] * self.trajectory.total_time
386+
387+
eff_rate = n_jumps / time_perc
388+
389+
e_act = -np.log(eff_rate / attempt_freq) * kBT
390+
e_act /= elementary_charge
391+
392+
if min_e_act <= e_act <= max_e_act:
393+
G.add_edge(start, stop, e_act=e_act)
394+
395+
return G
311396

312397
def split(self, n_parts: int) -> list[Jumps]:
313398
"""Split the jumps into parts.
@@ -336,12 +421,12 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:
336421
"""
337422
dct = {}
338423

339-
parts = [part.jumps_counter() for part in self.split(n_parts)]
424+
parts = [part.counter() for part in self.split(n_parts)]
425+
part_time = self.trajectory.total_time / n_parts
340426

341427
for site_pair in self.site_pairs:
342428
n_jumps = [part[site_pair] for part in parts]
343429

344-
part_time = self.trajectory.total_time / n_parts
345430
denom = self.n_floating * part_time
346431

347432
jump_freq_mean = np.mean(n_jumps) / denom

src/gemdat/simulation_metrics.py src/gemdat/metrics.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from trajectory import Trajectory
1818

1919

20-
class SimulationMetrics:
20+
class TrajectoryMetrics:
2121
"""Class for calculating different metrics and properties from a molecular
2222
dynamics simulation."""
2323

@@ -115,7 +115,7 @@ def tracer_diffusivity_center_of_mass(
115115
"""
116116
center_of_mass = self.trajectory.center_of_mass()
117117

118-
metrics = SimulationMetrics(center_of_mass)
118+
metrics = TrajectoryMetrics(center_of_mass)
119119

120120
return metrics.tracer_diffusivity(dimensions=dimensions)
121121

@@ -230,7 +230,7 @@ def amplitudes(self) -> np.ndarray:
230230
return np.asarray(amplitudes)
231231

232232

233-
class SimulationMetricsStd:
233+
class TrajectoryMetricsStd:
234234
"""Class for calculating different metrics and properties from a molecular
235235
dynamics simulation.
236236
@@ -246,7 +246,7 @@ def __init__(self, trajectories: list[Trajectory]):
246246
trajectories: list[Trajectory]
247247
Input trajectories
248248
"""
249-
self.metrics = [SimulationMetrics(trajectory) for trajectory in trajectories]
249+
self.metrics = [TrajectoryMetrics(trajectory) for trajectory in trajectories]
250250

251251
def speed(self) -> tuple[np.ndarray, np.ndarray]:
252252
"""Calculate mean speed and standard deviations.

src/gemdat/plots/matplotlib/_frequency_vs_occurence.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import matplotlib.pyplot as plt
66
import numpy as np
77

8-
from gemdat.simulation_metrics import SimulationMetrics
9-
108
if TYPE_CHECKING:
119
from gemdat.trajectory import Trajectory
1210

@@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
2422
fig : matplotlib.figure.Figure
2523
Output figure
2624
"""
27-
metrics = SimulationMetrics(trajectory)
25+
metrics = trajectory.metrics()
2826
speed = metrics.speed()
2927

3028
length = speed.shape[1]

src/gemdat/plots/matplotlib/_vibrational_amplitudes.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import numpy as np
77
from scipy import stats
88

9-
from gemdat.simulation_metrics import SimulationMetrics
10-
119
if TYPE_CHECKING:
1210
from gemdat.trajectory import Trajectory
1311

@@ -25,7 +23,7 @@ def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
2523
fig : matplotlib.figure.Figure
2624
Output figure
2725
"""
28-
metrics = SimulationMetrics(trajectory)
26+
metrics = trajectory.metrics()
2927

3028
fig, ax = plt.subplots()
3129
ax.hist(metrics.amplitudes(), bins=100, density=True)

src/gemdat/plots/plotly/_frequency_vs_occurence.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import numpy as np
66
import plotly.graph_objects as go
77

8-
from gemdat.simulation_metrics import SimulationMetrics
9-
108
if TYPE_CHECKING:
119
from gemdat.trajectory import Trajectory
1210

@@ -24,7 +22,7 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
2422
fig : plotly.graph_objects.Figure.Figure
2523
Output figure
2624
"""
27-
metrics = SimulationMetrics(trajectory)
25+
metrics = trajectory.metrics()
2826
speed = metrics.speed()
2927

3028
length = speed.shape[1]

src/gemdat/plots/plotly/_vibrational_amplitudes.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import plotly.graph_objects as go
99
from scipy import stats
1010

11-
from gemdat.simulation_metrics import SimulationMetrics
12-
1311
if TYPE_CHECKING:
1412
from gemdat.trajectory import Trajectory
1513

@@ -33,8 +31,8 @@ def vibrational_amplitudes(
3331
"""
3432

3533
trajectories = trajectory.split(n_parts)
36-
single_metrics = SimulationMetrics(trajectory)
37-
metrics = [SimulationMetrics(trajectory).amplitudes() for trajectory in trajectories]
34+
single_metrics = trajectory.metrics()
35+
metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories]
3836

3937
max_amp = max(max(metric) for metric in metrics)
4038
min_amp = min(min(metric) for metric in metrics)

src/gemdat/trajectory.py

+7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
if TYPE_CHECKING:
2222
from pymatgen.core import Structure
2323

24+
from .metrics import TrajectoryMetrics
2425
from .transitions import Transitions
2526
from .volume import Volume
2627

@@ -613,6 +614,12 @@ def transitions_between_sites(
613614
site_inner_fraction=site_inner_fraction,
614615
)
615616

617+
def metrics(self) -> TrajectoryMetrics:
618+
"""See [gemdat.TrajectoryMetrics][] for more info."""
619+
from .metrics import TrajectoryMetrics
620+
621+
return TrajectoryMetrics(trajectory=self)
622+
616623
@plot_backend
617624
def plot_displacement_per_atom(self, *, module, **kwargs):
618625
"""See [gemdat.plots.displacement_per_atom][] for more info."""

src/gemdat/transitions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pymatgen.core import Structure
1414

1515
from .caching import weak_lru_cache
16-
from .simulation_metrics import SimulationMetrics
16+
from .metrics import TrajectoryMetrics
1717
from .utils import bfill, ffill, integer_remap
1818

1919
if typing.TYPE_CHECKING:
@@ -108,7 +108,7 @@ def from_trajectory(
108108
diff_trajectory = trajectory.filter(floating_specie)
109109

110110
if site_radius is None:
111-
vibration_amplitude = SimulationMetrics(diff_trajectory).vibration_amplitude()
111+
vibration_amplitude = TrajectoryMetrics(diff_trajectory).vibration_amplitude()
112112

113113
site_radius = _compute_site_radius(
114114
trajectory=trajectory,

0 commit comments

Comments
 (0)