Skip to content

Commit

Permalink
faster simulation of nhmms, document bootstrap, add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Oct 7, 2024
1 parent 44eb3b3 commit 91bf701
Show file tree
Hide file tree
Showing 8 changed files with 419 additions and 134 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Imports:
numDeriv,
patchwork,
Rcpp (>= 0.12.0),
RcppHungarian,
rlang,
stats,
TraMineR (>= 2.2-7),
Expand Down
16 changes: 16 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,22 @@ objectivex <- function(transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols,
.Call(`_seqHMM_objectivex`, transition, emission, init, obs, ANZ, BNZ, INZ, nSymbols, coef, X, numberOfStates, threads)
}

# R-level entry point for the native routine
# `_seqHMM_simulate_nhmm_singlechannel`. The six arguments are forwarded
# unchanged, in the order given, through `.Call()`, and the native result is
# returned as-is; no argument validation happens in this wrapper.
# simulate_nhmm() calls this with model$etas$pi, model$X_initial,
# model$etas$A, model$X_transition, model$etas$B and model$X_emission, and
# then reads `$states` and `$observations` from the result, adding 1 before
# using them to index state/symbol names.
# NOTE(review): this lives in R/RcppExports.R, which is presumably regenerated
# by Rcpp::compileAttributes() -- confirm before relying on hand-written
# comments here.
simulate_nhmm_singlechannel <- function(eta_pi, X_i, eta_A, X_s, eta_B, X_o) {
.Call(`_seqHMM_simulate_nhmm_singlechannel`, eta_pi, X_i, eta_A, X_s, eta_B, X_o)
}

# R-level entry point for the native routine
# `_seqHMM_simulate_nhmm_multichannel`. All seven arguments are forwarded
# unchanged, in the order given, through `.Call()`, and the native result is
# returned as-is; no argument validation happens in this wrapper.
# simulate_nhmm() uses this when there is more than one channel, passing
# model$etas$pi, model$X_initial, model$etas$A, model$X_transition,
# model$etas$B, model$X_emission and, as `M`, model$n_symbols.
# NOTE(review): `M` presumably holds the number of symbols per channel --
# confirm against the C++ source.
simulate_nhmm_multichannel <- function(eta_pi, X_i, eta_A, X_s, eta_B, X_o, M) {
.Call(`_seqHMM_simulate_nhmm_multichannel`, eta_pi, X_i, eta_A, X_s, eta_B, X_o, M)
}

# R-level entry point for the native routine
# `_seqHMM_simulate_mnhmm_singlechannel`. All eight arguments are forwarded
# unchanged, in the order given, through `.Call()`, and the native result is
# returned as-is; no argument validation happens in this wrapper.
# simulate_mnhmm() calls this with model$etas$pi, model$X_initial,
# model$etas$A, model$X_transition, model$etas$B, model$X_emission,
# model$etas$omega and model$X_cluster, and then reads `$states` and
# `$observations` from the result, adding 1 before using them to index
# state/symbol names.
simulate_mnhmm_singlechannel <- function(eta_pi, X_i, eta_A, X_s, eta_B, X_o, eta_omega, X_d) {
.Call(`_seqHMM_simulate_mnhmm_singlechannel`, eta_pi, X_i, eta_A, X_s, eta_B, X_o, eta_omega, X_d)
}

# R-level entry point for the native routine
# `_seqHMM_simulate_mnhmm_multichannel`. All nine arguments are forwarded
# unchanged, in the order given, through `.Call()`, and the native result is
# returned as-is; no argument validation happens in this wrapper.
# simulate_mnhmm() uses this when there is more than one channel. It passes
# the same inputs as in the single-channel case, except that `eta_B` is
# unlist(model$etas$B, recursive = FALSE) and `M` is model$n_symbols.
# NOTE(review): `M` presumably holds the number of symbols per channel --
# confirm against the C++ source.
simulate_mnhmm_multichannel <- function(eta_pi, X_i, eta_A, X_s, eta_B, X_o, eta_omega, X_d, M) {
.Call(`_seqHMM_simulate_mnhmm_multichannel`, eta_pi, X_i, eta_A, X_s, eta_B, X_o, eta_omega, X_d, M)
}

