diff --git a/pysr/export_jax.py b/pysr/export_jax.py index e1730ca4..1a03b454 100644 --- a/pysr/export_jax.py +++ b/pysr/export_jax.py @@ -69,7 +69,7 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None): _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func] except KeyError: raise KeyError( - f"Function {expr.func} was not found in JAX function mappings." + f"Function {expr.func} was not found in JAX function mappings. " "Please add it to extra_jax_mappings in the format, e.g., " "{sympy.sqrt: 'jnp.sqrt'}." ) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 7fcb67e8..0b9113d7 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -33,6 +33,70 @@ def _initialize_torch(): torch = _torch + # Allows PyTorch to map Piecewise functions: + def expr_cond_pair(expr, cond): + if isinstance(cond, torch.Tensor) and not isinstance(expr, torch.Tensor): + expr = torch.tensor(expr, dtype=cond.dtype, device=cond.device) + elif isinstance(expr, torch.Tensor) and not isinstance(cond, torch.Tensor): + cond = torch.tensor(cond, dtype=expr.dtype, device=expr.device) + else: + return expr, cond + + # First, make sure expr and cond are same size: + if expr.shape != cond.shape: + if len(expr.shape) == 0: + expr = expr.expand(cond.shape) + elif len(cond.shape) == 0: + cond = cond.expand(expr.shape) + else: + raise ValueError( + "expr and cond must have same shape, or one must be a scalar." + ) + return expr, cond + + def if_then_else(*conds): + a, b, c = conds + return torch.where( + a, torch.where(b, True, False), torch.where(c, True, False) + ) + + def piecewise(*expr_conds): + output = None + already_used = None + for expr, cond in expr_conds: + if not isinstance(cond, torch.Tensor) and not isinstance( + expr, torch.Tensor + ): + # When we just have scalars, have to do this a bit more complicated + # due to the fact that we need to evaluate on the correct device. + if output is None: + already_used = cond + output = expr if cond else 0.0 + else: + if not isinstance(output, torch.Tensor): + output += expr if cond and not already_used else 0.0 + already_used = already_used or cond + else: + expr = torch.tensor( + expr, dtype=output.dtype, device=output.device + ).expand(output.shape) + output += torch.where( + cond & ~already_used, expr, torch.zeros_like(expr) + ) + already_used = already_used | cond + else: + if output is None: + already_used = cond + output = torch.where(cond, expr, torch.zeros_like(expr)) + else: + output += torch.where( + cond.bool() & ~already_used, expr, torch.zeros_like(expr) + ) + already_used = already_used | cond.bool() + return output + + # TODO: Add test that makes sure tensors are on the same device + _global_func_lookup = { sympy.Mul: _reduce(torch.mul), sympy.Add: _reduce(torch.add), @@ -81,6 +145,12 @@ def _initialize_torch(): sympy.Heaviside: torch.heaviside, sympy.core.numbers.Half: (lambda: 0.5), sympy.core.numbers.One: (lambda: 1.0), + sympy.logic.boolalg.Boolean: lambda x: x, + sympy.logic.boolalg.BooleanTrue: (lambda: True), + sympy.logic.boolalg.BooleanFalse: (lambda: False), + sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair, + sympy.Piecewise: piecewise, + sympy.logic.boolalg.ITE: if_then_else, } class _Node(torch.nn.Module): @@ -125,7 +195,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._torch_func = _func_lookup[expr.func] except KeyError: raise KeyError( - f"Function {expr.func} was not found in Torch function mappings." + f"Function {expr.func} was not found in Torch function mappings. " "Please add it to extra_torch_mappings in the format, e.g., " "{sympy.sqrt: torch.sqrt}." ) @@ -153,7 +223,13 @@ def forward(self, memodict): arg_ = arg(memodict) memodict[arg] = arg_ args.append(arg_) - return self._torch_func(*args) + try: + return self._torch_func(*args) + except Exception as err: + # Add information about the current node to the error: + raise type(err)( + f"Error occurred in node {self._sympy_func} with args {args}" + ) class _SingleSymPyModule(torch.nn.Module): """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""