Skip to content

Commit

Permalink
errors to warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 10, 2024
1 parent a21bc03 commit 5068577
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 54 deletions.
20 changes: 11 additions & 9 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,12 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes > 0)
stopifnot_(
length(successful) > 0,
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
if (length(successful) == 0) {
warning_(
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
}
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
if (save_all_solutions) {
Expand All @@ -235,10 +236,11 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
opts = control
)
end_time <- proc.time()
stopifnot_(
out$status >= 0,
paste("Optimization terminated due to error:", error_msg(out$status))
)
if (out$status < 0) {
warning_(
paste("Optimization terminated due to error:", error_msg(out$status))
)
}
pars <- out$solution
model$etas$pi <- create_eta_pi_mnhmm(
pars[seq_len(n_i)], S, K_pi, D
Expand Down
41 changes: 22 additions & 19 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
logliks <- -unlist(lapply(out, "[[", "objective")) * n_obs
return_codes <- unlist(lapply(out, "[[", "status"))
successful <- which(return_codes > 0)
stopifnot_(
length(successful) > 0,
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
if (length(successful) == 0) {
warning_(
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
}
optimum <- successful[which.max(logliks[successful])]
init <- out[[optimum]]$solution
if (save_all_solutions) {
Expand All @@ -188,10 +189,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
opts = control
)
end_time <- proc.time()
stopifnot_(
out$status >= 0,
paste("Optimization terminated due to error:", error_msg(out$status))
)
if (out$status < 0) {
warning_(
paste("Optimization terminated due to error:", error_msg(out$status))
)
}

pars <- out$solution
model$etas$pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_pi)
Expand Down Expand Up @@ -255,11 +257,12 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
},
future.seed = TRUE)
return_codes <- unlist(lapply(out, "[[", "return_code"))
stopifnot_(
any(return_codes == 0),
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
if (all(return_codes < 0)) {
warning_(
c("All optimizations terminated due to error.",
"Error of first restart: ", error_msg(return_codes[1]))
)
}
logliks <- unlist(lapply(out, "[[", "penalized_logLik")) * n_obs
optimum <- out[[which.max(logliks)]]
init <- stats::setNames(
Expand Down Expand Up @@ -297,11 +300,11 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, pseudocoun
control_mstep$print_level, lambda, pseudocount)
}
end_time <- proc.time()
stopifnot_(
out$return_code == 0,
paste("Optimization terminated due to error:", error_msg(out$return_code))
)

if (out$return_code < 0) {
warning_(
paste("Optimization terminated due to error:", error_msg(out$return_code))
)
}
model$etas$pi[] <- out$eta_pi
model$gammas$pi <- eta_to_gamma_mat(model$etas$pi)
model$etas$A <- out$eta_A
Expand Down
43 changes: 24 additions & 19 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,10 @@ create_emissionArray <- function(model) {
#' Convert error message to text
#' @noRd
error_msg <- function(error) {
gamma <- dplyr::case_when(
error %in% c(-(1:4)) ~ "",
error %in% c(1, -(101:104)) ~ "gamma_pi",
error %in% c(2, -(201:204)) ~ "gamma_A",
error == 3 | error %in% -301:-304 ~ "gamma_B"
)

if (error %in% c(1, -(101:104))) gamma <- "gamma_pi"
if (error %in% c(2, -(201:204))) gamma <- "gamma_A"
if (error %in% c(3, -(301:304))) gamma <- "gamma_B"

nonfinite_msg <- paste0(
"Error: Some of the values in ", gamma, " are nonfinite, likely due to ",
Expand All @@ -162,22 +160,29 @@ error_msg <- function(error) {
}

e <- seq(0, 300, by = 100)
msg <- dplyr::case_when(
error %in% 1:3 ~ nonfinite_msg,
error %in% (-1 - e) ~ paste0(
mstep, "NLOPT_FAILURE: Generic failure code."
),
error %in% (-2 - e) ~ paste0(
mstep, "NLOPT_INVALID_ARGS: Invalid arguments (e.g., lower bounds are ",
if (error %in% 1:3) {
msg <- nonfinite_msg
}
if (error %in% (-1 - e)) {
msg <- paste0(mstep, "NLOPT_FAILURE: Generic failure code.")
}
if (error %in% (-2 - e)) {
msg <- paste0(
mstep,
"NLOPT_INVALID_ARGS: Invalid arguments (e.g., lower bounds are ",
"bigger than upper bounds, an unknown algorithm was specified)."
),
error %in% (-3 - e) ~ paste0(
)
}
if (error %in% (-3 - e)) {
msg <- paste0(
mstep, "NLOPT_OUT_OF_MEMORY: Ran out of memory."
),
error %in% (-4 - e) ~ paste0(
mstep,
)
}
if (error %in% (-4 - e)) {
msg <- paste0(
mstep,
"NLOPT_ROUNDOFF_LIMITED: Halted because roundoff errors limited progress."
)
)
}
msg
}
74 changes: 67 additions & 7 deletions src/nhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
mstep_iter++;

double value = 0;

eta_pi = arma::mat(x.memptr(), S - 1, K_pi);
gamma_pi = sum_to_zero(eta_pi, Qs);
arma::mat tmpgrad(S, K_pi, arma::fill::zeros);
Expand Down Expand Up @@ -602,7 +602,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
model.mstep_A(
Expand All @@ -611,7 +621,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
model.mstep_B(
Expand All @@ -620,7 +640,17 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
// Update model
Expand Down Expand Up @@ -795,7 +825,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
model.mstep_A(
Expand All @@ -804,7 +844,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
model.mstep_B(
Expand All @@ -813,7 +863,17 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
if (model.mstep_error_code != 0) {
return Rcpp::List::create(
Rcpp::Named("return_code") = model.mstep_error_code,
Rcpp::Named("penalized_logLik") = arma::datum::nan
Rcpp::Named("eta_pi") = Rcpp::wrap(model.eta_pi),
Rcpp::Named("eta_A") = Rcpp::wrap(model.eta_A),
Rcpp::Named("eta_B") = Rcpp::wrap(model.eta_B),
Rcpp::Named("penalized_logLik") = arma::datum::nan,
Rcpp::Named("penalty_term") = arma::datum::nan,
Rcpp::Named("logLik") = arma::datum::nan,
Rcpp::Named("iterations") = iter,
Rcpp::Named("relative_f_change") = relative_change,
Rcpp::Named("absolute_f_change") = absolute_change,
Rcpp::Named("absolute_x_change") = absolute_x_change,
Rcpp::Named("relative_x_change") = relative_x_change
);
}
// Update model
Expand Down

0 comments on commit 5068577

Please sign in to comment.