Skip to content

Commit

Permalink
Further refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemahoney218 committed Oct 17, 2023
1 parent d9290fc commit 433976e
Showing 1 changed file with 59 additions and 61 deletions.
120 changes: 59 additions & 61 deletions R/multi_scale.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,47 @@ ww_multi_scale.SpatRaster <- function(
rlang::check_installed("terra")
rlang::check_installed("exactextractr")

data <- prep_multi_scale_raster(data, truth, estimate)
metrics <- handle_metrics(metrics)
grid_list <- handle_grids(data, grids, autoexpand_grid, sf::st_crs(data), ...)

grid_list$grids <- lapply(
grid_list$grids,
spatraster_extract,
data,
aggregation_function,
progress
)

.notes <- raster_method_notes(grid_list)

raster_method_summary(grid_list, .notes, metrics, na_rm)
}

spatraster_extract <- function(grid, data, aggregation_function, progress) {
grid <- sf::st_as_sf(grid)
sf::st_geometry(grid) <- "geometry"
exactextract_names <- c(".truth", ".estimate")
if (!rlang::is_function(aggregation_function) && aggregation_function != "count") {
exactextract_names <- c(exactextract_names, ".truth_count", ".estimate_count")
aggregation_function <- c(aggregation_function, "count")
}
grid_df <- exactextractr::exact_extract(
data,
grid,
fun = aggregation_function,
progress = progress
)
names(grid_df) <- exactextract_names

if (length(exactextract_names) == 4L) {
exactextract_names <- exactextract_names[c(1, 3, 2, 4)]
}

cbind(grid, grid_df)[c(exactextract_names, "geometry")]
}

prep_multi_scale_raster <- function(data, truth, estimate) {
data <- tryCatch(
terra::subset(data, c(truth, estimate)),
error = function(e) {
Expand All @@ -188,52 +229,7 @@ ww_multi_scale.SpatRaster <- function(
))
}
names(data) <- c("truth", "estimate")

metrics <- handle_metrics(metrics)
grid_list <- handle_grids(data, grids, autoexpand_grid, sf::st_crs(data), ...)

grid_list$grids <- purrr::map(
grid_list$grids,
function(grid) {
grid <- sf::st_as_sf(grid)
sf::st_geometry(grid) <- "geometry"
if (rlang::is_function(aggregation_function) || aggregation_function == "count") {
grid <- cbind(
grid,
stats::setNames(
exactextractr::exact_extract(
data,
grid,
fun = aggregation_function,
progress = progress
),
c(".truth", ".estimate")
)
)
grid[c(".truth", ".estimate", "geometry")]
} else {
grid_df <- exactextractr::exact_extract(
data,
grid,
fun = c(aggregation_function, "count"),
progress = progress
)
names(grid_df) <- c(".truth", ".estimate", ".truth_count", ".estimate_count")

cbind(grid, grid_df)[c(
".truth",
".truth_count",
".estimate",
".estimate_count",
"geometry"
)]
}
}
)

.notes <- raster_method_notes(grid_list)

raster_method_summary(grid_list, .notes, metrics, na_rm)
data
}

ww_multi_scale_raster_args <- function(
Expand Down Expand Up @@ -265,7 +261,7 @@ ww_multi_scale_raster_args <- function(
metrics <- handle_metrics(metrics)
grid_list <- handle_grids(truth, grids, autoexpand_grid, sf::st_crs(data), ...)

grid_list$grids <- purrr::map(
grid_list$grids <- lapply(
grid_list$grids,
function(grid) {
grid <- sf::st_as_sf(grid)
Expand Down Expand Up @@ -318,7 +314,7 @@ ww_multi_scale_raster_args <- function(
}

raster_method_notes <- function(grid_list) {
purrr::map(
lapply(
seq_along(grid_list$grids),
function(idx) {
tibble::tibble(
Expand All @@ -330,21 +326,21 @@ raster_method_notes <- function(grid_list) {
}

raster_method_summary <- function(grid_list, .notes, metrics, na_rm) {
purrr::pmap_dfr(
list(
grid = grid_list$grids,
grid_arg = grid_list$grid_arg_idx,
.notes = .notes
),
out <- mapply(
function(grid, grid_arg, .notes) {
out <- metrics(grid, .truth, .estimate, na_rm = na_rm)
out[attr(out, "sf_column")] <- NULL
out$.grid_args <- list(grid_list$grid_args[grid_arg, ])
out$.grid <- list(grid)
out$.notes <- list(.notes)
out
}
},
grid = grid_list$grids,
grid_arg = grid_list$grid_arg_idx,
.notes = .notes,
SIMPLIFY = FALSE
)
do.call(dplyr::bind_rows, out)
}

#' @exportS3Method
Expand Down Expand Up @@ -383,9 +379,7 @@ ww_multi_scale.sf <- function(
grid_list <- handle_grids(data, grids, autoexpand_grid, data_crs, ...)

data$.grid_idx <- seq_len(nrow(data))
out <- purrr::map2_dfr(
grid_list$grids,
grid_list$grid_arg_idx,
out <- mapply(
function(grid, grid_args_idx) {
grid_args <- grid_list[["grid_args"]][grid_args_idx, ]

Expand Down Expand Up @@ -423,10 +417,14 @@ ww_multi_scale.sf <- function(
out$.grid <- list(.grid)
out$.notes <- list(notes_tibble)
out
}
},
grid = grid_list$grids,
grid_args_idx = grid_list$grid_arg_idx,
SIMPLIFY = FALSE
)
out <- dplyr::bind_rows(out)

if (any(purrr::map_lgl(out[[".notes"]], function(x) nrow(x) > 0))) {
if (any(vapply(out[[".notes"]], function(x) nrow(x) > 0, logical(1)))) {
rlang::warn(
c(
"Some observations were not within any grid cell, and as such were not used in any assessments.",
Expand Down Expand Up @@ -490,7 +488,7 @@ handle_grids <- function(data, grids, autoexpand_grid, data_crs, ...) {
grid_args <- tibble::tibble()
grid_arg_idx <- 0
if (!is.na(data_crs)) {
grids <- purrr::map(grids, sf::st_transform, sf::st_crs(data))
grids <- lapply(grids, sf::st_transform, sf::st_crs(data))
}
}

Expand Down

0 comments on commit 433976e

Please sign in to comment.