From 616ec5c432271fa81db73d05e056677884bf7af9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 16 Jun 2024 14:02:09 +0100 Subject: [PATCH] test: list-like variable complexity --- pysr/test/test.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pysr/test/test.py b/pysr/test/test.py index f394b861..55fc5cb2 100644 --- a/pysr/test/test.py +++ b/pysr/test/test.py @@ -172,6 +172,26 @@ def test_multioutput_custom_operator_quiet_custom_complexity(self): self.assertLessEqual(mse1, 1e-4) self.assertLessEqual(mse2, 1e-4) + def test_custom_variable_complexity(self): + y = self.X[:, [0, 1]] ** 2 + model = PySRRegressor( + binary_operators=["*", "+"], + verbosity=0, + **self.default_test_kwargs, + early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7", + ) + model.fit( + self.X, + y, + complexity_of_variables=[2, 3] + [100 for _ in range(self.X.shape[1] - 2)], + ) + equations = model.equations_ + self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4) + self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4) + + self.assertEqual(model.get_best()[0]["complexity"], 5) + self.assertEqual(model.get_best()[1]["complexity"], 7) + def test_multioutput_weighted_with_callable_temp_equation(self): X = self.X.copy() y = X[:, [0, 1]] ** 2 @@ -1053,8 +1073,14 @@ def test_unit_checks(self): """This just checks the number of units passed""" use_custom_variable_names = False variable_names = None + complexity_of_variables = 1 weights = None - args = (use_custom_variable_names, variable_names, weights) + args = ( + use_custom_variable_names, + variable_names, + complexity_of_variables, + weights, + ) valid_units = [ (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"), (np.ones((10, 1)), np.ones(10), ["m/s"], None),