Skip to content

Commit 40ecd00

Browse files
committed
test/func: modifying refinement to allow lists and dicts and adding a test of refinement for squeeze and funcy
1 parent a450cd7 commit 40ecd00

File tree

5 files changed

+103
-9
lines changed

5 files changed

+103
-9
lines changed

src/diffpy/morph/morph_api.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,17 @@
3939
morph_helpers.TransformXtalRDFtoPDF,
4040
],
4141
qdamp=morphs.MorphResolutionDamping,
42+
squeeze=morphs.MorphSqueeze,
43+
parameters=morphs.MorphFuncy,
4244
)
4345
_default_config = dict(
44-
scale=None, stretch=None, smear=None, baselineslope=None, qdamp=None
46+
scale=None,
47+
stretch=None,
48+
smear=None,
49+
baselineslope=None,
50+
qdamp=None,
51+
squeeze=None,
52+
parameters=None,
4553
)
4654

4755

@@ -197,6 +205,14 @@ def morph(
197205
if k == "smear":
198206
[chain.append(el()) for el in morph_cls]
199207
refpars.append("baselineslope")
208+
elif k == "parameters":
209+
morph_inst = morph_cls()
210+
morph_inst.function = rv_cfg.get("function", None)
211+
if morph_inst.function is None:
212+
raise ValueError(
213+
"Must provide a 'function' when using 'parameters'"
214+
)
215+
chain.append(morph_inst)
200216
else:
201217
chain.append(morph_cls())
202218
refpars.append(k)

src/diffpy/morph/morphs/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919

2020
from diffpy.morph.morphs.morph import Morph # noqa: F401
2121
from diffpy.morph.morphs.morphchain import MorphChain # noqa: F401
22+
from diffpy.morph.morphs.morphfuncy import MorphFuncy
2223
from diffpy.morph.morphs.morphishape import MorphISphere, MorphISpheroid
2324
from diffpy.morph.morphs.morphresolution import MorphResolutionDamping
2425
from diffpy.morph.morphs.morphrgrid import MorphRGrid
2526
from diffpy.morph.morphs.morphscale import MorphScale
2627
from diffpy.morph.morphs.morphshape import MorphSphere, MorphSpheroid
2728
from diffpy.morph.morphs.morphshift import MorphShift
2829
from diffpy.morph.morphs.morphsmear import MorphSmear
30+
from diffpy.morph.morphs.morphsqueeze import MorphSqueeze
2931
from diffpy.morph.morphs.morphstretch import MorphStretch
3032

3133
# List of morphs
@@ -40,6 +42,8 @@
4042
MorphISpheroid,
4143
MorphResolutionDamping,
4244
MorphShift,
45+
MorphSqueeze,
46+
MorphFuncy,
4347
]
4448

4549
# End of file

src/diffpy/morph/morphs/morphfuncy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class MorphFuncy(Morph):
1111
yinlabel = LABEL_GR
1212
xoutlabel = LABEL_RA
1313
youtlabel = LABEL_GR
14+
parnames = ["parameters"]
1415

1516
def morph(self, x_morph, y_morph, x_target, y_target):
1617
"""General morph function that applies a user-supplied function to the

src/diffpy/morph/refine.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
class Refiner(object):
2727
"""Class for refining a Morph or MorphChain.
2828
29-
This is provided to allow for custom residuals and refinement algorithms.
30-
3129
Attributes
3230
----------
3331
chain
@@ -51,12 +49,27 @@ def __init__(self, chain, x_morph, y_morph, x_target, y_target):
5149
self.y_target = y_target
5250
self.pars = []
5351
self.residual = self._residual
52+
self.flat_to_grouped = {}
5453
return
5554

5655
def _update_chain(self, pvals):
5756
"""Update the parameters in the chain."""
58-
pairs = zip(self.pars, pvals)
59-
self.chain.config.update(pairs)
57+
updated = {}
58+
for idx, val in enumerate(pvals):
59+
p, subkey = self.flat_to_grouped[idx]
60+
if subkey is None:
61+
updated[p] = val
62+
else:
63+
if p not in updated:
64+
updated[p] = {} if isinstance(subkey, str) else []
65+
if isinstance(updated[p], dict):
66+
updated[p][subkey] = val
67+
else:
68+
while len(updated[p]) <= subkey:
69+
updated[p].append(0.0)
70+
updated[p][subkey] = val
71+
72+
self.chain.config.update(updated)
6073
return
6174

