Added possibility to have scalar optimization objects
S-Dafarra committed Aug 4, 2023
1 parent 54cf002 commit 3f3979a
Showing 4 changed files with 78 additions and 59 deletions.
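For context, a minimal sketch of what this commit enables (not part of the diff): a field of an OptimizationObject can now be initialized with a plain float and is treated as a 1x1 quantity. The top-level import path and the `ScalarDemo` class are assumptions modeled on the test files below.

```python
# Hedged sketch, not from the diff: declaring a scalar variable field.
# Assumes OptimizationObject, StorageType, Variable and default_storage_field
# are importable as in the tests below; the exact import path may differ.
import dataclasses

from hippopt import (
    OptimizationObject,
    StorageType,
    Variable,
    default_storage_field,
)


@dataclasses.dataclass
class ScalarDemo(OptimizationObject):
    gain: StorageType = default_storage_field(cls=Variable)

    def __post_init__(self):
        # Before this commit, a 1-element array (e.g. np.ones(1)) was needed;
        # now a plain float is accepted and becomes a (1, 1) CasADi variable
        # when the solver generates the optimization objects.
        self.gain = 1.0
```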
127 changes: 70 additions & 57 deletions src/hippopt/base/opti_solver.py
@@ -89,6 +89,9 @@ def _generate_opti_object(
if value.ndim < 2:
value = np.expand_dims(value, axis=1)

if isinstance(value, float):
value = value * np.ones((1, 1))

if storage_type is Variable.StorageTypeValue:
return self._solver.variable(*value.shape)

@@ -280,65 +283,14 @@ def _set_initial_guess_internal(
)

if isinstance(corresponding_value, list):
if not isinstance(guess, list):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ " is supposed to be a list. "
+ "Received "
+ str(type(guess))
+ " instead."
)

if len(corresponding_value) != len(guess):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ " is a list of the wrong size. Expected: "
+ str(len(corresponding_value))
+ ". Guess: "
+ str(len(guess))
)

for i in range(len(corresponding_value)):
if not isinstance(guess[i], np.ndarray):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ "["
+ str(i)
+ "] is not an numpy array."
)

input_shape = (
guess[i].shape
if len(guess[i].shape) > 1
else (guess[i].shape[0], 1)
)

if corresponding_value[i].shape != input_shape:
raise ValueError(
"The dimension of the guess for the field "
+ base_name
+ field.name
+ "["
+ str(i)
+ "] does not match with the corresponding"
+ " optimization variable"
)

self._set_opti_guess(
storage_type=field.metadata[
OptimizationObject.StorageTypeField
],
variable=corresponding_value[i],
value=guess[i],
)
self._set_list_object_guess_internal(
base_name, corresponding_value, field, guess
)
continue

if isinstance(guess, float):
guess = guess * np.ones((1, 1))

if not isinstance(guess, np.ndarray):
raise ValueError(
"The guess for the field "
@@ -387,6 +339,67 @@ def _set_initial_guess_internal(
)
continue

def _set_list_object_guess_internal(
self,
base_name: str,
corresponding_value: list,
field: dataclasses.Field,
guess: list,
) -> None:
if not isinstance(guess, list):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ " is supposed to be a list. "
+ "Received "
+ str(type(guess))
+ " instead."
)
if len(corresponding_value) != len(guess):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ " is a list of the wrong size. Expected: "
+ str(len(corresponding_value))
+ ". Guess: "
+ str(len(guess))
)
for i in range(len(corresponding_value)):
value = guess[i]
if isinstance(value, float):
value = value * np.ones((1, 1))

if not isinstance(value, np.ndarray):
raise ValueError(
"The guess for the field "
+ base_name
+ field.name
+ "["
+ str(i)
+ "] is supposed to be an array (or even a float if scalar)."
)

input_shape = value.shape if len(value.shape) > 1 else (value.shape[0], 1)

if corresponding_value[i].shape != input_shape:
raise ValueError(
"The dimension of the guess for the field "
+ base_name
+ field.name
+ "["
+ str(i)
+ "] does not match with the corresponding"
+ " optimization variable"
)

self._set_opti_guess(
storage_type=field.metadata[OptimizationObject.StorageTypeField],
variable=corresponding_value[i],
value=value,
)

def generate_optimization_objects(
self, input_structure: TOptimizationObject | list[TOptimizationObject], **kwargs
) -> TOptimizationObject | list[TOptimizationObject]:
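The guess handling above boils down to a small normalization step. A standalone sketch of that behaviour (an assumed helper, not part of the commit):

```python
# Standalone sketch of the normalization applied to an initial guess:
# a float is promoted to a (1, 1) array and a 1-D array is treated as a
# column vector before comparing shapes with the optimization variable.
import numpy as np


def normalize_guess(guess: float | np.ndarray) -> np.ndarray:
    if isinstance(guess, float):
        guess = guess * np.ones((1, 1))
    if not isinstance(guess, np.ndarray):
        raise ValueError("The guess is supposed to be an array, or a float if scalar.")
    return guess if guess.ndim > 1 else np.expand_dims(guess, axis=1)


assert normalize_guess(-9.81).shape == (1, 1)
assert normalize_guess(np.ones(3)).shape == (3, 1)
```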
2 changes: 1 addition & 1 deletion src/hippopt/base/optimization_object.py
@@ -7,7 +7,7 @@
import numpy as np

TOptimizationObject = TypeVar("TOptimizationObject", bound="OptimizationObject")
StorageType = cs.MX | np.ndarray | list[cs.MX] | list[np.ndarray]
StorageType = cs.MX | np.ndarray | float | list[cs.MX] | list[np.ndarray] | list[float]


class TimeExpansion(Enum):
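The widened StorageType union means a plain float (or a list of floats) now type-checks as a storage value. A trivial illustration (not from the commit):

```python
# Illustration only: the widened union accepts scalars and lists of scalars.
import casadi as cs
import numpy as np

StorageType = cs.MX | np.ndarray | float | list[cs.MX] | list[np.ndarray] | list[float]

gravity: StorageType = -9.81      # a plain float is now a valid storage value
masses: StorageType = [1.0, 2.0]  # and so is a list of floats
```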
2 changes: 1 addition & 1 deletion test/test_multiple_shooting.py
@@ -243,7 +243,7 @@ class MassFallingTestVariables(OptimizationObject):
foo: StorageType = default_storage_field(Variable)

def __post_init__(self):
self.g = -9.81 * np.ones(1)
self.g = -9.81
self.masses = []
for _ in range(3):
self.masses.append(MassFallingState())
6 changes: 6 additions & 0 deletions test/test_opti_generate_objects.py
@@ -17,10 +17,12 @@
class CustomVariable(OptimizationObject):
variable: StorageType = default_storage_field(cls=Variable)
parameter: StorageType = default_storage_field(cls=Parameter)
scalar: StorageType = default_storage_field(cls=Variable)

def __post_init__(self):
self.variable = np.ones(shape=3)
self.parameter = np.ones(shape=3)
self.scalar = 1.0


@dataclasses.dataclass
@@ -44,6 +46,8 @@ def test_generate_objects():
assert opti_var.aggregated.variable.shape == (3, 1)
assert isinstance(opti_var.other_parameter, cs.MX)
assert opti_var.other_parameter.shape == (3, 1)
assert isinstance(opti_var.aggregated.scalar, cs.MX)
assert opti_var.aggregated.scalar.shape == (1, 1)
assert opti_var.other == "untouched"
assert solver.get_optimization_objects() is opti_var

@@ -60,6 +64,8 @@ def test_generate_objects_list():
assert opti_var.aggregated.parameter.shape == (3, 1)
assert isinstance(opti_var.aggregated.variable, cs.MX)
assert opti_var.aggregated.variable.shape == (3, 1)
assert isinstance(opti_var.aggregated.scalar, cs.MX)
assert opti_var.aggregated.scalar.shape == (1, 1)
assert isinstance(opti_var.other_parameter, cs.MX)
assert opti_var.other_parameter.shape == (3, 1)
assert opti_var.other == "untouched"
