from math import ceil
from typing import TYPE_CHECKING, Callable

+ import networkx as nx
import numpy as np
import pandas as pd
from pymatgen.core.units import FloatWithUnit

from ._plot_backend import plot_backend
from .caching import weak_lru_cache
from .collective import Collective
- from .simulation_metrics import SimulationMetrics
+ from .metrics import TrajectoryMetrics
from .transitions import Transitions, _calculate_transitions_matrix


if TYPE_CHECKING:
@@ -223,7 +224,7 @@ def collective(self, max_dist: float = 1) -> Collective:
        sites = self.transitions.sites

        time_step = trajectory.time_step
-         attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
+         attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

        max_steps = ceil(1.0 / (attempt_freq * time_step))
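Side note on the unchanged `max_steps` line: it is the number of MD time steps that fit into one attempt period, i.e. the time window within which two jumps can be treated as collective. A rough sanity check with illustrative numbers (not taken from this diff):

    from math import ceil

    # Illustrative values only: ~10 THz attempt frequency, 2 fs MD time step.
    attempt_freq = 1.0e13   # Hz
    time_step = 2.0e-15     # s
    max_steps = ceil(1.0 / (attempt_freq * time_step))   # -> 50 steps per attempt period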
@@ -237,7 +238,7 @@ def collective(self, max_dist: float = 1) -> Collective:

    @weak_lru_cache()
    def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
-         """Calculate activation energies for jumps (UNITS?).
+         """Calculate activation energies for jumps in eV.

        Parameters
        ----------
@@ -251,7 +252,7 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
            between site pairs.
        """
        trajectory = self.trajectory
-         attempt_freq, _ = SimulationMetrics(trajectory).attempt_frequency()
+         attempt_freq, _ = TrajectoryMetrics(trajectory).attempt_frequency()

        dct = {}

@@ -260,13 +261,13 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
        atom_locations_parts = [
            part.atom_locations() for part in self.transitions.split(n_parts)
        ]
-         jumps_counter_parts = [part.jumps_counter() for part in self.split(n_parts)]
+         counter_parts = [part.counter() for part in self.split(n_parts)]
        n_floating = self.n_floating

        for site_pair in self.site_pairs:
            site_start, site_stop = site_pair

-             n_jumps = np.array([part[site_pair] for part in jumps_counter_parts])
+             n_jumps = np.array([part[site_pair] for part in counter_parts])

            part_time = trajectory.total_time / n_parts

@@ -292,22 +293,106 @@ def activation_energies(self, n_parts: int = 10) -> pd.DataFrame:
        return df

-     def jumps_counter(self) -> Counter:
-         """Calculate number of jumps between sites.
+     @weak_lru_cache()
+     def counter(self) -> Counter[tuple[str, str]]:
+         """Count number of jumps between sites.

        Returns
        -------
-         jumps : dict[tuple[str, str], int]
-             Dictionary with number of jumps per sites combination
+         counter : Counter[tuple[str, str]]
+             Dictionary with site pairs as keys and corresponding
+             number of jumps as dictionary values
        """
        labels = self.sites.labels
-         jumps = Counter(
-             [
-                 (labels[i], labels[j])
-                 for _, (i, j) in self.data[['start site', 'destination site']].iterrows()
-             ]
-         )
-         return jumps
+         counter: Counter[tuple[str, str]] = Counter()
+         for (i, j), val in self._counter().items():
+             counter[labels[i], labels[j]] += val
+         return counter
+
+     @weak_lru_cache()
+     def _counter(self) -> Counter[tuple[int, int]]:
+         """Count number of jumps between sites. Keys are site indices.
+
+         Returns
+         -------
+         counter : Counter[tuple[int, int]]
+             Dictionary with site pairs as keys and corresponding
+             number of jumps as dictionary values
+         """
+         counter = Counter(zip(self.data['start site'], self.data['destination site']))
+         return counter
+
+     def activation_energy_between_sites(self, start: str, stop: str) -> float:
+         """Returns activation energy between two sites.
+
+         Uses `Jumps.to_graph()` in the background. For a large number of operations,
+         it is more efficient to query the graph directly.
+
+         Parameters
+         ----------
+         start : str
+             Label of the start site
+         stop : str
+             Label of the stop site
+
+         Returns
+         -------
+         e_act : float
+             Activation energy in eV
+         """
+         G = self.to_graph()
+         edge_data = G.get_edge_data(start, stop)
+         if not edge_data:
+             raise IndexError(f'No jumps between ({start}) and ({stop})')
+         return edge_data['e_act']
+
+     @weak_lru_cache()
+     def to_graph(
+         self, min_e_act: float | None = None, max_e_act: float | None = None
+     ) -> nx.DiGraph:
+         """Create a graph from jumps data.
+
+         The edges are weighted by the activation energy. The nodes are indices that
+         correspond to `Jumps.sites`.
+
+         Parameters
+         ----------
+         min_e_act : float
+             Reject edges with activation energy below this threshold
+         max_e_act : float
+             Reject edges with activation energy above this threshold
+
+         Returns
+         -------
+         G : nx.DiGraph
+             A networkx DiGraph object.
+         """
+         min_e_act = min_e_act if min_e_act else float('-inf')
+         max_e_act = max_e_act if max_e_act else float('inf')
+
+         atom_percentage = [site.species.num_atoms for site in self.transitions.occupancy()]
+
+         attempt_freq, _ = self.trajectory.metrics().attempt_frequency()
+         temperature = self.trajectory.metadata['temperature']
+         kBT = Boltzmann * temperature
+
+         G = nx.DiGraph()
+
+         for i, site in enumerate(self.sites):
+             G.add_node(i, label=site.label)
+
+         for (start, stop), n_jumps in self._counter().items():
+             time_perc = atom_percentage[start] * self.trajectory.total_time
+
+             eff_rate = n_jumps / time_perc
+
+             e_act = -np.log(eff_rate / attempt_freq) * kBT
+             e_act /= elementary_charge
+
+             if min_e_act <= e_act <= max_e_act:
+                 G.add_edge(start, stop, e_act=e_act)
+
+         return G

    def split(self, n_parts: int) -> list[Jumps]:
        """Split the jumps into parts.
@@ -336,12 +421,12 @@ def rates(self, n_parts: int = 10) -> pd.DataFrame:
        """
        dct = {}

-         parts = [part.jumps_counter() for part in self.split(n_parts)]
+         parts = [part.counter() for part in self.split(n_parts)]
+         part_time = self.trajectory.total_time / n_parts

        for site_pair in self.site_pairs:
            n_jumps = [part[site_pair] for part in parts]

-             part_time = self.trajectory.total_time / n_parts
            denom = self.n_floating * part_time

            jump_freq_mean = np.mean(n_jumps) / denom
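Hoisting `part_time` out of the loop does not change the result: each block of the split trajectory contributes a rate n_jumps / (n_floating * part_time), and the mean over blocks is the reported jump frequency. A toy check with made-up numbers (nothing here comes from the diff):

    import numpy as np

    n_parts = 10
    part_time = 100e-12   # s per block, illustrative
    n_floating = 48       # mobile atoms, illustrative
    n_jumps = np.array([60, 55, 62, 58, 61, 57, 59, 63, 56, 60])  # jumps per block

    denom = n_floating * part_time
    jump_freq_mean = np.mean(n_jumps) / denom   # ~1.2e10 jumps / (atom * s)
    jump_freq_std = np.std(n_jumps) / denom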