# R-level entry point for the native routine `_seqHMM_softmax`: `x` is
# forwarded unchanged through `.Call()` and the native result is returned
# as-is.
# NOTE(review): by its name this presumably computes the softmax of `x`;
# the expected type and shape of `x` are not visible here -- confirm against
# the C++ source.
softmax <- function(x) {
.Call(`_seqHMM_softmax`, x)
}
Expand Down
26 changes: 22 additions & 4 deletions R/bootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ permute_states <- function(gammas_boot, gammas_mle) {
}
gammas_boot
}
#' Bootstrap Sampling of NHMM Coefficients
#'
#' @param model An `nhmm` or `mnhmm` object.
#' @param B number of bootstrap samples.
#' @param method Either `"nonparametric"` or `"parametric"`, to define whether
#' nonparametric or parametric bootstrap should be used. The former samples
#' sequences with replacement, whereas the latter simulates new datasets based
#' on the model.
#' @param penalty A penalty term for model estimation. By default, the same
#' penalty is used as in the original model estimation by `estimate_nhmm` or
#' `estimate_mnhmm`.
#' @param verbose Should the progress bar be displayed? Default is `FALSE`.
#' @rdname bootstrap
#' @export
bootstrap_coefs.nhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric"),
Expand All @@ -59,12 +71,13 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
gammas_mle <- model$gammas

coefs <- matrix(NA, length(unlist(gammas_mle)), B)
pb <- utils::txtProgressBar(min = 0, max = 100, style = 3)
if (method == "nonparametric") {
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, ...)
coefs[, i] <- unlist(permute_states(fit$gammas, gammas_mle))
if(verbose) print(paste0("Bootstrap replication ", i, " complete."))
if (verbose) setTxtProgressBar(pb, i)
}
} else {
N <- model$n_sequences
Expand All @@ -83,11 +96,14 @@ bootstrap_coefs.nhmm <- function(model, B = 1000,
data = d, time, id, init)$model
fit <- fit_nhmm(mod, init, 0, 0, 1, penalty, ...)
coefs[, i] <- unlist(permute_states(fit$gammas, gammas_mle))
print(paste0("Bootstrap replication ", i, " complete."))
if (verbose) setTxtProgressBar(pb, i)
}
}
close(pb)
return(coefs)
}
#' @inheritParams bootstrap_coefs.nhmm
#' @rdname bootstrap
#' @export
bootstrap_coefs.mnhmm <- function(model, B = 1000,
method = c("nonparametric", "parametric"),
Expand All @@ -101,13 +117,14 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
if (missing(penalty)) {
penalty <- model$estimation_results$penalty
}
pb <- utils::txtProgressBar(min = 0, max = 100, style = 3)
if (method == "nonparametric") {
coefs <- matrix(NA, length(unlist(init)), B)
for (i in seq_len(B)) {
mod <- bootstrap_model(model)
fit <- fit_mnhmm(mod, init, 0, 0, 1, penalty, FALSE)
coefs[, i] <- unlist(fit$coefficients)
print(paste0("Bootstrap replication ", i, " complete."))
if (verbose) setTxtProgressBar(pb, i)
}
} else {
coefs <- matrix(NA, length(unlist(init)), B)
Expand All @@ -129,8 +146,9 @@ bootstrap_coefs.mnhmm <- function(model, B = 1000,
data = d, time, id, init)$model
fit <- fit_mnhmm(mod, init, 0, 0, 1, penalty, FALSE)
coefs[, i] <- unlist(fit$coefficients)
print(paste0("Bootstrap replication ", i, " complete."))
if (verbose) setTxtProgressBar(pb, i)
}
}
close(pb)
return(coefs)
}
120 changes: 42 additions & 78 deletions R/simulate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,99 +77,63 @@ simulate_mnhmm <- function(
if (is.null(coefs$emission_probs)) coefs$emission_probs <- NULL
if (is.null(coefs$cluster_probs)) coefs$cluster_probs <- NULL
}
K_i <- nrow(model$X_initial)
K_s <- nrow(model$X_transition)
K_o <- nrow(model$X_emission)
K_d <- nrow(model$X_cluster)
model$etas <- create_initial_values(
coefs, n_states, n_symbols, init_sd, K_i, K_s, K_o, K_d, n_clusters
coefs, model$n_states, model$n_symbols, init_sd, nrow(model$X_initial),
nrow(model$X_transition), nrow(model$X_emission), nrow(model$X_cluster),
n_clusters
)
model$gammas$pi <- c(eta_to_gamma_mat_field(
model$etas$pi
))
model$gammas$A <- c(eta_to_gamma_cube_field(
model$etas$A
))
if (n_channels == 1L) {
model$gammas$B <- c(eta_to_gamma_cube_field(
model$etas$B
))
out <- simulate_mnhmm_singlechannel(
model$etas$pi, model$X_initial,
model$etas$A, model$X_transition,
model$etas$B, model$X_emission,
model$etas$omega, model$X_cluster
)
} else {
l <- lengths(model$etas$B)
gamma_B <- c(eta_to_gamma_cube_field(unlist(model$etas$B, recursive = FALSE)))
model$gammas$B <- split(gamma_B, rep(seq_along(l), l))
}
model$gammas$omega <- eta_to_gamma_mat(
model$etas$omega
)
probs <- get_probs(model)
states <- array(NA_character_, c(max(sequence_lengths), n_sequences))
obs <- array(NA_character_, c(max(sequence_lengths), n_channels, n_sequences))
ids <- unique(data[[id]])
times <- sort(unique(data[[time]]))
clusters <- character(n_sequences)
cluster_names <- model$cluster_names
state_names <- paste0(
rep(cluster_names, each = model$n_states), ": ", unlist(model$state_names)
)
for (i in seq_len(n_sequences)) {
p_cluster <- probs$cluster[
probs$cluster[[time]] == time[1] & probs$cluster[[id]] == ids[i],
"probability"
]
clusters[i] <- sample(model$cluster_names, 1, prob = p_cluster)
p_init <- probs$initial[
probs$initial[[time]] == time[1] & probs$initial[[id]] == ids[i] &
probs$initial$cluster == clusters[i],
"probability"
]
states[1, i] <- sample(state_names, 1, prob = p_init)
for (k in seq_len(n_channels)) {
p_emission <- probs$emission[
probs$emission[[time]] == time[1] & probs$emission[[id]] == ids[i] &
probs$emission$cluster == clusters[i] &
probs$emission$state == states[1, i] & probs$emission$channel == k,
"probability"
]
obs[1, k, i] <- sample(symbol_names[[k]], 1, prob = p_emission)
}
out <- simulate_mnhmm_multichannel(
model$etas$pi, model$X_initial,
model$etas$A, model$X_transition,
unlist(model$etas$B, recursive = FALSE), model$X_emission,
model$etas$omega, model$X_cluster,
model$n_symbols
)
}

