Skip to content

Commit

Permalink
Added TOFReflectionManager. Added TOFLeastSquaresResidualWithRmsdCutf…
Browse files Browse the repository at this point in the history
…f. Added wavelength columns to refinement output.
  • Loading branch information
toastisme committed Mar 27, 2024
1 parent 241baec commit 01e97a5
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/dials/algorithms/indexing/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def _xyzcal_mm_to_px(self, experiments, reflections):
if expt.scan is not None:
if expt.scan.has_property("time_of_flight"):
tof = expt.scan.get_property("time_of_flight")
frames = [i + 1 for i in range(len(tof))]
frames = list(range(len(tof)))
tof_to_frame = tof_helpers.tof_to_frame_interpolator(tof, frames)
z_px = flex.double(tof_to_frame(z))
else:
Expand Down
4 changes: 2 additions & 2 deletions src/dials/algorithms/indexing/model_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def score_by_volume(self, reverse=False):
def score_by_rmsd_xy(self, reverse=False):
# smaller rmsds = better
rmsd_x, rmsd_y, rmsd_z = flex.vec3_double(
s.rmsds for s in self.all_solutions
s.rmsds[:3] for s in self.all_solutions
).parts()
rmsd_xy = flex.sqrt(flex.pow2(rmsd_x) + flex.pow2(rmsd_y))
score = flex.log(rmsd_xy) / math.log(2)
Expand Down Expand Up @@ -275,7 +275,7 @@ def __str__(self):
perm = flex.sort_permutation(combined_scores)

rmsd_x, rmsd_y, rmsd_z = flex.vec3_double(
s.rmsds for s in self.all_solutions
s.rmsds[:3] for s in self.all_solutions
).parts()
rmsd_xy = flex.sqrt(flex.pow2(rmsd_x) + flex.pow2(rmsd_y))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _post_predict_one_experiment(self, experiment, reflections):

# Add frame to xyzcal.px
expt_tof = experiment.scan.get_property("time_of_flight") # (usec)
frames = [i + 1 for i in range(len(expt_tof))]
frames = list(range(len(expt_tof)))
tof_to_frame = tof_helpers.tof_to_frame_interpolator(expt_tof, frames)
reflection_frames = flex.double(tof_to_frame(tof_cal))
px, py, pz = reflections["xyzcal.px"].parts()
Expand Down
9 changes: 6 additions & 3 deletions src/dials/algorithms/refinement/refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ def print_step_table(self):
rmsd_multipliers.append(1.0)
elif units == "A":
header.append(name + "\n(A)")
rmsd_multipliers.append(1.0)
elif units == "rad": # convert radians to degrees for reporting
header.append(name + "\n(deg)")
rmsd_multipliers.append(RAD2DEG)
Expand Down Expand Up @@ -913,6 +914,8 @@ def calc_exp_rmsd_table(self):
# will convert other angles in radians to degrees (e.g. for
# RMSD_DeltaPsi and RMSD_2theta)
header.append(name + "\n(deg)")
elif name == "RMSD_wavelength" and units == "A":
header.append(name + "\n(A)")
else: # skip other/unknown RMSDs
pass

Expand Down Expand Up @@ -952,7 +955,7 @@ def calc_exp_rmsd_table(self):
elif name == "RMSD_Phi" and units == "rad":
rmsds.append(rmsd * images_per_rad)
elif name == "RMSD_wavelength" and units == "A":
header.append(name + "\n(A)")
rmsds.append(rmsd)
elif units == "rad":
rmsds.append(rmsd * RAD2DEG)
rows.append([str(iexp), str(num)] + [f"{r:.5g}" for r in rmsds])
Expand Down Expand Up @@ -1003,7 +1006,7 @@ def print_panel_rmsd_table(self):
name == "RMSD_DeltaPsi" and units == "rad"
): # convert radians to degrees for reporting of stills
header.append(name + "\n(deg)")
elif name == "RMSD_wavelength" and units == "A":
elif name == "RMSD_wavelength" and units == "frame":
header.append(name + "\n(frame)")
else: # skip RMSDs that cannot be expressed in image/scan space
pass
Expand Down Expand Up @@ -1031,7 +1034,7 @@ def print_panel_rmsd_table(self):
rmsds.append(rmsd * images_per_rad)
elif name == "RMSD_DeltaPsi" and units == "rad":
rmsds.append(rmsd * RAD2DEG)
elif name == "RMSD_wavelength" and units == "A":
elif name == "RMSD_wavelength" and units == "frame":
rmsds.append(rmsd)
rows.append([str(ipanel), str(num)] + [f"{r:.5g}" for r in rmsds])

Expand Down
97 changes: 96 additions & 1 deletion src/dials/algorithms/refinement/reflection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random

import libtbx
from dxtbx.model import tof_helpers
from dxtbx.model.experiment_list import ExperimentList
from libtbx.phil import parse
from scitbx import matrix
Expand Down Expand Up @@ -371,7 +372,19 @@ def laue_manager(
params: libtbx.phil.scope_extract,
) -> LaueReflectionManager:

