-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add post previewing postprocessing (#708)
Co-authored-by: Max Kuhn <[email protected]> Co-authored-by: Emil Hvitfeldt <[email protected]> Co-authored-by: Hannah Frick <[email protected]>
- Loading branch information
1 parent
ef09e35
commit 6bc8ac5
Showing
6 changed files
with
566 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+115 KB
content/blog/postprocessing-preview/figs/predictios-better-boost-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
--- | ||
output: hugodown::hugo_document | ||
|
||
slug: postprocessing-preview | ||
title: Postprocessing is coming to tidymodels | ||
date: 2024-10-08 | ||
author: Simon Couch, Hannah Frick, and Max Kuhn | ||
description: > | ||
The tidymodels team has been hard at work on postprocessing, a set of | ||
features to adjust model predictions. The functionality includes a new | ||
package as well as changes across the framework. | ||
photo: | ||
url: https://unsplash.com/photos/G6Y_YM3gO44 | ||
author: Haley Owens | ||
|
||
categories: [roundup] | ||
tags: [tidymodels, postprocessing, workflows] | ||
--- | ||
|
||
We're bristling with elation to share about a set of upcoming features for postprocessing with tidymodels. Postprocessors refine predictions outputted from machine learning models to improve predictive performance or better satisfy distributional limitations. The developmental versions of many tidymodels core packages include changes to support postprocessors, and we're ready to share about our work and hear the community's thoughts on our progress so far. | ||
|
||
Postprocessing support with tidymodels hasn't yet made it to CRAN, but you can install the needed versions of tidymodels packages with the following code. | ||
|
||
```{r} | ||
#| label: install-packages | ||
#| eval: false | ||
pak::pak( | ||
paste0( | ||
"tidymodels/", | ||
c("tune", "workflows", "rsample", "tailor") | ||
) | ||
) | ||
``` | ||
|
||
Now, we load packages with those developmental versions installed. | ||
|
||
```{r} | ||
#| label: load-packages | ||
#| message: false | ||
library(tidymodels) | ||
library(probably) | ||
library(tailor) | ||
``` | ||
|
||
Existing tidymodels users might have spotted something funky already; who is this tailor character? | ||
|
||
## Meet tailor👋 | ||
|
||
The tailor package introduces tailor objects, which compose iterative adjustments to model predictions. tailor is to postprocessing as recipes is to preprocessing; applying your mental model of recipes to tailor should get you a good bit of the way there. | ||
|
||
<div style="width: 140%; max-width: 140%; overflow-x: auto;"> | ||
|
||
| Tool | Applied to\... | Initialize with\... | Composes\... | Train with\... | Predict with\... | | ||
|------------|------------|------------|------------|------------|------------| | ||
| recipes | Training data | `recipe()` | `step_*()`s | `prep()` | `bake()` | | ||
| tailor | Model predictions | `tailor()` | `adjust_*()`ments | `fit()` | `predict()` | | ||
|
||
</div> | ||
|
||
First, users can initialize a tailor object with `tailor()`. | ||
|
||
```{r} | ||
#| label: tailor | ||
tailor() | ||
``` | ||
|
||
Tailors compose "adjustments," analogous to steps from the recipes package. | ||
|
||
```{r} | ||
#| label: tailor-adjust-1 | ||
tailor() %>% | ||
adjust_probability_threshold(threshold = .7) | ||
``` | ||
|
||
As an example, we'll apply this tailor to the `two_class_example` data made available after loading tidymodels. | ||
|
||
```{r} | ||
#| label: two-class-example | ||
head(two_class_example) | ||
``` | ||
|
||
This data gives the true value of an outcome variable `truth` as well as predicted probabilities (`Class1` and `Class2`). The hard class predictions, in `predicted`, are `"Class1"` if the probability assigned to `"Class1"` is above .5, and `"Class2"` otherwise. | ||
|
||
The model predicts `"Class1"` more often than it does `"Class2"`. | ||
|
||
```{r} | ||
#| label: count-two-class-example | ||
two_class_example %>% count(predicted) | ||
``` | ||
|
||
If we wanted the model to predict `"Class2"` more often, we could increase the probability threshold assigned to `"Class1"` above which the hard class prediction will be `"Class1"`. In the tailor package, this adjustment is implemented in `adjust_probability_threshold()`, which can be situated in a tailor object. | ||
|
||
```{r} | ||
#| label: tlr | ||
tlr <- | ||
tailor() %>% | ||
adjust_probability_threshold(threshold = .7) | ||
tlr | ||
``` | ||
|
||
tailors must be fitted before they can predict on new data. For adjustments like `adjust_probability_threshold()`, there's no training that actually happens at the `fit()` step besides recording the name and type of relevant variables. For other adjustments, like numeric calibration with `adjust_numeric_calibration()`, parameters are actually estimated at the `fit()` stage and separate data should be used to train the postprocessor and evaluate its performance. More on this in [Tailors in context](#tailors-in-context). | ||
|
||
In this case, though, we can `fit()` on the whole dataset. The resulting object is still a tailor, but is now flagged as trained. | ||
|
||
```{r} | ||
#| label: tlr-trained | ||
tlr_trained <- fit( | ||
tlr, | ||
two_class_example, | ||
outcome = truth, | ||
estimate = predicted, | ||
probabilities = c(Class1, Class2) | ||
) | ||
tlr_trained | ||
``` | ||
|
||
When used with a model [workflow](https://workflows.tidymodels.org) via [`add_tailor()`](https://workflows.tidymodels.org/dev/reference/add_tailor.html), the arguments to `fit()` a tailor will be set automatically. Generally, as in recipes, we recommend that users add tailors to model workflows for training and prediction rather than using them standalone for greater ease of use and to prevent data leakage, but tailors are totally functional by themselves, too. | ||
|
||
Now, when passed new data, the trained tailor will determine the outputted class based on whether the probability assigned to the level `"Class1"` is above `.7`, resulting in more predictions of `"Class2"` than before. | ||
|
||
```{r} | ||
#| label: count-tlr-trained | ||
predict(tlr_trained, two_class_example) %>% count(predicted) | ||
``` | ||
|
||
Changing the probability threshold is one of many possible adjustments available in tailor. | ||
|
||
- For probabilities: [calibration](https://tailor.tidymodels.org/reference/adjust_probability_calibration.html) | ||
- For transformation of probabilities to hard class predictions: [thresholds](https://tailor.tidymodels.org/reference/adjust_probability_threshold.html), [equivocal zones](https://tailor.tidymodels.org/reference/adjust_equivocal_zone.html) | ||
- For numeric outcomes: [calibration](https://tailor.tidymodels.org/reference/adjust_numeric_calibration.html), [range](https://tailor.tidymodels.org/reference/adjust_numeric_range.html) | ||
|
||
Support for tailors is now plumbed through workflows (via [`add_tailor()`](https://workflows.tidymodels.org/dev/reference/add_tailor.html)) and tune, and rsample includes a set of infrastructural changes to prevent data leakage behind the scenes. That said, we haven't yet implemented support for tuning parameters in tailors, but we plan to implement that before this functionality heads to CRAN. | ||
|
||
## Tailors in context | ||
|
||
As an example, let's model a study of food delivery times in minutes (i.e., the time from the initial order to receiving the food) for a single restaurant. The `deliveries` data is available upon loading the tidymodels meta-package. | ||
|
||
```{r} | ||
#| label: deliveries | ||
data(deliveries) | ||
# split into training and testing sets | ||
set.seed(1) | ||
delivery_split <- initial_split(deliveries) | ||
delivery_train <- training(delivery_split) | ||
delivery_test <- testing(delivery_split) | ||
# resample the training set using 10-fold cross-validation | ||
set.seed(1) | ||
delivery_folds <- vfold_cv(delivery_train) | ||
# print out the training set | ||
delivery_train | ||
``` | ||
|
||
Let's deliberately define a regression model that has poor predicted values: a boosted tree with only three ensemble members. | ||
|
||
```{r} | ||
#| label: bad-boost | ||
delivery_wflow <- | ||
workflow() %>% | ||
add_formula(time_to_delivery ~ .) %>% | ||
add_model(boost_tree(mode = "regression", trees = 3)) | ||
``` | ||
|
||
Evaluating against resamples: | ||
|
||
```{r} | ||
#| label: resample-bad-boost | ||
set.seed(1) | ||
delivery_res <- | ||
fit_resamples( | ||
delivery_wflow, | ||
delivery_folds, | ||
control = control_resamples(save_pred = TRUE) | ||
) | ||
``` | ||
|
||
The $R^2$ looks quite strong! | ||
|
||
```{r} | ||
#| label: metrics-bad-boost | ||
collect_metrics(delivery_res) | ||
``` | ||
|
||
Let's take a closer look at the predictions, though. How well are they calibrated? We can use the `cal_plot_regression()` helper from the probably package to put together a quick diagnostic plot. | ||
|
||
```{r} | ||
#| label: predictions-bad-boost | ||
collect_predictions(delivery_res) %>% | ||
cal_plot_regression(truth = time_to_delivery, estimate = .pred) | ||
``` | ||
|
||
Ooof. | ||
|
||
In comes tailor! Numeric calibration can help address the correlated errors here. We can add a tailor to our existing workflow to "bump up" predictions towards their true value. | ||
|
||
```{r} | ||
#| label: better-boost | ||
delivery_wflow_improved <- | ||
delivery_wflow %>% | ||
add_tailor(tailor() %>% adjust_numeric_calibration()) | ||
``` | ||
|
||
The resampling code looks the same from here. | ||
|
||
```{r} | ||
#| label: resample-better-boost | ||
set.seed(1) | ||
delivery_res_improved <- | ||
fit_resamples( | ||
delivery_wflow_improved, | ||
delivery_folds, | ||
control = control_resamples(save_pred = TRUE) | ||
) | ||
``` | ||
|
||
Checking out the same plot reveals a much better fit! | ||
|
||
```{r} | ||
#| label: predictios-better-boost | ||
collect_predictions(delivery_res_improved) %>% | ||
cal_plot_regression(truth = time_to_delivery, estimate = .pred) | ||
``` | ||
|
||
There's actually some tricky data leakage prevention happening under the hood here. When you add tailors to workflow and fit them with tune, this is all taken care of for you. If you're interested in using tailors outside of that context, check out [this documentation section](https://workflows.tidymodels.org/dev/reference/add_tailor.html#data-usage) in `add_tailor()`. | ||
|
||
## What's to come | ||
|
||
We're excited about how this work is shaping up and would love to hear yall's thoughts on what we've brought together so far. Please do comment on our social media posts about this blog entry or leave issues on the [tailor GitHub repository](https://github.com/tidymodels/tailor) and let us know what you think! | ||
|
||
Before these changes head out to CRAN, we'll also be implementing tuning functionality for postprocessors. You'll be able to tag arguments like `adjust_probability_threshold(threshold)` or `adjust_probability_calibration(method)` with `tune()` to optimize across several values. Besides that, post-processing with tidymodels should "just work" on the developmental versions of our packages---let us know if you come across anything wonky. | ||
|
||
## Acknowledgements | ||
|
||
Postprocessing support has been a longstanding feature request across many of our repositories; we're grateful for the community discussions there for shaping this work. Additionally, we thank Ryan Tibshirani and Daniel McDonald for fruitful discussions on how we might scope these features. |
Oops, something went wrong.