Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add division model k-calibration #276

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 201 additions & 1 deletion sed/calibrator/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -107,7 +108,10 @@ def __init__(
self.correction: Dict[Any, Any] = {"applied": False}
self.adjust_params: Dict[Any, Any] = {"applied": False}
self.calibration: Dict[Any, Any] = {}

self.division_model_params: Dict[str, Any] = self._config["momentum"].get(
"division_model_params",
{},
)
self.x_column = self._config["dataframe"]["x_column"]
self.y_column = self._config["dataframe"]["y_column"]
self.corrected_x_column = self._config["dataframe"]["corrected_x_column"]
Expand Down Expand Up @@ -1837,6 +1841,122 @@ def gather_calibration_metadata(self, calibration: dict = None) -> dict:

return metadata

def calibrate_k_division_model(
self,
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
warp_params: Union[Dict[str, Any], Sequence[float]] = None,
x_column: str = None,
y_column: str = None,
kx_column: str = None,
ky_column: str = None,
) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
"""Use the division model to calibrate the momentum axis.

This function returns the distorted coordinates given the undistorted ones
a little complicated by the fact that gamma needs to go to (0,0).
it uses a radial distortion model called division model
(https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction)
commonly used to correct for lens artifacts.

The radial distortion parameters k0, k1, k2 are defined as follows:
.. math::
K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4)
where rpx is the distance from the center of distortion in pixels and rk is the
distance from the center of distortion in k space.

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the
distotion correction to.
warp_params (Sequence[float], optional): Parameters of the division model.
Either a dictionary containing the parameters or a sequence of the
parameters in the order ['center','k0','k1','k2','gamma'].
center and gamma are both 2D vectors, k0, k1 and k2 are scalars.
Defaults to config["momentum"]["division_model_params"].
x_column (str, optional): Label of the source 'X' column.
Defaults to config["momentum"]["x_column"].
y_column (str, optional): Label of the source 'Y' column.
Defaults to config["momentum"]["y_column"].
kx_column (str, optional): Label of the destination 'X' column after
momentum calibration. Defaults to config["momentum"]["kx_column"].
ky_column (str, optional): Label of the destination 'Y' column after
momentum calibration. Defaults to config["momentum"]["ky_column"].

Returns:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe with added columns
metadata (dict): momentum calibration metadata dictionary.
"""
if x_column is None:
x_column = self.x_column
if y_column is None:
y_column = self.y_column
if kx_column is None:
kx_column = self.kx_column
if ky_column is None:
ky_column = self.ky_column

if warp_params is None:
warp_params = self.division_model_params

if isinstance(warp_params, Sequence):
if len(warp_params) != 7:
raise ValueError(
f"Warp parameters must be a sequence of 7 floats! (center, k0, k1, k2, gamma)\n"
f"Got {len(warp_params)} instead",
)
warp_params = {
"center": np.asarray(warp_params[0:2]),
"k0": warp_params[2],
"k1": warp_params[3],
"k2": warp_params[4],
"gamma": np.asarray(warp_params[5:7]),
}
elif isinstance(warp_params, dict):
if not all(key in warp_params for key in ["center", "k0", "k1", "k2", "gamma"]):
raise ValueError(
f"Warp parameters must be a dictionary containing the keys "
"'center', 'k0', 'k1', 'k2', 'gamma'!\n"
f"Got {warp_params.keys()} instead",
)
if len(warp_params["center"]) != 2:
raise ValueError(
f"Warp parameter 'center' must be a 2D vector!\n"
f"Got {warp_params['center']} instead",
)
if len(warp_params["gamma"]) != 2:
raise ValueError(
f"Warp parameter 'gamma' must be a 2D vector!\n"
f"Got {warp_params['gamma']} instead",
)
if not all(
isinstance(value, (int, float, np.integer, np.floating))
for value in [warp_params[k] for k in ["k0", "k1", "k2"]]
):
raise ValueError(
f"Warp parameters 'k0', 'k1' and 'k2' must be floats!\n"
f"Got {warp_params['k0']}, {warp_params['k1']} and {warp_params['k2']} instead",
)
else:
raise TypeError("Warp parameters must be a dictionary or a sequence of floats!")

df = calibrate_k_division_model(
df,
x_column=x_column,
y_column=y_column,
kx_column=kx_column,
ky_column=ky_column,
**warp_params,
)

metadata = {
"applied": True,
"warp_params": warp_params,
"x_column": x_column,
"y_column": y_column,
"kx_column": kx_column,
"ky_column": ky_column,
}
return df, metadata


def cm2palette(cmap_name: str) -> list:
"""Convert certain matplotlib colormap (cm) to bokeh palette.
Expand Down Expand Up @@ -2091,3 +2211,83 @@ def load_dfield(file: str) -> Tuple[np.ndarray, np.ndarray]:
pass

return rdeform_field, cdeform_field


def calibrate_k_division_model(
df: Union[pd.DataFrame, dask.dataframe.DataFrame],
center: Tuple[float, float] = None,
k0: float = None,
k1: float = None,
k2: float = None,
rot: float = None,
gamma: Tuple[float, float] = None,
x_column: str = None,
y_column: str = None,
kx_column: str = None,
ky_column: str = None,
) -> dask.dataframe.DataFrame:
"""K calibration based on the division model

This function returns the distorted coordinates given the undistorted ones
a little complicated by the fact that gamma needs to go to (0,0).
it uses a radial distortion model called division model
(https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction)
commonly used to correct for lens artifacts.

The radial distortion parameters k0, k1, k2 are defined as follows:
.. math::
K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4)
where rpx is the distance from the center of distortion in pixels and rk is the
distance from the center of distortion in k space.

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the
distotion correction to.
center (Tuple[float, float]): center of distortion in px
k0 (float): radial distortion parameter
k1 (float): radial distortion parameter
k2 (float): radial distortion parameter
rot (float): rotation in rad
gamma (Tuple[float, float]): normal emission (Gamma) in px
x_column (str): Name of the column containing the x steps.
y_column (str): Name of the column containing the y steps.
kx_column (str, optional): Name of the target calibrated x column.
If None, defaults to x_column.
ky_column (str, optional): Name of the target calibrated x column.
If None, defaults to y_column.