refman = LaueReflectionManager
all_tof_experiments = False
for expt in experiments:
if expt.scan is not None and expt.scan.has_property("time_of_flight"):
all_tof_experiments = True
elif all_tof_experiments:
raise ValueError(
"Cannot refine ToF and non-ToF experiments at the same time"
)

if all_tof_experiments:
refman = TOFReflectionManager
else:
refman = LaueReflectionManager

## Outlier detection
if params.outlier.algorithm in ("auto", libtbx.Auto):
Expand Down Expand Up @@ -1125,3 +1138,85 @@ def update_residuals(self):
self._reflections["wavelength_resid2"] = (
self._reflections["wavelength_resid"] ** 2
)


class TOFReflectionManager(LaueReflectionManager):
def __init__(
self,
reflections,
experiments,
nref_per_degree=None,
max_sample_size=None,
min_sample_size=0,
close_to_spindle_cutoff=0.02,
scan_margin=0.0,
outlier_detector=None,
weighting_strategy_override=None,
wavelength_weight=1e7,
):

super().__init__(
reflections=reflections,
experiments=experiments,
nref_per_degree=nref_per_degree,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
close_to_spindle_cutoff=close_to_spindle_cutoff,
scan_margin=scan_margin,
outlier_detector=outlier_detector,
weighting_strategy_override=weighting_strategy_override,
wavelength_weight=wavelength_weight,
)

tof_to_frame_interpolators = []
sample_to_source_distances = []
for expt in self._experiments:
tof = expt.scan.get_property("time_of_flight") # (usec)
frames = list(range(len(tof)))
tof_to_frame = tof_helpers.tof_to_frame_interpolator(tof, frames)
tof_to_frame_interpolators.append(tof_to_frame)
sample_to_source_distances.append(
expt.beam.get_sample_to_source_distance() * 10**-3 # (m)
)

self.tof_to_frame_interpolators = tof_to_frame_interpolators
self.sample_to_source_distances = sample_to_source_distances

def update_residuals(self):
x_obs, y_obs, _ = self._reflections["xyzobs.mm.value"].parts()
x_calc, y_calc, _ = self._reflections["xyzcal.mm"].parts()
wavelength_obs = self._reflections["wavelength"]
wavelength_cal = self._reflections["wavelength_cal"]
L2 = self._reflections["s1"].norms() * 10**-3
self._reflections["x_resid"] = x_calc - x_obs
self._reflections["y_resid"] = y_calc - y_obs
self._reflections["wavelength_resid"] = wavelength_cal - wavelength_obs
self._reflections["wavelength_resid2"] = (
self._reflections["wavelength_resid"] ** 2
)

frame_resid = flex.double(len(self._reflections))
frame_resid2 = flex.double(len(self._reflections))
for idx, expt in enumerate(self._experiments):
if "imageset_id" in self._reflections:
r_expt = self._reflections["imageset_id"] == idx
else:
r_expt = self._reflections["id"] == idx
L_expt = self.sample_to_source_distances[idx] + L2.select(r_expt)
tof_obs_expt = (
tof_helpers.tof_from_wavelength(L_expt, wavelength_obs.select(r_expt))
* 10**6
) # (usec)
tof_cal_expt = (
tof_helpers.tof_from_wavelength(L_expt, wavelength_cal.select(r_expt))
* 10**6
) # (usec)
tof_to_frame = self.tof_to_frame_interpolators[idx]
frame_resid_expt = flex.double(
tof_to_frame(tof_cal_expt) - tof_to_frame(tof_obs_expt)
)
frame_resid.set_selected(r_expt, frame_resid_expt)
frame_resid2.set_selected(r_expt, frame_resid_expt**2)

self._reflections["frame_resid"] = frame_resid
self._reflections["frame_resid2"] = frame_resid2
27 changes: 26 additions & 1 deletion src/dials/algorithms/refinement/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def from_parameters_and_experiments(

if all_tof_experiments:
from dials.algorithms.refinement.target import (
LaueLeastSquaresResidualWithRmsdCutoff as targ,
TOFLeastSquaresResidualWithRmsdCutoff as targ,
)

# Determine whether the target is in X, Y, Phi space or just X, Y to choose
Expand Down Expand Up @@ -854,3 +854,28 @@ def achieved(self):
):
return True
return False


class TOFLeastSquaresResidualWithRmsdCutoff(LaueLeastSquaresResidualWithRmsdCutoff):

_grad_names = ["dX_dp", "dY_dp", "dwavelength_dp"]
rmsd_names = ["RMSD_X", "RMSD_Y", "RMSD_wavelength", "RMSD_wavelength"]
rmsd_units = ["mm", "mm", "A", "frame"]

def _rmsds_core(self, reflections):

"""calculate unweighted RMSDs for the specified reflections"""

resid_x = flex.sum(reflections["x_resid2"])
resid_y = flex.sum(reflections["y_resid2"])
resid_wavelength = flex.sum(reflections["wavelength_resid2"])
resid_frame = flex.sum(reflections["frame_resid2"])
n = len(reflections)

rmsds = (
math.sqrt(resid_x / n),
math.sqrt(resid_y / n),
math.sqrt(abs(resid_wavelength) / n),
math.sqrt(abs(resid_frame) / n),
)
return rmsds

0 comments on commit 01e97a5

Please sign in to comment.