gee model #154
Closed
spsanderson
started this conversation in
Ideas
gee model
#154
Replies: 3 comments 2 replies
-
I have made what I think can be a generic update to this that will add a class to the model_spec output #157 |
Beta Was this translation helpful? Give feedback.
2 replies
-
I think I am starting to get a framework of sorts going. Here is what I ended up with tonight: New internal_make_spec_tbl:
#' Build a parsnip model specification for every row of a base model tibble
#'
#' Every row is turned into a model specification by looking up the function
#' named in `.parsnip_fns` and calling it with that row's `.parsnip_mode` and
#' `.parsnip_engine`. Each specification is tagged with a
#' `.tidyaml_mod_class` attribute of the form "<engine>_<function>" (lower
#' case) so that later steps can dispatch on the engine/function pair.
#'
#' @param .model_tbl A tibble inheriting from "tidyaml_base_tbl" with the
#'   columns `.parsnip_engine`, `.parsnip_mode` and `.parsnip_fns`.
#'
#' @return `.model_tbl` with an integer `.model_id` as first column and a
#'   `model_spec` list column appended.
internal_make_spec_tblv2 <- function(.model_tbl) {
  # Tidyeval ----
  model_tbl <- .model_tbl

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_base_tbl")) {
    rlang::abort(
      message = "The model tibble must come from the make base tbl function.",
      use_cli_format = TRUE
    )
  }

  # Manipulation ----
  # as_factor() keeps the levels in order of appearance, so group_split()
  # below yields the rows in their original order.
  model_factor_tbl <- model_tbl |>
    dplyr::mutate(
      .model_id = dplyr::row_number() |>
        forcats::as_factor()
    ) |>
    dplyr::select(.model_id, dplyr::everything())

  # One single-row tibble per model
  models_list <- model_factor_tbl |>
    dplyr::group_split(.model_id)

  # Build one model specification per row
  model_spec <- models_list |>
    purrr::map(
      .f = function(obj) {
        # Columns are addressed by name rather than by position so that a
        # change in column order cannot silently swap engine/mode/function.
        pe <- obj[[".parsnip_engine"]][[1]]
        pm <- obj[[".parsnip_mode"]][[1]]
        pf <- obj[[".parsnip_fns"]][[1]]

        # match.fun() resolves the function by name, so the function named
        # in `.parsnip_fns` has to be visible from here.
        ret <- match.fun(pf)(mode = pm, engine = pe)

        # Tag the spec with "<engine>_<function>" as an attribute (not as a
        # class) so the spec's own class vector is left untouched.
        attr(ret, ".tidyaml_mod_class") <- paste0(
          base::tolower(pe), "_", base::tolower(pf)
        )
        ret
      }
    )

  # Return ----
  # Attach the specs and turn the factor id back into its integer code
  model_factor_tbl |>
    dplyr::mutate(
      model_spec = model_spec,
      .model_id = as.integer(.model_id)
    )
}
New internal_make_wflw_gee_lin_reg specific to gee linear reg:
#' Build workflows for `linear_reg()` specifications using the "gee" engine
#'
#' Instead of attaching the recipe, the workflow is built from the recipe's
#' outcome and predictor variables, and the model is added with a formula in
#' which the first predictor `x` is replaced by `id_var(x)`. Because this is
#' meant to be 'fast', the first predictor is always the one used.
#'
#' @param .model_tbl A tibble inheriting from "tidyaml_mod_spec_tbl" whose
#'   `model_spec` column holds specifications tagged with the
#'   `.tidyaml_mod_class` attribute "gee_linear_reg". Only the first row's
#'   tag is checked.
#' @param .rec_obj A recipe; its `var_info` must contain at least one
#'   variable with role "predictor".
#'
#' @return A list with one element per row of `.model_tbl`: the workflow, or
#'   `NULL` when adding the model failed (the error is reported through
#'   `message()`).
internal_make_wflw_gee_lin_reg <- function(.model_tbl, .rec_obj) {
  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")) {
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }

  # The tag is read only after the class check. pluck() yields NULL for an
  # empty table or a missing column, and identical() copes with a NULL tag,
  # so every malformed input ends in the abort below rather than in an
  # "argument is of length zero" error.
  mod_class <- attr(
    purrr::pluck(model_tbl, "model_spec", 1),
    ".tidyaml_mod_class",
    exact = TRUE
  )
  if (!identical(mod_class, "gee_linear_reg")) {
    rlang::abort(
      message = "The model class is not 'gee_linear_reg'.",
      use_cli_format = TRUE
    )
  }

  # Variables ----
  # Everything in this section depends on the recipe only, so it is computed
  # once instead of once per model.
  predictor_vars <- rec_obj$var_info |>
    dplyr::filter(role == "predictor") |>
    dplyr::pull(variable)

  outcome_var <- rec_obj$var_info |>
    dplyr::filter(role == "outcome") |>
    dplyr::pull(variable)

  if (length(predictor_vars) == 0L) {
    rlang::abort(
      message = "The recipe must contain at least one predictor variable.",
      use_cli_format = TRUE
    )
  }

  # 'fast': the first predictor is the one wrapped in id_var()
  var_to_replace <- predictor_vars[[1]]

  # Formula ----
  my_formula <- stats::formula(recipes::prep(rec_obj))

  # Build the call id_var(<var>) directly; unlike parsing pasted text this
  # also works for variable names that are not syntactic.
  id_term <- call("id_var", as.name(var_to_replace))

  # substitute() swaps every occurrence of the variable for id_var(<var>)
  new_formula <- do.call(
    "substitute",
    list(
      my_formula,
      stats::setNames(list(id_term), var_to_replace)
    )
  )
  new_formula <- stats::as.formula(new_formula)

  # Create a safe add_model function
  safe_add_model <- purrr::safely(
    workflows::add_model,
    otherwise = NULL,
    quiet = TRUE
  )

  # Workflows ----
  # One single-row tibble per model, in row order
  models_list <- model_tbl |>
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) |>
    dplyr::group_split(.model_id)

  wflw_list <- models_list |>
    purrr::map(
      .f = function(obj) {
        mod <- obj[["model_spec"]][[1]]

        # all_of() marks the character vectors as external to the data,
        # which is what tidyselect expects for names held in a variable.
        ret <- workflows::workflow() |>
          workflows::add_variables(
            outcomes = dplyr::all_of(outcome_var),
            predictors = dplyr::all_of(predictor_vars)
          ) |>
          safe_add_model(mod, formula = new_formula)

        # Report a failure but keep going; the element stays NULL
        if (!is.null(ret$error)) message(stringr::str_glue("{ret$error}"))

        ret$result
      }
    )

  # Return ----
  wflw_list
}
A new function to make the workflow based upon the model attributes:
#' Build one workflow per model, dispatching on the model's tidyAML tag
#'
#' Rows whose specification carries the `.tidyaml_mod_class` attribute
#' "gee_linear_reg" are handled by `internal_make_wflw_gee_lin_reg()`; every
#' other row (including specifications without a tag) is handled by
#' `internal_make_wflw()`.
#'
#' @param .model_tbl A tibble inheriting from "tidyaml_mod_spec_tbl" with a
#'   `.model_id` and a `model_spec` column.
#' @param .rec_obj The recipe passed on to the workflow builders.
#'
#' @return A list with one workflow per row of `.model_tbl`, in row order,
#'   suitable for use as a list column via `dplyr::mutate()`. For a
#'   single-row table this is the same length-one list as before; for more
#'   rows it now contains every model's workflow instead of only the first.
full_internal_make_wflw <- function(.model_tbl, .rec_obj) {
  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")) {
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }

  # Manipulation ----
  # One single-row tibble per model, in row order
  models_list <- model_tbl |>
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) |>
    dplyr::group_split(.model_id)

  # Workflows ----
  wflw_list <- models_list |>
    purrr::map(
      .f = function(obj) {
        mod <- obj[["model_spec"]][[1]]
        mod_class <- attr(mod, ".tidyaml_mod_class", exact = TRUE)

        # The per-model builders check for this class, so make sure the
        # single-row piece carries it.
        if (!inherits(obj, "tidyaml_mod_spec_tbl")) {
          class(obj) <- c("tidyaml_mod_spec_tbl", class(obj))
        }

        # Dispatch on the tag; identical() is FALSE for a missing tag, so
        # untagged specs take the generic route.
        ret <- if (identical(mod_class, "gee_linear_reg")) {
          internal_make_wflw_gee_lin_reg(obj, rec_obj)
        } else {
          internal_make_wflw(obj, rec_obj)
        }

        # Each builder returns a list holding the one workflow for this
        # row; unwrap it (NULL if the builder produced nothing).
        purrr::pluck(ret, 1)
      }
    )

  # Return ----
  wflw_list
}
Examples:
library(tidyAML)
library(dplyr)
library(recipes)
library(tidyverse)
> mod_tbl <- make_regression_base_tbl()
> mod_tbl <- mod_tbl |>
+ filter(
+ .parsnip_engine %in% c("lm", "glm", "gee") &
+ .parsnip_fns == "linear_reg"
+ )
>
> class(mod_tbl) <- c("tidyaml_mod_spec_tbl", class(mod_tbl))
>
> mod_spec_tbl <- internal_make_spec_tblv2(mod_tbl)
> internal_make_wflw(mod_spec_tbl, rec_obj)
[[1]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: lm
## We see that this model is wrong as there is no preprocessor
[[2]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: gee
[[3]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: glm
>
> gee_mod_spec_tbl <- mod_spec_tbl |>
+ dplyr::slice(2)
> internal_make_wflw(gee_mod_spec_tbl, rec_obj)
## We see that the current function does not create the appropriate preprocessor
[[1]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: gee
> internal_make_wflw_gee_lin_reg(gee_mod_spec_tbl, rec_obj)
## We see that there is now an appropriate preprocessor
[[1]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Variables
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
Outcomes: outcome_var
Predictors: predictor_vars
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: gee
## The new function properly creates the workflow on the single object
> full_internal_make_wflw(gee_mod_spec_tbl, rec_obj)
[[1]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Variables
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
Outcomes: outcome_var
Predictors: predictor_vars
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: gee
## The new function creates the workflows correctly on all objects
> full_internal_make_wflw(mod_spec_tbl, rec_obj)
[[1]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Computational engine: lm
>
> mod_wflw_tbl <- mod_spec_tbl |>
+ mutate(wflw = full_internal_make_wflw(mod_spec_tbl, rec_obj))
> mod_wflw_tbl
# A tibble: 3 × 6
.model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw
<int> <chr> <chr> <chr> <list> <list>
1 1 lm regression linear_reg <spec[+]> <workflow>
2 2 gee regression linear_reg <spec[+]> <workflow>
3 3 glm regression linear_reg <spec[+]> <workflow>
>
> internal_make_fitted_wflw(mod_wflw_tbl, splits_obj) |>
+ purrr::map(broom::glance)
[[1]]
# A tibble: 1 × 12
r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC deviance df.residual nobs
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <int>
1 0.902 0.826 2.54 11.9 0.0000531 10 -49.1 122. 136. 84.1 13 24
[[2]]
# A tibble: 1 × 12
r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC deviance df.residual nobs
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <int>
1 0.902 0.826 2.54 11.9 0.0000531 10 -49.1 122. 136. 84.1 13 24
[[3]]
# A tibble: 1 × 12
r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC deviance df.residual nobs
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int> <int>
1 0.902 0.826 2.54 11.9 0.0000531 10 -49.1 122. 136. 84.1 13 24 |
Beta Was this translation helpful? Give feedback.
0 replies
-
this is now fixed! |
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
I got the following to work for the
gee
model, so I can turn it into a function, but this requires a bit of a paradigm shift in how the fast_regression/classification functions work: It's important to note that this is 'fast', so we just pick the first predictor variable.
Also, maybe it is only necessary to make functions for those models that fail the traditional workflow? Thoughts?
Beta Was this translation helpful? Give feedback.
All reactions