Welcome back! In the last lesson, we explored how to make predictions and evaluate your model's performance using a confusion matrix. Now, let's move forward to an essential technique in machine learning: cross-validation. This method helps ensure your model's performance is robust and reliable.
In this lesson, you will learn how to:
- Define cross-validation and understand its purpose.
- Implement k-fold cross-validation using the `trainControl` function from the `caret` package in R.
- Train a machine learning model using cross-validation.
Cross-validation is a vital part of the model-building process. It helps you ensure your model generalizes well to unseen data, reducing the risk of overfitting and improving the model's reliability.
Cross-validation splits your data into several subsets, or "folds," and trains the model multiple times using different folds as training and validation sets. This technique gives you a more comprehensive evaluation of the model's performance, rather than relying on a single training/test split. It provides a more accurate estimate of how your model will perform on new, unseen data.
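To make the folding idea concrete, here is a minimal sketch of the splitting loop, assuming `caret`'s `createFolds` helper to generate the fold indices (in the rest of this lesson, `train` will handle this loop for you automatically):

```r
library(caret)

set.seed(123)
# Split the 150 iris rows into 5 folds of row indices
folds <- createFolds(iris$Species, k = 5)

# Each fold in turn serves as the validation set;
# the remaining rows form that run's training set
for (fold in folds) {
  validation <- iris[fold, ]
  training   <- iris[-fold, ]
  # ... fit a model on `training`, evaluate it on `validation` ...
}
```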
Let's now look at an example to see how cross-validation can be implemented in practice.
First, let's load the `caret` package and the `iris` dataset, and set a seed for reproducibility.

```r
# Load the caret package
library(caret)

# Load iris dataset
data(iris)

# For reproducibility
set.seed(123)
```
In this step, we will define the control parameters for cross-validation using the `trainControl` function from the `caret` package.
```r
# Define the control using k-fold cross-validation
train_control <- trainControl(method = "cv", number = 10)
```
Here, we specify that we want to use 10-fold cross-validation (`method = "cv"`, `number = 10`). This means the data will be split into 10 subsets, or folds. The model will be trained on 9 folds and tested on the remaining one, and this process will be repeated 10 times, with each fold serving once as the test set.
Next, we will train a Support Vector Machine (SVM) model using cross-validation. We'll use the `train` function, which is also part of the `caret` package, and pass in the control parameters we defined earlier.
```r
# Train the SVM model using cross-validation
# (method = "svmLinear" requires the kernlab package to be installed)
cv_model <- train(Species ~ ., data = iris, method = "svmLinear", trControl = train_control)
```
In this code:
- `Species ~ .` indicates that we are predicting the `Species` column using all other columns in the `iris` dataset.
- `method = "svmLinear"` specifies that we are using a linear SVM model.
- `trControl = train_control` passes the control parameters for cross-validation.
Finally, we can display the results of the cross-validation to see how well the model performed.
```r
# Display cross-validation results
print(cv_model)
```
The `print` function will output the model's performance metrics, including accuracy and the kappa statistic, averaged across the folds of the cross-validation process.
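If you want the numbers programmatically rather than printed, the fitted `train` object stores them; a short sketch using fields from `caret`'s train object:

```r
# Overall resampling summary (one row per tuning-parameter value)
cv_model$results

# Per-fold accuracy and kappa values
cv_model$resample

# Mean accuracy across the 10 folds
mean(cv_model$resample$Accuracy)
```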
Here’s a rough idea of what you might see:
```
Support Vector Machines with Linear Kernel

150 samples
  4 predictor
  3 classes: 'setosa', 'versicolor', 'virginica'

No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 135, 135, 135, 135, 135, 135, ...
Resampling results:

  Accuracy  Kappa
  0.96      0.94

Tuning parameter 'C' was held constant at a value of 1
```
Note that these scores come from the resampling procedure itself: in each of the 10 runs, accuracy is measured on the held-out fold, and the values shown are the averages across all runs.
By following these steps, you can ensure that your machine learning model is more reliable and less prone to overfitting. This practice will give you confidence in your model's ability to generalize well to new, unseen data.
Ready to make your model even more reliable and robust? Let's dive into the practice section and learn how to implement cross-validation effectively.