Skip to content

Commit

Permalink
threading outside of estimation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Oct 22, 2024
1 parent 8183b6c commit 65be624
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 58 deletions.
48 changes: 26 additions & 22 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ permute_clusters <- function(model, pcp_mle) {
}
#' Bootstrap Sampling of NHMM Coefficients
#'
#' It is possible to parallelize the bootstrap runs using the `future` package,
#' e.g., by calling `future::plan(future::multisession, workers = 2)` before
#' `bootstrap_coefs()`. See [future::plan()] for details.
#'
#' @param model An `nhmm` or `mnhmm` object.
#' @param B number of bootstrap samples.
#' @param method Either `"nonparametric"` or `"parametric"`, to define whether
Expand Down Expand Up @@ -92,17 +96,17 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
gamma_pi <- replicate(B, gammas_mle$pi, simplify = FALSE)
gamma_A <- replicate(B, gammas_mle$A, simplify = FALSE)
gamma_B <- replicate(B, gammas_mle$B, simplify = FALSE)

if (verbose) pb <- utils::txtProgressBar(min = 0, max = 100, style = 3)
if (method == "nonparametric") {
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, threads = 1, ...)
fit$gammas <- permute_states(fit$gammas, gammas_mle)
gamma_pi[[i]] <- fit$gammas$pi
gamma_A[[i]] <- fit$gammas$A
gamma_B[[i]] <- fit$gammas$B
if (verbose) utils::setTxtProgressBar(pb, 100 * i/B)
}
out <- future.apply::future_lapply(
seq_len(B), function(i) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, ...)
if (verbose) utils::setTxtProgressBar(pb, 100 * i/B)
permute_states(fit$gammas, gammas_mle)
}
)
} else {
N <- model$n_sequences
T_ <- model$sequence_lengths
Expand All @@ -114,19 +118,19 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
d <- model$data
time <- model$time_variable
id <- model$id_variable
for (i in seq_len(B)) {
mod <- simulate_nhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
data = d, time, id, init)$model
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, threads = 1, ...)
fit$gammas <- permute_states(fit$gammas, gammas_mle)
gamma_pi[[i]] <- fit$gammas$pi
gamma_A[[i]] <- fit$gammas$A
gamma_B[[i]] <- fit$gammas$B
if (verbose) utils::setTxtProgressBar(pb, 100 * i/B)
}
out <- future.apply::future_lapply(
seq_len(B), function(i) {
mod <- simulate_nhmm(
N, T_, M, S, formula_pi, formula_A, formula_B,
data = d, time, id, init)$model
fit <- fit_nhmm(mod, init, init_sd = 0, restarts = 0, ...)
if (verbose) utils::setTxtProgressBar(pb, 100 * i/B)
fit$gammas <- permute_states(fit$gammas, gammas_mle)
}
)
}
if (verbose) close(pb)
browser()
model$boot <- list(gamma_pi = gamma_pi, gamma_A = gamma_A, gamma_B = gamma_B)
model
}
Expand All @@ -152,7 +156,7 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
if (method == "nonparametric") {
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, threads = 1, ...)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
out <- permute_states(
Expand Down Expand Up @@ -185,7 +189,7 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
mod <- simulate_mnhmm(
N, T_, M, S, D, formula_pi, formula_A, formula_B, formula_omega,
data = d, time, id, init)$model
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, threads = 1, ...)
fit <- fit_mnhmm(mod, init, init_sd = 0, restarts = 0, ...)
fit <- permute_clusters(fit, pcp_mle)
for (j in seq_len(D)) {
out <- permute_states(
Expand Down
4 changes: 2 additions & 2 deletions R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ estimate_mnhmm <- function(
transition_formula = ~1, emission_formula = ~1, cluster_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, threads = 1L, store_data = TRUE, ...) {
restarts = 0L, store_data = TRUE, ...) {

call <- match.call()
model <- build_mnhmm(
Expand All @@ -59,7 +59,7 @@ estimate_mnhmm <- function(
if (store_data) {
model$data <- data
}
out <- fit_mnhmm(model, inits, init_sd, restarts, threads, ...)
out <- fit_mnhmm(model, inits, init_sd, restarts, ...)

attr(out, "call") <- call
out
Expand Down
11 changes: 7 additions & 4 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
#' parameters is less than `1e-8`. The covariate data is standardized before
#' optimization.
#'
#' With multiple runs of optimization (by using the `restarts` argument), it is
#' possible to parallelize these runs using the `future` package, e.g., by
#' calling `future::plan(future::multisession, workers = 2)` before
#' `estimate_nhmm()`. See [future::plan()] for details.
#'
#' @param observations Either the name of the response variable in `data`, or
#' an `stslist` object (see [TraMineR::seqdef()]) containing the
#' sequences. In case of multichannel data, `observations` should be a vector
Expand Down Expand Up @@ -45,8 +50,6 @@
#' of the regression coefficients to zero, use `init_sd = 0`.
#' @param restarts Number of times to run optimization using random starting
#' values (in addition to the final run). Default is 0.
#' @param threads Number of parallel threads for optimization with restarts.
#' Default is 1.
#' @param store_data If `TRUE` (default), original data frame passed as `data`
#' is stored to the model object. For large datasets, this can be set to
#' `FALSE`, in which case you might need to pass the data separately to some
Expand Down Expand Up @@ -78,7 +81,7 @@ estimate_nhmm <- function(
observations, n_states, initial_formula = ~1,
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 0L, threads = 1L,
inits = "random", init_sd = 2, restarts = 0L,
store_data = TRUE, ...) {

call <- match.call()
Expand All @@ -95,7 +98,7 @@ estimate_nhmm <- function(
if (store_data) {
model$data <- data
}
out <- fit_nhmm(model, inits, init_sd, restarts, threads, ...)
out <- fit_nhmm(model, inits, init_sd, restarts, ...)
attr(out, "call") <- call
out
}
11 changes: 1 addition & 10 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#' Estimate a Mixture Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts, threads,
fit_mnhmm <- function(model, inits, init_sd, restarts,
save_all_solutions = FALSE, ...) {
stopifnot_(
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
)
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
Expand Down Expand Up @@ -209,11 +205,6 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads,
all_solutions <- NULL
start_time <- proc.time()
if (restarts > 0L) {
if (threads > 1L) {
future::plan(future::multisession, workers = threads)
} else {
future::plan(future::sequential)
}
dots$control_restart$algorithm <- dots$algorithm
if (is.null(dots$control_restart$maxeval))
dots$control_restart$maxeval <- dots$maxeval
Expand Down
12 changes: 2 additions & 10 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
#' Estimate a Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_nhmm <- function(model, inits, init_sd, restarts, threads, save_all_solutions = FALSE, ...) {
stopifnot_(
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
)
fit_nhmm <- function(model, inits, init_sd, restarts, save_all_solutions = FALSE, ...) {

stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
Expand Down Expand Up @@ -154,11 +151,6 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, save_all_solution
all_solutions <- NULL
start_time <- proc.time()
if (restarts > 0L) {
if (threads > 1L) {
future::plan(future::multisession, workers = threads)
} else {
future::plan(future::sequential)
}
dots$control_restart$algorithm <- dots$algorithm
if (is.null(dots$control_restart$maxeval))
dots$control_restart$maxeval <- dots$maxeval
Expand Down
2 changes: 1 addition & 1 deletion R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ is_list_of_lists <- function(x) {
#' @noRd
p_to_eta <-function(x) {
Q <- create_Q(length(x))
x <- pmin(pmax(x, 0.001), 0.999)
x <- pmin(pmax(x, 1e-6), 1-1e-6)
x <- x / sum(x)
log_x <- log(x)
t(Q) %*% (log_x - mean(log_x))
Expand Down
4 changes: 3 additions & 1 deletion man/bootstrap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions man/estimate_mnhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions man/estimate_nhmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 65be624

Please sign in to comment.