Skip to content

Commit 9a6c196

Browse files
authored
Splits resampleParams into two parameters and bumps version to 0.0.0.dev2 (#75)
* split resample params in source code * split resample params in rat.cpp * split resample params in tests * made min angle consistent with MATLAB * update submodule * Update controls.py * bumped version
1 parent fe0a45e commit 9a6c196

File tree

8 files changed

+114
-128
lines changed

8 files changed

+114
-128
lines changed

RATapi/controls.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
Field,
77
ValidationError,
88
ValidatorFunctionWrapHandler,
9-
field_validator,
109
model_serializer,
1110
model_validator,
1211
)
1312

1413
from RATapi.utils.custom_errors import custom_pydantic_validation_error
1514
from RATapi.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies
1615

17-
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"]
16+
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleMinAngle", "resampleNPoints", "display"]
1817
update_fields = ["updateFreq", "updatePlotFreq"]
1918
fields = {
2019
"calculate": common_fields,
@@ -41,7 +40,8 @@ class Controls(BaseModel, validate_assignment=True, extra="forbid"):
4140
procedure: Procedures = Procedures.Calculate
4241
parallel: Parallel = Parallel.Single
4342
calcSldDuringFit: bool = False
44-
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
43+
resampleMinAngle: float = Field(0.9, le=1, gt=0)
44+
resampleNPoints: int = Field(50, gt=0)
4545
display: Display = Display.Iter
4646
# Simplex
4747
xTolerance: float = Field(1.0e-6, gt=0.0)
@@ -117,16 +117,6 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle
117117

118118
return validated_self
119119

120-
@field_validator("resampleParams")
121-
@classmethod
122-
def check_resample_params(cls, values: list[float]) -> list[float]:
123-
"""Make sure each of the two values of resampleParams satisfy their conditions."""
124-
if not 0 < values[0] < 1:
125-
raise ValueError("resampleParams[0] must be between 0 and 1")
126-
if values[1] < 0:
127-
raise ValueError("resampleParams[1] must be greater than or equal to 0")
128-
return values
129-
130120
@model_serializer
131121
def serialize(self):
132122
"""Filter fields so only those applying to the chosen procedure are serialized."""

RATapi/examples/absorption/absorption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def absorption():
150150
)
151151

152152
# Now make a controls block and run the code
153-
controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0])
153+
controls = RAT.Controls(parallel="contrasts", resampleNPoints=150)
154154
problem, results = RAT.run(problem, controls)
155155

156156
return problem, results

RATapi/inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ def make_controls(input_controls: RATapi.Controls, checks: Checks) -> Control:
436436
controls.procedure = input_controls.procedure
437437
controls.parallel = input_controls.parallel
438438
controls.calcSldDuringFit = input_controls.calcSldDuringFit
439-
controls.resampleParams = input_controls.resampleParams
439+
controls.resampleMinAngle = input_controls.resampleMinAngle
440+
controls.resampleNPoints = input_controls.resampleNPoints
440441
controls.display = input_controls.display
441442
# Simplex
442443
controls.xTolerance = input_controls.xTolerance

