Hello again, and welcome to another crucial part of our machine learning journey with the mtcars
dataset. In the previous lesson, you learned how to visualize the results of your logistic regression model and identify the importance of different features. This time, we're moving ahead to a critical step in creating robust machine learning models: generalization and validation through cross-validation.
In this lesson, we will focus on using the `caret` package in R to validate our logistic regression model.

Cross-validation is a technique used to evaluate how well your model generalizes to unseen data by splitting the dataset into multiple subsets, or "folds". The model is trained on a portion of the data and validated on the remaining part, rotating through the folds to produce a comprehensive performance metric. This helps ensure that your model isn't just fitting noise in the training data but can perform well on independent datasets, reducing the risk of overfitting and providing a more reliable estimate of model performance.
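To make that rotation concrete, here is a minimal sketch of 5-fold cross-validation implemented by hand in base R. The fold assignment via `sample()` and the 0.5 probability threshold are illustrative choices; `caret` automates all of this for you:

```r
# Manual 5-fold cross-validation sketch (caret does this for you)
data(mtcars)
set.seed(123)
k <- 5

# Randomly assign each row to one of k folds
folds <- sample(rep(1:k, length.out = nrow(mtcars)))

accuracies <- numeric(k)
for (i in 1:k) {
  train_fold <- mtcars[folds != i, ]  # train on the other k-1 folds
  test_fold  <- mtcars[folds == i, ]  # validate on the held-out fold
  fit <- glm(am ~ mpg + hp + wt, data = train_fold, family = binomial)
  probs <- predict(fit, newdata = test_fold, type = "response")
  preds <- ifelse(probs > 0.5, 1, 0)
  accuracies[i] <- mean(preds == test_fold$am)
}

# The cross-validated estimate is the mean accuracy across folds
mean(accuracies)
```

Note that with only 32 rows in `mtcars`, individual folds are tiny, so `glm()` may warn about fitted probabilities of 0 or 1; this is expected on such a small dataset.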
Mastering cross-validation matters: understanding and implementing it is a significant step toward becoming a proficient data scientist who can build reliable and effective models.
To give you a feel for what you will be doing, here’s a snippet of code to perform cross-validation on your logistic regression model:
```r
# Load the caret package for data splitting, preprocessing, and training
library(caret)

# Load the mtcars dataset
data(mtcars)

# Set seed for reproducibility
set.seed(123)

# Convert categorical columns to factors
mtcars$am <- as.factor(mtcars$am)
mtcars$cyl <- as.factor(mtcars$cyl)
mtcars$vs <- as.factor(mtcars$vs)
mtcars$gear <- as.factor(mtcars$gear)
mtcars$carb <- as.factor(mtcars$carb)

# Split data into training and testing sets
trainIndex <- createDataPartition(mtcars$am, p = 0.7, list = FALSE, times = 1)
trainData <- mtcars[trainIndex, ]
testData <- mtcars[-trainIndex, ]

# Feature scaling (excluding factor columns)
numericColumns <- sapply(trainData, is.numeric)
preProcValues <- preProcess(trainData[, numericColumns], method = c("center", "scale"))
trainData[, numericColumns] <- predict(preProcValues, trainData[, numericColumns])
testData[, numericColumns] <- predict(preProcValues, testData[, numericColumns])

# Train a logistic regression model with 10-fold cross-validation,
# surfacing (then muffling) any warnings raised during fitting
withCallingHandlers({
  cross_val <- train(am ~ mpg + hp + wt, data = trainData, method = "glm",
                     family = "binomial",
                     trControl = trainControl(method = "cv", number = 10))
}, warning = function(w) {
  message("Warning: ", conditionMessage(w))
  invokeRestart("muffleWarning")
})

# Display cross-validation results
print(cross_val)
```
Here is a breakdown of the parameters used in the cross-validation training call:

- `am ~ mpg + hp + wt`: This formula specifies `am` as the response variable and `mpg`, `hp`, and `wt` as the predictors.
- `data = trainData`: Specifies that training will be performed on the `trainData` dataset.
- `method = "glm"`: Indicates the use of a generalized linear model (here, logistic regression).
- `family = "binomial"`: Specifies the binomial family, required for logistic regression.
- `trControl = trainControl(method = "cv", number = 10)`: Creates a training control object that specifies cross-validation (`cv`) with 10 folds as the resampling method.

Output:
```
Generalized Linear Model

24 samples
 3 predictor
 2 classes: '0', '1'

No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 22, 22, 21, 22, 21, 22, ...
Resampling results:

  Accuracy  Kappa
  0.8       0.64
```
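Once the model has been validated, you can also check how it fares on the held-out test set. Assuming the snippet above has been run (so `cross_val` and `testData` exist in your session), a quick evaluation might look like this:

```r
# Predict classes on the held-out test set
test_preds <- predict(cross_val, newdata = testData)

# Compare predictions against the true labels
confusionMatrix(test_preds, testData$am)
```

`predict()` on a caret `train` object returns class labels by default, and `confusionMatrix()` reports accuracy, Kappa, sensitivity, and specificity, giving you an independent check on the cross-validated estimate.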
Excited to jump in? In the following practice section, you’ll get hands-on experience with cross-validation, solidifying your understanding and making your models more dependable. Let's get started!