Skip to content

Commit

Permalink
Allow setting input prefix name when converting to function
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Apr 17, 2024
1 parent 4d15127 commit 9bbca1e
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 16 deletions.
11 changes: 7 additions & 4 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,16 @@ def get_free_parameters_names(self) -> list[str]:
return self._free_parameters

def to_function(
self, name: str = "opti_function", options: dict = None
self,
input_name_prefix: str,
function_name: str = "opti_function",
options: dict = None,
) -> cs.Function:
self._cost = self._cost if self._cost is not None else cs.MX(0)
self._solver.minimize(self._cost)

# Prepend guess to the variable names
guess_names = ["guess." + name for name in self._objects_dict]
# Prepend input_name_prefix to the variable names
guess_names = [input_name_prefix + name for name in self._objects_dict]
all_variables_values = list(self._objects_dict.values())

# Workaround for https://github.com/casadi/casadi/issues/3655
Expand All @@ -622,7 +625,7 @@ def to_function(

options = {} if options is None else options
return self._solver.to_function(
name,
function_name,
all_variables_values,
output_variables,
guess_names,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,11 @@ def compute_state(
pf_parametric_link_length_multipliers
)
pf_function = pf_input.to_function(
name="pose_finder", options={"error_on_fail": True}
input_name_prefix="pf_in.",
function_name="pose_finder",
options={"error_on_fail": True},
)
pf_guess_dict = pf_guess.to_dict(prefix="guess.")
pf_guess_dict = pf_guess.to_dict(prefix="pf_in.")
output_pf_dict = pf_function(**pf_guess_dict)

for k in output_pf_dict:
Expand Down Expand Up @@ -549,10 +551,12 @@ def get_references(
)

planner_function = planner.to_function(
name="kinodynamic_walking", options={"error_on_fail": True}
input_name_prefix="in.",
function_name="kinodynamic_walking",
options={"error_on_fail": True},
)

initial_guess_dict = planner_guess.to_dict(prefix="guess.")
initial_guess_dict = planner_guess.to_dict(prefix="in.")
initial_guess_dict_pruned = {}
for key in initial_guess_dict:
if isinstance(initial_guess_dict[key], np.ndarray) or isinstance(
Expand Down
24 changes: 19 additions & 5 deletions src/hippopt/turnkey_planners/humanoid_kinodynamic/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,11 +1071,16 @@ def solve(self) -> hp.Output[Variables]:
return output

def to_function(
self, name: str = "opti_function", options: dict = None
self,
input_name_prefix: str,
function_name: str = "opti_function",
options: dict = None,
) -> cs.Function:

inner_function = self.optimization_solver.to_function(
name=name + "_internal", options=options
input_name_prefix=input_name_prefix,
function_name=function_name + "_internal",
options=options,
)
variable_names = inner_function.name_in()
variables_list = []
Expand All @@ -1087,9 +1092,13 @@ def to_function(
optimization_structure = copy.deepcopy(
self.optimization_solver.get_optimization_objects()
)
optimization_structure.from_dict(input_dict=variables_sym, prefix="guess.")
optimization_structure.from_dict(
input_dict=variables_sym, prefix=input_name_prefix
)
mass_regularized_vars = self._apply_mass_regularization(optimization_structure)
output_dict = inner_function(**mass_regularized_vars.to_dict(prefix="guess."))
output_dict = inner_function(
**mass_regularized_vars.to_dict(prefix=input_name_prefix)
)
optimization_structure = copy.deepcopy(
self.optimization_solver.get_optimization_objects()
)
Expand All @@ -1104,7 +1113,12 @@ def to_function(
output_names.append(n)

return cs.Function(
name, variables_list, output_values, variable_names, output_names, options
function_name,
variables_list,
output_values,
variable_names,
output_names,
options,
)

def get_adam_model(self) -> adam.model.Model:
Expand Down
11 changes: 9 additions & 2 deletions src/hippopt/turnkey_planners/humanoid_pose_finder/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,16 @@ def solve(self) -> hp.Output[Variables]:
return self.op.problem.solve()

def to_function(
self, name: str = "opti_function", options: dict = None
self,
input_name_prefix: str,
function_name: str = "opti_function",
options: dict = None,
) -> cs.Function:
return self.optimization_solver.to_function(name=name, options=options)
return self.optimization_solver.to_function(
input_name_prefix=input_name_prefix,
function_name=function_name,
options=options,
)

def get_adam_model(self) -> adam.model.Model:
if self.parametric_model:
Expand Down
2 changes: 1 addition & 1 deletion test/test_optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_opti_to_function():
),
)

opti_function = opti_solver.to_function()
opti_function = opti_solver.to_function(input_name_prefix="guess.")
output_dict = opti_function(**initial_guess.to_dict(prefix="guess."))
output = MyTestVarAndPar()
output.from_dict(output_dict)
Expand Down

0 comments on commit 9bbca1e

Please sign in to comment.