Skip to content

Commit

Permalink
fix get_parameter_types, fix categories in as.data.table, fix extende…
Browse files Browse the repository at this point in the history
…d tests, run-extended
  • Loading branch information
santikka committed May 3, 2024
1 parent a988911 commit 249f3e9
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 192 deletions.
7 changes: 4 additions & 3 deletions R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,16 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
)
}
categories <- c(
NA_character_,
ulapply(
x$stan$responses,
unlist(x$stan$responses),
function(y) {
channel <- get_channel(x, y)
if (is_cumulative(channel$family)) {
seq_len(channel$S - 1L)
} else {
} else if (is_categorical(channel$family)) {
channel$categories[-1L]
} else {
NA_character_
}
}
)
Expand Down
3 changes: 2 additions & 1 deletion R/coef.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#' deltas <- coef(gaussian_example_fit, type = "delta")
#'
coef.dynamitefit <- function(object,
types = NULL, parameters = NULL,
types = c("alpha", "beta", "delta"),
parameters = NULL,
responses = NULL, times = NULL, groups = NULL,
summary = TRUE, probs = c(0.05, 0.95), ...) {
stopifnot_(
Expand Down
2 changes: 1 addition & 1 deletion R/getters.R
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ get_parameter_types <- function(x, ...) {
#' @rdname get_parameter_types
#' @export
get_parameter_types.dynamitefit <- function(x, ...) {
d <- as.data.table(x)
d <- as.data.table(x, types = all_types)
unique(d$type)
}

Expand Down
1 change: 0 additions & 1 deletion R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ loo.dynamitefit <- function(x, separate_channels = FALSE, thin = 1L, ...) {
)
loo::loo(ll, r_eff = reff)
}

if (separate_channels) {
ll <- split(
x = data.table::melt(
Expand Down
2 changes: 1 addition & 1 deletion R/summary.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Summary for a Dynamite Model Fit
#'
#' The `summary` method provides statistics of the posterior samples of the
#' The `summary()` method provides statistics of the posterior samples of the
#' model; this is an alias of [dynamite::as.data.frame.dynamitefit()] with
#' `summary = TRUE`.
#'
Expand Down
2 changes: 1 addition & 1 deletion man/coef.dynamitefit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/dynamite.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-extended.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ test_that("multivariate gaussian fit and predict work", {
iter = 2000,
refresh = 0
)
expect_error(sumr <- summary(fit, type = "corr"), NA)
expect_error(sumr <- summary(fit, types = "corr"), NA)
expect_equal(sumr$mean, cov2cor(S)[2, 1], tolerance = 0.1)
expect_error(sumr <- summary(fit, type = "sigma"), NA)
expect_error(sumr <- summary(fit, types = "sigma"), NA)
expect_equal(sumr$mean, c(0.5, sqrt(diag(S))), tolerance = 0.1)
expect_error(sumr <- summary(fit, type = "beta"), NA)
expect_error(sumr <- summary(fit, types = "beta"), NA)
expect_equal(sumr$mean, c(0.5, 0.7, -0.2, 0.4), tolerance = 0.1)
expect_error(predict(fit, n_draws = 5), NA)
})
Expand Down
11 changes: 7 additions & 4 deletions tests/testthat/test-lfactor.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ test_that("nonidentifiable lfactor specification gives warning", {
correlated = TRUE,
noncentered_psi = TRUE
) + splines(30),
data = d, time = "time", group = "id", debug = list(no_compile = TRUE)),
data = d,
time = "time",
group = "id",
debug = list(no_compile = TRUE)),
NA
)
expect_warning(
Expand Down Expand Up @@ -167,22 +170,22 @@ test_that("latent factor related parameters can be got", {
skip_if_not(run_extended_tests)
expect_equal(
get_parameter_types(latent_factor_example_fit),
c("alpha", "sigma", "lambda", "sigma_lambda", "psi", "tau_psi", "omega_psi")
c("alpha", "lambda", "omega_psi", "psi", "sigma", "sigma_lambda", "tau_psi")
)
})

test_that("lambdas can be plotted", {
skip_if_not(run_extended_tests)
expect_error(
plot_lambdas(latent_factor_example_fit),
plot(latent_factor_example_fit, types = "lambda"),
NA
)
})

test_that("psis can be plotted", {
skip_if_not(run_extended_tests)
expect_error(
plot_psis(latent_factor_example_fit),
plot(latent_factor_example_fit, types = "psi"),
NA
)
})
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-output.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ test_that("default plot works", {

test_that("trace type plot works", {
expect_error(
plot(gaussian_example_fit, plot_type = "trace", type = "beta"),
plot(gaussian_example_fit, plot_type = "trace", types = "beta"),
NA
)
})
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ test_that("fitted works", {
idx <- categorical_example_fit$permutation[1L]
iter <- idx %% n
chain <- 1 + idx %/% n
xzy <- categorical_example_fit$data |> dplyr::filter(id == 5 & time == 20)
xzy <- categorical_example_fit$data |>
dplyr::filter(id == 5 & time == 20)
manual <- as_draws(categorical_example_fit) |>
dplyr::filter(.iteration == iter & .chain == chain) |>
dplyr::summarise(
Expand Down
Loading

0 comments on commit 249f3e9

Please sign in to comment.