Hello, and welcome to this lesson on creating a scatter plot to visualize the relationship between the carat and price of diamonds using the popular `diamonds` dataset in Seaborn. Scatter plots are a fundamental part of exploratory data analysis (EDA) because they allow us to identify potential correlations and patterns between two continuous variables. By the end of this lesson, you will understand how to create and customize scatter plots for better data insights.
A scatter plot is a type of data visualization that displays values for two variables in a two-dimensional space. Each point on the scatter plot represents an observation, making it an excellent tool for visualizing relationships and identifying trends, clusters, or outliers in the data.
In our case, we will use a scatter plot to explore the relationship between the carat (weight) of a diamond and its price. This will help us understand how these two variables are related and how strongly carat is associated with diamond pricing.
Once the dataset is loaded, we can proceed to create a scatter plot. We'll use the `scatterplot` function from the Seaborn library.
Here is the basic code to create a scatter plot:
```python
import seaborn as sns
import matplotlib.pyplot as plt

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Filter out any entries with missing values
diamonds = diamonds.dropna()

# Basic scatter plot of carat vs. price
plt.figure(figsize=(10, 6))
sns.scatterplot(x='carat', y='price', data=diamonds)
plt.title('Scatter Plot of Carat vs. Price')
plt.xlabel('Carat')
plt.ylabel('Price')
plt.show()
```
Running this code produces a scatter plot with carat on the x-axis and price on the y-axis.
The plot indicates a positive correlation between the carat of a diamond and its price: in general, heavier diamonds tend to be more expensive.
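The visual impression of a positive relationship can also be backed by a number. The following is a minimal sketch, assuming pandas is installed; the helper name `carat_price_correlation` and the five sample rows are made up for illustration and are not taken from the diamonds dataset.

```python
import pandas as pd


def carat_price_correlation(df: pd.DataFrame) -> float:
    """Return the Pearson correlation between the 'carat' and 'price' columns."""
    return df["carat"].corr(df["price"])


# Tiny made-up sample (illustrative values only, not real diamonds data)
sample = pd.DataFrame({
    "carat": [0.3, 0.5, 0.7, 1.0, 1.5],
    "price": [400, 1000, 2200, 4500, 9000],
})

r = carat_price_correlation(sample)
# A value close to +1 suggests a strong positive linear relationship
print(f"Pearson r = {r:.2f}")
```

To apply the same helper to the lesson's data, pass the `diamonds` DataFrame loaded with `sns.load_dataset('diamonds')` instead of `sample`.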
To make the scatter plot more effective, we can adjust options such as the figure size, the transparency of the points, and their colors.
Let's update our plot by adding transparency to the points:
```python
import seaborn as sns
import matplotlib.pyplot as plt

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Filter out any entries with missing values
diamonds = diamonds.dropna()

# Enhanced scatter plot with added transparency
plt.figure(figsize=(10, 6))
sns.scatterplot(x='carat', y='price', data=diamonds, alpha=0.6)
plt.title('Scatter Plot of Carat vs. Price')
plt.xlabel('Carat')
plt.ylabel('Price')
plt.show()
```
Running this code produces the same plot, now drawn with semi-transparent points.
In this enhanced version, `alpha=0.6` adds transparency to the points. Regions where many points overlap appear darker, which gives a clearer picture of the concentration and distribution of the data.
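To see the effect of `alpha` in isolation, here is a small sketch that draws the same data twice, side by side. It uses synthetic data: the random values, the arbitrary formula behind `price`, and the variable name `demo` are assumptions for illustration, and only the column names mimic the lesson's. Nothing is loaded from the diamonds dataset.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data (arbitrary formula, not real diamond prices)
rng = np.random.default_rng(0)
carat = rng.uniform(0.2, 2.5, size=2000)
price = 4000 * carat + rng.normal(0, 500, size=2000)
demo = pd.DataFrame({"carat": carat, "price": price})

# Two Axes in one figure: opaque points on the left, transparent on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

sns.scatterplot(x="carat", y="price", data=demo, ax=axes[0])
axes[0].set_title("Default opacity")

sns.scatterplot(x="carat", y="price", data=demo, alpha=0.6, ax=axes[1])
axes[1].set_title("alpha=0.6")
```

Calling `plt.show()` afterwards displays the figure. Lower values of `alpha` make dense regions stand out more, at the cost of fainter isolated points.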
When creating scatter plots in Seaborn, there are several parameters you can adjust to customize your visualization. Here are some commonly used parameters:
- `hue`: Grouping variable that will produce points with different colors.
- `size`: Grouping variable that will produce points with different sizes.
- `style`: Grouping variable that will produce points with different markers.
- `markers`: Marker style to use for the different levels of the `style` variable.
- `edgecolor`: Color of the point borders.
- `linewidth`: Width of the point borders.
- `legend`: How to draw the legend, if any (`'auto'`, `'brief'`, `'full'`, or `False`).
- `ax`: Matplotlib Axes to draw the plot onto; if omitted, the plot is drawn onto the current Axes.
By tweaking these parameters, you can tailor the scatter plot to better fit your specific data visualization needs.
One example implementation is the following code:
```python
import seaborn as sns
import matplotlib.pyplot as plt

# Load the diamonds dataset
diamonds = sns.load_dataset('diamonds')

# Filter out any entries with missing values
diamonds = diamonds.dropna()

# Create matplotlib Axes for the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Scatter plot with various customizations
sns.scatterplot(
    x='carat', y='price', data=diamonds,
    hue='cut',        # Points colored by 'cut'
    size='depth',     # Points sized by 'depth'
    style='clarity',  # Points styled by 'clarity'
    # Custom markers for all clarity levels
    markers={'VS2': 'o', 'SI1': 's', 'SI2': 'D', 'VS1': 'P',
             'VVS1': 'X', 'IF': '*', 'I1': 'v', 'VVS2': 'p'},
    edgecolor='w',    # White border around points
    linewidth=0.5,    # Width of the point borders
    legend='brief',   # Display a brief legend
    ax=ax             # Draw onto specified Axes
)
plt.title('Customized Scatter Plot of Carat vs. Price')
plt.xlabel('Carat')
plt.ylabel('Price')
plt.show()
```
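When many parameters are combined, it can be hard to tell which one produces which visual effect. As a simpler illustration, the sketch below uses only `hue` and `ax` on a handful of made-up rows (the values and the name `demo` are hypothetical, not taken from the diamonds dataset) and reads back the legend entries that Seaborn generates.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Made-up rows for illustration only (not taken from the diamonds dataset)
demo = pd.DataFrame({
    "carat": [0.3, 0.4, 0.7, 0.9, 1.1, 1.5],
    "price": [450, 700, 2300, 3900, 5600, 9800],
    "cut": ["Fair", "Ideal", "Fair", "Ideal", "Fair", "Ideal"],
})

fig, ax = plt.subplots(figsize=(6, 4))

# One color per distinct value in the 'cut' column
sns.scatterplot(x="carat", y="price", hue="cut", data=demo, ax=ax)
ax.set_title("Points colored by cut")

# Collect the text of the legend entries drawn for the hue variable
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
print(legend_labels)
```

Adding `size` or `style` one at a time to this call is a convenient way to check how each parameter changes the plot and its legend.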
With our scatter plot created, it's time to interpret it. Here are some aspects to focus on:
- Correlation: Check whether there is a positive or negative correlation between `carat` and `price`.
- Clusters: Look for any natural grouping of data points.
- Outliers: Identify any points that stand out from the general pattern, which might represent unusual diamonds.
By examining these aspects, we can gain valuable insights into how diamond characteristics affect their pricing.
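Outliers can also be flagged programmatically. The sketch below shows one simple heuristic, not a method prescribed by this lesson: fit a straight line to the data and mark points whose residual lies more than a chosen number of standard deviations from the mean residual. The function name `flag_outliers`, the threshold of 3, and the synthetic data (including the deliberately injected unusual point) are all assumptions for illustration.

```python
import numpy as np
import pandas as pd


def flag_outliers(df, x="carat", y="price", threshold=3.0):
    """Return a boolean Series that is True where a row's residual from a
    straight-line fit is more than `threshold` standard deviations away
    from the mean residual."""
    slope, intercept = np.polyfit(df[x], df[y], deg=1)
    residuals = df[y] - (slope * df[x] + intercept)
    z_scores = (residuals - residuals.mean()) / residuals.std()
    return z_scores.abs() > threshold


# Synthetic data: a noisy linear trend (arbitrary numbers, not real prices)
rng = np.random.default_rng(42)
carat = rng.uniform(0.2, 2.0, size=200)
price = 5000 * carat + rng.normal(0, 300, size=200)
demo = pd.DataFrame({"carat": carat, "price": price})

# Inject one deliberately unusual point: a small stone with a very high price
demo.loc[len(demo)] = {"carat": 0.5, "price": 15000.0}

mask = flag_outliers(demo)
print(demo[mask])
```

A straight-line fit is a rough approximation here; if the relationship in the real data is curved, a different model or a transformation of the variables may be more appropriate before judging what counts as unusual.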
In summary, we covered:
- The concept and importance of scatter plots.
- Creating a scatter plot to visualize the relationship between `carat` and `price`.
- Enhancing the scatter plot for better understanding and readability.
- Interpreting the resulting scatter plot to derive meaningful insights.