From bcbb0a9c29c264946f5f1bbeab64b917192483b4 Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Tue, 22 Oct 2024 11:12:10 -0400 Subject: [PATCH] #253 Add .export_vars to parallel_start() --- NEWS.md | 8 ++++++++ R/utils-control-par.R | 31 +++++++++++++++++++++++++++---- man/parallel_start.Rd | 11 ++++++++++- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 7a70dd2f..d4a7c98b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# modeltime 1.3.0.9000 + +Parallel Computation: +- `parallel_start()`: New parameters `.export_vars` and `.packages` allow passing environment variables and packages to the parallel workers. + +Fixes: +- Adam (`adam_reg()`): Fixes #254 + # modeltime 1.3.0 #### Overview diff --git a/R/utils-control-par.R b/R/utils-control-par.R index 26d84bd4..47ac7562 100644 --- a/R/utils-control-par.R +++ b/R/utils-control-par.R @@ -8,6 +8,8 @@ #' #' - "parallel" - Uses the `parallel` and `doParallel` packages #' - "spark" - Uses the `sparklyr` package +#' @param .export_vars Environment variables that can be sent to the workers +#' @param .packages Packages that can be sent to the workers #' #' #' @details @@ -44,7 +46,8 @@ #' @export #' @rdname parallel_start -parallel_start <- function(..., .method = c("parallel", "spark")) { +parallel_start <- function(..., .method = c("parallel", "spark"), + .export_vars = NULL, .packages = NULL) { meth <- tolower(.method[1]) @@ -53,11 +56,31 @@ parallel_start <- function(..., .method = c("parallel", "spark")) { } if (meth == "parallel") { + # Step 1: Create the cluster cl <- parallel::makeCluster(...) 
+ + # Step 2: Register the cluster doParallel::registerDoParallel(cl) - invisible( - parallel::clusterCall(cl, function(x) .libPaths(x), .libPaths()) - ) + + # Step 3: Export variables (if provided) + if (!is.null(.export_vars)) { + parallel::clusterExport(cl, varlist = .export_vars) + } + + # Step 4: Load .packages (if provided) + if (!is.null(.packages)) { + parallel::clusterCall(cl, function(pkgs) { + lapply(pkgs, function(pkg) { + if (!requireNamespace(pkg, quietly = TRUE)) { + stop(paste("Package", pkg, "is not installed.")) + } + library(pkg, character.only = TRUE) + }) + }, .packages) + } + + # Step 5: Set the library paths for each worker + invisible(parallel::clusterCall(cl, function(x) .libPaths(x), .libPaths())) } if (meth == "spark") { diff --git a/man/parallel_start.Rd b/man/parallel_start.Rd index 68032016..45880284 100644 --- a/man/parallel_start.Rd +++ b/man/parallel_start.Rd @@ -5,7 +5,12 @@ \alias{parallel_stop} \title{Start parallel clusters using \code{parallel} package} \usage{ -parallel_start(..., .method = c("parallel", "spark")) +parallel_start( + ..., + .method = c("parallel", "spark"), + .export_vars = NULL, + .packages = NULL +) parallel_stop() } @@ -17,6 +22,10 @@ parallel_stop() \item "parallel" - Uses the \code{parallel} and \code{doParallel} packages \item "spark" - Uses the \code{sparklyr} package }} + +\item{.export_vars}{Environment variables that can be sent to the workers} + +\item{.packages}{Packages that can be sent to the workers} } \description{ Start parallel clusters using \code{parallel} package