Skip to content

Commit

Permalink
predict for cumulative
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed May 3, 2024
1 parent 7191774 commit 4f7d4f5
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 21 deletions.
16 changes: 10 additions & 6 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ initialize_predict <- function(object, newdata, type, eval_type, funs, impute,
)
)
simulated <- newdata[, .SD, .SDcols = c(resp_draw, group_var, time_var)]
obs_names <- setdiff(names(newdata), resp_draw)
mode <- "full"
if (length(funs) > 0L) {
mode <- "summary"
Expand All @@ -354,7 +355,7 @@ initialize_predict <- function(object, newdata, type, eval_type, funs, impute,
object = object,
simulated = simulated,
storage = simulated,
observed = newdata[, .SD, .SDcols = setdiff(names(newdata), resp_draw)],
observed = newdata[, .SD, .SDcols = obs_names],
mode = mode,
type = type,
eval_type = eval_type,
Expand Down Expand Up @@ -396,6 +397,8 @@ predict_ <- function(object, simulated, storage, observed,
attr(object$dformulas$lag_det, "rank_order")
)
ro_ls <- seq_along(lhs_ls)
resp_store <- grep("_store", names(observed), value = TRUE)
obs_merge <- setdiff(names(observed), c(names(simulated), resp_store))
n_group <- n_unique(observed[[group_var]])
time <- observed[[time_var]]
draw_time <- rep(time, each = n_draws)
Expand All @@ -416,8 +419,8 @@ predict_ <- function(object, simulated, storage, observed,
env = list(n_new = n_new, n_draws = n_draws)
]
simulated[,
(".draw") := rep(seq.int(1L, n_draws), n_new),
env = list(n_new = n_new, n_draws = n_draws)
(".draw") := rep(seq.int(1L, n_draws), n_new),
env = list(n_new = n_new, n_draws = n_draws)
]
idx <- which(draw_time == u_time[1L]) + (fixed - 1L) * n_draws
n_sim <- n_draws
Expand Down Expand Up @@ -488,7 +491,10 @@ predict_ <- function(object, simulated, storage, observed,
for (j in model_topology) {
cg_idx <- which(channel_groups == j)
k <- cg_idx[1L]
sub <- cbind_datatable(simulated[idx, ], observed[idx_obs, ])
sub <- cbind_datatable(
simulated[idx, ],
observed[idx_obs, .SD, .SDcols = obs_merge]
)
if (is_deterministic(families[[k]])) {
assign_deterministic_predict(
simulated,
Expand Down Expand Up @@ -553,8 +559,6 @@ predict_ <- function(object, simulated, storage, observed,
if (identical(mode, "full")) {
lhs_lag <- c(lhs_ld, lhs_ls)
if (length(lhs_lag) > 0L) {
# This if might not be needed in next version of data.table
#simulated[, c(lhs_ld, lhs_ls) := NULL]
simulated[, c(lhs_lag) := NULL]
}
data.table::setkeyv(simulated, cols = c(".draw", group_var, time_var))
Expand Down
116 changes: 101 additions & 15 deletions R/predict_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ parse_newdata <- function(dformulas, newdata, data, type, eval_type,
}
for (i in seq_along(resp_stoch)) {
y <- resp_stoch[i]
if (type %in% c("mean", "link", "fitted")) {
family <- dformulas$stoch[[i]]$family
if (type %in% c("mean", "fitted")) {
# create a separate column for each level of
# a categorical or cumulative response variable
pred_col <- ifelse_(
is_categorical(dformulas$stoch[[i]]$family),
is_categorical(family) || (is_cumulative(family) && type != "link"),
glue::glue("{y}_{type}_{categories[[y]]}"),
glue::glue("{y}_{type}")
)
Expand Down Expand Up @@ -370,20 +371,25 @@ clear_nonfixed <- function(newdata, newdata_null, resp_stoch, eval_type,
env = list(fixed = fixed, group = group_var)
]$V1
} else {
first_obs <- function(x) {
out <- base::which(
base::apply(!base::is.na(x), 1L, base::any)
)
ifelse_(base::length(out) == 0, 1L, out[1L])
}
clear_idx <- newdata[,
.I[
base::seq.int(
fixed + base::which(
base::apply(!base::is.na(.SD), 1L, base::any)
)[1L],
fixed + first_obs(.SD),
.N
)
],
.SDcols = resp_stoch,
by = group,
env = list(
fixed = fixed,
group = group_var
group = group_var,
first_obs = first_obs
)
]$V1
}
Expand Down Expand Up @@ -520,7 +526,8 @@ prepare_eval_envs <- function(object, simulated, observed,
new_ids <- unique(observed[[group_var]])
extra_levels <- unique(new_ids[!new_ids %in% orig_ids])
has_lfactor <- attr(object$dformulas$stoch, "lfactor")$P > 0
stopifnot_(identical(length(extra_levels), 0L) || !has_lfactor,
stopifnot_(
identical(length(extra_levels), 0L) || !has_lfactor,
c(
"Grouping variable {.var {group_var}} contains unknown levels:",
`x` = "Level{?s} {.val {as.character(extra_levels)}}
Expand Down Expand Up @@ -603,9 +610,18 @@ prepare_eval_envs <- function(object, simulated, observed,
j <- cg_idx[1L]
k <- k + 1L
resp <- object$dformulas$all[[j]]$response
resp_levels <- ifelse_(
is_cumulative(family),
attr(
attr(object$stan$responses, "resp_class")[[resp]],
"levels"
),
resp
)
prepare_eval_env_univariate(
e = e,
resp = resp,
resp_levels = resp_levels,
cvars = channel_vars[[k]],
samples = samples,
nu_samples = nu_samples,
Expand All @@ -623,11 +639,12 @@ prepare_eval_envs <- function(object, simulated, observed,
#' Prepare an Evaluation Environment for a Univariate Channel
#'
#' @noRd
prepare_eval_env_univariate <- function(e, resp, cvars, samples, nu_samples,
has_random_effects,
prepare_eval_env_univariate <- function(e, resp, resp_levels, cvars,
samples, nu_samples, has_random_effects,
idx, type, eval_type) {
alpha <- paste0("alpha_", resp)
beta <- paste0("beta_", resp)
cutpoints <- paste0("cutpoints_", resp)
delta <- paste0("delta_", resp)
phi <- paste0("phi_", resp)
sigma <- paste0("sigma_", resp)
Expand All @@ -654,11 +671,32 @@ prepare_eval_env_univariate <- function(e, resp, cvars, samples, nu_samples,
nus <- make.unique(rep(paste0("nu_", resp), e$K_random))
e$nu <- nu_samples[, , nus, drop = FALSE]
}
if (cvars$has_fixed_intercept) {
e$alpha <- array(samples[[alpha]][idx], c(e$n_draws, 1L))
}
if (cvars$has_varying_intercept) {
e$alpha <- samples[[alpha]][idx, , drop = FALSE]
if (is_cumulative(e$family)) {
e$d <- cvars$S
e$mean_cols <- paste0(resp, "_mean_", resp_levels)
e$fitted_cols <- paste0(resp, "_fitted_", resp_levels)
e$invlink <- ifelse_(
identical(e$family$link, "logit"),
plogis,
pnorm
)
if (cvars$has_fixed_intercept) {
e$cutpoints <- samples[[cutpoints]][idx, , drop = FALSE]
e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , drop = FALSE]
e$alpha <- matrix(0.0, e$n_draws, 1L)
}
if (cvars$has_varying_intercept) {
e$cutpoints <- samples[[cutpoints]][idx, , , drop = FALSE]
e$cutpoints <- e$cutpoints[rep_len(e$n_draws, e$k), , , drop = FALSE]
e$alpha <- matrix(0.0, e$n_draws, dim(e$cutpoints)[2L])
}
} else {
if (cvars$has_fixed_intercept) {
e$alpha <- array(samples[[alpha]][idx], c(e$n_draws, 1L))
}
if (cvars$has_varying_intercept) {
e$alpha <- samples[[alpha]][idx, , drop = FALSE]
}
}
if (cvars$has_lfactor) {
e$lambda <- samples[[lambda]][idx, , drop = FALSE]
Expand Down Expand Up @@ -713,7 +751,7 @@ generate_sim_call_univariate <- function(resp, family, type, eval_type,
has_varying_intercept,
has_random_intercept,
has_offset, has_lfactor) {

idx_cuts <- ifelse_(has_varying_intercept, "[, time, ]", "")
out <- paste0(
"{\n",
"idx_draw <- seq.int(1L, n_draws) - n_draws\n",
Expand Down Expand Up @@ -1022,6 +1060,19 @@ predict_expr$fitted$categorical <- "
}}
"

# Code template that computes fitted category probabilities for a cumulative
# channel and stores them column by column.
#
# `prob` is built as cbind(1, F) - cbind(F, 0) where
# F = invlink(xbeta - cutpoints), so column s of `prob` is the difference of
# adjacent inverse-link values (first column 1 - F[, 1], last column the final
# column of F). Column s is written to `fitted_cols[s]` of `out` at rows `idx`.
#
# The cutpoints are referenced by the name `cutpoints`, which is the name the
# evaluation environment assigns and the name used by the other cumulative
# templates (predicted, mean, loglik); the previous `cuts` was not consistent
# with them.
#
# `{idx_cuts}` is a placeholder that is either "[, time, ]" or "" depending on
# whether the channel has a varying intercept. The doubled braces are
# presumably glue-style escapes for literal braces -- TODO confirm against the
# code that expands this template.
# `xbeta`, `invlink`, `d`, `out`, `idx` and `fitted_cols` are expected to exist
# in the environment where the expanded code is evaluated.
predict_expr$fitted$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
for (s in 1:d) {{
data.table::set(
x = out,
i = idx,
j = fitted_cols[s],
value = prob[, s]
)
}}
"

predict_expr$fitted$multinomial <- "
mval <- exp(xbeta - log_sum_exp_rows(xbeta, k, d))
for (s in 1:d) {{
Expand Down Expand Up @@ -1119,6 +1170,17 @@ predict_expr$predicted$categorical <- "
)
"

# Code template that draws a predicted category for a cumulative channel.
#
# `prob` is built as cbind(1, F) - cbind(F, 0) where
# F = invlink(xbeta - cutpoints), so its columns are differences of adjacent
# inverse-link values. A category index is then picked for each row with
# max.col() applied to log(prob) perturbed by -log(-log(U)), U = runif(k * d)
# (i.e. Gumbel-distributed noise added to the log-probabilities), and the
# entries selected by `idx_out` are written to column '{resp}' of `out` at
# rows `idx_data`.
#
# `{idx_cuts}` is a placeholder that is either "[, time, ]" or "" depending on
# whether the channel has a varying intercept; `{resp}` is presumably the
# response variable name -- TODO confirm against the code that expands this
# template.
# `xbeta`, `invlink`, `cutpoints`, `k`, `d`, `out`, `idx_data` and `idx_out`
# are expected to exist in the environment where the expanded code is
# evaluated.
predict_expr$predicted$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
data.table::set(
x = out,
i = idx_data,
j = '{resp}',
value = max.col(log(prob) - log(-log(runif(k * d))))[idx_out]
)
"

predict_expr$predicted$multinomial <- "
pred <- matrix(0L, k, d)
n <- max(trials)
Expand Down Expand Up @@ -1257,6 +1319,19 @@ predict_expr$mean$categorical <- "
}}
"

# Code template that stores the category probabilities of a cumulative channel
# as its "mean" prediction, one column per category.
#
# `prob` is built as cbind(1, F) - cbind(F, 0) where
# F = invlink(xbeta - cutpoints), so its columns are differences of adjacent
# inverse-link values. For each s in 1..d, rows `idx_out` of column s are
# written to column `mean_cols[s]` of `out` at rows `idx_data`.
#
# `{idx_cuts}` is a placeholder that is either "[, time, ]" or "" depending on
# whether the channel has a varying intercept. The doubled braces are
# presumably glue-style escapes for literal braces -- TODO confirm against the
# code that expands this template.
# `xbeta`, `invlink`, `cutpoints`, `d`, `out`, `idx_data`, `idx_out` and
# `mean_cols` are expected to exist in the environment where the expanded code
# is evaluated.
predict_expr$mean$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
for (s in 1:d) {{
data.table::set(
x = out,
i = idx_data,
j = mean_cols[s],
value = prob[idx_out, s]
)
}}
"

predict_expr$mean$multinomial <- "
mval <- exp(xbeta - log_sum_exp_rows(xbeta, k, d))
for (s in 1:d) {{
Expand Down Expand Up @@ -1384,6 +1459,17 @@ predict_expr$loglik$categorical <- "
)
"

# Code template that computes the pointwise log-likelihood for a cumulative
# channel.
#
# `prob` is built as cbind(1, F) - cbind(F, 0) where
# F = invlink(xbeta - cutpoints), so its columns are differences of adjacent
# inverse-link values. The two-column matrix index cbind(seq_along(y), y)
# picks, for each row i, the entry prob[i, y[i]]; its logarithm is written to
# column '{resp}_loglik' of `out` at rows `idx`.
# NOTE(review): this indexing assumes `y` holds integer category positions in
# 1..ncol(prob) -- confirm where `y` is prepared.
#
# `{idx_cuts}` is a placeholder that is either "[, time, ]" or "" depending on
# whether the channel has a varying intercept; `{resp}` is presumably the
# response variable name -- TODO confirm against the code that expands this
# template.
# `xbeta`, `invlink`, `cutpoints`, `y`, `out` and `idx` are expected to exist
# in the environment where the expanded code is evaluated.
predict_expr$loglik$cumulative <- "
prob <- cbind(1, invlink(xbeta - cutpoints{idx_cuts})) -
cbind(invlink(xbeta - cutpoints{idx_cuts}), 0)
data.table::set(
x = out,
i = idx,
j = '{resp}_loglik',
value = log(prob[cbind(seq_along(y), y)])
)
"

predict_expr$loglik$multinomial <- "
data.table::set(
x = out,
Expand Down

0 comments on commit 4f7d4f5

Please sign in to comment.