Skip to content

Commit 34a0cc0

Browse files
authored
Fix crash in Transitions.occupancy if input sites are disordered (#342)
* Add work-around for disordered sites listing #339 * Add warning for disordered structures * Fix atom locations and add occupancy_by_site_type
1 parent 6f8375e commit 34a0cc0

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

src/gemdat/transitions.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import typing
77
from collections import defaultdict
88
from itertools import pairwise
9+
from warnings import warn
910

1011
import numpy as np
1112
import pandas as pd
@@ -67,6 +68,15 @@ def __init__(
6768
inner_states : np.ndarray
6869
Input states for inner sites
6970
"""
71+
if not (sites.is_ordered):
72+
warn(
73+
'Input `sites` are disordered! '
74+
'Although the code may work, it was written under the assumption '
75+
'that an ordered structure would be passed. '
76+
'See https://github.com/GEMDAT-repos/GEMDAT/issues/339 for more information.',
77+
stacklevel=2,
78+
)
79+
7080
self.sites = sites
7181
self.trajectory = trajectory
7282
self.diff_trajectory = diff_trajectory
@@ -252,7 +262,10 @@ def occupancy(self) -> Structure:
252262
counts = counts / len(states)
253263
occupancies = dict(zip(unq, counts))
254264

255-
species = [{site.specie.name: occupancies.get(i, 0)} for i, site in enumerate(sites)]
265+
species = [
266+
{site.species.elements[0].name: occupancies.get(i, 0)}
267+
for i, site in enumerate(sites)
268+
]
256269

257270
return Structure(
258271
lattice=sites.lattice,
@@ -262,27 +275,36 @@ def occupancy(self) -> Structure:
262275
labels=sites.labels,
263276
)
264277

265-
def atom_locations(self):
278+
def occupancy_by_site_type(self) -> dict[str, float]:
279+
"""Calculate average occupancy per a type of site.
280+
281+
Returns
282+
-------
283+
occupancy : dict[str, float]
284+
Return dict with average occupancy per site type
285+
"""
286+
compositions_by_label = defaultdict(list)
287+
288+
for site in self.occupancy():
289+
compositions_by_label[site.label].append(site.species.num_atoms)
290+
291+
return {k: sum(v) / len(v) for k, v in compositions_by_label.items()}
292+
293+
def atom_locations(self) -> dict[str, float]:
266294
"""Calculate fraction of time atoms spent at a type of site.
267295
268296
Returns
269297
-------
270298
dict[str, float]
271299
Return dict with the fraction of time atoms spent at a site
272300
"""
273-
multiplier = len(self.sites) / self.n_floating
274-
301+
n = self.n_floating
275302
compositions_by_label = defaultdict(list)
276303

277304
for site in self.occupancy():
278305
compositions_by_label[site.label].append(site.species.num_atoms)
279306

280-
ret = {}
281-
282-
for k, v in compositions_by_label.items():
283-
ret[k] = (sum(v) / len(v)) * multiplier
284-
285-
return ret
307+
return {k: sum(v) / n for k, v in compositions_by_label.items()}
286308

287309
def split(self, n_parts: int = 10) -> list[Transitions]:
288310
"""Split data into equal parts in time for statistics.

tests/integration/transitions_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def test_occupancy_parts(self, vasp_transitions):
133133
35.43733333333334,
134134
]
135135

136+
def test_occupancy_by_site_type(self, vasp_transitions):
137+
occ = vasp_transitions.occupancy_by_site_type()
138+
assert occ == {'48h': 0.3806277777777776}
139+
136140
def test_atom_locations(self, vasp_transitions):
137141
dct = vasp_transitions.atom_locations()
138142
assert dct == {'48h': 0.7612555555555552}

0 commit comments

Comments
 (0)