Skip to content

Commit 2f39aa5

Browse files
committed
Add Hypersphere TestProblem to optimization fixtures
1 parent 386fcf7 commit 2f39aa5

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/optimization_problem_fixtures.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
__all__ = [
1717
'Quadratic',
18+
'Hypersphere',
1819
'Rosenbrock',
1920
'LinearConstraintsSooTestProblem',
2021
'LinearConstraintsSooTestProblem2',
@@ -125,6 +126,50 @@ def test_if_solved(self, optimization_results: OptimizationResults,
125126
np.testing.assert_allclose(x, x_true, **test_kwargs)
126127

127128

129+
class Hypersphere(TestProblem):
130+
def __init__(self, *args, n_var=1, **kwargs):
131+
super().__init__('quadratic', *args, **kwargs)
132+
133+
for i in range(n_var):
134+
self.add_variable(f'x_{i}', lb=-1, ub=1)
135+
136+
def objective_factory(i):
137+
def objective(x):
138+
return x[i]
139+
140+
return objective
141+
142+
for i in range(n_var):
143+
objective = objective_factory(i)
144+
145+
self.add_objective(objective, name=f"f_{i}")
146+
147+
self.add_nonlinear_constraint(self.hypersphere_constraint_violation)
148+
149+
@staticmethod
150+
def hypersphere_constraint_violation(point, center=None, radius=1):
151+
"""
152+
Calculate the constraint violation for a point relative to an n-dimensional hypersphere.
153+
154+
Args:
155+
- center (list/tuple): The coordinates of the center of the hypersphere.
156+
- radius (float): The radius of the hypersphere.
157+
- point (list/tuple): The coordinates of the point to check.
158+
159+
Returns:
160+
- float: Negative value if the point is inside the hypersphere, positive if outside,
161+
where the value represents the squared distance from the surface of the hypersphere.
162+
"""
163+
if center is None:
164+
center = np.zeros((len(point),))
165+
166+
if len(center) != len(point):
167+
raise ValueError("Center and point must have the same dimensions")
168+
169+
squared_distance = sum((c - p) ** 2 for c, p in zip(center, point))
170+
return squared_distance - radius ** 2
171+
172+
128173
class Rosenbrock(TestProblem):
129174
def __init__(self, *args, n_var=2, **kwargs):
130175
super().__init__('rosenbrock', *args, **kwargs)

0 commit comments

Comments
 (0)