Skip to content

Commit

Permalink
Refactor sympy and export functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Sep 17, 2023
1 parent f89d890 commit b2d7f41
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 104 deletions.
14 changes: 7 additions & 7 deletions pysr/export_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _print_Float(self, expr):
return super()._print_Float(reduced_float)


def to_latex(expr, prec=3, full_prec=True, **settings):
def sympy2latex(expr, prec=3, full_prec=True, **settings):
"""Convert sympy expression to LaTeX with custom precision."""
settings["full_prec"] = full_prec
printer = PreciseLatexPrinter(settings=settings, prec=prec)
Expand Down Expand Up @@ -56,7 +56,7 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
return top_latex_table, bottom_latex_table


def generate_single_table(
def sympy2latextable(
equations: pd.DataFrame,
indices: List[int] = None,
precision: int = 3,
Expand All @@ -74,16 +74,16 @@ def generate_single_table(
indices = range(len(equations))

for i in indices:
latex_equation = to_latex(
latex_equation = sympy2latex(
equations.iloc[i]["sympy_format"],
prec=precision,
)
complexity = str(equations.iloc[i]["complexity"])
loss = to_latex(
loss = sympy2latex(
sympy.Float(equations.iloc[i]["loss"]),
prec=precision,
)
score = to_latex(
score = sympy2latex(
sympy.Float(equations.iloc[i]["score"]),
prec=precision,
)
Expand Down Expand Up @@ -124,7 +124,7 @@ def generate_single_table(
return "\n".join([latex_top, *latex_table_content, latex_bottom])


def generate_multiple_tables(
def sympy2multilatextable(
equations: List[pd.DataFrame],
indices: List[List[int]] = None,
precision: int = 3,
Expand All @@ -135,7 +135,7 @@ def generate_multiple_tables(
# TODO: Let user specify custom output variable

latex_tables = [
generate_single_table(
sympy2latextable(
equations[i],
(None if not indices else indices[i]),
precision=precision,
Expand Down
9 changes: 6 additions & 3 deletions pysr/export_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from sympy import lambdify


def sympy2numpy(eqn, sympy_symbols, *, selection=None):
return CallableEquation(eqn, sympy_symbols, selection=selection)


class CallableEquation:
"""Simple wrapper for numpy lambda functions built with sympy"""

def __init__(self, sympy_symbols, eqn, selection=None, variable_names=None):
def __init__(self, eqn, sympy_symbols, selection=None):
self._sympy = eqn
self._sympy_symbols = sympy_symbols
self._selection = selection
self._variable_names = variable_names

def __repr__(self):
return f"PySRFunction(X=>{self._sympy})"
Expand All @@ -23,7 +26,7 @@ def __call__(self, X):
if isinstance(X, pd.DataFrame):
# Lambda function takes as argument:
return self._lambda(
**{k: X[k].values for k in self._variable_names}
**{k: X[k].values for k in map(str, self._sympy_symbols)}
) * np.ones(expected_shape)
if self._selection is not None:
if X.shape[1] != len(self._selection):
Expand Down
72 changes: 72 additions & 0 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Define utilities to export to sympy"""
from typing import Callable, Dict, List, Optional

import sympy
from sympy import sympify

sympy_mappings = {
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
"sqrt": lambda x: sympy.sqrt(x),
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
"square": lambda x: x**2,
"cube": lambda x: x**3,
"plus": lambda x, y: x + y,
"sub": lambda x, y: x - y,
"neg": lambda x: -x,
"pow": lambda x, y: x**y,
"pow_abs": lambda x, y: abs(x) ** y,
"cos": sympy.cos,
"sin": sympy.sin,
"tan": sympy.tan,
"cosh": sympy.cosh,
"sinh": sympy.sinh,
"tanh": sympy.tanh,
"exp": sympy.exp,
"acos": sympy.acos,
"asin": sympy.asin,
"atan": sympy.atan,
"acosh": lambda x: sympy.acosh(x),
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
"erfc": sympy.erfc,
"log": lambda x: sympy.log(x),
"log10": lambda x: sympy.log(x, 10),
"log2": lambda x: sympy.log(x, 2),
"log1p": lambda x: sympy.log(x + 1),
"log_abs": lambda x: sympy.log(abs(x)),
"log10_abs": lambda x: sympy.log(abs(x), 10),
"log2_abs": lambda x: sympy.log(abs(x), 2),
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
"floor": sympy.floor,
"ceil": sympy.ceiling,
"sign": sympy.sign,
"gamma": sympy.gamma,
}


def create_sympy_symbols(
feature_names_in: Optional[List[str]] = None,
) -> List[sympy.Symbol]:
return [sympy.Symbol(variable) for variable in feature_names_in]


def pysr2sympy(
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
) -> sympy.Expr:
local_sympy_mappings = {
**(extra_sympy_mappings if extra_sympy_mappings else {}),
**sympy_mappings,
}

return sympify(equation, locals=local_sympy_mappings)


def assert_valid_sympy_symbol(var_name: str) -> None:
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
raise ValueError(f"Variable name {var_name} is already a function name.")
110 changes: 26 additions & 84 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

import numpy as np
import pandas as pd
import sympy
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.utils import check_array, check_consistent_length, check_random_state
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
from sympy import sympify

from .deprecated import make_deprecated_kwargs_for_pysr_regressor
from .export_latex import generate_multiple_tables, generate_single_table, to_latex
from .export_numpy import CallableEquation
from .export_jax import sympy2jax
from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
from .export_numpy import sympy2numpy
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
from .export_torch import sympy2torch
from .julia_helpers import (
_escape_filename,
_load_backend,
Expand All @@ -37,51 +38,6 @@

already_ran = False

sympy_mappings = {
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
"sqrt": lambda x: sympy.sqrt(x),
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
"square": lambda x: x**2,
"cube": lambda x: x**3,
"plus": lambda x, y: x + y,
"sub": lambda x, y: x - y,
"neg": lambda x: -x,
"pow": lambda x, y: x**y,
"pow_abs": lambda x, y: abs(x) ** y,
"cos": sympy.cos,
"sin": sympy.sin,
"tan": sympy.tan,
"cosh": sympy.cosh,
"sinh": sympy.sinh,
"tanh": sympy.tanh,
"exp": sympy.exp,
"acos": sympy.acos,
"asin": sympy.asin,
"atan": sympy.atan,
"acosh": lambda x: sympy.acosh(x),
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
"erfc": sympy.erfc,
"log": lambda x: sympy.log(x),
"log10": lambda x: sympy.log(x, 10),
"log2": lambda x: sympy.log(x, 2),
"log1p": lambda x: sympy.log(x + 1),
"log_abs": lambda x: sympy.log(abs(x)),
"log10_abs": lambda x: sympy.log(abs(x), 10),
"log2_abs": lambda x: sympy.log(abs(x), 2),
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
"floor": sympy.floor,
"ceil": sympy.ceiling,
"sign": sympy.sign,
"gamma": sympy.gamma,
}


def pysr(X, y, weights=None, **kwargs): # pragma: no cover
warnings.warn(
Expand Down Expand Up @@ -188,17 +144,14 @@ def _check_assertions(
assert len(variable_names) == X.shape[1]
# Check none of the variable names are function names:
for var_name in variable_names:
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
raise ValueError(
f"Variable name {var_name} is already a function name."
)
# Check if alphanumeric only:
if not re.match(r"^[₀₁₂₃₄₅₆₇₈₉a-zA-Z0-9_]+$", var_name):
raise ValueError(
f"Invalid variable name {var_name}. "
"Only alphanumeric characters, numbers, "
"and underscores are allowed."
)
assert_valid_sympy_symbol(var_name)
if X_units is not None and len(X_units) != X.shape[1]:
raise ValueError(
"The number of units in `X_units` must equal the number of features in `X`."
Expand Down Expand Up @@ -2116,10 +2069,10 @@ def latex(self, index=None, precision=3):
if self.nout_ > 1:
output = []
for s in sympy_representation:
latex = to_latex(s, prec=precision)
latex = sympy2latex(s, prec=precision)
output.append(latex)
return output
return to_latex(sympy_representation, prec=precision)
return sympy2latex(sympy_representation, prec=precision)

def jax(self, index=None):
"""
Expand Down Expand Up @@ -2282,53 +2235,41 @@ def get_hof(self):
jax_format = []
if self.output_torch_format:
torch_format = []
local_sympy_mappings = {
**(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
**sympy_mappings,
}

sympy_symbols = [
sympy.Symbol(variable) for variable in self.feature_names_in_
]

for _, eqn_row in output.iterrows():
eqn = sympify(eqn_row["equation"], locals=local_sympy_mappings)
eqn = pysr2sympy(
eqn_row["equation"],
extra_sympy_mappings=self.extra_sympy_mappings,
)
sympy_format.append(eqn)

# Numpy:
# NumPy:
sympy_symbols = create_sympy_symbols(self.feature_names_in_)
lambda_format.append(
CallableEquation(
sympy_symbols, eqn, self.selection_mask_, self.feature_names_in_
sympy2numpy(
eqn,
sympy_symbols,
selection=self.selection_mask_,
)
)

# JAX:
if self.output_jax_format:
from .export_jax import sympy2jax

func, params = sympy2jax(
eqn,
sympy_symbols,
selection=self.selection_mask_,
extra_jax_mappings=(
self.extra_jax_mappings if self.extra_jax_mappings else {}
),
extra_jax_mappings=self.extra_jax_mappings,
)
jax_format.append({"callable": func, "parameters": params})

# Torch:
if self.output_torch_format:
from .export_torch import sympy2torch

module = sympy2torch(
eqn,
sympy_symbols,
selection=self.selection_mask_,
extra_torch_mappings=(
self.extra_torch_mappings
if self.extra_torch_mappings
else {}
),
extra_torch_mappings=self.extra_torch_mappings,
)
torch_format.append(module)

Expand Down Expand Up @@ -2410,17 +2351,18 @@ def latex_table(
assert isinstance(indices[0], list)
assert len(indices) == self.nout_

generator_fnc = generate_multiple_tables
table_string = sympy2multilatextable(
self.equations_, indices=indices, precision=precision, columns=columns
)
else:
if indices is not None:
assert isinstance(indices, list)
assert isinstance(indices[0], int)

generator_fnc = generate_single_table
table_string = sympy2latextable(
self.equations_, indices=indices, precision=precision, columns=columns
)

table_string = generator_fnc(
self.equations_, indices=indices, precision=precision, columns=columns
)
preamble_string = [
r"\usepackage{breqn}",
r"\usepackage{booktabs}",
Expand Down
19 changes: 9 additions & 10 deletions pysr/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import numpy as np
import pandas as pd
import sympy
from sklearn import model_selection
from sklearn.utils.estimator_checks import check_estimator

from .. import PySRRegressor, julia_helpers
from ..export_latex import to_latex
from ..export_latex import sympy2latex
from ..sr import (
_check_assertions,
_csv_filename_to_pkl_filename,
Expand Down Expand Up @@ -884,23 +883,23 @@ def test_multi_output(self):
def test_latex_float_precision(self):
"""Test that we can print latex expressions with custom precision"""
expr = sympy.Float(4583.4485748, dps=50)
self.assertEqual(to_latex(expr, prec=6), r"4583.45")
self.assertEqual(to_latex(expr, prec=5), r"4583.4")
self.assertEqual(to_latex(expr, prec=4), r"4583.")
self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
self.assertEqual(sympy2latex(expr, prec=6), r"4583.45")
self.assertEqual(sympy2latex(expr, prec=5), r"4583.4")
self.assertEqual(sympy2latex(expr, prec=4), r"4583.")
self.assertEqual(sympy2latex(expr, prec=3), r"4.58 \cdot 10^{3}")
self.assertEqual(sympy2latex(expr, prec=2), r"4.6 \cdot 10^{3}")

# Multiple numbers:
x = sympy.Symbol("x")
expr = x * 3232.324857384 - 1.4857485e-10
self.assertEqual(
to_latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
sympy2latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
)
self.assertEqual(
to_latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
sympy2latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
)
self.assertEqual(
to_latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
sympy2latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
)

def test_latex_break_long_equation(self):
Expand Down

0 comments on commit b2d7f41

Please sign in to comment.