diff --git a/R/average_marginal_prediction.R b/R/average_marginal_prediction.R index b01c3e83..bf5fb4d3 100644 --- a/R/average_marginal_prediction.R +++ b/R/average_marginal_prediction.R @@ -19,6 +19,10 @@ average_marginal_prediction <- function( # avoid warnings of NSEs cluster <- state <- estimate <- state_from <- state_to <- time_var <- channel <- observation <- NULL + stopifnot_( + attr(model, "intercept_only") == FALSE, + "Model does not contain any covariates." + ) stopifnot_( inherits(model, "nhmm") || inherits(model, "mnhmm"), "Argument {.arg model} must be a {.cls nhmm} or {.cls mnhmm} object." diff --git a/R/build_mmm.R b/R/build_mmm.R index ea8ade67..a00c83d6 100644 --- a/R/build_mmm.R +++ b/R/build_mmm.R @@ -119,8 +119,8 @@ build_mmm <- function(observations, n_clusters, transition_probs, initial_probs, "{.arg transition_probs} is not a {.cls list}." ) stopifnot_( - is.list(.check_initial_probs), - "{.arg .check_initial_probs} is not a {.cls list}." + is.list(initial_probs), + "{.arg initial_probs} is not a {.cls list}." ) n_clusters <- length(transition_probs) for (i in seq_len(n_clusters)) { diff --git a/R/build_mnhmm.R b/R/build_mnhmm.R index 3fb0dd31..e8dca4d0 100644 --- a/R/build_mnhmm.R +++ b/R/build_mnhmm.R @@ -29,6 +29,7 @@ build_mnhmm <- function( class = "mnhmm", nobs = attr(out$observations, "nobs"), df = out$extras$n_pars, - type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type) + type = paste0(out$extras$multichannel, "mnhmm_", out$extras$model_type), + intercept_only = intercept_only ) } diff --git a/R/build_nhmm.R b/R/build_nhmm.R index 198eae74..84b9dcb7 100644 --- a/R/build_nhmm.R +++ b/R/build_nhmm.R @@ -16,6 +16,7 @@ build_nhmm <- function( class = "nhmm", nobs = attr(out$observations, "nobs"), df = out$extras$n_pars, - type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type) + type = paste0(out$extras$multichannel, "nhmm_", out$extras$model_type), + intercept_only = intercept_only ) } diff --git a/R/create_base_nhmm.R b/R/create_base_nhmm.R index 91a75963..de413102 100644 --- a/R/create_base_nhmm.R +++ b/R/create_base_nhmm.R @@ -55,8 +55,11 @@ create_base_nhmm <- function(observations, data, time, id, n_states, "Can't find response variable {.var {y}} in {.arg data}." ) x <- suppressMessages( - seqdef(matrix(data[[y]], n_sequences, length_of_sequences), - id = sort(ids) + seqdef(matrix( + data[[y]], + n_sequences, + length_of_sequences, byrow = TRUE), + id = ids ) ) colnames(x) <- sort(times) @@ -97,8 +100,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states, list( model = list( observations = observations, - time_variable = time, - id_variable = id, + time_variable = if (is.null(time)) "time" else time, + id_variable = if (is.null(id)) "id" else id, X_initial = pi$X, X_transition = A$X, X_emission = B$X, X_cluster = if(mixture) theta$X else NULL, initial_formula = pi$formula, @@ -126,7 +129,8 @@ create_base_nhmm <- function(observations, data, time, id, n_states, multichannel = ifelse(n_channels > 1, "multichannel_", ""), model_type = paste0( pi$type, A$type, B$type, if (mixture) theta$type else "" - ) + ), + intercept_only = icp_only_i && icp_only_s && icp_only_o && icp_only_d ) ) } \ No newline at end of file diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 7772b089..28b82375 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -120,16 +120,15 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) { model$coefficients <- out$par[ c("beta_i_raw", "beta_s_raw", "beta_o_raw", "theta_raw") ] - model$stan_model <- model_code + model$stan_model <- model_code@model_code model$estimation_results <- list( hessian = out$hessian, penalized_loglik = out$value, - loglik = out$par["log_lik"], - penalty = out$par["prior"], + loglik = out$par[["log_lik"]], + penalty = out$par[["prior"]], return_code = out$return_code, plogliks_of_restarts = if(restarts > 1L) logliks else NULL, - return_codes_of_restarts = if(restarts > 1L) return_codes else NULL, - stan_model = stanmodels[[attr(model, "type")]] + return_codes_of_restarts = if(restarts > 1L) return_codes else NULL ) model } diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index 15dd2882..48a916c7 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -108,12 +108,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) { )[c("par", "value", "return_code", "hessian")] model$coefficients <- out$par[c("beta_i_raw", "beta_s_raw", "beta_o_raw")] - model$stan_model <- model_code + model$stan_model <- model_code@model_code model$estimation_results <- list( hessian = out$hessian, penalized_loglik = out$value, - loglik = out$par["log_lik"], - penalty = out$par["prior"], + loglik = out$par[["log_lik"]], + penalty = out$par[["prior"]], return_code = out$return_code, plogliks_of_restarts = if(restarts > 1L) logliks else NULL, return_codes_of_restarts = if(restarts > 1L) return_codes else NULL diff --git a/R/get_coefs.R b/R/get_coefs.R index 273c859e..66e0bf46 100644 --- a/R/get_coefs.R +++ b/R/get_coefs.R @@ -35,14 +35,14 @@ coef.nhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) { if (object$n_channels == 1) { beta_o <- data.frame( state = object$state_names, - symbol = rep(object$symbol_names[-1], each = S), + observation = rep(object$symbol_names[-1], each = S), parameter = rep(object$coef_names_emission, each = S * (M - 1)), estimate = beta_o_raw ) } else { beta_o <- data.frame( state = object$state_names, - symbol = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S), + observation = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S), parameter = unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$coef_names_emission, each = S * (M[i] - 1)) })), @@ -113,7 +113,7 @@ coef.mnhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) { if (object$n_channels == 1) { beta_o <- data.frame( state = rep(object$state_names, each = D), - symbol = rep(object$symbol_names[-1], each = S * D), + observations = rep(object$symbol_names[-1], each = S * D), parameter = rep(object$coef_names_emission, each = S * (M - 1) * D), estimate = beta_o_raw, cluster = object$cluster_names @@ -121,7 +121,7 @@ coef.mnhmm <- function(object, nsim = 0, probs = c(0.025, 0.5, 0.975), ...) { } else { beta_o <- data.frame( state = rep(object$state_names, each = D), - symbol = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S * D), + observations = rep(unlist(lapply(object$symbol_names, "[", -1)), each = S * D), parameter = rep(unlist(lapply(seq_len(object$n_channels), function(i) { rep(object$coef_names_emission, each = S * (M[i] - 1)) })), each = D), diff --git a/R/get_probs.R b/R/get_probs.R index 34c41bbc..ff06f150 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -31,6 +31,8 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0, S <- model$n_states M <- model$n_symbols C <- model$n_channels + N <- model$n_sequences + T <- model$length_of_sequences beta_i_raw <- stan_to_cpp_initial( model$coefficients$beta_i_raw ) @@ -47,32 +49,51 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0, X_emission <- aperm(model$X_emission, c(3, 1, 2)) initial_probs <- get_pi(beta_i_raw, X_initial, 0) transition_probs <- get_A(beta_s_raw, X_transition, 0) - emission_probs <- if (model$n_channels == 1) { - get_B(beta_o_raw, X_emission, 0) + emission_probs <- if (C == 1) { + get_B(beta_o_raw, X_emission, 0, 0) } else { get_multichannel_B(beta_o_raw, X_emission, S, C, M, 0, 0) + } + if (C == 1) { + ids <- rownames(model$observations) + times <- colnames(model$observations) + symbol_names <- list(model$symbol_names) + } else { + ids <- rownames(model$observations[[1]]) + times <- colnames(model$observations[[1]]) + symbol_names <- model$symbol_names } - ids <- rownames(model$observations) - times <- colnames(model$observations) initial_probs <- data.frame( id = rep(ids, each = S), state = model$state_names, estimate = c(initial_probs) ) + colnames(initial_probs)[1] <- model$id_variable transition_probs <- data.frame( - id = rep(ids, each = S^2), - time = rep(times, each = S^2 * length(ids)), + id = rep(ids, each = S^2 * T), + time = rep(times, each = S^2), state_from = model$state_names, state_to = rep(model$state_names, each = S), estimate = unlist(transition_probs) ) - emission_probs <- data.frame( - id = rep(ids, each = S * M), - time = rep(times, each = S * M * length(ids)), - state = model$state_names, - observation = rep(model$symbol_names, each = S), - estimate = unlist(emission_probs) + colnames(transition_probs)[1] <- model$id_variable + colnames(transition_probs)[2] <- model$time_variable + emission_probs <- do.call( + "rbind", + lapply(seq_len(C), function(i) { + data.frame( + id = rep(ids, each = S * M[i] * T), + time = rep(times, each = S * M[i]), + state = model$state_names, + channel = model$channel_names[i], + observation = rep(symbol_names[[i]], each = S), + estimate = unlist(emission_probs[((i - 1) * N + 1):(i * N)]) + ) + }) ) + colnames(emission_probs)[1] <- model$id_variable + colnames(emission_probs)[2] <- model$time_variable + if (nsim > 0) { out <- sample_parameters(model, nsim, probs) for(i in seq_along(probs)) { @@ -84,10 +105,10 @@ get_probs.nhmm <- function(model, newdata = NULL, nsim = 0, for(i in seq_along(probs)) { emission_probs[paste0("q", 100 * probs[i])] <- out$quantiles_B[, i] } - for(i in seq_along(probs)) { - cluster_probs[paste0("q", 100 * probs[i])] <- out$quantiles_omega[, i] - } } + rownames(initial_probs) <- NULL + rownames(transition_probs) <- NULL + rownames(emission_probs) <- NULL list( initial_probs = initial_probs, transition_probs = remove_voids(model, transition_probs), @@ -150,7 +171,7 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0, initial_probs[[d]] <- get_pi(beta_i_raw, X_initial, 0) transition_probs[[d]] <- get_A(beta_s_raw, X_transition, 0) emission_probs[[d]] <- if (C == 1) { - get_B(beta_o_raw, X_emission, 0) + get_B(beta_o_raw, X_emission, 0, 0) } else { get_multichannel_B(beta_o_raw, X_emission, S, C, M, 0, 0) } @@ -182,16 +203,12 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0, colnames(transition_probs)[2] <- model$id_variable colnames(transition_probs)[3] <- model$time_variable emission_probs <- data.frame( - cluster = rep(model$cluster_names, each = S * sum(M) * T * N), + cluster = rep(model$cluster_names, each = S * sum(M) * T * N), id = unlist(lapply(seq_len(C), function(i) rep(ids, each = S * M[i] * T))), time = unlist(lapply(seq_len(C), function(i) rep(times, each = S * M[i]))), state = model$state_names, - channel = unlist(lapply(seq_len(C), function(i) { - rep(model$channel_names[i], each = S * M[i]* T * N) - })), - observation = unlist(lapply(seq_len(C), function(i) { - rep(symbol_names[[i]], each = S) - })), + channel = rep(model$channel_names, S * M * T * N), + observation = rep(unlist(symbol_names), each = S), estimate = unlist(emission_probs) ) colnames(emission_probs)[2] <- model$id_variable @@ -218,6 +235,10 @@ get_probs.mnhmm <- function(model, newdata = NULL, nsim = 0, cluster_probs[paste0("q", 100 * probs[i])] <- out$quantiles_omega[, i] } } + rownames(initial_probs) <- NULL + rownames(transition_probs) <- NULL + rownames(emission_probs) <- NULL + rownames(cluster_probs) <- NULL list( initial_probs = initial_probs, transition_probs = remove_voids(model, transition_probs), diff --git a/R/predict.R b/R/predict.R index b4db6765..51b13765 100644 --- a/R/predict.R +++ b/R/predict.R @@ -59,7 +59,7 @@ predict.nhmm <- function( out$pi <- get_pi(beta_i_raw, X_initial, 0) out$A <- get_A(beta_s_raw, X_transition, 0) out$B <- if (object$n_channels == 1) { - get_B(beta_o_raw, X_emission, 0) + get_B(beta_o_raw, X_emission, 0, 0) } else { get_multichannel_B( beta_o_raw, @@ -131,7 +131,7 @@ predict.mnhmm <- function( out$pi <- get_pi(beta_i_raw, X_initial, 0) out$A <- get_A(beta_s_raw, X_transition, 0) out$B <- if (object$n_channels == 1) { - get_B(beta_o_raw, X_emission, 0) + get_B(beta_o_raw, X_emission, 0, 0) } else { get_multichannel_B( beta_o_raw, diff --git a/R/sample_parameters.R b/R/sample_parameters.R index ce8ce328..4d5c5aba 100644 --- a/R/sample_parameters.R +++ b/R/sample_parameters.R @@ -35,6 +35,7 @@ sample_parameters <- function(model, nsim, probs, return_samples = FALSE) { p_i <- length(beta_i_raw) p_s <- length(beta_s_raw) p_o <- length(beta_o_raw) + D <- model$n_clusters if (mixture) { theta_raw <- model$coefficients$theta_raw pars <- c(pars, theta_raw) @@ -44,26 +45,38 @@ sample_parameters <- function(model, nsim, probs, return_samples = FALSE) { samples_pi <- apply( x[seq_len(p_i), ], 2, function(z) { z <- array(z, dim = dim(beta_i_raw)) - get_pi(z, X_initial) + get_pi(z, X_initial, 0) } ) samples_A <- apply( x[p_i + seq_len(p_s), ], 2, function(z) { z <- array(z, dim = dim(beta_s_raw)) - unlist(get_A(aperm(z, c(2, 3, 1)), X_transition)) - } - ) - samples_B <- apply( - x[p_i + p_s + seq_len(p_o), ], 2, function(z) { - z <- array(z, dim = dim(beta_o_raw)) - unlist(get_B(aperm(z, c(2, 3, 1)), X_emission)) + unlist(get_A(stan_to_cpp_transition(z, D), X_transition, 0)) } ) + if (model$n_channels == 1) { + samples_B <- apply( + x[p_i + p_s + seq_len(p_o), ], 2, function(z) { + z <- array(z, dim = dim(beta_o_raw)) + unlist(get_B(stan_to_cpp_emission(z, D, FALSE), X_emission, 0, 0)) + } + ) + } else { + samples_B <- apply( + x[p_i + p_s + seq_len(p_o), ], 2, function(z) { + z <- array(z, dim = dim(beta_o_raw)) + unlist(get_multichannel_B( + stan_to_cpp_emission(z, D, TRUE), X_emission, + model$n_states, model$n_channels, model$n_symbols, 0, 0 + )) + } + ) + } if (mixture) { samples_omega <- apply( x[p_i + p_s + p_o + seq_len(p_d), ], 2, function(z) { z <- array(z, dim = dim(theta_raw)) - unlist(get_omega(z, X_cluster)) + unlist(get_omega(z, X_cluster, 0)) } ) } diff --git a/tests/testthat/test-build_lcm.R b/tests/testthat/test-build_lcm.R index bff91618..1f97cea9 100644 --- a/tests/testthat/test-build_lcm.R +++ b/tests/testthat/test-build_lcm.R @@ -20,6 +20,17 @@ test_that("build_lcm returns object of class 'mhmm'", { emission_probs = cbind(1, matrix(0, 2, s - 1))), NA ) + expect_warning( + model <- build_lcm( + list(obs, obs), n_clusters = k, + channel_names = 1:2, + cluster_names = letters[1:(k + 1)]), + "The length of `cluster_names` does not match the number of clusters. Names were not used." + ) + expect_equal( + cluster_names(model), + paste("Class", seq_len(k)) + ) }) test_that("build_lcm errors with incorrect dims", { expect_error( diff --git a/tests/testthat/test-build_mm.R b/tests/testthat/test-build_mm.R index 74323014..27d6d18c 100644 --- a/tests/testthat/test-build_mm.R +++ b/tests/testthat/test-build_mm.R @@ -1,19 +1,30 @@ -# create test data -set.seed(123) -s <- 4 -obs <- suppressMessages( - seqdef(matrix(sample(letters[1:s], 50, replace = TRUE), ncol = 10)) -) - test_that("build_mm returns object of class 'hmm'", { - expect_error( + set.seed(123) + s <- 4 + obs_matrix <- matrix(sample(letters[1:s], 50, replace = TRUE), ncol = 10) + obs_matrix[1:3, 10] <- NA + obs_matrix[5, 5] <- NA + obs_matrix[4, 0] <- "z" + obs_matrix[5, 10] <- "z" + obs <- suppressMessages(seqdef(obs_matrix)) + expect_message( model <- build_mm(obs), - NA + "Sequences contain missing values, initial and transition probabilities estimated via EM." ) expect_s3_class( model, "hmm" ) + set.seed(123) + s <- 4 + obs_matrix <- matrix(sample(letters[1:s], 50, replace = TRUE), ncol = 10) + obs_matrix[4, 10] <- "z" + obs_matrix[5, 10] <- "z" + obs <- suppressMessages(seqdef(obs_matrix)) + expect_warning( + model <- build_mm(obs), + "There are no observed transitions from some of the symbols." + ) }) test_that("build_mm errors with incorrect observations", { expect_error( @@ -35,6 +46,11 @@ test_that("build_mm errors with incorrect observations", { }) test_that("build_mm returns the correct number of states", { + set.seed(123) + s <- 4 + obs <- suppressMessages( + seqdef(matrix(sample(letters[1:s], 50, replace = TRUE), ncol = 10)) + ) expect_error( model <- build_mm(obs), NA @@ -54,6 +70,11 @@ test_that("build_mm returns the correct number of states", { }) test_that("build_mm returns the correct probabilities", { + set.seed(123) + s <- 4 + obs <- suppressMessages( + seqdef(matrix(sample(letters[1:s], 50, replace = TRUE), ncol = 10)) + ) model <- build_mm(obs) expect_equal( model$initial_probs, diff --git a/tests/testthat/test-build_mmm.R b/tests/testthat/test-build_mmm.R index ddd8c349..0c911d74 100644 --- a/tests/testthat/test-build_mmm.R +++ b/tests/testthat/test-build_mmm.R @@ -11,10 +11,35 @@ test_that("build_mmm returns object of class 'mhmm'", { model <- build_mmm(obs, n_clusters = k), NA ) + expect_error( + model <- build_mmm( + obs, + transition_probs = list(diag(2), diag(2)), initial_probs = list(1:0, 1:0) + ), + NA + ) expect_s3_class( model, "mhmm" ) + +}) + +test_that("build_mmm errors when neither clusters or probs are given", { + expect_error( + build_mmm(obs), + "Provide either `n_clusters` or both `initial_probs` and `transition_probs`." + ) +}) +test_that("build_mmm errors with incorrect argument types", { + expect_error( + build_mmm(obs, transition = 1, initial = "a"), + "`transition_probs` is not a ." + ) + expect_error( + build_mmm(obs, transition = list(diag(2), diag(2)), initial = "a"), + "`initial_probs` is not a ." + ) }) test_that("build_mmm errors with incorrect observations", { expect_error( diff --git a/tests/testthat/test-build_mnhmm.R b/tests/testthat/test-build_mnhmm.R index 32554428..8dd28918 100644 --- a/tests/testthat/test-build_mnhmm.R +++ b/tests/testthat/test-build_mnhmm.R @@ -108,8 +108,32 @@ test_that("estimate_mnhmm errors with incorrect observations", { }) test_that("build_mnhmm works with vector of characters as observations", { expect_error( - estimate_mnhmm("y", s, d, data = data, time = "time", id = "id", iter = 0, - verbose = FALSE), + model <- estimate_mnhmm("y", s, d, data = data, time = "time", id = "id", iter = 0, + verbose = FALSE), NA ) + expect_error( + cluster_names(model) <- seq_len(d), + NA + ) + expect_equal( + cluster_names(model), + seq_len(d) + ) }) + +test_that("build_mnhmm works with missing observations", { + data <- data[data$time != 5, ] + data$y[50:55] <- NA + expect_error( + model <- estimate_mnhmm( + "y", s, d, data = data, time = "time", id = "id", iter = 0, + verbose = FALSE), + NA + ) + expect_equal( + which(model$observations == "*"), + c(41L, 42L, 43L, 44L, 45L, 46L, 47L, 48L, 49L, 50L, 60L, 61L, + 62L, 63L, 64L, 65L) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-get_probs.R b/tests/testthat/test-get_probs.R new file mode 100644 index 00000000..ca6cea71 --- /dev/null +++ b/tests/testthat/test-get_probs.R @@ -0,0 +1,63 @@ +test_that("'get_probs' works for multichannel 'nhmm'", { + data("hmm_biofam") + set.seed(1) + expect_error( + fit <- estimate_nhmm( + hmm_biofam$observations, n_states = 5, + inits = hmm_biofam[ + c("initial_probs", "transition_probs", "emission_probs") + ], verbose = FALSE + ), + NA + ) + expect_error( + p <- get_probs(fit, nsim = 10), + NA + ) +}) +test_that("'get_probs' works for single-channel 'nhmm'", { + data("hmm_biofam") + set.seed(1) + expect_error( + fit <- estimate_nhmm( + hmm_biofam$observations[[1]], n_states = 3, + verbose = FALSE, iter = 1 + ), + NA + ) + expect_error( + p <- get_probs(fit), + NA + ) +}) + +test_that("'get_probs' works for multichannel 'mnhmm'", { + data("hmm_biofam") + set.seed(1) + expect_error( + fit <- estimate_mnhmm( + hmm_biofam$observations, n_states = 3, n_clusters = 2, + verbose = FALSE, iter = 1 + ), + NA + ) + expect_error( + p <- get_probs(fit), + NA + ) +}) + +test_that("'get_probs' works for single-channel 'mnhmm'", { + set.seed(1) + expect_error( + fit <- estimate_mnhmm( + hmm_biofam$observations[[1]], n_states = 4, n_clusters = 2, + verbose = FALSE, iter = 1 + ), + NA + ) + expect_error( + p <- get_probs(fit), + NA + ) +})