Skip to content

Commit

Permalink
Merge pull request #358 from OpenCOMPES/interpolation_performance_fix
Browse files Browse the repository at this point in the history
Improve interpolation performance
  • Loading branch information
rettigl authored Mar 18, 2024
2 parents 8c9ebff + ef499f6 commit 8c0e91b
Showing 1 changed file with 69 additions and 40 deletions.
109 changes: 69 additions & 40 deletions sed/calibrator/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import itertools as it
from copy import deepcopy
from datetime import datetime
from multiprocessing import Pool
from typing import Any
from typing import Dict
from typing import List
Expand All @@ -18,6 +19,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import scipy.ndimage as ndi
import xarray as xr
from bokeh.colors import RGB
Expand All @@ -27,11 +29,13 @@
from matplotlib import cm
from numpy.linalg import norm
from scipy.interpolate import griddata
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import map_coordinates
from symmetrize import pointops as po
from symmetrize import sym
from symmetrize import tps

N_CPU = psutil.cpu_count()


class MomentumCorrector:
"""
Expand Down Expand Up @@ -68,6 +72,10 @@ def __init__(

self._config = config

self.num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
if self.num_cores >= N_CPU:
self.num_cores = N_CPU - 1

self.image: np.ndarray = None
self.img_ndim: int = None
self.slice: np.ndarray = None
Expand Down Expand Up @@ -1218,6 +1226,7 @@ def calc_inverse_dfield(self):
self.cdeform_field,
self.bin_ranges,
self.detector_ranges,
self.num_cores,
)

return self.inverse_dfield
Expand Down Expand Up @@ -1707,6 +1716,7 @@ def apply_corrections(
self.cdeform_field,
self.bin_ranges,
self.detector_ranges,
self.num_cores,
)
self.dfield_updated = False

Expand Down Expand Up @@ -2027,34 +2037,12 @@ def apply_dfield(
x = df[x_column]
y = df[y_column]

r_axis = np.linspace(
detector_ranges[0][0],
dfield[0].shape[0],
detector_ranges[0][1],
endpoint=False,
)

c_axis = np.linspace(
detector_ranges[1][0],
dfield[0].shape[1],
detector_ranges[1][1],
endpoint=False,
)

interp_x = RegularGridInterpolator(
(r_axis, c_axis),
dfield[0],
bounds_error=False,
)
interp_y = RegularGridInterpolator(
(r_axis, c_axis),
dfield[1],
bounds_error=False,
)
r_axis_steps = (detector_ranges[0][1] - detector_ranges[0][0]) / dfield[0].shape[0]
c_axis_steps = (detector_ranges[1][1] - detector_ranges[1][0]) / dfield[0].shape[1]

df[new_x_column], df[new_y_column] = (
interp_x((x, y)),
interp_y((x, y)),
map_coordinates(dfield[0], (x, y), order=1) * r_axis_steps,
map_coordinates(dfield[1], (x, y), order=1) * c_axis_steps,
)
return df

Expand All @@ -2064,6 +2052,7 @@ def generate_inverse_dfield(
cdeform_field: np.ndarray,
bin_ranges: List[Tuple],
detector_ranges: List[Tuple],
num_cores: int,
) -> np.ndarray:
"""Generate inverse deformation field using inperpolation with griddata.
Assuming the binning range of the input ``rdeform_field`` and ``cdeform_field``
Expand All @@ -2074,6 +2063,7 @@ def generate_inverse_dfield(
cdeform_field (np.ndarray): Column-wise deformation field.
bin_ranges (List[Tuple]): Detector ranges of the binned coordinates.
detector_ranges (List[Tuple]): Ranges of detector coordinates to interpolate to.
num_cores (int): number of cores to use for parallelization.
Returns:
np.ndarray: The calculated inverse deformation field (row/column)
Expand Down Expand Up @@ -2106,7 +2096,49 @@ def generate_inverse_dfield(
rc_position = [] # row/column position in c/rdeform_field
r_dest = [] # destination pixel row position
c_dest = [] # destination pixel column position
for i in np.arange(cdeform_field.shape[0]):
compute_i0 = [(cdeform_field.shape[0] * i) // num_cores for i in np.arange(0, num_cores)]
compute_i1 = [(cdeform_field.shape[0] * i) // num_cores for i in np.arange(1, num_cores + 1)]
data = [
(rdeform_field, cdeform_field, bin_ranges, bin_step, i0, i1)
for (i0, i1) in zip(compute_i0, compute_i1)
]
with Pool(num_cores) as p:
ret = p.map(generate_lists, data)

for pos, rd, cd in ret:
rc_position += pos
r_dest += rd
c_dest += cd

with Pool(2) as p:
ret = p.map(
griddata_,
[
(np.asarray(rc_position), np.asarray(r_dest), (r_mesh, c_mesh)),
(np.asarray(rc_position), np.asarray(c_dest), (r_mesh, c_mesh)),
],
)

inverse_dfield = np.asarray([ret[0], ret[1]])

return inverse_dfield


def generate_lists(args):
"""Function for paralellizing code with multiprocessing.Pool.map
Args:
args: argument tuple containing (rdeform_field, cdeform_field, bin_ranges, bin_step, i0, i1)
Returns:
return tuple of lists (rc_position, r_dest, c_dest)
"""
(rdeform_field, cdeform_field, bin_ranges, bin_step, i0, i1) = args
rc_position = [] # row/column position in c/rdeform_field
r_dest = [] # destination pixel row position
c_dest = [] # destination pixel column position

for i in np.arange(i0, i1):
for j in np.arange(cdeform_field.shape[1]):
if not np.isnan(rdeform_field[i, j]) and not np.isnan(
cdeform_field[i, j],
Expand All @@ -2123,22 +2155,19 @@ def generate_inverse_dfield(
c_dest.append(
bin_step[1] * j + bin_ranges[1][0],
)
return (rc_position, r_dest, c_dest)

inv_rdeform_field = griddata(
np.asarray(rc_position),
r_dest,
(r_mesh, c_mesh),
)

inv_cdeform_field = griddata(
np.asarray(rc_position),
c_dest,
(r_mesh, c_mesh),
)
def griddata_(args):
"""Wrapper for griddata to use with multiprocessing.Pool.map
inverse_dfield = np.asarray([inv_rdeform_field, inv_cdeform_field])
Args:
args: argument tuple to griddata
return inverse_dfield
Returns:
return value of griddata
"""
return griddata(*args)


def load_dfield(file: str) -> Tuple[np.ndarray, np.ndarray]:
Expand Down

0 comments on commit 8c0e91b

Please sign in to comment.