Skip to content

Commit

Permalink
fix initial values based on probs
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Oct 5, 2024
1 parent 440f3d8 commit 741b8bc
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 30 deletions.
4 changes: 2 additions & 2 deletions R/create_initial_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ create_eta_omega_inits <- function(x, D, K, init_sd = 0) {
#' @noRd
create_inits_vector <- function(x, n, K, sd = 0, D = 1) {
cbind(
inv_softmax(x)[-1],
p_to_eta(x), # intercepts
matrix(rnorm((n - 1) * (K - 1), sd = sd), n - 1, K - 1)
)
}
create_inits_matrix <- function(x, n, m, K, sd = 0) {
z <- array(0, c(m - 1, K, n))
for (i in seq_len(n)) {
z[, , i] <- t(create_inits_vector(x[i, ], m, K, sd))
z[, , i] <- create_inits_vector(x[i, ], m, K, sd)
}
z
}
Expand Down
24 changes: 12 additions & 12 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
Qd <- t(create_Q(D))
if (need_grad) {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
}
# if (any(!is.finite(exp(pars)))) {
# return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
# }
eta_pi <- create_eta_pi_mnhmm(pars[seq_len(n_i)], S, K_i, D)
eta_A <- create_eta_A_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D)
eta_B <- create_eta_B_mnhmm(
Expand All @@ -112,9 +112,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
} else {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(Inf)
}
# if (any(!is.finite(exp(pars)))) {
# return(Inf)
# }
eta_pi <- create_eta_pi_mnhmm(pars[seq_len(n_i)], S, K_i, D)
eta_A <- create_eta_A_mnhmm(pars[n_i + seq_len(n_s)], S, K_s, D)
eta_B <- create_eta_B_mnhmm(
Expand All @@ -137,9 +137,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
Qd <- t(create_Q(D))
if (need_grad) {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
}
# if (any(!is.finite(exp(pars)))) {
# return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
# }
eta_pi <- create_eta_pi_mnhmm(pars[seq_len(n_i)], S, K_i, D)
eta_A <- create_eta_A_mnhmm(
pars[n_i + seq_len(n_s)],
Expand All @@ -164,9 +164,9 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
} else {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(Inf)
}
# if (any(!is.finite(exp(pars)))) {
# return(Inf)
# }
eta_pi <- create_eta_pi_mnhmm(pars[seq_len(n_i)], S, K_i, D)
eta_A <- create_eta_A_mnhmm(
pars[n_i + seq_len(n_s)], S, K_s, D
Expand Down
24 changes: 12 additions & 12 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
Qm <- t(create_Q(M))
if (need_grad) {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
}
# if (any(!is.finite(exp(pars)))) {
# return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
# }
eta_pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_i)
eta_A <- create_eta_A_nhmm(pars[n_i + seq_len(n_s)], S, K_s)
eta_B <- create_eta_B_nhmm(
Expand All @@ -94,9 +94,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
} else {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(Inf)
}
# if (any(!is.finite(exp(pars)))) {
# return(Inf)
# }
eta_pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_i)
eta_A <- create_eta_A_nhmm(pars[n_i + seq_len(n_s)], S, K_s)
eta_B <- create_eta_B_nhmm(
Expand All @@ -114,9 +114,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
Qm <- lapply(M, function(m) t(create_Q(m)))
if (need_grad) {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
}
# if (any(!is.finite(exp(pars)))) {
# return(list(objective = Inf, gradient = rep(-Inf, length(pars))))
# }
eta_pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_i)
eta_A <- create_eta_A_nhmm(pars[n_i + seq_len(n_s)], S, K_s)
eta_B <- create_eta_multichannel_B_nhmm(
Expand All @@ -131,9 +131,9 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, penalty, ...) {
}
} else {
objectivef <- function(pars) {
if (any(!is.finite(exp(pars)))) {
return(Inf)
}
# if (any(!is.finite(exp(pars)))) {
# return(Inf)
# }
eta_pi <- create_eta_pi_nhmm(pars[seq_len(n_i)], S, K_i)
eta_A <- create_eta_A_nhmm(pars[n_i + seq_len(n_s)], S, K_s)
eta_B <- create_eta_multichannel_B_nhmm(
Expand Down
9 changes: 5 additions & 4 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ is_list_of_lists <- function(x) {
}
}
}
#' Regularized Inverse Softmax Function
#' (Regularized) Inverse of softmax(Q*eta)
#'
#' @noRd
inv_softmax <- function(x) {
p_to_eta <-function(x) {
Q <- create_Q(length(x))
x <- pmin(pmax(x, 0.001), 0.999)
x <- x / sum(x)
log(x) - log(x[1])
log_x <- log(x)
t(Q) %*% (log_x - mean(log_x))
}

#' Stop Function Execution Unless Condition Is True
#'
#' Function copied from the `dynamite` package.
Expand Down

0 comments on commit 741b8bc

Please sign in to comment.