Welcome to this lesson on "Visualizing Clusters with Matplotlib using an Iris Dataset". Previously, we introduced unsupervised learning, with a focus on clustering and the K-means clustering algorithm. In this unit, we will visualize the clusters resulting from the K-means algorithm using Python's Matplotlib library. The goal of this lesson is to illustrate the implementation of the K-means clustering algorithm and to demonstrate how to visualize the results using Matplotlib, utilizing a simple Iris dataset as an example.
In this lesson, we will focus on the `scatter()` function, which creates scatter plots, and the `show()` function, which displays the plot.
We begin by loading the Iris dataset. It is a classic and widely used dataset in pattern recognition, consisting of 150 samples from three species of Iris flowers (Iris setosa, Iris virginica, and Iris versicolor), with four features measured for each sample: the length and the width of the sepals and petals.
```python
# Required libraries
from sklearn.datasets import load_iris

# Load the Iris dataset
iris = load_iris()
data = iris.data
```
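Before clustering, it can help to confirm what was loaded. The sketch below is only an illustration: it reads attributes that `load_iris()` returns (`data`, `feature_names`, `target_names`) and prints them.

```python
from sklearn.datasets import load_iris

iris = load_iris()
data = iris.data

# Number of samples and number of features per sample
print(data.shape)
# Feature order tells us which column is which when we slice later
print(iris.feature_names)
# Names of the three species
print(iris.target_names)
# The four measurements of the first sample
print(data[0])
```

The feature order matters later in the lesson: columns 0 and 1 hold sepal length and sepal width, which are the two features plotted.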
Let's see how to apply `KMeans` from scikit-learn and visualize the results using Matplotlib. Positions along the horizontal and vertical axes alone may not tell us everything we need from a plot. That's where Matplotlib's plot customization options help. We will customize the colormap and other parameters to make the plot easier to read.
Before jumping into plotting, let's cluster our data first:
```python
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Perform K-means clustering and take cluster centers
kmeans_model = KMeans(n_clusters=2, random_state=1, n_init=10).fit(data)
labels = kmeans_model.labels_
clusters_sklearn = kmeans_model.cluster_centers_
```
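To get a feel for what the fitted model holds, the following sketch inspects the two attributes used in the rest of the lesson. It repeats the clustering step so that it runs on its own. The variable name `centers` and the use of `np.bincount` to count points per cluster are illustrative choices, not part of the lesson's code.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

data = load_iris().data
kmeans_model = KMeans(n_clusters=2, random_state=1, n_init=10).fit(data)

labels = kmeans_model.labels_
centers = kmeans_model.cluster_centers_

# One label per sample
print(labels.shape)
# One row per cluster, one column per feature
print(centers.shape)
# The distinct cluster ids
print(np.unique(labels))
# Number of points assigned to each cluster
print(np.bincount(labels))
```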
A colormap is like an artist's palette: it maps values to colors on a plot. Matplotlib offers a variety of built-in colormaps. The `scatter()` function uses a colormap to define the colors of markers.
The colormap can be specified using the `cmap` parameter of the `scatter()` function.
The `c` parameter passed to `scatter()` is a sequence with one value per data point. When those values are numbers, such as cluster labels, each one is mapped to a color through the colormap.
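To make the value-to-color mapping concrete, here is a minimal sketch that looks up colors in `'viridis'` directly. It assumes the default behavior in which `scatter()` rescales the numbers in `c` to the range [0, 1] using their minimum and maximum. The `labels` list here is a made-up stand-in for real cluster labels.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

cmap = plt.get_cmap('viridis')

# Hypothetical cluster labels for four points
labels = [0, 1, 1, 0]

# Rescale label values to [0, 1], as scatter() does by default
norm = Normalize(vmin=min(labels), vmax=max(labels))

for label in sorted(set(labels)):
    # Look up the RGBA color the colormap assigns to this label
    rgba = cmap(float(norm(label)))
    print(label, rgba)
```

With only two clusters, the labels land at the two ends of the colormap, so the two groups get the most distinct colors that `'viridis'` offers.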
Let's customize our plot:
```python
# Function to plot final clusters
def plot_clusters_sklearn(data, labels, clusters):
    # Plot data points, choosing first two features for visualization
    plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', label='Data points')
    # Plot cluster centers, also focusing on the same two features
    plt.scatter(clusters[:, 0], clusters[:, 1], s=200, color='red', marker='X', label='Centers')
    plt.title('Visualizing Clusters with Matplotlib using Iris Dataset')
    plt.xlabel('Sepal Length (cm)')
    plt.ylabel('Sepal Width (cm)')
    plt.legend()
    plt.grid(True)
    plt.show()

# Visualize the clusters
plot_clusters_sklearn(data, labels, clusters_sklearn)
```
With scikit-learn's implementation, we instantiate the `KMeans` model with the desired number of clusters, then fit it to our data. The cluster label of each data point is available through the model's `labels_` attribute, while the `cluster_centers_` attribute gives the centroids of the clusters. Both are then drawn with Matplotlib's `scatter()` function. Running the code produces a scatter plot of sepal length against sepal width, with the points colored by cluster and the centers marked as large red X's.
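The relationship between `labels_` and `cluster_centers_` can also be checked numerically. The sketch below is an illustration rather than part of the lesson's code: it computes the Euclidean distance from every sample to every center with NumPy and compares the nearest center's index with the label the model assigned. The names `distances` and `nearest` are introduced here for the example.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

data = load_iris().data
kmeans_model = KMeans(n_clusters=2, random_state=1, n_init=10).fit(data)
labels = kmeans_model.labels_
centers = kmeans_model.cluster_centers_

# Distance from every sample to every center: one row per sample, one column per center
distances = np.linalg.norm(data[:, None, :] - centers[None, :, :], axis=2)

# Index of the closest center for each sample
nearest = distances.argmin(axis=1)

# Fraction of samples whose label matches their nearest center
print((nearest == labels).mean())
```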
Here's a consolidated explanation of the parameters used across both `plt.scatter` calls for plotting data points and cluster centers:
- `data[:, 0]` and `clusters[:, 0]`: These positional arguments specify the x coordinates for the scatter plot. `data[:, 0]` selects the first feature (sepal length) of the Iris dataset for all data points. `clusters[:, 0]` selects the x coordinates of the cluster centers based on the same feature.
- `data[:, 1]` and `clusters[:, 1]`: These positional arguments specify the y coordinates for the scatter plot. `data[:, 1]` selects the second feature (sepal width) of the Iris dataset for all data points. `clusters[:, 1]` selects the y coordinates of the cluster centers based on the same feature.
- `c=labels`: Used in the first `plt.scatter` call to determine the color of each data point based on its cluster label. The `labels` array contains the cluster assignment of each point, so points in the same cluster share a color.
- `cmap='viridis'`: Specifies the colormap used to map the cluster labels (`c=labels`) to colors. The `'viridis'` colormap runs from dark purple through blue and green to yellow. It is perceptually uniform, which makes it easier to distinguish between different clusters in the visualization.
- `s=200`: Sets the size of the markers, in points squared. Used in the second `plt.scatter` call to make the cluster centers significantly larger than the data points. The size can be adjusted based on visual preference.
- `color='red'`: Determines the color of the markers. In the second `plt.scatter` call, this colors the cluster centers red to distinguish them clearly from the data points.
- `marker='X'`: Controls the shape of the marker. The `'X'` shape is used for the cluster centers to differentiate them visually from the data points.
- `label='Data points'` and `label='Centers'`: These parameters name the plotted elements for the legend. `label='Data points'` labels the scatter plot of the Iris samples, and `label='Centers'` labels the plot of the cluster centers. With these labels, `plt.legend()` can produce a legend that distinguishes the data points from the cluster centers.
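Because the first `scatter()` call draws all data points as a single collection, the legend shows one entry for the points rather than one per cluster. One possible variation, which is not part of the lesson's code, uses the `legend_elements()` method of the object returned by `scatter()` to build one legend entry per cluster. The `'Agg'` backend line, the legend titles and positions, and the variable names are assumptions chosen so the sketch runs in a script without a display.

```python
import matplotlib
matplotlib.use('Agg')  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

data = load_iris().data
kmeans_model = KMeans(n_clusters=2, random_state=1, n_init=10).fit(data)
centers = kmeans_model.cluster_centers_

fig, ax = plt.subplots()
# Keep the returned collection so we can ask it for legend entries
points = ax.scatter(data[:, 0], data[:, 1], c=kmeans_model.labels_, cmap='viridis')
ax.scatter(centers[:, 0], centers[:, 1], s=200, color='red', marker='X', label='Centers')

# One legend handle per distinct value in `c`
handles, texts = points.legend_elements()
cluster_legend = ax.legend(handles, texts, title='Cluster', loc='upper left')
# Re-add the first legend so the second legend call does not replace it
ax.add_artist(cluster_legend)
# Second legend for artists that carry a label (here, the centers)
ax.legend(loc='upper right')

ax.set_xlabel('Sepal Length (cm)')
ax.set_ylabel('Sepal Width (cm)')
```

In an interactive session, drop the backend line and call `plt.show()` to display the figure.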
Congratulations! In this lesson, we have successfully implemented K-means clustering and visualized clusters using Matplotlib with an Iris dataset. Practice exercises following this lesson will enable you to utilize these newly acquired skills. Continue practicing and enriching your machine learning knowledge with Python. Keep up the good work!