Welcome back! In the previous lesson, we learned how to train a simple machine learning model using the Linear SVM method with the `caret` package in R. Now that you have a trained model, it's time to make predictions and evaluate how well your model performs.
In this lesson, you will learn how to:

- Make predictions on the test set with the `predict` function.
- Evaluate those predictions using a confusion matrix.
- Extract metrics such as accuracy, precision, and recall from the confusion matrix.
Making predictions and evaluating your model are crucial steps in the machine learning workflow. They help you assess whether your model performs well on unseen data and identify areas for improvement.
In this lesson, we continue from where we left off in the last lesson. Before making predictions, remember that the initial steps include loading the dataset, splitting the data into training and testing sets, and training your model. Here’s a brief reminder of those steps:
```r
# Load the caret package and the iris dataset
library(caret)
data(iris)

# For reproducibility
set.seed(123)

# Splitting data into train and test sets
trainIndex <- createDataPartition(iris$Species, p = 0.7, list = FALSE, times = 1)
irisTrain <- iris[trainIndex, ]
irisTest <- iris[-trainIndex, ]

# Training a Linear SVM model
model <- train(Species ~ ., data = irisTrain, method = "svmLinear")
```
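As a quick check that the split behaved as expected before moving on, you can look at the resulting set sizes. This is a small illustrative sketch; with `p = 0.7` on the 150-row iris data, it should report 105 training rows and 45 test rows:

```r
# Confirm the sizes of the training and test sets
nrow(irisTrain)  # expected: 105 rows (70% of 150)
nrow(irisTest)   # expected: 45 rows

# createDataPartition samples within each class, so the test set stays balanced
table(irisTest$Species)
```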
Once you have your trained model, the next step is to make predictions on your test data. In R, you can use the `predict` function, which `caret` extends for models trained with `train`, to generate these predictions. Here's how you can do it:
```r
# Making predictions on the test set
predictions <- predict(model, irisTest)
```
In this code snippet:

- `model` is your trained machine learning model.
- `irisTest` is your test dataset, which contains the same features as your training data, including the target labels. The `predict` function will ignore the labels and use only the features for making predictions.

After making predictions, you need to evaluate how well your model performed. One of the common ways to do this is by using a confusion matrix. But what exactly is a confusion matrix?
A confusion matrix is a table used to evaluate the performance of a classification model. It compares the actual target values with the predicted values and provides a detailed breakdown of your model's performance. For each class, the matrix captures the following terms:

- True Positives (TP): instances of the class that were correctly predicted as that class.
- True Negatives (TN): instances of other classes that were correctly not assigned to that class.
- False Positives (FP): instances of other classes that were incorrectly predicted as that class.
- False Negatives (FN): instances of the class that were incorrectly assigned to another class.
In the output produced by `caret`'s `confusionMatrix`, the rows represent the predicted classes and the columns represent the actual (reference) classes. This makes it easy to see where your model is making correct and incorrect classifications.
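For intuition, you can build the same kind of cross-tabulation yourself with base R's `table` function before turning to `caret`. This quick sketch produces only the counts, without the extra statistics:

```r
# Cross-tabulate predicted vs. actual classes (counts only)
table(Predicted = predictions, Actual = irisTest$Species)
```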
Here's how you can create and print a confusion matrix in R:
```r
# Create and print the confusion matrix
confusion <- confusionMatrix(predictions, irisTest$Species)
print(confusion)
```
In this code snippet:

- `confusionMatrix` is a function from the `caret` package that takes the predictions and the actual target values as arguments and returns a confusion matrix.
- `irisTest$Species` is the actual target values from your test dataset.

Here’s a rough idea of what you might see:
```
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         15         1
  virginica       0          0        14

Overall Statistics

               Accuracy : 0.9778
                 95% CI : (0.8823, 0.9994)
    No Information Rate : 0.3333
    P-Value [Acc > NIR] : < 2.2e-16

                  Kappa : 0.9667

 Mcnemar's Test P-Value : NA

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           0.9333
Specificity                 1.0000            0.9667           1.0000
Pos Pred Value              1.0000            0.9375           1.0000
Neg Pred Value              1.0000            1.0000           0.9677
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3111
Detection Prevalence        0.3333            0.3556           0.3111
Balanced Accuracy           1.0000            0.9833           0.9667
```
Accuracy is a common metric used to evaluate the performance of a classification model. It is calculated as the number of correct predictions divided by the total number of predictions. Mathematically, it's represented as:

Accuracy = Number of Correct Predictions / Total Number of Predictions
In the example confusion matrix above, the correct predictions are the diagonal entries: 15 setosa + 15 versicolor + 14 virginica = 44 correct predictions out of 45 total.

To calculate the accuracy manually from the confusion matrix: 44 / 45 ≈ 0.9778, which matches the accuracy reported in the output.
Using the `caret` package, you don't need to calculate it manually, as it is provided under the "Overall Statistics" section when you print the confusion matrix.
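If you want to verify the manual calculation in code, the object returned by `confusionMatrix` stores the raw counts in its `table` element, so a sketch like the following should reproduce the reported accuracy:

```r
# Correct predictions sit on the diagonal of the confusion matrix table
manual_accuracy <- sum(diag(confusion$table)) / sum(confusion$table)
print(manual_accuracy)  # approximately 0.9778 for the example above
```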
Evaluating your model is essential because it tells you how good your predictions are. A model that performs well on training data but poorly on test data is not useful. By utilizing tools like the confusion matrix, you can quantify your model's performance and gain insights into its accuracy and errors.
Precision is the ratio of correctly predicted positive observations to the total predicted positives. It answers the question, "Of all the instances classified as positive, how many are actually positive?"
Recall, or sensitivity, is the ratio of correctly predicted positive observations to all observations in the actual class. It answers the question, "Of all the actual positive instances, how many were correctly classified as positive?"
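To see how these definitions map onto the confusion matrix, here is a small illustrative sketch that computes precision and recall for the versicolor class directly from the raw counts in the `confusion` object created above (where rows are predictions and columns are actual classes):

```r
cm <- confusion$table

# Precision: correct versicolor predictions out of everything predicted as versicolor (row sum)
precision_manual <- cm["versicolor", "versicolor"] / sum(cm["versicolor", ])

# Recall: correct versicolor predictions out of all actual versicolor samples (column sum)
recall_manual <- cm["versicolor", "versicolor"] / sum(cm[, "versicolor"])

print(precision_manual)  # 15 / 16 = 0.9375 in the example above
print(recall_manual)     # 15 / 15 = 1
```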
Using the `caret` package, you can calculate these metrics from the confusion matrix object. Here's how you can do it:
```r
# Evaluating the model
confusion <- confusionMatrix(predictions, irisTest$Species)
print(confusion)

# Extracting the relevant metrics
accuracy <- confusion$overall['Accuracy']

# Extracting metrics for Setosa (class 1)
precision_setosa <- confusion$byClass[1, 'Pos Pred Value']
recall_setosa <- confusion$byClass[1, 'Sensitivity']

# Extracting metrics for Versicolor (class 2)
precision_versicolor <- confusion$byClass[2, 'Pos Pred Value']
recall_versicolor <- confusion$byClass[2, 'Sensitivity']

# Extracting metrics for Virginica (class 3)
precision_virginica <- confusion$byClass[3, 'Pos Pred Value']
recall_virginica <- confusion$byClass[3, 'Sensitivity']

# Printing the extracted metrics
print(paste("Accuracy:", accuracy))
print(paste("Precision for Setosa:", precision_setosa))
print(paste("Recall for Setosa:", recall_setosa))
print(paste("Precision for Versicolor:", precision_versicolor))
print(paste("Recall for Versicolor:", recall_versicolor))
print(paste("Precision for Virginica:", precision_virginica))
print(paste("Recall for Virginica:", recall_virginica))
```
In this code:

- `confusion$overall['Accuracy']` gives you the accuracy of the model.
- `confusion$byClass[i, 'Pos Pred Value']` gives you the precision for class `i`.
- `confusion$byClass[i, 'Sensitivity']` gives you the recall for class `i`.

Note: The `paste` function in R concatenates strings together, allowing you to create a single string from multiple elements. In this example, `paste` is used to combine the metric names (e.g., "Accuracy:") with their respective values for easier readability when printing.
Here’s a rough idea of what you might see:
1[1] "Accuracy: 0.977777777777778" 2[1] "Precision for Setosa: 1" 3[1] "Recall for Setosa: 1" 4[1] "Precision for Versicolor: 0.9375" 5[1] "Recall for Versicolor: 1" 6[1] "Precision for Virginica: 1" 7[1] "Recall for Virginica: 0.933333333333333"
These metrics provide a comprehensive view of your model's performance:

- Accuracy summarizes how often the model is correct overall.
- Precision tells you how reliable the predictions for each class are.
- Recall tells you how completely the model captures the actual instances of each class.
Excited to see how your hard work pays off? Let's dive into the practice section and put your model to the test.