Returns:
df (dask.dataframe.DataFrame): Dataframe with added columns
"""
if kx_column is None:
kx_column = x_column
if ky_column is None:
ky_column = y_column

def convert_to_kx(x):
"""Converts the x steps to kx."""
x_diff = x[x_column] - center[0]
y_diff = x[y_column] - center[1]
dist = np.sqrt(x_diff**2 + y_diff**2)
den = k0 + k1 * dist**2 + k2 * dist**4
angle = np.arctan2(y_diff, x_diff) - rot
warp_diff = np.sqrt((gamma[0] - center[0]) ** 2 + (gamma[1] - center[1]) ** 2)
warp_den = k0 + k1 * (gamma[0] - center[0]) ** 2 + k2 * (gamma[1] - center[1]) ** 2
warp_angle = np.arctan2(gamma[1] - center[1], gamma[0] - center[0]) - rot
return (dist / den) * np.cos(angle) - (warp_diff / warp_den) * np.cos(warp_angle)

def convert_to_ky(x):
x_diff = x[x_column] - center[0]
y_diff = x[y_column] - center[1]
dist = np.sqrt(x_diff**2 + y_diff**2)
den = k0 + k1 * dist**2 + k2 * dist**4
angle = np.arctan2(y_diff, x_diff) - rot
warp_diff = np.sqrt((gamma[0] - center[0]) ** 2 + (gamma[1] - center[1]) ** 2)
warp_den = k0 + k1 * (gamma[0] - center[0]) ** 2 + k2 * (gamma[1] - center[1]) ** 2
warp_angle = np.arctan2(gamma[1] - center[1], gamma[0] - center[0]) - rot
return (dist / den) * np.sin(angle) - (warp_diff / warp_den) * np.sin(warp_angle)

df[kx_column] = df.map_partitions(convert_to_kx, meta=(kx_column, np.float64))
df[ky_column] = df.map_partitions(convert_to_ky, meta=(ky_column, np.float64))

return df
82 changes: 81 additions & 1 deletion sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,6 +1507,86 @@ def save_delay_calibration(
}
save_config(config, filename, overwrite)

def calibrate_k_division_model(
self,
warp_params: Sequence[float] = None,
**kwargs,
) -> None:
"""Use the division model to calibrate the momentum axis.

This function returns the distorted coordinates given the undistorted ones
a little complicated by the fact that gamma needs to go to (0,0).
it uses a radial distortion model called division model
(https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction)
commonly used to correct for lens artifacts.

The radial distortion parameters k0, k1, k2 are defined as follows:
.. math::
K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4)
where rpx is the distance from the center of distortion in pixels and rk is the
distance from the center of distortion in k space.

Args:
df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the
distotion correction to.
warp_params (Sequence[float], optional): Parameters of the division model.
Either a dictionary containing the parameters or a sequence of the
parameters in the order ['center','k0','k1','k2','gamma'].
center and gamma are both 2D vectors, k0, k1 and k2 are scalars.
Center is the center of distortion in pixels, gamma is the center of
the image in k space. k0, k1 and k2 are the radial distortion parameters.
Defaults to config["momentum"]["division_model_params"].
kwargs: Keyword arguments passed to ``calibrate_k_division_model``:
x_column (str, optional): Label of the source 'X' column.
Defaults to config["momentum"]["x_column"].
y_column (str, optional): Label of the source 'Y' column.
Defaults to config["momentum"]["y_column"].
kx_column (str, optional): Label of the destination 'X' column after
momentum calibration. Defaults to config["momentum"]["kx_column"].
ky_column (str, optional): Label of the destination 'Y' column after
momentum calibration. Defaults to config["momentum"]["ky_column"].
"""
self._dataframe, metadata = self.mc.calibrate_k_division_model(
df=self._dataframe,
warp_params=warp_params,
**kwargs,
)
self._attributes.add(
metadata,
"k_division_model",
duplicate_policy="raise",
)

def save_k_division_model(
self,
filename: str = None,
overwrite: bool = False,
) -> None:
"""save the generated k division model parameters to the folder config file.



Args:
filename (str, optional): Filename of the config dictionary to save to.
Defaults to "sed_config.yaml" in the current folder.
overwrite (bool, optional): Option to overwrite the present dictionary.
Defaults to False.
"""
if filename is None:
filename = "sed_config.yaml"
params = {}
try:
for key in ["center", "k0", "k1", "k2", "gamma"]:
params[key] = self.mc.division_model_params[key]
except KeyError as exc:
raise KeyError(
"k division model parameters not found, need to generate parameters first!",
) from exc

config: Dict[str, Any] = {"momentum": {"k_division_model": params}}
save_config(config, filename, overwrite)
print(f"Saved k division model parameters to {filename}")

def add_delay_offset(
self,
constant: float = None,
Expand All @@ -1529,7 +1609,6 @@ def add_delay_offset(
of dask.dataframe.Series. For example "mean". In this case the function is applied
to the column to generate a single value for the whole dataset. If None, the shift
is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.

Returns:
None
"""
Expand Down Expand Up @@ -1618,6 +1697,7 @@ def save_workflow_params(
self.save_energy_offset,
self.save_delay_calibration,
self.save_delay_offsets,
self.save_k_division_model,
]:
try:
method(filename, overwrite)
Expand Down