6275
def _residual(self, pvals):
@@ -118,20 +131,38 @@ def refine(self, *args, **kw):
118131
if not self.pars:
119132
return 0.0
120133

121-
initial = [config[p] for p in self.pars]
134+
# Build flat list of initial parameters and flat_to_grouped mapping
135+
initial = []
136+
self.flat_to_grouped = {}
137+
138+
for p in self.pars:
139+
val = config[p]
140+
if isinstance(val, dict):
141+
for k, v in val.items():
142+
initial.append(v)
143+
self.flat_to_grouped[len(initial) - 1] = (p, k)
144+
elif isinstance(val, list):
145+
for i, v in enumerate(val):
146+
initial.append(v)
147+
self.flat_to_grouped[len(initial) - 1] = (p, i)
148+
else:
149+
initial.append(val)
150+
self.flat_to_grouped[len(initial) - 1] = (p, None)
151+
152+
# Perform least squares refinement
122153
sol, cov_sol, infodict, emesg, ier = leastsq(
123154
self.residual, initial, full_output=1
124155
)
125156
fvec = infodict["fvec"]
157+
126158
if ier not in (1, 2, 3, 4):
127-
emesg
128159
raise ValueError(emesg)
129160

130-
# Place the fit parameters in config
161+
# Place the fit parameters back into config
131162
vals = sol
132163
if not hasattr(vals, "__iter__"):
133164
vals = [vals]
134-
self.chain.config.update(zip(self.pars, vals))
165+
self._update_chain(vals)
135166

136167
return dot(fvec, fvec)
137168

tests/test_morph_func.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,45 @@ def test_smear_with_morph_func():
101101
assert np.allclose(y0, y1, atol=1e-3) # numerical error -> 1e-4
102102
# verify morphed param
103103
assert np.allclose(smear, morphed_cfg["smear"], atol=1e-1)
104+
105+
106+
def test_squeeze_with_morph_func():
107+
squeeze_init = [0, -0.001, -0.0001, 0.0001]
108+
x_morph = np.linspace(0, 10, 101)
109+
y_morph = 2 * np.sin(
110+
x_morph + x_morph * (-0.01) - 0.0001 * x_morph**2 + 0.0002 * x_morph**3
111+
)
112+
expected_squeeze = [0, -0.01, -0.0001, 0.0002]
113+
expected_scale = 1 / 2
114+
x_target = x_morph.copy()
115+
y_target = np.sin(x_target)
116+
cfg = morph_default_config(scale=1.1, squeeze=squeeze_init) # off init
117+
morph_rv = morph(x_morph, y_morph, x_target, y_target, **cfg)
118+
morphed_cfg = morph_rv["morphed_config"]
119+
# verified they are morphable
120+
x1, y1, x0, y0 = morph_rv["morph_chain"].xyallout
121+
assert np.allclose(x0, x1)
122+
assert np.allclose(y0, y1, atol=1e-3) # numerical error -> 1e-4
123+
# verify morphed param
124+
assert np.allclose(expected_squeeze, morphed_cfg["squeeze"], atol=1e-4)
125+
assert np.allclose(expected_scale, morphed_cfg["scale"], atol=1e-4)
126+
127+
128+
def test_funcy_with_morph_func():
129+
def linear_function(x, y, scale, offset):
130+
return (scale * x) * y + offset
131+
132+
x_morph = np.linspace(0, 10, 101)
133+
y_morph = np.sin(x_morph)
134+
x_target = x_morph.copy()
135+
y_target = np.sin(x_target) * 2 * x_target + 0.4
136+
cfg = morph_default_config(parameters={"scale": 1.2, "offset": 0.1})
137+
cfg["function"] = linear_function
138+
morph_rv = morph(x_morph, y_morph, x_target, y_target, **cfg)
139+
morphed_cfg = morph_rv["morphed_config"]
140+
x1, y1, x0, y0 = morph_rv["morph_chain"].xyallout
141+
assert np.allclose(x0, x1)
142+
assert np.allclose(y0, y1, atol=1e-6)
143+
fitted_parameters = morphed_cfg["parameters"]
144+
assert np.allclose(fitted_parameters["scale"], 2, atol=1e-6)
145+
assert np.allclose(fitted_parameters["offset"], 0.4, atol=1e-6)

0 commit comments

Comments
 (0)