cpp/rat.cpp

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,8 @@ struct Control {
509509
real_T propScale {};
510510
real_T nsTolerance {};
511511
boolean_T calcSldDuringFit {};
512-
py::array_t<real_T> resampleParams;
512+
real_T resampleMinAngle {};
513+
real_T resampleNPoints {};
513514
real_T updateFreq {};
514515
real_T updatePlotFreq {};
515516
real_T nSamples {};
@@ -914,8 +915,8 @@ RAT::struct2_T createStruct2T(const Control& control)
914915
stringToRatArray(control.procedure, control_struct.procedure.data, control_struct.procedure.size);
915916
stringToRatArray(control.display, control_struct.display.data, control_struct.display.size);
916917
control_struct.xTolerance = control.xTolerance;
917-
control_struct.resampleParams[0] = control.resampleParams.at(0);
918-
control_struct.resampleParams[1] = control.resampleParams.at(1);
918+
control_struct.resampleMinAngle = control.resampleMinAngle;
919+
control_struct.resampleNPoints = control.resampleNPoints;
919920
stringToRatArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size);
920921
control_struct.adaptPCR = control.adaptPCR;
921922
control_struct.checks = createStruct3(control.checks);
@@ -1616,7 +1617,8 @@ PYBIND11_MODULE(rat_core, m) {
16161617
.def_readwrite("propScale", &Control::propScale)
16171618
.def_readwrite("nsTolerance", &Control::nsTolerance)
16181619
.def_readwrite("calcSldDuringFit", &Control::calcSldDuringFit)
1619-
.def_readwrite("resampleParams", &Control::resampleParams)
1620+
.def_readwrite("resampleMinAngle", &Control::resampleMinAngle)
1621+
.def_readwrite("resampleNPoints", &Control::resampleNPoints)
16201622
.def_readwrite("updateFreq", &Control::updateFreq)
16211623
.def_readwrite("updatePlotFreq", &Control::updatePlotFreq)
16221624
.def_readwrite("nSamples", &Control::nSamples)
@@ -1633,14 +1635,14 @@ PYBIND11_MODULE(rat_core, m) {
16331635
return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance,
16341636
ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability,
16351637
ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale,
1636-
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq,
1637-
ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR,
1638-
ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift,
1639-
ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
1638+
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleMinAngle, ctrl.resampleNPoints,
1639+
ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma,
1640+
ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam,
1641+
ctrl.checks.fitQzshift, ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
16401642
ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio);
16411643
},
16421644
[](py::tuple t) { // __setstate__
1643-
if (t.size() != 36)
1645+
if (t.size() != 37)
16441646
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");
16451647

16461648
/* Create a new C++ instance */
@@ -1664,25 +1666,26 @@ PYBIND11_MODULE(rat_core, m) {
16641666
ctrl.propScale = t[15].cast<real_T>();
16651667
ctrl.nsTolerance = t[16].cast<real_T>();
16661668
ctrl.calcSldDuringFit = t[17].cast<boolean_T>();
1667-
ctrl.resampleParams = t[18].cast<py::array_t<real_T>>();
1668-
ctrl.updateFreq = t[19].cast<real_T>();
1669-
ctrl.updatePlotFreq = t[20].cast<real_T>();
1670-
ctrl.nSamples = t[21].cast<real_T>();
1671-
ctrl.nChains = t[22].cast<real_T>();
1672-
ctrl.jumpProbability = t[23].cast<real_T>();
1673-
ctrl.pUnitGamma = t[24].cast<real_T>();
1674-
ctrl.boundHandling = t[25].cast<std::string>();
1675-
ctrl.adaptPCR = t[26].cast<boolean_T>();
1676-
ctrl.IPCFilePath = t[27].cast<std::string>();
1669+
ctrl.resampleMinAngle = t[18].cast<real_T>();
1670+
ctrl.resampleNPoints = t[19].cast<real_T>();
1671+
ctrl.updateFreq = t[20].cast<real_T>();
1672+
ctrl.updatePlotFreq = t[21].cast<real_T>();
1673+
ctrl.nSamples = t[22].cast<real_T>();
1674+
ctrl.nChains = t[23].cast<real_T>();
1675+
ctrl.jumpProbability = t[24].cast<real_T>();
1676+
ctrl.pUnitGamma = t[25].cast<real_T>();
1677+
ctrl.boundHandling = t[26].cast<std::string>();
1678+
ctrl.adaptPCR = t[27].cast<boolean_T>();
1679+
ctrl.IPCFilePath = t[28].cast<std::string>();
16771680

1678-
ctrl.checks.fitParam = t[28].cast<py::array_t<real_T>>();
1679-
ctrl.checks.fitBackgroundParam = t[29].cast<py::array_t<real_T>>();
1680-
ctrl.checks.fitQzshift = t[30].cast<py::array_t<real_T>>();
1681-
ctrl.checks.fitScalefactor = t[31].cast<py::array_t<real_T>>();
1682-
ctrl.checks.fitBulkIn = t[32].cast<py::array_t<real_T>>();
1683-
ctrl.checks.fitBulkOut = t[33].cast<py::array_t<real_T>>();
1684-
ctrl.checks.fitResolutionParam = t[34].cast<py::array_t<real_T>>();
1685-
ctrl.checks.fitDomainRatio = t[35].cast<py::array_t<real_T>>();
1681+
ctrl.checks.fitParam = t[29].cast<py::array_t<real_T>>();
1682+
ctrl.checks.fitBackgroundParam = t[30].cast<py::array_t<real_T>>();
1683+
ctrl.checks.fitQzshift = t[31].cast<py::array_t<real_T>>();
1684+
ctrl.checks.fitScalefactor = t[32].cast<py::array_t<real_T>>();
1685+
ctrl.checks.fitBulkIn = t[33].cast<py::array_t<real_T>>();
1686+
ctrl.checks.fitBulkOut = t[34].cast<py::array_t<real_T>>();
1687+
ctrl.checks.fitResolutionParam = t[35].cast<py::array_t<real_T>>();
1688+
ctrl.checks.fitDomainRatio = t[36].cast<py::array_t<real_T>>();
16861689

16871690
return ctrl;
16881691
}));

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from setuptools.command.build_clib import build_clib
99
from setuptools.command.build_ext import build_ext
1010

11-
__version__ = "0.0.0.dev1"
11+
__version__ = "0.0.0.dev2"
1212
PACKAGE_NAME = "RATapi"
1313

1414
with open("README.md") as f:

0 commit comments

Comments
 (0)