Skip to content

Commit

Permalink
fix and test get_probs, add default id and time variables
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 7, 2024
1 parent 1ff2404 commit c68247c
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 66 deletions.
4 changes: 4 additions & 0 deletions R/average_marginal_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ average_marginal_prediction <- function(
# avoid warnings of NSEs
cluster <- state <- estimate <- state_from <- state_to <- time_var <-
channel <- observation <- NULL
stopifnot_(
attr(model, "intercept_only") == FALSE,
"Model does not contain any covariates."
)
stopifnot_(
inherits(model, "nhmm") || inherits(model, "mnhmm"),
"Argument {.arg model} must be a {.cls nhmm} or {.cls mnhmm} object."
Expand Down
4 changes: 2 additions & 2 deletions R/build_mmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ build_mmm <- function(observations, n_clusters, transition_probs, initial_probs,
"{.arg transition_probs} is not a {.cls list}."
)
stopifnot_(
is.list(.check_initial_probs),
"{.arg .check_initial_probs} is not a {.cls list}."
is.list(initial_probs),
"{.arg initial_probs} is not a {.cls list}."
)
n_clusters <- length(transition_probs)
for (i in seq_len(n_clusters)) {
Expand Down
3 changes: 2 additions & 1 deletion R/build_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ build_mnhmm <- function(
class = "mnhmm",
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type)
type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type),
intercept_only = intercept_only
)
}
3 changes: 2 additions & 1 deletion R/build_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ build_nhmm <- function(
class = "nhmm",
nobs = attr(out$observations, "nobs"),
df = out$extras$n_pars,
type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type)
type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type),
intercept_only = intercept_only
)
}
14 changes: 9 additions & 5 deletions R/create_base_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
"Can't find response variable {.var {y}} in {.arg data}."
)
x <- suppressMessages(
seqdef(matrix(data[[y]], n_sequences, length_of_sequences),
id = sort(ids)
seqdef(matrix(
data[[y]],
n_sequences,
length_of_sequences, byrow = TRUE),
id = ids
)
)
colnames(x) <- sort(times)
Expand Down Expand Up @@ -97,8 +100,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
list(
model = list(
observations = observations,
time_variable = time,
id_variable = id,
time_variable = if (is.null(time)) "time" else time,
id_variable = if (is.null(id)) "id" else id,
X_initial = pi$X, X_transition = A$X, X_emission = B$X,
X_cluster = if(mixture) theta$X else NULL,
initial_formula = pi$formula,
Expand Down Expand Up @@ -126,7 +129,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
multichannel = ifelse(n_channels > 1, "multichannel_", ""),
model_type = paste0(
pi$type, A$type, B$type, if (mixture) theta$type else ""
)
),
intercept_only = icp_only_i && icp_only_s && icp_only_o && icp_only_d
)
)
}
9 changes: 4 additions & 5 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,15 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
model$coefficients <- out$par[
c("beta_i_raw", "beta_s_raw", "beta_o_raw", "theta_raw")
]
model$stan_model <- model_code
model$stan_model <- model_code@model_code
model$estimation_results <- list(
hessian = out$hessian,
penalized_loglik = out$value,
loglik = out$par["log_lik"],
penalty = out$par["prior"],
loglik = out$par[["log_lik"]],
penalty = out$par[["prior"]],
return_code = out$return_code,
plogliks_of_restarts = if(restarts > 1L) logliks else NULL,
return_codes_of_restarts = if(restarts > 1L) return_codes else NULL,
stan_model = stanmodels[[attr(model, "type")]]
return_codes_of_restarts = if(restarts > 1L) return_codes else NULL
)
model
}
6 changes: 3 additions & 3 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
)[c("par", "value", "return_code", "hessian")]

