Skip to content

Commit

Permalink
Add validators to base_model2
Browse files Browse the repository at this point in the history
  • Loading branch information
Stefan Kuethe committed Oct 17, 2024
1 parent 231a2ca commit 2015a19
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 18 deletions.
38 changes: 36 additions & 2 deletions R/experimental.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ model_config <- function(allow_extra = FALSE, ...) {
}

# ---
base_model2 <- function(fields = list(), ..., .model_config = model_config()) {
base_model2 <- function(fields = list(), ...,
.model_config = model_config(),
.validators_before = list(),
.validators_after = list()) {
fields <- utils::modifyList(fields, list(...), keep.null = TRUE)
fields <- purrr::map(fields, ~ {
if (inherits(.x, "function")) {
Expand All @@ -31,12 +34,15 @@ base_model2 <- function(fields = list(), ..., .model_config = model_config()) {
obj <- as.list(environment())
}

obj <- validate_fields(obj, .validators_before)

for (name in names(fields)) {
check_type_fn <- rlang::as_function(fields[[name]]$fn)
obj_value <- obj[[name]]
if (isFALSE(check_type_fn(obj_value))) {
cli::cli_abort(
c(
x = "Type check failed.",
i = "{name} = {rlang::quo_text(obj_value)}",
i = "typeof({name}): {typeof(obj_value)}",
i = "length({name}): {length(obj_value)}",
Expand All @@ -47,6 +53,8 @@ base_model2 <- function(fields = list(), ..., .model_config = model_config()) {
}
}

obj <- validate_fields(obj, .validators_after)

if (is.environment(obj)) {
return(invisible(obj))
}
Expand All @@ -58,5 +66,31 @@ base_model2 <- function(fields = list(), ..., .model_config = model_config()) {
return(obj)
}))

return(set_attributes(model_fn, fields = fields))
return(
set_attributes(
model_fn,
fields = fields,
class = c(class(model_fn), "base_model")
)
)
}

# ---
check_args <- function(...) {
fields <- list(...)
if (length(fields) == 0) {
fn <- rlang::caller_fn()
fmls <- rlang::fn_fmls(fn)
fields <- purrr::map(as.list(fmls), eval)
}

e <- rlang::caller_env()
for (name in names(e)) {
value <- e[[name]]
if (is.list(value)) {
e[[name]] <- value$default
}
}

base_model2(fields)(.x = e)
}
6 changes: 6 additions & 0 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ set_attributes <- function(x, ...) {
return(x)
}

# ---
add_class <- function(x, cls = "base_model") {
class(x) <- c(class(x), cls)
return(x)
}

# ---
is_not_null <- function(x) {
isFALSE(is.null(x))
Expand Down
46 changes: 30 additions & 16 deletions examples/base-model2.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,44 @@
devtools::load_all()

check_args <- function(...) {
if (length(list(...)) > 0) {
return(base_model2(...)(.x = rlang::caller_env()))
}

fn <- rlang::caller_fn()
fmls <- rlang::fn_fmls(fn)
fields <- purrr::map(as.list(fmls), eval)
base_model2(fields)(.x = rlang::caller_env())
}
#check_args <- function(...) {
# fields <- list(...)
# if (length(fields) == 0) {
# fn <- rlang::caller_fn()
# fmls <- rlang::fn_fmls(fn)
# fields <- purrr::map(as.list(fmls), eval)
# }

# e <- rlang::caller_env()
# for (name in names(e)) {
# value <- e[[name]]
# if (is.list(value)) {
# e[[name]] <- value$default
# }
# }

# base_model2(fields)(.x = e)
#}

f <- function(a, b) {
f <- function(a, b = 80L) {
check_args(a = is.integer, b = is.integer)
a + b
}

f(2L, 4)
f(2L)

f2 <- function(aa = is.numeric, bb = is.integer) {
f2 <- function(aa = is.numeric, bb = model_field(is.integer, 10L)) {
check_args()
aa + bb
}

f2(10, 20)
f2(5)

# ---
my_model2 <- base_model2(cyl = is.double, mpg = is.integer)
my_model2(.x = mtcars)
my_model2 <- base_model2(
cyl = is.double,
mpg = is.integer,
.validators_before = list(
mpg = as.integer
)
)
my_model2(.x = tibble::as_tibble(mtcars))

0 comments on commit 2015a19

Please sign in to comment.