for (i in seq_len(n_sequences)) {
for (t in 2:sequence_lengths[i]) {
p_transition <- probs$transition[
probs$transition[[time]] == times[t] & probs$transition[[id]] == ids[i] &
probs$transition$cluster == clusters[i] &
probs$transition$state_from == states[t - 1, i], "probability"
]
states[t, i] <- sample(state_names, 1, prob = p_transition)
for (k in seq_len(n_channels)) {
p_emission <- probs$emission[
probs$emission[[time]] == time[t] & probs$emission[[id]] == ids[i] &
probs$emission$cluster == clusters[i] &
probs$emission$state == states[t, i] & probs$emission$channel == k,
"probability"
]
obs[t, k, i] <- sample(symbol_names[[k]], 1, prob = p_emission)
}
T_ <- model$length_of_sequences
for (i in seq_len(model$n_sequences)) {
Ti <- sequence_lengths[i]
if (Ti < T_) {
out$states[(Ti + 1):T_, i] <- NA
out$observations[(Ti + 1):T_, i] <- NA
}
}
state_names <- paste0(
rep(model$cluster_names, each = model$n_states),
": ", unlist(model$state_names)
)
symbol_names <- model$symbol_names
out$states[] <- state_names[c(out$states) + 1]
states <- suppressWarnings(suppressMessages(
seqdef(
matrix(
t(states),
t(out$states),
n_sequences, max(sequence_lengths)
),
alphabet = state_names
)
))
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
seqdef(t(obs[, i, ]), alphabet = symbol_names[[i]])
if (n_channels == 1) {
out$observations[] <- symbol_names[c(out$observations) + 1]
model$observations <- suppressWarnings(suppressMessages(
seqdef(t(out$observations), alphabet = symbol_names)
))
})
names(obs) <- model$channel_names
if (n_channels == 1) obs <- obs[[1]]
model$observations <- obs
} else {
model$observations <- lapply(seq_len(n_channels), function(i) {
out$observations[, , i] <- symbol_names[[i]][c(out$observations[, , i]) + 1]
suppressWarnings(suppressMessages(
seqdef(t(out$observations[, , i]), alphabet = symbol_names[[i]])
))
})
names(model$observations) <- model$channel_names
}
list(model = model, states = states)
}
90 changes: 38 additions & 52 deletions R/simulate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,71 +69,57 @@ simulate_nhmm <- function(
if (is.null(coefs$transition_probs)) coefs$transition_probs <- NULL
if (is.null(coefs$emission_probs)) coefs$emission_probs <- NULL
}
K_i <- nrow(model$X_initial)
K_s <- nrow(model$X_transition)
K_o <- nrow(model$X_emission)
model$gammas$pi <- eta_to_gamma_mat(model$etas$pi)
model$gammas$A <- eta_to_gamma_cube(model$etas$A)
model$etas <- create_initial_values(
coefs, model$n_states, model$n_symbols, init_sd, nrow(model$X_initial),
nrow(model$X_transition), nrow(model$X_emission)
)
if (n_channels == 1L) {
model$gammas$B <- eta_to_gamma_cube(model$etas$B)
out <- simulate_nhmm_singlechannel(
model$etas$pi, model$X_initial,
model$etas$A, model$X_transition,
model$etas$B, model$X_emission
)
} else {
model$gammas$B <- eta_to_gamma_cube_field(model$etas$B)
}
probs <- get_probs(model)
states <- array(NA_character_, c(max(sequence_lengths), n_sequences))
obs <- array(NA_character_, c(max(sequence_lengths), n_channels, n_sequences))
ids <- unique(data[[id]])
times <- sort(unique(data[[time]]))
state_names <- model$state_names
for (i in seq_len(n_sequences)) {
p_init <- probs$initial[
probs$initial[[time]] == time[1] & probs$initial[[id]] == ids[i],
"probability"
]
states[1, i] <- sample(state_names, 1, prob = p_init)
for (k in seq_len(n_channels)) {
p_emission <- probs$emission[
probs$emission[[time]] == time[1] & probs$emission[[id]] == ids[i] &
probs$emission$state == states[1, i] & probs$emission$channel == k,
"probability"
]
obs[1, k, i] <- sample(symbol_names[[k]], 1, prob = p_emission)
}
out <- simulate_nhmm_multichannel(
model$etas$pi, model$X_initial,
model$etas$A, model$X_transition,
model$etas$B, model$X_emission,
model$n_symbols
)
}

