Topic 13 Trees

Learning Goals

  • Clearly describe the recursive binary splitting algorithm for tree building for both regression and classification
  • Compute the weighted average Gini index to measure the quality of a classification tree split (a short worked example follows this list)
  • Compute the sum of squared residuals to measure the quality of a regression tree split
  • Explain how recursive binary splitting is a greedy algorithm
  • Explain how different tree parameters relate to the bias-variance tradeoff
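
As a quick illustration of the Gini goal above, here is a small R sketch that computes the weighted average Gini index for a made-up two-class split (the node counts are invented purely for illustration):

# Gini index of a single node, given the class counts in that node
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Made-up split: left node has 8 grass and 2 tree cases; right node has 1 grass and 9 tree cases
left  <- c(grass = 8, tree = 2)
right <- c(grass = 1, tree = 9)

# Weighted average Gini index: weight each node's Gini by its share of the cases
(sum(left) * gini(left) + sum(right) * gini(right)) / (sum(left) + sum(right))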


Slides from today are available here.




Trees in tidymodels

To build tree models in tidymodels, first load the required packages and set the seed for the random number generator to ensure reproducible results:

library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
tidymodels_prefer()

set.seed(___) # Pick your favorite number to fill in the parentheses

To fit a classification tree, we can adapt the following:

ct_spec <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = NULL,  #default is 0.01 (used for pruning a tree)
           min_n = NULL, #min number of observations to try split: default is 20 [I think the documentation has a typo and says 2]  (used to stop early)
           tree_depth = NULL) %>% #max depth, number of branches/splits to get to any final group: default is 30 (used to stop early)
  set_mode('classification') # change this for regression tree


data_rec <- recipe(___ ~ ___, data = ______)

data_wf <- workflow() %>%
  add_model(ct_spec) %>%
  add_recipe(data_rec)

fit_mod <- data_wf %>% # or use tune_grid() to tune any of the parameters above
  fit(data = _____)


Visualizing and interpreting the “best” tree

# Plot the tree (make sure to load the rpart.plot package first)
fit_mod %>%
  extract_fit_engine() %>%
  rpart.plot()

# Get variable importance metrics 
# Sum of the goodness of split measures (impurity reduction) for each split for which it was the primary variable.
fit_mod %>%
  extract_fit_engine() %>%
  pluck('variable.importance')




Exercises

You can download a template RMarkdown file to start from here.

Context

Before proceeding, install the rpart and rpart.plot packages (for building and plotting decision trees) by entering install.packages(c("rpart", "rpart.plot")) in the Console.

Our goal will be to classify types of urban land cover in small subregions within a high resolution aerial image of a land region. Data from the UCI Machine Learning Repository include the observed type of land cover (determined by human eye) and “spectral, size, shape, and texture information” computed from the image. See this page for the data codebook.


Source: https://ncap.org.uk/sites/default/files/EK_land_use_0.jpg
library(dplyr)
library(readr)
library(ggplot2)
library(rpart.plot)
library(tidymodels)
tidymodels_prefer()


# Read in the data
land <- read_csv("https://www.macalester.edu/~ajohns24/data/land_cover.csv")

# There are 9 land types, but we'll focus on 3 of them
land <- land %>% 
    filter(class %in% c("asphalt", "grass", "tree"))

Exercise 1: Core theme: parametric/nonparametric

  1. What does it mean for a method to be nonparametric? In general, when might we prefer nonparametric to parametric methods? When might we not?

  2. Where do you think trees fall on the parametric/nonparametric spectrum?

Exercise 2: Core theme: Tuning parameters and the BVT

The key feature governing complexity of a tree model is the number of splits used in the tree. How is the number of splits related to model complexity, bias, and variance?

In practice, the number of splits is controlled indirectly through the following tuning parameters. For each, discuss how low/high parameter settings would affect the number of tree splits. (A small code sketch for experimenting with these settings appears after the list.)

  • min_n: the minimum number of observations that must exist in a node in order for a split to be attempted.

  • cost_complexity: complexity parameter. Any split that does not increase node purity by cost_complexity is not attempted.

  • tree_depth: the maximum depth of any node of the final tree. The depth of a node is the number of branches that need to be followed to get to it from the root node. (The root node has depth 0.)
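
If you'd like to experiment rather than just reason it out, the sketch below fits two trees to the land data with very restrictive and very permissive settings and counts their splits. The n_splits() helper and the specific parameter values are just illustrative choices.

# Helper: the number of splits equals the number of non-leaf nodes in the fitted rpart tree
n_splits <- function(fitted) {
  sum(extract_fit_engine(fitted)$frame$var != "<leaf>")
}

# Very restrictive settings (few splits) vs. very permissive settings (many splits)
restrictive_spec <- decision_tree(cost_complexity = 0.1, min_n = 100, tree_depth = 2) %>%
  set_engine('rpart') %>%
  set_mode('classification')

permissive_spec <- decision_tree(cost_complexity = 0.00001, min_n = 2, tree_depth = 30) %>%
  set_engine('rpart') %>%
  set_mode('classification')

# parsnip expects a factor outcome for classification, so convert class first
land_fct <- land %>% mutate(class = factor(class))

n_splits(fit(restrictive_spec, class ~ ., data = land_fct)) # few splits: simpler tree, higher bias
n_splits(fit(permissive_spec, class ~ ., data = land_fct))  # many splits: more complex tree, higher variance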

Exercise 3: Building & tuning trees in tidymodels

  1. Fit a tree model for the class outcome (land type), and allow all possible predictors to be considered (~ . in the model formula).
  • Use 10-fold CV.
  • Choose a final model whose test overall accuracy is within one standard error of the overall best metric.
  • The Gini index impurity measure (the default for rpart) has a minimum of zero and an upper bound of 1.
  • Try a sequence of 10 cost_complexity values from \(10^{-5}\) to 0.1.
# Make sure you understand what each line of code is doing
set.seed(123) # don't change this

data_fold <- vfold_cv(land, v = 10)

ct_spec_tune <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = tune(),  
           min_n = 2, 
           tree_depth = NULL) %>% 
  set_mode('classification') 

data_rec <- recipe(class ~ ., data = land)

data_wf_tune <- workflow() %>%
  add_model(ct_spec_tune) %>%
  add_recipe(data_rec)

param_grid <- grid_regular(cost_complexity(range = c(-5, -1)), levels = 10) # 10 values from 10^-5 to 10^-1 on the log10 scale

tune_res <- tune_grid(
  data_wf_tune, 
  resamples = data_fold, 
  grid = param_grid, 
  metrics = metric_set(accuracy) #change this for regression trees
)
  2. Make a plot of test performance versus the cost_complexity tuning parameter. Does it look as you expected?
autoplot(tune_res) + theme_classic()
  3. Now choose the cost_complexity value that gives the simplest tree (high or low cost_complexity?) within 1 SE of the max accuracy. Pull out the CV accuracy for the chosen cost_complexity.
best_complexity <- select_by_one_std_err(tune_res, metric = 'accuracy', ______)
data_wf_final <- finalize_workflow(data_wf_tune, best_complexity)

land_final_fit <- fit(data_wf_final, data = land)


tune_res %>% 
  collect_metrics() %>%
  filter(cost_complexity == ________)

Exercise 4: Visualizing trees

Let’s visualize the difference between the trees learned under different cost_complexity parameters. The code below fits a tree for a lower-than-optimal cost_complexity value (tree_mod_lowcp) and a higher-than-optimal cost_complexity (tree_mod_highcp). We then plot these trees (1st and 3rd) along with our best tree (2nd).

Look at page 3 of the rpart.plot package vignette (an example-heavy manual) to understand what the plot components mean.

tree_mod_lowcp <- fit(
    data_wf_tune %>%
        update_model(ct_spec_tune %>% set_args(cost_complexity = .00001)),
    data = land
)
tree_mod_highcp <- fit(
    data_wf_tune %>%
        update_model(ct_spec_tune %>% set_args(cost_complexity = .1)),
    data = land
)

# Plot all 3 trees in a row
par(mfrow = c(1,3))
tree_mod_lowcp %>% extract_fit_engine() %>% rpart.plot()
land_final_fit %>% extract_fit_engine() %>% rpart.plot()
tree_mod_highcp %>% extract_fit_engine() %>% rpart.plot()
  • For a couple of splits, verify the idea of increasing node purity/homogeneity in tree building. (How is this idea reflected in the numbers in the plot output?) Discuss as a group.

  • Tuning classification trees (like with the cost_complexity parameter) is also referred to as “pruning”. Why does this make sense? NOTE: If “pruning” is a new word to you, first Google it.

Exercise 5: Predictions from trees

  1. Looking at the plot of the best fitted tree, manually make a soft (probability) and hard (class) prediction for the case shown below.
# Pick out training case 2 to make a prediction
test_case <- land[2,]
# Show only the needed predictors
test_case %>% select(NDVI, Bright_100, SD_NIR, GLCM2_100)
    
land_final_fit %>% extract_fit_engine() %>% rpart.plot()
  2. Verify your predictions with the predict() function. (Note: we introduced this code in Logistic Regression, but this type of code applies to any classification model fit in tidymodels).
# Soft (probability) prediction
predict(land_final_fit, new_data = test_case, type = "prob")

# Hard (class) prediction
predict(land_final_fit, new_data = test_case, type = "class")

Exercise 6: Variable importance in trees

We can obtain numerical variable importance measures from trees. These measure, roughly, “the total decrease in node impurities from splitting on the variable” (even if the variable isn’t ultimately used in the split).

What are the 3 most important predictors by this measure? Does this agree with what you might have expected based on the plot of the best fitted tree? What might greedy behavior have to do with this?

land_final_fit %>%
  extract_fit_engine() %>%
  pluck('variable.importance')
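
If it’s helpful, you can also visualize these scores. The sketch below is just one way to do so, using enframe() to turn the named importance vector into a tibble and plotting the top 10.

# Optional: plot the top 10 variable importance scores
land_final_fit %>%
  extract_fit_engine() %>%
  pluck('variable.importance') %>%
  enframe(name = 'variable', value = 'importance') %>%
  slice_max(importance, n = 10) %>%
  ggplot(aes(x = importance, y = reorder(variable, importance))) +
  geom_col() +
  labs(y = NULL) +
  theme_classic()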

Extra! Regression Trees!

As discussed in the video, trees can also be used for regression! Let’s work through a step of building a regression tree by hand.

For the two possible splits below, determine the better split for the tree by computing the sum of squared residuals as the measure of node impurity. (The numbers following Yes: and No: indicate the outcome values of the cases in the left (Yes) and right (No) regions.) A small code check appears after the two splits.

Split 1: x1 < 3
    - Yes: 1, 1, 2, 4
    - No: 2, 2, 4, 4

Split 2: x1 < 4
    - Yes: 1, 1, 2
    - No: 2, 2, 4, 4, 4
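
If you want to check your by-hand arithmetic, here is a small base R helper. (ssr_split() is just a throwaway function written for this exercise, not something from a package.)

# Sum of squared residuals within a node, using the node mean as the prediction
ssr <- function(y) sum((y - mean(y))^2)

# Total SSR for a split = SSR of the Yes region + SSR of the No region
ssr_split <- function(yes, no) ssr(yes) + ssr(no)

ssr_split(yes = c(1, 1, 2, 4), no = c(2, 2, 4, 4))    # Split 1: x1 < 3
ssr_split(yes = c(1, 1, 2), no = c(2, 2, 4, 4, 4))    # Split 2: x1 < 4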

In case you want to explore building regression trees in R, try out the following exercises using the College data from the ISLR package. Our goal will be to predict graduation rate (Grad.Rate) as a function of other predictors. You can look at the data codebook with ?College in the Console.

library(ISLR)

data(College)

# A little data cleaning
college_clean <- College %>% 
    mutate(school = rownames(College)) %>% 
    filter(Grad.Rate <= 100) # Remove one school with grad rate of 118%
rownames(college_clean) <- NULL # Remove school names as row names
  1. Adapt our general decision tree code for the regression setting by adapting the metric used to pick the final model. (Note how other parts stay the same! One possible adaptation is sketched after this list.)

  2. Plot test performance as a function of cost_complexity, and comment on the shape of the plot.

  3. Plot the “best” tree. (See page 3 of the rpart.plot package vignette for a refresher on what the plot shows.) Do the sequence of splits and outcomes in the leaf nodes make sense?

  4. Look at the variable importance metrics from the best tree. Do the most important variables align with your intuition?
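
As referenced in part 1 above, here is one possible adaptation. The metric, the grid, and the choice to drop the school identifier column are choices you might reasonably make differently.

set.seed(123)

college_fold <- vfold_cv(college_clean, v = 10)

rt_spec_tune <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = tune(),
           min_n = 2,
           tree_depth = NULL) %>%
  set_mode('regression') # regression instead of classification

college_rec <- recipe(Grad.Rate ~ ., data = college_clean) %>%
  step_rm(school) # school is an identifier, not a predictor

college_wf_tune <- workflow() %>%
  add_model(rt_spec_tune) %>%
  add_recipe(college_rec)

reg_param_grid <- grid_regular(cost_complexity(range = c(-5, -1)), levels = 10)

reg_tune_res <- tune_grid(
  college_wf_tune,
  resamples = college_fold,
  grid = reg_param_grid,
  metrics = metric_set(rmse) # an error metric (e.g., RMSE or MAE) instead of accuracy
)

autoplot(reg_tune_res) + theme_classic()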