model$coefficients <- out$par[c("beta_i_raw", "beta_s_raw", "beta_o_raw")]
model$stan_model <- model_code
model$stan_model <- model_code@model_code
model$estimation_results <- list(
hessian = out$hessian,
penalized_loglik = out$value,
loglik = out$par["log_lik"],
penalty = out$par["prior"],
loglik = out$par[["log_lik"]],
penalty = out$par[["prior"]],
return_code = out$return_code,
plogliks_of_restarts = if(restarts > 1L) logliks else NULL,
return_codes_of_restarts = if(restarts > 1L) return_codes else NULL
Expand Down
8 changes: 4 additions & 4 deletions R/get_coefs.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ coef.nhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) {
if (object$n_channels == 1) {
beta_o <- data.frame(
state = object$state_names,
symbol = rep(object$symbol_names[-1], each = S),
observation = rep(object$symbol_names[-1], each = S),
parameter = rep(object$coef_names_emission, each = S * (M - 1)),
estimate = beta_o_raw
)
} else {
beta_o <- data.frame(
state = object$state_names,
symbol = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S),
observation = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S),
parameter = unlist(lapply(seq_len(object$n_channels), function(i) {
rep(object$coef_names_emission, each = S * (M[i] - 1))
})),
Expand Down Expand Up @@ -113,15 +113,15 @@ coef.mnhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) {
if (object$n_channels == 1) {
beta_o <- data.frame(
state = rep(object$state_names, each = D),
symbol = rep(object$symbol_names[-1], each = S * D),
observations = rep(object$symbol_names[-1], each = S * D),
parameter = rep(object$coef_names_emission, each = S * (M - 1) * D),
estimate = beta_o_raw,
cluster = object$cluster_names
)
} else {
beta_o <- data.frame(
state = rep(object$state_names, each = D),
symbol = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S * D),
observations = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S * D),
parameter = rep(unlist(lapply(seq_len(object$n_channels), function(i) {
rep(object$coef_names_emission, each = S * (M[i] - 1))
})), each = D),
Expand Down
67 changes: 44 additions & 23 deletions R/get_probs.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0,
S <- model$n_states
M <- model$n_symbols
C <- model$n_channels
N <- model$n_sequences
T <- model$length_of_sequences
beta_i_raw <- stan_to_cpp_initial(
model$coefficients$beta_i_raw
)
Expand All @@ -47,32 +49,51 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0,
X_emission <- aperm(model$X_emission, c(3, 1, 2))
initial_probs <- get_pi(beta_i_raw, X_initial, 0)
transition_probs <- get_A(beta_s_raw, X_transition, 0)
emission_probs <- if (model$n_channels == 1) {
get_B(beta_o_raw, X_emission, 0)
emission_probs <- if (C == 1) {
get_B(beta_o_raw, X_emission, 0, 0)
} else {
get_multichannel_B(beta_o_raw, X_emission, S, C, M, 0, 0)
}
if (C == 1) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
symbol_names <- list(model$symbol_names)
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
symbol_names <- model$symbol_names
}
ids <- rownames(model$observations)
times <- colnames(model$observations)
initial_probs <- data.frame(
id = rep(ids, each = S),
state = model$state_names,
estimate = c(initial_probs)
)
colnames(initial_probs)[1] <- model$id_variable
transition_probs <- data.frame(
id = rep(ids, each = S^2),
time = rep(times, each = S^2 * length(ids)),
id = rep(ids, each = S^2 * T),
time = rep(times, each = S^2),
state_from = model$state_names,
state_to = rep(model$state_names, each = S),
estimate = unlist(transition_probs)
)
emission_probs <- data.frame(
id = rep(ids, each = S * M),
time = rep(times, each = S * M * length(ids)),
state = model$state_names,
observation = rep(model$symbol_names, each = S),
estimate = unlist(emission_probs)
colnames(transition_probs)[1] <- model$id_variable
colnames(transition_probs)[2] <- model$time_variable
emission_probs <- do.call(
"rbind",
lapply(seq_len(C), function(i) {
data.frame(
id = rep(ids, each = S * M[i] * T),
time = rep(times, each = S * M[i]),
state = model$state_names,
channel = model$channel_names[i],
observation = rep(symbol_names[[i]], each = S),
estimate = unlist(emission_probs[((i - 1) * N + 1):(i * N)])
)
})
)
colnames(emission_probs)[1] <- model$id_variable
colnames(emission_probs)[2] <- model$time_variable

if (nsim > 0) {
out <- sample_parameters(model, nsim, probs)
for(i in seq_along(probs)) {
Expand All @@ -84,10 +105,10 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0,
for(i in seq_along(probs)) {
emission_probs[paste0("q", 100 * probs[i])] <- out$quantiles_B[, i]
}
for(i in seq_along(probs)) {
cluster_probs[paste0("q", 100 * probs[i])] <- out$quantiles_omega[, i]
}
}
rownames(initial_probs) <- NULL
rownames(transition_probs) <- NULL
rownames(emission_probs) <- NULL
list(
initial_probs = initial_probs,
transition_probs = remove_voids(model, transition_probs),
Expand Down Expand Up @@ -150,7 +171,7 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0,
initial_probs[[d]] <- get_pi(beta_i_raw, X_initial, 0)
transition_probs[[d]] <- get_A(beta_s_raw, X_transition, 0)
emission_probs[[d]] <- if (C == 1) {
get_B(beta_o_raw, X_emission, 0)
get_B(beta_o_raw, X_emission, 0, 0)
} else {
get_multichannel_B(beta_o_raw, X_emission, S, C, M, 0, 0)
}
Expand Down Expand Up @@ -182,16 +203,12 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0,
colnames(transition_probs)[2] <- model$id_variable
colnames(transition_probs)[3] <- model$time_variable
emission_probs <- data.frame(
cluster = rep(model$cluster_names, each = S * sum(M) * T * N),
cluster = rep(model$cluster_names, each = S * sum(M) * T * N),
id = unlist(lapply(seq_len(C), function(i) rep(ids, each = S * M[i] * T))),
time = unlist(lapply(seq_len(C), function(i) rep(times, each = S * M[i]))),
state = model$state_names,
channel = unlist(lapply(seq_len(C), function(i) {
rep(model$channel_names[i], each = S * M[i]* T * N)
})),
observation = unlist(lapply(seq_len(C), function(i) {
rep(symbol_names[[i]], each = S)
})),
channel = rep(model$channel_names, S * M * T * N),
observation = rep(unlist(symbol_names), each = S),
estimate = unlist(emission_probs)
)
colnames(emission_probs)[2] <- model$id_variable
Expand All @@ -218,6 +235,10 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0,
cluster_probs[paste0("q", 100 * probs[i])] <- out$quantiles_omega[, i]
}
}
rownames(initial_probs) <- NULL
rownames(transition_probs) <- NULL
rownames(emission_probs) <- NULL
rownames(cluster_probs) <- NULL
list(
initial_probs = initial_probs,
transition_probs = remove_voids(model, transition_probs),
Expand Down
4 changes: 2 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ predict.nhmm <- function(
out$pi <- get_pi(beta_i_raw, X_initial, 0)
out$A <- get_A(beta_s_raw, X_transition, 0)
out$B <- if (object$n_channels == 1) {
get_B(beta_o_raw, X_emission, 0)
get_B(beta_o_raw, X_emission, 0, 0)
} else {
get_multichannel_B(
beta_o_raw,
Expand Down Expand Up @@ -131,7 +131,7 @@ predict.mnhmm <- function(
out$pi <- get_pi(beta_i_raw, X_initial, 0)
out$A <- get_A(beta_s_raw, X_transition, 0)
out$B <- if (object$n_channels == 1) {
get_B(beta_o_raw, X_emission, 0)
get_B(beta_o_raw, X_emission, 0, 0)
} else {
get_multichannel_B(
beta_o_raw,
Expand Down
31 changes: 22 additions & 9 deletions R/sample_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ sample_parameters <- function(model, nsim, probs, return_samples = FALSE) {
p_i <- length(beta_i_raw)
p_s <- length(beta_s_raw)
p_o <- length(beta_o_raw)
D <- model$n_clusters
if (mixture) {
theta_raw <- model$coefficients$theta_raw
pars <- c(pars, theta_raw)
Expand All @@ -44,26 +45,38 @@ sample_parameters <- function(model, nsim, probs, return_samples = FALSE) {
samples_pi <- apply(
x[seq_len(p_i), ], 2, function(z) {
z <- array(z, dim = dim(beta_i_raw))
get_pi(z, X_initial)
get_pi(z, X_initial, 0)
}
)
samples_A <- apply(
x[p_i + seq_len(p_s), ], 2, function(z) {
z <- array(z, dim = dim(beta_s_raw))
unlist(get_A(aperm(z, c(2, 3, 1)), X_transition))
}
)
samples_B <- apply(
x[p_i + p_s + seq_len(p_o), ], 2, function(z) {
z <- array(z, dim = dim(beta_o_raw))
unlist(get_B(aperm(z, c(2, 3, 1)), X_emission))
unlist(get_A(stan_to_cpp_transition(z, D), X_transition, 0))
}
)
if (model$n_channels == 1) {
samples_B <- apply(
x[p_i + p_s + seq_len(p_o), ], 2, function(z) {
z <- array(z, dim = dim(beta_o_raw))
unlist(get_B(stan_to_cpp_emission(z, D, FALSE), X_emission, 0, 0))
}
)
} else {
samples_B <- apply(
x[p_i + p_s + seq_len(p_o), ], 2, function(z) {
z <- array(z, dim = dim(beta_o_raw))
unlist(get_multichannel_B(
stan_to_cpp_emission(z, D, TRUE), X_emission,
model$n_states, model$n_channels, model$n_symbols, 0, 0
))
}
)
}
if (mixture) {
samples_omega <- apply(
x[p_i + p_s + p_o + seq_len(p_d), ], 2, function(z) {
z <- array(z, dim = dim(theta_raw))
unlist(get_omega(z, X_cluster))
unlist(get_omega(z, X_cluster, 0))
}
)
}
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-build_lcm.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ test_that("build_lcm returns object of class 'mhmm'", {
emission_probs = cbind(1, matrix(0, 2, s - 1))),
NA
)
expect_warning(
model <- build_lcm(
list(obs, obs), n_clusters = k,
channel_names = 1:2,
cluster_names = letters[1:(k + 1)]),
"The length of `cluster_names` does not match the number of clusters. Names were not used."
)
expect_equal(
cluster_names(model),
paste("Class", seq_len(k))
)
})
test_that("build_lcm errors with incorrect dims", {
expect_error(
Expand Down
Loading

0 comments on commit c68247c

Please sign in to comment.