diff --git a/R/fit_mnhmm.R b/R/fit_mnhmm.R index 9af6a40..a123af6 100644 --- a/R/fit_mnhmm.R +++ b/R/fit_mnhmm.R @@ -215,11 +215,12 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs return_codes <- unlist(lapply(out, "[[", "status")) successful <- which(return_codes > 0) - stopifnot_( - length(successful) > 0, - c("All optimizations terminated due to error.", - "Error of first restart: ", error_msg(return_codes[1])) - ) + if (length(successful) == 0) { + warning_( + c("All optimizations terminated due to error.", + "Error of first restart: ", error_msg(return_codes[1])) + ) + } optimum <- successful[which.max(logliks[successful])] init <- out[[optimum]]$solution if (save_all_solutions) { @@ -235,10 +236,11 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method, opts = control ) end_time <- proc.time() - stopifnot_( - out$status >= 0, - paste("Optimization terminated due to error:", error_msg(out$status)) - ) + if (out$status < 0) { + warning_( + paste("Optimization terminated due to error:", error_msg(out$status)) + ) + } pars <- out$solution model$etas$pi <- create_eta_pi_mnhmm( pars[seq_len(n_i)], S, K_pi, D diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index 2e64b49..7953025 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -167,11 +167,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs return_codes <- unlist(lapply(out, "[[", "status")) successful <- which(return_codes > 0) - stopifnot_( - length(successful) > 0, - c("All optimizations terminated due to error.", - "Error of first restart: ", error_msg(return_codes[1])) - ) + if (length(successful) == 0) { + warning_( + c("All optimizations terminated due to error.", + "Error of first restart: ", error_msg(return_codes[1])) + ) + } optimum <- successful[which.max(logliks[successful])] init <- out[[optimum]]$solution if (save_all_solutions) { @@ -188,10 +189,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun opts = control ) end_time <- proc.time() - stopifnot_( - out$status >= 0, - paste("Optimization terminated due to error:", error_msg(out$status)) - ) + if (out$status < 0) { + warning_( + paste("Optimization terminated due to error:", error_msg(out$status)) + ) + } pars <- out$solution model$etas$pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_pi) @@ -255,11 +257,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun }, future.seed = TRUE) return_codes <- unlist(lapply(out, "[[", "return_code")) - stopifnot_( - any(return_codes == 0), - c("All optimizations terminated due to error.", - "Error of first restart: ", error_msg(return_codes[1])) - ) + if (all(return_codes < 0)) { + warning_( + c("All optimizations terminated due to error.", + "Error of first restart: ", error_msg(return_codes[1])) + ) + } logliks <- unlist(lapply(out, "[[", "penalized_logLik")) * n_obs optimum <- out[[which.max(logliks)]] init <- stats::setNames( @@ -297,11 +300,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun control_mstep$print_level, lambda, pseudocount) } end_time <- proc.time() - stopifnot_( - out$return_code == 0, - paste("Optimization terminated due to error:", error_msg(out$return_code)) - ) - + if (out$return_code < 0) { + warning_( + paste("Optimization terminated due to error:", error_msg(out$return_code)) + ) + } model$etas$pi[] <- out$eta_pi model$gammas$pi <- eta_to_gamma_mat(model$etas$pi) model$etas$A <- out$eta_A diff --git a/R/utilities.R b/R/utilities.R index 5ca26ad..1c99c51 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -144,12 +144,10 @@ create_emissionArray <- function(model) { #' Convert error message to text #' @noRd error_msg <- function(error) { - gamma <- dplyr::case_when( - error %in% c(-(1:4)) ~ "", - error %in% c(1, -(101:104)) ~ "gamma_pi", - error %in% c(2, -(201:204)) ~ "gamma_A", - error == 3 | error %in% -301:-304 ~ "gamma_B" - ) + + if (error %in% c(1, -(101:104))) gamma <- "gamma_pi" + if (error %in% c(2, -(201:204))) gamma <- "gamma_A" + if (error %in% c(3, -(301:304))) gamma <- "gamma_B" nonfinite_msg <- paste0( "Error: Some of the values in ", gamma, " are nonfinite, likely due to ", @@ -162,22 +160,29 @@ error_msg <- function(error) { } e <- seq(0, 300, by = 100) - msg <- dplyr::case_when( - error %in% 1:3 ~ nonfinite_msg, - error %in% (-1 - e) ~ paste0( - mstep, "NLOPT_FAILURE: Generic failure code." - ), - error %in% (-2 - e) ~ paste0( - mstep, "NLOPT_INVALID_ARGS: Invalid arguments (e.g., lower bounds are ", + if (error %in% 1:3) { + msg <- nonfinite_msg + } + if (error %in% (-1 - e)) { + msg <- paste0(mstep, "NLOPT_FAILURE: Generic failure code.") + } + if (error %in% (-2 - e)) { + msg <- paste0( + mstep, + "NLOPT_INVALID_ARGS: Invalid arguments (e.g., lower bounds are ", "bigger than upper bounds, an unknown algorithm was specified)." - ), - error %in% (-3 - e) ~ paste0( + ) + } + if (error %in% (-3 - e)) { + msg <- paste0( mstep, "NLOPT_OUT_OF_MEMORY: Ran out of memory." - ), - error %in% (-4 - e) ~ paste0( - mstep, + ) + } + if (error %in% (-4 - e)) { + msg <- paste0( + mstep, "NLOPT_ROUNDOFF_LIMITED: Halted because roundoff errors limited progress." ) - ) + } msg } \ No newline at end of file diff --git a/src/nhmm_EM.cpp b/src/nhmm_EM.cpp index 7e972ec..a67107a 100644 --- a/src/nhmm_EM.cpp +++ b/src/nhmm_EM.cpp @@ -12,7 +12,7 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) { mstep_iter++; double value = 0; - + eta_pi = arma::mat(x.memptr(), S - 1, K_pi); gamma_pi = sum_to_zero(eta_pi, Qs); arma::mat tmpgrad(S, K_pi, arma::fill::zeros); @@ -602,7 +602,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } model.mstep_A( @@ -611,7 +621,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } model.mstep_B( @@ -620,7 +640,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } // Update model @@ -795,7 +825,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } model.mstep_A( @@ -804,7 +844,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } model.mstep_B( @@ -813,7 +863,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel( if (model.mstep_error_code != 0) { return Rcpp::List::create( Rcpp::Named("return_code") = model.mstep_error_code, - Rcpp::Named("penalized_logLik") = arma::datum::nan + Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi), + Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A), + Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B), + Rcpp::Named("penalized_logLik") = arma::datum::nan, + Rcpp::Named("penalty_term") = arma::datum::nan, + Rcpp::Named("logLik") = arma::datum::nan, + Rcpp::Named("iterations") = iter, + Rcpp::Named("relative_f_change") = relative_change, + Rcpp::Named("absolute_f_change") = absolute_change, + Rcpp::Named("absolute_x_change") = absolute_x_change, + Rcpp::Named("relative_x_change") = relative_x_change ); } // Update model