Topic 11 Logistic Regression & Evaluation
Learning Goals
- Use a logistic regression model to make hard (class) and soft (probability) predictions
- Interpret non-intercept coefficients from logistic regression models in the data context
- Calculate (by hand from confusion matrices) and contextually interpret overall accuracy, sensitivity, and specificity
- Construct and interpret plots of predicted probabilities across classes
- Explain how a ROC curve is constructed and the rationale behind AUC as an evaluation metric
- Appropriately use and interpret the no-information rate (largest proportion of the observed classes) to evaluate accuracy metrics
Slides from today are available here.
Logistic regression in tidymodels
To build logistic regression models in tidymodels, first load the packages and set the seed for the random number generator to ensure reproducible results:
library(dplyr)
library(ggplot2)
library(tidymodels)
library(probably) #install.packages('probably')
tidymodels_prefer()
set.seed(___) # Pick your favorite number to fill in the parentheses
Then adapt the following code to fit a logistic regression model:
# Make sure you set reference level (the outcome you are NOT interested in)
data <- data %>%
  mutate(outcome = relevel(outcome, ref='failure')) #set reference level
data_cv10 <- vfold_cv(data, v = 10)
# Logistic Regression Model Spec
logistic_spec <- logistic_reg() %>%
  set_engine('glm') %>%
  set_mode('classification')
# Recipe
logistic_rec <- recipe(outcome ~ ., data = data) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors())
# Workflow (Recipe + Model)
log_wf <- workflow() %>%
  add_recipe(logistic_rec) %>%
  add_model(logistic_spec)
# Fit Model to Training Data
log_fit <- fit(log_wf, data = data)
Examining the logistic model
# Print out Coefficients
log_fit %>% tidy()
# Get Exponentiated coefficients + CI
log_fit %>% tidy() %>%
  mutate(OR.conf.low = exp(estimate - 1.96*std.error), OR.conf.high = exp(estimate + 1.96*std.error)) %>% # do this first
  mutate(OR = exp(estimate))
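For reference, the model behind these coefficients can be written as below. This is the standard logistic regression form; the symbols p, x_j, and beta_j are notation assumed here for illustration and do not appear in the code above.

```latex
\log\left(\frac{p}{1-p}\right) = \beta_0 + \beta_1 x_1 + \cdots + \beta_k x_k,
\qquad p = P(\text{outcome} = \text{non-reference level} \mid x_1, \ldots, x_k)

\frac{\text{odds}(x_j + 1)}{\text{odds}(x_j)} = e^{\beta_j}
\quad \text{(holding the other predictors fixed)}
```

So an exponentiated coefficient is the multiplicative change in the odds of the non-reference outcome for a one-unit increase in that predictor. Because the recipe above applies step_normalize(), a one-unit increase in a numeric predictor corresponds to a one standard deviation increase on its original scale.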
Making predictions from the logistic model
# Make soft (probability) predictions
predict(log_fit, new_data = ___, type = "prob")
# Make hard (class) predictions (using a default 0.5 probability threshold)
predict(log_fit, new_data = ___, type = "class")
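To see what the two prediction types mean, here is a minimal base R sketch of turning a log-odds value into a soft and then a hard prediction. The coefficients b0 and b1, the value x_new, and the class labels are hypothetical numbers chosen only for illustration; they do not come from any fitted model.

```r
# Hypothetical coefficients on the log-odds scale (made up for illustration)
b0 <- -1.5   # intercept
b1 <- 0.8    # slope for a single predictor

# Hypothetical new observation
x_new <- 2

# Soft prediction: convert log-odds to a probability
log_odds <- b0 + b1 * x_new
prob <- exp(log_odds) / (1 + exp(log_odds))  # equivalent to plogis(log_odds)

# Hard prediction: compare the probability to a threshold
threshold <- 0.5
hard <- ifelse(prob >= threshold, "success", "failure")

print(prob)
print(hard)
```

Raising the threshold makes the model more reluctant to predict "success", which is the trade-off explored in the evaluation sections.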
Evaluating the logistic model on training data
# Soft predictions
logistic_output <- data %>%
  bind_cols(predict(log_fit, new_data = data, type = 'prob'))
# Hard predictions (you pick threshold)
logistic_output <- logistic_output %>%
  mutate(.pred_class = make_two_class_pred(.pred_REFLEVEL, levels(outcome), threshold = .8)) #Try changing threshold (.5, 0, 1, .2, .8)
# Visualize Soft Predictions
logistic_output %>%
  ggplot(aes(x = outcome, y = .pred_OUTCOME)) +
  geom_boxplot() +
  geom_hline(yintercept = 0.5, color='red') + # try changing threshold
  labs(y = 'Predicted Probability of Outcome', x = 'Observed Outcome') +
  theme_classic()
Calculate evaluation metrics of the logistic model on training data
# Confusion Matrix
logistic_output %>%
  conf_mat(truth = outcome, estimate = .pred_class)
log_metrics <- metric_set(sens, yardstick::spec, accuracy) # these metrics are based on hard predictions
#sens: sensitivity = chance of correctly predicting second level, given second level (Yes)
#spec: specificity = chance of correctly predicting first level, given first level (No)
#accuracy: accuracy = chance of correctly predicting outcome
logistic_output %>%
  log_metrics(estimate = .pred_class, truth = outcome, event_level = "second") # set second level of outcome as "success"
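The same metrics can be computed by hand from a confusion matrix. The sketch below uses base R and a hypothetical set of counts (invented for illustration, not taken from any dataset in these notes), with "Yes" treated as the second level. It also computes the no-information rate for comparison with accuracy.

```r
# Hypothetical confusion matrix: rows = predicted class, columns = observed class
cm <- matrix(c(80, 20,
               10, 40),
             nrow = 2, byrow = TRUE,
             dimnames = list(predicted = c("No", "Yes"),
                             observed  = c("No", "Yes")))

# Sensitivity: of the observed "Yes" cases, the fraction predicted "Yes"
sensitivity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])

# Specificity: of the observed "No" cases, the fraction predicted "No"
specificity <- cm["No", "No"] / sum(cm[, "No"])

# Accuracy: fraction of all cases predicted correctly
accuracy <- sum(diag(cm)) / sum(cm)

# No-information rate: proportion of the largest observed class
no_info_rate <- max(colSums(cm)) / sum(cm)

print(c(sensitivity = sensitivity, specificity = specificity,
        accuracy = accuracy, no_info_rate = no_info_rate))
```

An accuracy that is not clearly above the no-information rate suggests the model is doing little better than always predicting the most common class.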
ROC Curve: evaluating logistic model using soft predictions
For each threshold, you can calculate the sensitivity and specificity of the resulting hard predictions. A ROC (receiver operating characteristic) curve plots sensitivity on the y-axis against (1 - specificity) on the x-axis, with one point per threshold. Ideally, we want sensitivity = 1 (no errors for one outcome group) and (1 - specificity) = 0 (no errors for the other outcome group). Visually, this looks like a vertical line at x = 0 joined to a horizontal line at y = 1 (a “sharp corner”). The closer a model's curve is to this shape, the better the model.
The diagonal y = x line is drawn for reference: it is the curve you would expect from random guessing. A poor model's curve will be close to this line.
Typically, the ROC curve falls somewhere between the “sharp corner” and the diagonal y = x line. As the threshold changes, sensitivity and specificity trade off: as sensitivity increases, specificity decreases (that is, 1 - specificity increases).
logistic_roc <- logistic_output %>%
  roc_curve(outcome, .pred_OUTCOME, event_level = "second") # set second level of outcome as "success"
autoplot(logistic_roc) + theme_classic()
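To make the construction concrete, here is a base R sketch that builds the points of a ROC curve by hand and approximates the area under it with the trapezoid rule. The observed classes and predicted probabilities are toy values invented for illustration, and this is not how roc_curve() is implemented internally.

```r
# Toy data (invented): observed class and predicted probability of "Yes"
obs  <- c("No", "No", "No", "No", "Yes", "No", "Yes", "Yes", "Yes", "Yes")
prob <- c(0.05, 0.10, 0.20, 0.35, 0.40, 0.55, 0.60, 0.75, 0.85, 0.95)

# Candidate thresholds, from strictest (predict nobody "Yes") to loosest
thresholds <- c(Inf, sort(unique(prob), decreasing = TRUE))

# For each threshold, predict "Yes" when prob >= threshold
sens <- sapply(thresholds, function(t) mean(prob[obs == "Yes"] >= t))
fpr  <- sapply(thresholds, function(t) mean(prob[obs == "No"] >= t))  # 1 - specificity

roc_points <- data.frame(threshold = thresholds,
                         one_minus_spec = fpr,
                         sensitivity = sens)
print(roc_points)

# Area under the curve via the trapezoid rule
auc <- sum(diff(fpr) * (head(sens, -1) + tail(sens, -1)) / 2)
print(auc)

# Quick base R plot of the curve with the y = x reference line
plot(fpr, sens, type = "l", xlab = "1 - specificity", ylab = "sensitivity")
abline(0, 1, lty = 2)
```

An AUC near 1 corresponds to the “sharp corner” shape, while an AUC near 0.5 corresponds to the diagonal guessing line.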
Calculate evaluation metrics of the logistic model using CV
# CV Fit Model
log_cv_fit <- fit_resamples(
  log_wf, resamples = data_cv10,
  metrics = metric_set(sens, yardstick::spec, accuracy, roc_auc),
  control = control_resamples(save_pred = TRUE, event_level = 'second')) # you need predictions for ROC calculations
collect_metrics(log_cv_fit) #default threshold is 0.5
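As a rough picture of what cross-validation does, the base R sketch below runs 10-fold CV by hand with glm() on simulated data. The data, the variable names (sim, x, y), and the coefficients used to simulate it are all assumptions made for illustration; fit_resamples() handles these steps for you and may differ in details.

```r
set.seed(1)

# Simulated data (invented): one numeric predictor and a two-level outcome
n <- 200
x <- rnorm(n)
y <- factor(ifelse(runif(n) < plogis(0.5 + 1.5 * x), "Yes", "No"),
            levels = c("No", "Yes"))  # "No" is the reference level
sim <- data.frame(x = x, y = y)

# Randomly assign each row to one of k folds
k <- 10
fold <- sample(rep(1:k, length.out = n))

# For each fold: fit on the other folds, predict on the held-out fold
fold_accuracy <- sapply(1:k, function(i) {
  train <- sim[fold != i, ]
  test  <- sim[fold == i, ]
  fit   <- glm(y ~ x, data = train, family = binomial)
  p_yes <- predict(fit, newdata = test, type = "response")  # P(y = "Yes")
  pred  <- ifelse(p_yes >= 0.5, "Yes", "No")                # 0.5 threshold
  mean(pred == test$y)
})

# CV estimate of accuracy: average over the held-out folds
cv_accuracy <- mean(fold_accuracy)
print(cv_accuracy)
```

Each observation is predicted by a model that never saw it during fitting, which is why CV metrics are usually a more honest estimate of performance on new data than training metrics.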
Exercises
You can download a template RMarkdown file to start from here.
Context
We’ll be working with a spam dataset that contains information on different features of emails and whether or not the email was spam. The variables are as follows:
- spam: Either spam or not spam
- word_freq_WORD: percentage of words in the e-mail that match WORD (0-100)
- char_freq_CHAR: percentage of characters in the e-mail that match CHAR (e.g., exclamation points, dollar signs)
- capital_run_length_average: average length of uninterrupted sequences of capital letters
- capital_run_length_longest: length of longest uninterrupted sequence of capital letters
- capital_run_length_total: sum of length of uninterrupted sequences of capital letters
Our goal will be to use email features to predict whether or not an email is spam - essentially, to build a spam filter!
library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
library(probably) #install.packages('probably')
tidymodels_prefer()
spam <- read_csv("https://www.dropbox.com/s/leurr6a30f4l32a/spambase.csv?dl=1")
Exercise 1: Visualization warmup
Let’s take a look at the frequency of the word “George” (word_freq_george; the email recipient’s name is George) and the frequency of exclamation points (char_freq_exclam).
Create appropriate visualizations to assess the predictive ability of these two predictors.
# If you want to adjust the axis limits, you can add the following to your plot:
# + coord_cartesian(ylim = c(0,1))
# + coord_cartesian(xlim = c(0,1))
ggplot(spam, aes(x = ___, y = ___)) +
geom_???()
Exercise 2: Implementing logistic regression in tidymodels
Our goal is to fit a logistic regression model with word_freq_george and char_freq_exclam as predictors.
- Write down the corresponding logistic regression model formula using mathematical notation.
- Use tidymodels to fit this logistic regression model to the training data.
# Make sure you set reference level (to the outcome you are NOT interested in)
spam <- spam %>%
  mutate(spam = relevel(factor(spam), ref='not spam')) #set reference level
# Logistic Regression Model Spec
logistic_spec <- logistic_reg() %>%
  set_engine('glm') %>%
  set_mode('classification')
# Recipe
logistic_rec <- recipe(spam ~ _____, data = spam)
# Workflow (Recipe + Model)
log_wf <- workflow() %>%
  add_recipe(logistic_rec) %>%
  add_model(logistic_spec)
# Fit Model
log_fit <- fit(log_wf, data = ___)
Exercise 3: Interpreting the model
- Take a look at the log-scale coefficients with tidy(log_fit). Do the signs of the coefficients for the 2 predictors agree with your visual inspection from Exercise 1?
- Display the exponentiated coefficients, and provide contextual interpretations for them (not the intercept).
Exercise 4: Making predictions
Consider a new email where the frequency of “George” is 0.25% and the frequency of exclamation points is 1%.
- Use the model summary to make both a soft (probability) and hard (class) prediction for this test case by hand. Use a default probability threshold of 0.5. (You can use math expressions to use R as a calculator. The exp() function exponentiates a number.)
- Check your work from the previous part by using predict().
predict(___, new_data = data.frame(word_freq_george = 0.25, char_freq_exclam = 1), type = ___)
Exercise 5: Evaluate the model
- Visualize the soft predictions and comment on a potential threshold for doing hard predictions.
# Soft predictions
logistic_output <- spam %>%
  bind_cols(predict(log_fit, new_data = spam, type = 'prob'))
head(logistic_output)
- Choose a threshold and calculate training metrics based on hard predictions (sensitivity, specificity, accuracy). Write a sentence interpreting these values in context.
# Hard predictions (you pick threshold)
logistic_output <- logistic_output %>%
  mutate(.pred_class = make_two_class_pred(`.pred_not spam`, levels(spam), threshold = 1 - ___))
# Confusion Matrix
logistic_output %>%
  conf_mat(truth = spam, estimate = .pred_class)
# Calculate them by hand first and then confirm below
#sens: sensitivity = chance of correctly predicting second level, given second level (Spam)
#spec: specificity = chance of correctly predicting first level, given first level (Not Spam)
#accuracy: accuracy = chance of correctly predicting outcome
log_metrics <- metric_set(sens, yardstick::spec, accuracy)
logistic_output %>%
  log_metrics(estimate = .pred_class, truth = spam, event_level = "second")
- Now evaluate the soft predictions across all possible thresholds by plotting the ROC curve and calculating the area under the ROC curve (AUC). Discuss what this tells you about how well the model classifies spam emails.
logistic_roc <- logistic_output %>%
  roc_curve(spam, .pred_spam, event_level = "second")
autoplot(logistic_roc) + theme_classic()
logistic_output %>%
  roc_auc(spam, .pred_spam, event_level = "second")
- Now, let’s use CV to estimate how the model would perform on new (test) data. Comment on how the CV metrics differ from the training metrics.
set.seed(123)
data_cv10 <- vfold_cv(spam, v = 10)
# CV Fit Model
log_cv_fit <- fit_resamples(
  log_wf, resamples = data_cv10,
  metrics = metric_set(sens, yardstick::spec, accuracy, roc_auc),
  control = control_resamples(save_pred = TRUE, event_level = 'second')) # you need predictions for ROC calculations
collect_metrics(log_cv_fit) #default threshold is 0.5