Skip to content

Commit

Permalink
Update domain_reduction.py (#454)
Browse files Browse the repository at this point in the history
* Update domain_reduction.py

Adds ability to pass a dictionary as the minimum_window argument in SequentialDomainReductionTransformer. Deals with issue where TargetSpace sorts the pBounds dictionary, causing the order of the list used for minimum_window to possibly not match the order of the stored boundaries.

* Update test_seq_domain_red.py

Added test_minimum_window_dict_ordering to test whether dictionary input for SequentialDomainReductionTransformer is ordered to match pBounds.
  • Loading branch information
lm314 authored Jan 22, 2024
1 parent 844927b commit 129caac
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
10 changes: 7 additions & 3 deletions bayes_opt/domain_reduction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, List
from typing import Optional, Union, List, Dict

import numpy as np
from .target_space import TargetSpace
Expand Down Expand Up @@ -29,12 +29,16 @@ def __init__(
gamma_osc: float = 0.7,
gamma_pan: float = 1.0,
eta: float = 0.9,
minimum_window: Optional[Union[List[float], float]] = 0.0
minimum_window: Optional[Union[List[float], float, Dict[str, float]]] = 0.0
) -> None:
self.gamma_osc = gamma_osc
self.gamma_pan = gamma_pan
self.eta = eta
self.minimum_window_value = minimum_window
if isinstance(minimum_window, dict):
self.minimum_window_value = [item[1] for item in sorted(minimum_window.items(), key=lambda x: x[0])]
else:
self.minimum_window_value = minimum_window


def initialize(self, target_space: TargetSpace) -> None:
"""Initialize all of the parameters.
Expand Down
17 changes: 16 additions & 1 deletion tests/test_seq_domain_red.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,24 @@ def verify_bounds_in_range(new_bounds, global_bounds):
trimmed_bounds = bounds_transformer._trim(new_bounds, global_bounds)
assert verify_bounds_in_range(trimmed_bounds, global_bounds)


def test_minimum_window_dict_ordering():
"""Tests if dictionary input for minimum_window is reordered the same as pbounds"""
window_ranges = {'y': 1, 'x': 3,'w': 1e5}
bounds_transformer = SequentialDomainReductionTransformer(minimum_window=window_ranges)
pbounds = {'y': (-1, 1),'w':(-1e6,1e6), 'x': (-10, 10)}

_ = BayesianOptimization(
f=None,
pbounds=pbounds,
verbose=0,
random_state=1,
bounds_transformer=bounds_transformer
)

if __name__ == '__main__':
r"""
CommandLine:
python tests/test_seq_domain_red.py
"""
pytest.main([__file__])
pytest.main([__file__])

0 comments on commit 129caac

Please sign in to comment.