Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions src/pyslice/multislice/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def setup(
cleanup_temp_files: bool = False,
slice_axis: int = 2,
cache_levels: list = ["exitwaves"], # options include: exitwaves, slices, potentials (this replaces store_all_slices)
cache_layer_indices: Optional[List[int]] = None, # NEW: subset of slice indices to store; None = store all layers
max_kx = np.inf,
max_ky = np.inf,
use_memmap = False,
Expand All @@ -151,6 +152,11 @@ def setup(
save_path: Optional path to save wave function data
cleanup_temp_files: Whether to delete temp files after loading
store_all_slices: If True, store wavefunction at each slice for 3D visualization
cache_layer_indices: Optional list of slice-layer indices (0-based) to record when
cache_levels includes "slices". If None (default), all nz layers are stored.
Specifying a small subset (e.g. the 6 depths needed for EELS thickness series)
can reduce disk usage by >98% without affecting propagation accuracy.
Example: cache_layer_indices=[44, 88, 176, 264, 352, 440]
"""

self.trajectory = trajectory
Expand All @@ -166,6 +172,7 @@ def setup(
self.cleanup_temp_files = cleanup_temp_files
self.slice_axis = slice_axis
self.cache_levels = cache_levels
self.cache_layer_indices = cache_layer_indices # NEW: store for use in run()
self.max_kx = max_kx
self.max_ky = max_ky
self.use_memmap = use_memmap # bool: frame_data (p,x,y,l,1) and wavefunction_data (p,t,x,y,l) will be memmapped instead of held in RAM
Expand Down Expand Up @@ -293,6 +300,41 @@ def run(self) -> WFData:
self.output_dir = Path("psi_data/" + ("torch" if TORCH_AVAILABLE else "numpy") + "_"+cache_key)
self.output_dir.mkdir(parents=True, exist_ok=True)

# ── Resolve which layers to store ──
# NEW: if cache_layer_indices is set, only those layers are FFT'd and
# written to disk; the propagation itself still runs through all nz
# slices (physically required). cache_layer_indices=None keeps the
# original behaviour of storing every layer.
if "slices" in self.cache_levels and self.cache_layer_indices is not None:
# Validate and clip indices to [0, nz-1]
_requested = sorted(set(int(i) for i in self.cache_layer_indices))
_dropped = [i for i in _requested if not (0 <= i < self.nz)]
_active_layers = [i for i in _requested if 0 <= i < self.nz]
if _dropped:
logger.warning(
f"cache_layer_indices: dropped out-of-range indices {_dropped} "
f"(nz={self.nz})"
)
if not _active_layers:
raise ValueError(
"cache_layer_indices produced no valid layer indices after "
f"clipping to [0, {self.nz-1}]."
)
logger.info(
f"Selective layer storage: recording {len(_active_layers)}/{self.nz} "
f"layers -> {_active_layers}"
)
print(
f"[MultisliceCalculator] cache_layer_indices: storing "
f"{len(_active_layers)}/{self.nz} layers: {_active_layers}",
flush=True
)
else:
# Default: store every layer (original behaviour)
_active_layers = list(range(self.nz)) if "slices" in self.cache_levels else [0]

self._active_layers = _active_layers # expose for inspection / post-processing


# if probes are over vacuum (e.g. nanoparticles), we don't need to propagate them?
self.probe_indices = xp.arange(len(self.probe_positions))
Expand All @@ -315,7 +357,8 @@ def run(self) -> WFData:
nc,npt,nx,ny = self.base_probe._array.shape
self.n_probes = nc*len(self.probe_positions)
# Storage: [probe, frame, x, y, layer] - matches WFData expected format
self.n_layers = self.nz if "slices" in self.cache_levels else 1
# CHANGED: n_layers is now len(_active_layers) instead of always self.nz
self.n_layers = len(_active_layers)
if self.store_full:
fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes
#if self.base_probe.cropping:
Expand Down Expand Up @@ -406,6 +449,7 @@ def run(self) -> WFData:

#print("create frame_data") ; start = time.time()
# frame_data is always: p,x,y,l,1 (self.wavefunction_data expects p,t,x,y,l, since we loop time. recall Propagate gave l,p,x,y)
# CHANGED: last dim is self.n_layers = len(_active_layers), not nz
if self.store_full or self.prism:
fd_nx = self.nx ; fd_ny = self.ny ; fd_npt = self.n_probes
#if self.base_probe.cropping:
Expand Down Expand Up @@ -454,16 +498,22 @@ def run(self) -> WFData:
exit_waves_single = expand_dims(exit_waves_single,0) if len(exit_waves_single.shape)==3 else exit_waves_single
# FFT and load into frame_data
kwarg = {"dim":(-2,-1)} if TORCH_AVAILABLE else {"axes":(-2,-1)}
for layer_idx in range(self.n_layers):
exit_waves_k = xp.fft.fft2(exit_waves_single[layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y

# CHANGED: iterate over (out_idx, real_layer_idx) pairs instead of
# range(self.n_layers). When cache_layer_indices=None, _active_layers
# is list(range(nz)) so out_idx == real_layer_idx and behaviour is
# identical to the original code.
for out_idx, real_layer_idx in enumerate(_active_layers):
exit_waves_k = xp.fft.fft2(exit_waves_single[real_layer_idx,:,:,:], **kwarg) # l,p,x,y --> p,x,y
diffraction_patterns = xp.fft.fftshift(exit_waves_k, **kwarg)
#if not self.prism:
diffraction_patterns = diffraction_patterns[:,self.keep_kxs_indices,:][:,:,self.keep_kys_indices]*self.kth**2
if self.use_memmap:
diffraction_patterns = to_cpu(diffraction_patterns)
selected = to_cpu(selected)
if self.store_full or self.prism:
frame_data[selected,:,:,layer_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices
# CHANGED: write to compact slot out_idx, not real_layer_idx
frame_data[selected,:,:,out_idx,0] = diffraction_patterns # load p,x,y --> p,x,y,l,1 indices
if self.ADF and not self.prism:
#print(self.ADF._wf_array[0,:,:,0,0,0,0])
intensities = einsum('pxy,xy->p',absolute(diffraction_patterns[:,:,:])**2,self.ADFmask)
Expand Down Expand Up @@ -550,7 +600,11 @@ def run(self) -> WFData:
#kxs = xp.fft.fftshift(xp.fft.fftfreq(self.nx, self.sampling)) # k-space in 1/Å MOVING TO INIT SO WE CAN CROP ON-THE-FLY
#kys = xp.fft.fftshift(xp.fft.fftfreq(self.ny, self.sampling)) # k-space in 1/Å
time_array = np.arange(self.n_frames) * self.trajectory.timestep # Time array in ps
layer_array = np.arange(self.nz) if "slices" in self.cache_levels else np.array([0]) # Layer indices

# CHANGED: layer_array now reflects the actual stored layer indices.
# When cache_layer_indices=None, _active_layers == list(range(nz)) so
# layer_array == np.arange(nz), identical to the original behaviour.
layer_array = np.array(_active_layers) if "slices" in self.cache_levels else np.array([0]) # Layer indices

# Package results
array = zeros((self.n_probes,1,1,1,1),dtype=self.complex_dtype)
Expand Down Expand Up @@ -667,4 +721,3 @@ def plot(self,w,filename=None): # TODO MAYBE "RUN" SHOULD RETURN A TACAW OBJECT
else:
plt.show()