Skip to content

Commit

Permalink
fix predict for multinomial, more tests, threaded cumulative, run-ext…
Browse files Browse the repository at this point in the history
…ended
  • Loading branch information
santikka committed May 7, 2024
1 parent cc033f8 commit af47542
Show file tree
Hide file tree
Showing 11 changed files with 647 additions and 71 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* A new argument `plot_type` has been added to control what type of plot will be drawn by the `plot()` method. The default value `"default"` draws the posterior means and posterior intervals of all parameters. The old functionality of drawing posterior densities and traceplots is provided by the option `"trace"`.
* The `plot()` method has gained the argument `n_params` to limit the amount of parameters drawn at once (per parameter type).
* Both time-varying and time-invariant parameters can now be plotted simultaneously.
* Fixed an issue with `predict()` and `fitted()` for multinomial responses.

# dynamite 1.5.0

Expand Down
1 change: 1 addition & 0 deletions R/as_data_frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#'
#' * `alpha`\cr Intercept terms (time-invariant or time-varying).
#' * `beta`\cr Time-invariant regression coefficients.
#' * `cutpoints`\cr Cutpoints for ordinal regression.
#' * `delta`\cr Time-varying regression coefficients.
#' * `nu`\cr Group-level random effects.
#' * `lambda`\cr Factor loadings.
Expand Down
49 changes: 27 additions & 22 deletions R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,22 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
)
)
}
categories <- c(
ulapply(
unlist(x$stan$responses),
function(y) {
channel <- get_channel(x, y)
if (is_cumulative(channel$family)) {
seq_len(channel$S - 1L)
} else if (is_categorical(channel$family)) {
channel$categories[-1L]
} else {
NA_character_
categories <- unique(
c(
NA_character_,
ulapply(
unlist(x$stan$responses),
function(y) {
channel <- get_channel(x, y)
if (is_cumulative(channel$family)) {
seq_len(channel$S - 1L)
} else if (is_categorical(channel$family)) {
channel$categories[-1L]
} else {
NA_character_
}
}
}
)
)
)
tmp <- data.table::as.data.table(
Expand Down Expand Up @@ -368,18 +371,20 @@ as_data_table_default <- function(type, draws, response, ...) {
)
}

#' Shrinkage feature removed at least for now.
#'
#' @describeIn as_data_table_default Data Table for a "xi" Parameter
#' @noRd
as_data_table_xi <- function(x, draws, n_draws, ...) {
D <- x$stan$model_vars$D
data.table::data.table(
parameter = rep(
paste0("xi_d", seq_len(D - 1L)),
each = n_draws
),
value = c(draws)
)
}
# as_data_table_xi <- function(x, draws, n_draws, ...) {
# D <- x$stan$model_vars$D
# data.table::data.table(
# parameter = rep(
# paste0("xi_d", seq_len(D - 1L)),
# each = n_draws
# ),
# value = c(draws)
# )
# }

#' @describeIn as_data_table_default Data Table for a "corr_nu" Parameter
#' @noRd
Expand Down
7 changes: 7 additions & 0 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,13 @@ predict_ <- function(object, simulated, storage, observed,
)
idx_summ <- which(summaries[[time_var]] == u_time[1L]) + (fixed - 1L)
}
data.table::setcolorder(
simulated,
neworder = c(
group_var, time_var, ".draw",
setdiff(names(simulated), c(group_var, time_var, ".draw"))
)
)
eval_envs <- prepare_eval_envs(
object,
simulated,
Expand Down
20 changes: 11 additions & 9 deletions R/predict_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -839,13 +839,15 @@ prepare_eval_env_multivariate <- function(e, resp, resp_levels, cvars,
e$d <- d
e$resp <- resp
e$L <- samples[[paste(c("L", resp), collapse = "_")]]
e$link_cols <- ifelse_(
is_multivariate(e$family) && !is_multinomial(e$family),
paste0(resp, "_link"),
paste0(resp, "_link_", resp_levels)
)
e$mean_cols <- paste0(resp, "_mean_", resp_levels)
e$fitted_cols <- paste0(resp, "_fitted_", resp_levels)
if (is_multivariate(e$family)) {
e$link_cols <- paste0(resp, "_link")
e$mean_cols <- paste0(resp, "_mean")
e$fitted_cols <- paste0(resp, "_fitted")
} else {
e$link_cols <- paste0(resp, "_link_", resp_levels)
e$mean_cols <- paste0(resp, "_mean_", resp_levels)
e$fitted_cols <- paste0(resp, "_fitted_", resp_levels)
}
e$sigma <- matrix(0.0, e$n_draws, d)
has_fixed <- logical(d)
has_varying <- logical(d)
Expand Down Expand Up @@ -1061,8 +1063,8 @@ predict_expr$fitted$categorical <- "
"

predict_expr$fitted$cumulative <- "
prob <- cbind(1, invlink(xbeta - cuts{idx_cuts})) -
cbind(invlink(xbeta - cuts{idx_cuts}), 0)
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
for (s in 1:d) {{
data.table::set(
x = out,
Expand Down
11 changes: 6 additions & 5 deletions R/stanblocks.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,11 +612,12 @@ create_transformed_parameters_lines <- function(idt, backend, cvars, cgvars) {
create_model <- function(idt, backend, cg, cvars, cgvars, mvars, threading) {
spline_def <- mvars$spline_def
spline_text <- ""
if (!is.null(spline_def) && spline_def$shrinkage) {
xi_prior <- mvars$common_priors
xi_prior <- xi_prior[xi_prior$parameter == "xi", "prior"]
spline_text <- paste_rows("xi ~ {xi_prior};", .indent = idt(1))
}
# Shringake feature removed for now
# if (!is.null(spline_def) && spline_def$shrinkage) {
# xi_prior <- mvars$common_priors
# xi_prior <- xi_prior[xi_prior$parameter == "xi", "prior"]
# spline_text <- paste_rows("xi ~ {xi_prior};", .indent = idt(1))
# }
random_text <- ""
if (mvars$random_def$M > 0L) {
if (mvars$random_def$correlated) {
Expand Down
48 changes: 34 additions & 14 deletions R/stanblocks_families.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,16 @@ intercept_lines <- function(y, obs, family, has_varying, has_fixed, has_random,
backend, ydim = y, ...) {

intercept_alpha <- ifelse_(
has_fixed_intercept,
glue::glue("alpha_{y}"),
is_cumulative(family),
"0",
ifelse_(
has_varying_intercept,
glue::glue("alpha_{y}[t]"),
"0"
has_fixed_intercept,
glue::glue("alpha_{y}"),
ifelse_(
has_varying_intercept,
glue::glue("alpha_{y}[t]"),
"0"
)
)
)
offset <- ifelse_(
Expand Down Expand Up @@ -227,12 +231,15 @@ loglik_lines_default <- function(y, idt, obs, family, has_missing,
)
glm <- attr(intercept, "glm")
scalar_intercept <- !has_offset && !has_random && !has_random_intercept &&
!has_lfactor && (glm || !has_X)
!has_lfactor && (glm || !has_X) && !is_cumulative(family)
n_obs <- ifelse_(
nchar(obs),
paste0("n_obs_", y, "[t]"),
"N"
)
if (is_cumulative(family) && intercept == "0") {
intercept <- glue::glue("rep_vector(0, {n_obs})")
}
intercept_line <- ifelse_(
scalar_intercept,
"real intercept_{y} = {intercept};",
Expand All @@ -254,9 +261,12 @@ loglik_lines_default <- function(y, idt, obs, family, has_missing,
"data int N"
),
paste("data", y_type, "y_{y}"),
onlyif(has_fixed_intercept, "real alpha_{y}"),
onlyif(
has_varying_intercept,
has_fixed_intercept && !is_cumulative(family),
"real alpha_{y}"
),
onlyif(
has_varying_intercept && !is_cumulative(family),
stan_array_arg(backend, "real", "alpha_{y}", 0L)
),
onlyif(
Expand Down Expand Up @@ -618,7 +628,7 @@ loglik_lines_categorical <- function(y, idt, obs, family, has_missing,

loglik_lines_cumulative <- function(y, obs, idt, default, family,
has_fixed_intercept,
has_varying_intercept, ...) {
has_varying_intercept, backend, ...) {
u <- default$u
is_logit <- identical("logit", family$link)
link <- ifelse_(is_logit, "logistic", "probit")
Expand All @@ -643,7 +653,11 @@ loglik_lines_cumulative <- function(y, obs, idt, default, family,
default$fun_args,
onlyif(
has_fixed_intercept,
glue::glue("ordered[S_{y} - 1] cutpoints_{y}")
glue::glue("vector cutpoints_{y}")
),
onlyif(
has_varying_intercept,
glue::glue("array[] vector cutpoints_{y}")
)
)
}
Expand Down Expand Up @@ -2023,9 +2037,12 @@ prior_lines <- function(y, idt, noncentered, shrinkage,
model_lines_default <- function(y, obs, idt, threading, default, family, ...) {
likelihood <- ifelse_(
threading,
glue::glue(
"target += reduce_sum({family$name}_loglik_{y}_lpmf, {default$seq1T}, ",
"grainsize, {default$fun_call_args});"
paste_rows(
paste0(
"target += reduce_sum({family$name}_loglik_{y}_lpmf, {default$seq1T}, ",
"grainsize, {default$fun_call_args});"
),
.indent = idt(1)
),
do.call(
paste0("loglik_lines_", family$name),
Expand Down Expand Up @@ -2144,7 +2161,10 @@ model_lines_categorical <- function(y, idt, obs, family, priors,
model_lines_cumulative <- function(y, obs, idt, priors,
threading, default, ...) {
if (threading) {
default$fun_call_args <- cs(default$fun_call_args, glue::glue("alpha_{y}"))
default$fun_call_args <- cs(
default$fun_call_args,
glue::glue("cutpoints_{y}")
)
}
paste_rows(
priors,
Expand Down
Loading

0 comments on commit af47542

Please sign in to comment.