diff --git a/Dockerfile b/Dockerfile index 859e82b4..ed0b5f89 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ ADD ./pysr /pysr/pysr RUN pip3 install --no-cache-dir . # Install Julia pre-requisites: -RUN python3 -c 'import pysr' +RUN python3 -c 'import pysr; pysr.load_all_packages()' # metainformation LABEL org.opencontainers.image.authors = "Miles Cranmer" diff --git a/pysr/__init__.py b/pysr/__init__.py index fe204dae..b40ee840 100644 --- a/pysr/__init__.py +++ b/pysr/__init__.py @@ -7,6 +7,7 @@ from .deprecated import best, best_callable, best_row, best_tex, install, pysr from .export_jax import sympy2jax from .export_torch import sympy2torch +from .julia_extensions import load_all_packages from .sr import PySRRegressor # This file is created by setuptools_scm during the build process: @@ -19,6 +20,7 @@ "sympy2jax", "sympy2torch", "install", + "load_all_packages", "PySRRegressor", "best", "best_callable", diff --git a/pysr/julia_extensions.py b/pysr/julia_extensions.py index 5c537105..b71f8acd 100644 --- a/pysr/julia_extensions.py +++ b/pysr/julia_extensions.py @@ -22,6 +22,17 @@ def load_required_packages( load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e") +def load_all_packages(): + """Install and load all Julia extensions available to PySR.""" + load_required_packages( + turbo=True, bumper=True, enable_autodiff=True, cluster_manager="slurm" + ) + + +# TODO: Refactor this file so we can install all packages at once using `juliapkg`, +# ideally parameterizable via the regular Python extras API + + def isinstalled(uuid_s: str): return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s)) diff --git a/pysr/test/test.py b/pysr/test/test.py index 00a25444..c641e9f6 100644 --- a/pysr/test/test.py +++ b/pysr/test/test.py @@ -12,7 +12,7 @@ import sympy # type: ignore from sklearn.utils.estimator_checks import check_estimator -from pysr import PySRRegressor, install, jl +from pysr import PySRRegressor, install, jl, load_all_packages from pysr.export_latex import sympy2latex from pysr.feature_selection import _handle_feature_selection, run_feature_selection from pysr.julia_helpers import init_julia @@ -739,6 +739,11 @@ def test_param_groupings(self): # Check the sets are equal: self.assertSetEqual(set(params), set(regressor_params)) + def test_load_all_packages(self): + """Test we can load all packages at once.""" + load_all_packages() + self.assertTrue(jl.seval("ClusterManagers isa Module")) + class TestHelpMessages(unittest.TestCase): """Test user help messages."""