for (i in seq_len(n_sequences)) {
for (t in 2:sequence_lengths[i]) {
p_transition <- probs$transition[
probs$transition[[time]] == times[t] & probs$transition[[id]] == ids[i] &
probs$transition$state_from == states[t - 1, i], "probability"
]
states[t, i] <- sample(state_names, 1, prob = p_transition)
for (k in seq_len(n_channels)) {
p_emission <- probs$emission[
probs$emission[[time]] == time[t] & probs$emission[[id]] == ids[i] &
probs$emission$state == states[t, i] & probs$emission$channel == k,
"probability"
]
obs[t, k, i] <- sample(symbol_names[[k]], 1, prob = p_emission)
}
T_ <- model$length_of_sequences
for (i in seq_len(model$n_sequences)) {
Ti <- sequence_lengths[i]
if (Ti < T_) {
out$states[(Ti + 1):T_, i] <- NA
out$observations[(Ti + 1):T_, i] <- NA
}
}
state_names <- model$state_names
symbol_names <- model$symbol_names
out$states[] <- state_names[c(out$states) + 1]
states <- suppressWarnings(suppressMessages(
seqdef(
matrix(
t(states),
t(out$states),
n_sequences, max(sequence_lengths)
),
alphabet = state_names
)
))
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
seqdef(t(obs[, i, ]), alphabet = symbol_names[[i]])
if (n_channels == 1) {
out$observations[] <- symbol_names[c(out$observations) + 1]
model$observations <- suppressWarnings(suppressMessages(
seqdef(t(out$observations), alphabet = symbol_names)
))
})
names(obs) <- model$channel_names
if (n_channels == 1) obs <- obs[[1]]
model$observations <- obs
} else {
model$observations <- lapply(seq_len(n_channels), function(i) {
out$observations[, , i] <- symbol_names[[i]][c(out$observations[, , i]) + 1]
suppressWarnings(suppressMessages(
seqdef(t(out$observations[, , i]), alphabet = symbol_names[[i]])
))
})
names(model$observations) <- model$channel_names
}
list(model = model, states = states)
}
Loading

0 comments on commit 91bf701

Please sign in to comment.