Hello and welcome! In this lesson, you'll learn how to compute and visualize a correlation matrix using the diamonds dataset. The goal is to understand how different features in the dataset relate to each other through correlations and visualize these relationships using a heatmap.
The diamonds dataset contains categorical features such as `cut`, `color`, and `clarity`. Correlation matrices require numerical data, so we need to convert these categorical variables to numerical codes.
First, let's identify the categorical columns that need conversion, then we'll convert them using `astype('category').cat.codes`. `astype('category')` makes sure the feature is a categorical type, after which `.cat.codes` maps each category to a unique integer code ranging from `0` to `number_of_categories - 1`. The codes follow the column's category order, so any correlation computed on them reflects that ordering.
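As a minimal sketch of what `.cat.codes` returns, the toy Series below (made-up values, not the diamonds data) is converted once with the default category order and once with an explicit one. The variable names and the chosen size order are illustrative assumptions.

```python
import pandas as pd

# Toy data for illustration only (not taken from the diamonds dataset)
sizes = pd.Series(['small', 'large', 'medium', 'small', 'large'])

# Default: categories are inferred from the values and sorted alphabetically,
# so 'large' -> 0, 'medium' -> 1, 'small' -> 2
default_codes = sizes.astype('category').cat.codes

# Explicit order (an assumed ranking): codes follow the order given here,
# so 'small' -> 0, 'medium' -> 1, 'large' -> 2
size_dtype = pd.CategoricalDtype(categories=['small', 'medium', 'large'], ordered=True)
ordered_codes = sizes.astype(size_dtype).cat.codes

print(default_codes.tolist())
print(ordered_codes.tolist())
```

The same values produce different integers depending on the category order, so it is worth inspecting `.cat.categories` on a categorical column before reading meaning into a particular code. Missing values receive the code `-1`.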
```python
import seaborn as sns
import pandas as pd

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Convert categorical variables for correlation calculation
category_columns = ['cut', 'color', 'clarity']
for col in category_columns:
    diamonds[col] = diamonds[col].astype('category').cat.codes

print(diamonds.head())
```
After this conversion every column is numeric, so the dataset can be used in correlation computations. The first five rows now look like this:
```text
   carat  cut  color  clarity  depth  table  price     x     y     z
0   0.23    0      1        6   61.5   55.0    326  3.95  3.98  2.43
1   0.21    1      1        5   59.8   61.0    326  3.89  3.84  2.31
2   0.23    3      1        3   56.9   65.0    327  4.05  4.07  2.31
3   0.29    1      5        4   62.4   58.0    334  4.20  4.23  2.63
4   0.31    3      6        6   63.3   58.0    335  4.34  4.35  2.75
```
Next, we'll compute the correlation matrix. A correlation matrix is a table that shows correlation coefficients between variables. Each cell in the table shows the correlation between two variables.
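To make the idea of a single correlation coefficient concrete, here is a small sketch on made-up numbers. It uses the pandas `Series.corr` method; the column names and values are hypothetical and chosen only so that the three typical cases show up.

```python
import pandas as pd

# Toy data for illustration only; column names are hypothetical
toy = pd.DataFrame({
    'x': [1, 2, 3, 4, 5],
    'rises_with_x': [3, 5, 7, 9, 11],         # 2 * x + 1
    'falls_with_x': [-3, -6, -9, -12, -15],   # -3 * x
    'u_shaped': [4, 1, 0, 1, 4],              # (x - 3) ** 2
})

# Pearson correlation of x with each of the other columns
r_up = toy['x'].corr(toy['rises_with_x'])     # close to 1
r_down = toy['x'].corr(toy['falls_with_x'])   # close to -1
r_u = toy['x'].corr(toy['u_shaped'])          # close to 0

print(round(r_up, 3), round(r_down, 3), round(r_u, 3))
```

The `u_shaped` column depends entirely on `x`, yet its coefficient is close to zero, because the Pearson coefficient only measures linear association.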
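To compute the full matrix, we'll use the `corr()` method of the pandas DataFrame, which calculates the pairwise Pearson correlation coefficient by default: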
```python
import seaborn as sns
import pandas as pd

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Convert categorical variables for correlation calculation
category_columns = ['cut', 'color', 'clarity']
for col in category_columns:
    diamonds[col] = diamonds[col].astype('category').cat.codes

# Compute the correlation matrix
correlation_matrix = diamonds.corr()
print(correlation_matrix)
```
The correlation matrix will give us an understanding of how each feature relates to every other feature in the dataset. The values range from -1 to 1, where 1 indicates a perfect positive linear relationship, -1 indicates a perfect negative linear relationship, and 0 indicates no linear relationship.
Understanding these values helps us see the strength and direction of the relationships between different features in the dataset.
```text
            carat       cut     color   clarity     depth     table     price  \
carat    1.000000  0.134967  0.291437  0.352841  0.028224  0.181618  0.921591
cut      0.134967  1.000000  0.020519  0.189175  0.218055  0.433405  0.053491
color    0.291437  0.020519  1.000000 -0.025631  0.047279  0.026465  0.172511
clarity  0.352841  0.189175 -0.025631  1.000000  0.067384  0.160327  0.146800
depth    0.028224  0.218055  0.047279  0.067384  1.000000 -0.295779 -0.010647
table    0.181618  0.433405  0.026465  0.160327 -0.295779  1.000000  0.127134
price    0.921591  0.053491  0.172511  0.146800 -0.010647  0.127134  1.000000
x        0.975094  0.125565  0.270287  0.371999 -0.025289  0.195344  0.884435
y        0.951722  0.121462  0.263584  0.358420 -0.029341  0.183760  0.865421
z        0.953387  0.149323  0.268227  0.366952  0.094924  0.150929  0.861249

                x         y         z
carat    0.975094  0.951722  0.953387
cut      0.125565  0.121462  0.149323
color    0.270287  0.263584  0.268227
clarity  0.371999  0.358420  0.366952
depth   -0.025289 -0.029341  0.094924
table    0.195344  0.183760  0.150929
price    0.884435  0.865421  0.861249
x        1.000000  0.974701  0.970772
y        0.974701  1.000000  0.952006
z        0.970772  0.952006  1.000000
```
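One way to read a table like this is to list each pair of features once and sort the pairs by absolute strength. The sketch below does that for a 4x4 excerpt of the matrix printed above, copied by hand so the example runs without downloading the dataset. The helper names `pairs` and `ranked` are illustrative.

```python
from itertools import combinations

import pandas as pd

# A 4x4 excerpt of the matrix printed above, copied by hand for illustration
cols = ['carat', 'depth', 'table', 'price']
corr = pd.DataFrame(
    [
        [1.000000, 0.028224, 0.181618, 0.921591],
        [0.028224, 1.000000, -0.295779, -0.010647],
        [0.181618, -0.295779, 1.000000, 0.127134],
        [0.921591, -0.010647, 0.127134, 1.000000],
    ],
    index=cols,
    columns=cols,
)

# The matrix is symmetric, so take each unordered pair of features once
pairs = pd.Series({(a, b): corr.loc[a, b] for a, b in combinations(cols, 2)})

# Sort by absolute value so strong negative correlations rank alongside
# strong positive ones
ranked = pairs.sort_values(key=lambda s: s.abs(), ascending=False)
print(ranked)
```

With the full matrix, the same approach should work by replacing `cols` with `correlation_matrix.columns` and `corr` with `correlation_matrix`.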
Correlation matrices, while informative, can be hard to interpret when not visualized. A heatmap can make it easier to see patterns.
We'll use the `sns.heatmap()` function from the Seaborn library to visualize our correlation matrix.
```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Convert categorical variables for correlation calculation
category_columns = ['cut', 'color', 'clarity']
for col in category_columns:
    diamonds[col] = diamonds[col].astype('category').cat.codes

# Compute the correlation matrix
correlation_matrix = diamonds.corr()

# Visualize the correlation matrix
plt.figure(figsize=(10, 6))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5)
plt.title('Heatmap of Correlations')
plt.show()
```
This heatmap uses colors to highlight the strength of the correlations, making it easier to spot strong positive or negative relationships between features. The `annot=True` parameter displays the correlation value in each cell, `cmap='coolwarm'` selects a diverging color map that runs from blue for low values to red for high values, and `linewidths=0.5` draws thin lines between the cells.
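An optional refinement, sketched here on synthetic data rather than the diamonds dataset, is to pin the color scale to the full correlation range and hide the mirrored upper triangle. The data, column names, and figure size are assumptions made for illustration. `mask`, `vmin`, `vmax`, and `fmt` are existing `sns.heatmap()` parameters.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic data for illustration only (not the diamonds dataset)
rng = np.random.default_rng(0)
base = rng.normal(size=200)
toy = pd.DataFrame({
    'a': base,
    'b': base + rng.normal(scale=0.5, size=200),    # positively related to a
    'c': -base + rng.normal(scale=0.5, size=200),   # negatively related to a
    'd': rng.normal(size=200),                      # independent noise
})
corr = toy.corr()

# True marks cells to hide: everything strictly above the diagonal,
# which only repeats the values below it
mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt='.2f',          # two decimals in the cell labels
    cmap='coolwarm',
    vmin=-1, vmax=1,    # fix the scale so 0 maps to the middle of the color map
    linewidths=0.5,
    ax=ax,
)
ax.set_title('Heatmap of Correlations (toy data)')
```

Call `plt.show()` afterwards to display the figure. Without `vmin` and `vmax`, the color scale is derived from the values in the matrix, so the neutral color does not necessarily correspond to a correlation of 0.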
In this lesson, we've covered how to convert categorical variables to numerical values, compute a correlation matrix, and visualize this correlation matrix using a heatmap. These are crucial skills for correlation analysis, enabling you to identify and interpret relationships between features in a dataset.
Next, you'll get to practice these tasks on your own, reinforcing your understanding and improving your data analysis skills. Keep practicing to master these essential data science skills!