library(tidyverse)
library(tidymodels)
library(palmerpenguins)
library(ranger)
Machine Learning using tidymodels
About the activity
Access the Quarto document here.
Download the raw file.
Open it in RStudio.
We will work our way through this Quarto document together during class. The activity will use two classification models to predict the species of penguin based on the penguin biometric data.
Load the Packages
Explore the Data
glimpse(penguins)
Rows: 344
Columns: 8
$ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel…
$ island <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgerse…
$ bill_length_mm <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39.2, 34.1, …
$ bill_depth_mm <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19.6, 18.1, …
$ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193, 190, 186…
$ body_mass_g <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 4675, 3475, …
$ sex <fct> male, female, female, NA, female, male, female, male…
$ year <int> 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007…
|> count(species) penguins
# A tibble: 3 × 2
species n
<fct> <int>
1 Adelie 152
2 Chinstrap 68
3 Gentoo 124
Prep the Data
# set a seed in order to make the analysis reproducible.
set.seed(462)
# split the data into training and testing sets. We will train the model on the training set and then test how well it worked on the testing data.
# split the data 70% for training and 30% for testing. The bulk of the data is usually used for training the models.
<- initial_split(penguins, prop=0.7, strata = species)
split_data <- training(split_data)
data_training <- testing(split_data)
data_testing
# lets check it did the split correctly, if a different seed was used a the splits would be slightly different.
|>
data_training group_by(species) |>
summarise( count = n(),
percent = n()/nrow(data_training) * 100)
# A tibble: 3 × 3
species count percent
<fct> <int> <dbl>
1 Adelie 106 44.4
2 Chinstrap 47 19.7
3 Gentoo 86 36.0
|>
data_testing group_by(species) |>
summarise( count = n(),
percent = n()/nrow(data_testing) * 100)
# A tibble: 3 × 3
species count percent
<fct> <int> <dbl>
1 Adelie 46 43.8
2 Chinstrap 21 20
3 Gentoo 38 36.2
# The recipe sets up what data we are going to use and how it to be treated before doing the modeling.
<-
penguin_recipe recipe( species ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g, data = penguins) %>%
step_normalize(all_predictors())
penguin_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs
Number of variables by role
outcome: 1
predictor: 4
── Operations
• Centering and scaling for: all_predictors()
# The prep step pulls in all the variables from the recipe based on the dataset we give it.
<- prep(penguin_recipe, data_training)
data_prep data_prep
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs
Number of variables by role
outcome: 1
predictor: 4
── Training information
Training data contained 239 data points and no incomplete rows.
── Operations
• Centering and scaling for: bill_length_mm bill_depth_mm, ... | Trained
# the bake steps preforms the prep steps and in this case normalizes all the data.
<- bake(data_prep, new_data = NULL)
data_bake data_bake
# A tibble: 239 × 5
bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species
<dbl> <dbl> <dbl> <dbl> <fct>
1 -0.808 0.120 -1.10 -0.501 Adelie
2 -0.660 0.426 -0.447 -1.21 Adelie
3 -0.845 1.75 -0.812 -0.695 Adelie
4 -0.918 0.324 -1.47 -0.727 Adelie
5 -0.863 1.24 -0.447 0.625 Adelie
6 -1.80 0.477 -0.593 -0.920 Adelie
7 -0.346 1.55 -0.812 0.0780 Adelie
8 -1.12 -0.0333 -1.10 -1.15 Adelie
9 -1.12 0.0687 -1.54 -0.630 Adelie
10 -0.974 2.06 -0.739 -0.501 Adelie
# ℹ 229 more rows
Define the models
Random Forest uses the command rand_forest()
which takes the following arguments. We will use the defaults for some values.
- mode options are “unknown”, “regression”, “classification”, or “censored regression”
- engine options are “ranger”, “randomForest”, or “spark”
- mtry the number of predictors that will be randomly sampled at each split when creating the tree model.
- trees the number of trees to build.
- min_n the minimum number of data points in a node to stop splitting
# MODEL 1 Random Forest
<-
rf_model
# specify model
rand_forest() |>
# mode as classification not continuous
set_mode("classification") |>
# engine/package that underlies the model (ranger is default)
set_engine("ranger") |>
# we only have 4 predictors so mtry can't be more than 4
set_args(mtry = 4, trees = 200)
# Put everything together
<-
rf_wflow workflow() |>
add_recipe(penguin_recipe) |>
add_model(rf_model)
# train the model
<- fit(rf_wflow, data_training) rf_fit
Multinomial (logistic) Regression uses the command multinom_reg()
which takes the following arguments. We will use the defaults for some values.
- mode only “classification” is available
- engine options are “nnet”, “brulee”, “glmnet”, “h2o”, “keras”, “spark”
- penalty amount of regularization (only supported by specific engines, e.g. glmnet, brulee, keras)
- mixture proportion of L1 regularization (only supported by specific engines)
# MODEL 2 Logistic Regression
<-
lr_model
# specify that the model is a multinom_reg
multinom_reg() |>
# mode as classification not continuous
set_mode("classification") |>
# select the engine/package that underlies the model (nnet is default)
set_engine("nnet")
# Put everything together
<-
lr_wflow workflow() |>
add_recipe(penguin_recipe) |>
add_model(lr_model)
# train the model
<- fit(lr_wflow, data_training) lr_fit
Compare the performance of the two models
# predict the species of the testing data we held back for each model
<- predict(rf_fit, data_testing)
rf.predict <- predict(lr_fit, data_testing)
lr.predict
# create a table comparing the predicted species from the true species
<- rf.predict %>%
rf.outcome transmute(pred = .pred_class,
truth = data_testing$species)
# confusion matrix
|> conf_mat(pred, truth) rf.outcome
Truth
Prediction Adelie Chinstrap Gentoo
Adelie 46 0 0
Chinstrap 1 20 0
Gentoo 1 0 37
# accuracy
|> accuracy(pred, truth) -> rf.acc
rf.outcome
# specificity
|> spec(pred, truth) -> rf.spec
rf.outcome
# sensitivity
|> sens(pred, truth) -> rf.sens
rf.outcome
# precision
|> precision(pred, truth) -> rf.prec
rf.outcome
<- c(rf.acc$.estimate, rf.spec$.estimate, rf.sens$.estimate, rf.prec$.estimate)
rf.eval names(rf.eval) <- c("accuracy", "specificity", "sensitivity", "precision")
# create a table comparing the predicted species from the true species
<- lr.predict %>%
lr.outcome transmute(pred = .pred_class,
truth = data_testing$species)
# confusion matrix
|> conf_mat(pred, truth) lr.outcome
Truth
Prediction Adelie Chinstrap Gentoo
Adelie 44 1 0
Chinstrap 0 21 0
Gentoo 0 0 37
# accuracy
|> accuracy(pred, truth) -> lr.acc
lr.outcome
# specificity
|> spec(pred, truth) -> lr.spec
lr.outcome
# sensitivity
|> sens(pred, truth) -> lr.sens
lr.outcome
# precision
|> precision(pred, truth) -> lr.prec
lr.outcome
= c(lr.acc$.estimate, lr.spec$.estimate, lr.sens$.estimate, lr.prec$.estimate)
lr.eval names(lr.eval) <- c("accuracy", "specificity", "sensitivity", "precision")
rbind(rf.eval, lr.eval)
accuracy specificity sensitivity precision
rf.eval 0.9809524 0.9911765 0.9861111 0.9753551
lr.eval 0.9902913 0.9943503 0.9848485 0.9925926
sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: aarch64-apple-darwin20
Running under: macOS Sonoma 14.5
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
time zone: America/Chicago
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] ranger_0.17.0 palmerpenguins_0.1.1 yardstick_1.3.2
[4] workflowsets_1.1.1 workflows_1.2.0 tune_1.3.0
[7] rsample_1.3.0 recipes_1.3.1 parsnip_1.3.2
[10] modeldata_1.4.0 infer_1.0.8 dials_1.4.0
[13] scales_1.3.0 broom_1.0.7 tidymodels_1.3.0
[16] lubridate_1.9.4 forcats_1.0.0 stringr_1.5.1
[19] dplyr_1.1.4 purrr_1.0.4 readr_2.1.5
[22] tidyr_1.3.1 tibble_3.2.1 ggplot2_3.5.2
[25] tidyverse_2.0.0
loaded via a namespace (and not attached):
[1] tidyselect_1.2.1 timeDate_4041.110 fastmap_1.2.0
[4] digest_0.6.37 rpart_4.1.24 timechange_0.3.0
[7] lifecycle_1.0.4 survival_3.8-3 magrittr_2.0.3
[10] compiler_4.4.1 rlang_1.1.5 tools_4.4.1
[13] utf8_1.2.4 yaml_2.3.10 data.table_1.17.0
[16] knitr_1.50 htmlwidgets_1.6.4 DiceDesign_1.10
[19] withr_3.0.2 nnet_7.3-20 grid_4.4.1
[22] sparsevctrs_0.3.4 colorspace_2.1-1 future_1.49.0
[25] globals_0.18.0 iterators_1.0.14 MASS_7.3-64
[28] cli_3.6.4 rmarkdown_2.29 generics_0.1.3
[31] rstudioapi_0.17.1 future.apply_1.11.3 tzdb_0.4.0
[34] splines_4.4.1 parallel_4.4.1 vctrs_0.6.5
[37] hardhat_1.4.1 Matrix_1.7-2 jsonlite_1.9.1
[40] hms_1.1.3 listenv_0.9.1 foreach_1.5.2
[43] gower_1.0.2 glue_1.8.0 parallelly_1.45.0
[46] codetools_0.2-20 stringi_1.8.4 gtable_0.3.6
[49] munsell_0.5.1 GPfit_1.0-9 pillar_1.10.1
[52] furrr_0.3.1 htmltools_0.5.8.1 ipred_0.9-15
[55] lava_1.8.1 R6_2.6.1 lhs_1.2.0
[58] evaluate_1.0.3 lattice_0.22-6 backports_1.5.0
[61] class_7.3-23 Rcpp_1.0.14 prodlim_2024.06.25
[64] xfun_0.51 pkgconfig_2.0.3