Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export scaled controls when running forward models #9932

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,12 @@ def _init_batch_data(
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
prefix: str = "",
) -> dict[int, dict[str, Any]]:
def _add_controls(
controls_config: list[ControlConfig], values: NDArray[np.float64]
controls_config: list[ControlConfig],
values: NDArray[np.float64],
prefix: str = "",
) -> dict[str, Any]:
batch_data_item: dict[str, Any] = {}
value_list = values.tolist()
Expand All @@ -345,13 +348,28 @@ def _add_controls(
else:
variable_value = value_list.pop(0)
control_dict[variable.name] = variable_value
batch_data_item[control.name] = control_dict
batch_data_item[prefix + control.name] = control_dict
return batch_data_item

def _add_controls_with_rescaling(
controls_config: list[ControlConfig], values: NDArray[np.float64]
) -> dict[str, Any]:
batch_data_item = _add_controls(controls_config, values)
if self._opt_model_transforms.variables is not None:
rescaled_item = _add_controls(
controls_config,
self._opt_model_transforms.variables.backward(values),
prefix="rescaled-",
)
batch_data_item.update(rescaled_item)
return batch_data_item

active = evaluator_context.active
realizations = evaluator_context.realizations
return {
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
idx: _add_controls_with_rescaling(
self._everest_config.controls, control_values[idx, :]
)
for idx in range(control_values.shape[0])
if (
idx not in cached_results
Expand Down Expand Up @@ -393,7 +411,12 @@ def _check_suffix(
f"Key {key} has suffixes, a suffix must be specified"
)

if set(controls.keys()) != set(self._everest_config.control_names):
control_names = set(self._everest_config.control_names)
if self._opt_model_transforms.variables is not None:
control_names |= {
"rescaled-" + name for name in self._everest_config.control_names
}
if set(controls.keys()) != control_names:
err_msg = "Mismatch between initialized and provided control names."
raise KeyError(err_msg)

Expand Down
6 changes: 0 additions & 6 deletions src/everest/config/control_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,6 @@ def ropt_perturbation_type(self) -> PerturbationType:
def ropt_control_type(self) -> VariableType:
return VariableType[self.control_type.upper()]

@property
def has_auto_scale(self) -> bool:
return self.auto_scale or any(
variable.auto_scale for variable in self.variables
)

@model_validator(mode="after")
def validate_variables(self) -> Self:
if self.variables is None:
Expand Down
20 changes: 1 addition & 19 deletions src/everest/config/control_variable_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ropt.enums import VariableType

from .sampler_config import SamplerConfig
from .validation_utils import no_dots_in_string, valid_range
from .validation_utils import no_dots_in_string


class _ControlVariable(BaseModel):
Expand All @@ -34,24 +34,6 @@ class _ControlVariable(BaseModel):
initial value.
""",
)
auto_scale: bool | None = Field(
default=None,
description="""
Can be set to true to re-scale variable from the range
defined by [min, max] to the range defined by scaled_range (default [0, 1])
""",
)
scaled_range: Annotated[tuple[float, float] | None, AfterValidator(valid_range)] = (
Field(
default=None,
description="""
Can be used to set the range of the variable values
after scaling (default = [0, 1]).

This option has no effect if auto_scale is not set.
""",
)
)
min: float | None = Field(
default=None,
description="""
Expand Down
2 changes: 0 additions & 2 deletions src/everest/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def _add_variable(
for key in [
"control_type",
"enabled",
"auto_scale",
"scaled_range",
"min",
"max",
"perturbation_magnitude",
Expand Down
6 changes: 6 additions & 0 deletions src/everest/simulator/everest_to_ert.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,5 +521,11 @@ def _get_variables(
input_keys=_get_variables(control.variables),
output_file=control.name + ".json",
)
if control.auto_scale:
ens_config.parameter_configs["rescaled-" + control.name] = ExtParamConfig(
name="rescaled-" + control.name,
input_keys=_get_variables(control.variables),
output_file="rescaled-" + control.name + ".json",
)

return ert_config
17 changes: 10 additions & 7 deletions test-data/everest/math_func/jobs/distance3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import json
import sys
from pathlib import Path


def compute_distance_squared(p, q):
Expand All @@ -24,23 +25,25 @@ def main(argv):
arg_parser.add_argument("--target-file", type=str)
arg_parser.add_argument("--target", nargs=3, type=float)
arg_parser.add_argument("--out", type=str)
arg_parser.add_argument("--scaling", nargs=4, type=float)
arg_parser.add_argument("--realization", type=float)
options, _ = arg_parser.parse_known_args(args=argv)

point = options.point if options.point else read_point(options.point_file)
point = (
options.point
if options.point
else read_point(
"rescaled-" + options.point_file
if Path("rescaled-" + options.point_file).exists()
else options.point_file
)
)
if len(point) != 3:
raise RuntimeError("Failed parsing point")

target = options.target if options.target else read_point(options.target_file)
if len(target) != 3:
raise RuntimeError("Failed parsing target")

if options.scaling is not None:
min_range, max_range, target_min, target_max = options.scaling
point = [(p - target_min) / (target_max - target_min) for p in point]
point = [p * (max_range - min_range) + min_range for p in point]

value = compute_distance_squared(point, target)
# If any realizations with an index > 0 are passed we make those incorrect
# by taking the negative value. This used by test_cvar.py.
Expand Down
1 change: 0 additions & 1 deletion tests/everest/test_math_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_math_func_auto_scaled_controls(
{"weights": {"point.x": 1.0, "point.y": 1.0}, "upper_bound": 0.5}
],
}
config_dict["forward_model"][0] += " --scaling -1 1 0.3 0.7"
config = EverestConfig.model_validate(config_dict)

# Act
Expand Down
16 changes: 0 additions & 16 deletions tests/everest/test_ropt_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,6 @@ def test_everest2ropt_controls_auto_scale():
assert np.allclose(ropt_config.variables.upper_bounds, 0.7)


def test_everest2ropt_variables_auto_scale():
config = EverestConfig.load_file(os.path.join(_CONFIG_DIR, _CONFIG_FILE))
controls = config.controls
controls[0].variables[1].auto_scale = True
controls[0].variables[1].scaled_range = [0.3, 0.7]
ropt_config = everest2ropt(
config, transforms=get_opt_model_transforms(config.controls)
)
assert ropt_config.variables.lower_bounds[0] == 0.0
assert ropt_config.variables.upper_bounds[0] == 0.1
assert ropt_config.variables.lower_bounds[1] == 0.3
assert ropt_config.variables.upper_bounds[1] == 0.7
assert np.allclose(ropt_config.variables.lower_bounds[2:], 0.0)
assert np.allclose(ropt_config.variables.upper_bounds[2:], 0.1)


def test_everest2ropt_controls_input_constraint():
config = EverestConfig.load_file(
os.path.join(_CONFIG_DIR, "config_input_constraints.yml")
Expand Down