So far, we have explored the Flights dataset from Seaborn, analyzed trends, and visualized these trends using various charts such as line plots and heat maps. Today, we focus on leveraging this historical data to predict future trends using a technique known as Linear Regression.
Linear regression is a powerful tool for predicting an outcome (the dependent variable) from one or more predictor (independent) variables by fitting a linear relationship between them. For instance, using our air travel dataset, we might want to forecast future passenger numbers based on past trends.
Why is this important? Anticipating future trends is key to strategic decision-making. Airlines, for example, use these predictions for many operational and strategic decisions, including flight scheduling, capacity planning, resource allocation, and expansion planning. Let's see how we can make such predictions!
Linear regression assumes a linear relationship between the dependent and predictor variable(s), which can be represented as $y = \beta_0 + \beta_1 x$. Here, $y$ is the dependent variable we want to predict, $x$ is our predictor variable, $\beta_0$ is the y-intercept, and $\beta_1$ is the slope of the line. In the context of our Flights dataset, $y$ can represent the number of passengers, and $x$ can represent time (years or months).
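As a minimal sketch of how these symbols map onto scikit-learn, the example below fits `LinearRegression` to hypothetical, noise-free data generated from $y = 2 + 3x$ (the values 2 and 3 are arbitrary choices for illustration). The fitted model stores the estimated intercept $\beta_0$ in `intercept_` and the slope $\beta_1$ in `coef_`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical noise-free data: y = 2 + 3x (intercept 2 and slope 3 chosen for illustration)
x = np.arange(10).reshape(-1, 1)  # scikit-learn expects a 2D array of shape (n_samples, n_features)
y = 2 + 3 * x.ravel()

model = LinearRegression().fit(x, y)

beta_0 = model.intercept_  # estimated y-intercept
beta_1 = model.coef_[0]    # estimated slope (coef_ holds one coefficient per predictor)
print(f"beta_0 = {beta_0:.2f}, beta_1 = {beta_1:.2f}")
```

Because these data contain no noise, the fit recovers the generating line and prints `beta_0 = 2.00, beta_1 = 3.00`. With real data such as the passenger counts, the two estimates instead summarize the overall trend.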
```python
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Load the Flights dataset: one row per (year, month) with a passenger count
flights_data = sns.load_dataset("flights")
flights_data['year'] = pd.to_datetime(flights_data['year'], format='%Y')

# Sum the monthly counts into a total passenger count per year
flights_pivot = pd.pivot_table(data=flights_data, values='passengers', index='year', aggfunc='sum').reset_index()

# Extract the year from the date and convert it back to an integer predictor
flights_pivot['year'] = flights_pivot['year'].dt.year

# scikit-learn expects a 2D feature array of shape (n_samples, n_features)
X = np.array(flights_pivot['year']).reshape(-1, 1)
y = flights_pivot['passengers']

# Fit an ordinary least squares line: passengers = intercept + slope * year
reg = LinearRegression().fit(X, y)

# Actual yearly totals as magenta ("m") dots
plt.scatter(X, y, color="m", marker="o", s=30)

# Fitted regression line in green
Y_pred = reg.predict(X)
plt.plot(X, Y_pred, color="g")

plt.xlabel('Year')
plt.ylabel('Passengers')
plt.title('Linear Regression: Passengers Over Time')
plt.show()
```
In the above code example, we load the Flights dataset and pivot it to get the total passenger count for each year. Next, we create our Linear Regression model and fit it to the data (years and passenger counts). We use a scatter plot to visualize the data points, while the line plot shows our fitted regression line.
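The pivot step collapses the monthly rows for each year into a single annual total. As a minimal sketch, the toy DataFrame below is hypothetical (its passenger values are invented) but uses the same `year`, `month`, and `passengers` columns as the Flights dataset. It shows that `pd.pivot_table(..., aggfunc='sum')` produces the same yearly totals as a `groupby` followed by `sum()`.

```python
import pandas as pd

# Hypothetical toy data in the same long format as the Flights dataset:
# one row per (year, month) pair. Passenger values are invented for illustration.
toy = pd.DataFrame({
    "year": [1949, 1949, 1949, 1950, 1950, 1950],
    "month": ["Jan", "Feb", "Mar", "Jan", "Feb", "Mar"],
    "passengers": [100, 110, 120, 130, 140, 150],
})

# Sum the monthly counts into one total per year, as in the lesson's pivot step
annual = pd.pivot_table(data=toy, values="passengers", index="year", aggfunc="sum").reset_index()
print(annual)

# The same aggregation written with groupby
annual_gb = toy.groupby("year", as_index=False)["passengers"].sum()
print(annual_gb)
```

Either form works; `groupby` is often the more direct choice when only a single aggregated column is needed.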
The magenta dots represent the actual number of passengers for each year from 1949 to 1960, plotted against the year. The green line is the line of best fit produced by linear regression. It is chosen to minimize the sum of the squared vertical distances between the line and each point (these distances are the errors, or residuals). It represents our model's best guess for the passenger count in a given year.
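For a single predictor, the slope and intercept that minimize the sum of squared residuals have a closed form: $\beta_1 = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2$ and $\beta_0 = \bar{y} - \beta_1 \bar{x}$. The sketch below computes them with NumPy on hypothetical noisy data (the seed, trend, and noise level are arbitrary assumptions) and checks them against `LinearRegression`, which fits ordinary least squares. The helper `sse` is a hypothetical name introduced only for this illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical data: a linear trend plus random noise (values chosen only for illustration)
rng = np.random.default_rng(seed=0)
x = np.arange(1949, 1961, dtype=float)
y = 1500.0 + 300.0 * (x - 1949) + rng.normal(scale=100.0, size=x.size)

# Closed-form ordinary least squares for a single predictor
x_mean, y_mean = x.mean(), y.mean()
slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
intercept = y_mean - slope * x_mean

# LinearRegression also fits ordinary least squares, so it should find the same line
reg = LinearRegression().fit(x.reshape(-1, 1), y)
print(f"closed form:  slope={slope:.4f}, intercept={intercept:.4f}")
print(f"scikit-learn: slope={reg.coef_[0]:.4f}, intercept={reg.intercept_:.4f}")


def sse(b0, b1):
    """Sum of squared residuals for the line y = b0 + b1 * x (hypothetical helper)."""
    return np.sum((y - (b0 + b1 * x)) ** 2)


# Nudging either coefficient away from the least-squares values increases the error
print(sse(intercept, slope) < sse(intercept + 10.0, slope))
print(sse(intercept, slope) < sse(intercept, slope + 1.0))
```

Squaring the residuals means large misses are penalized more than small ones, and it is what makes the closed-form solution above possible.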
A natural question arises: how well does our model fit the data? We can answer this using R-squared, a statistical measure of the proportion of the variance in the dependent variable that is explained by the independent variable(s) in a regression model. It measures how well the model fits the data it was trained on; a good fit is encouraging, but it does not guarantee accurate forecasts.
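Concretely, for observed values $y_i$, model predictions $\hat{y}_i$, and the mean observed value $\bar{y}$, R-squared compares the model's squared errors with those of a baseline that always predicts $\bar{y}$. This is the standard definition, and it is what scikit-learn's `score` method computes for regressors:

$$
R^2 = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2}
$$

If the predictions match the data exactly, the numerator is zero and $R^2 = 1$; if the model does no better than predicting $\bar{y}$ everywhere, the two sums are equal and $R^2 = 0$.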
```python
print("R-squared statistic:", reg.score(X, y))
```
The closer the R-squared statistic is to 1.0, the more of the variation in the dependent variable the regression equation accounts for. An R-squared score of 1 means the regression predictions fit the observed data perfectly, while a score of 0 means the model does no better than always predicting the mean. A high score on the training data is a good sign, but it does not by itself guarantee good predictions for future years.
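To see what `reg.score` computes, here is a sketch that evaluates R-squared by hand as 1 minus the ratio of the residual sum of squares to the total sum of squares, then compares it with `LinearRegression.score` and `sklearn.metrics.r2_score`. The data are hypothetical (a noisy linear trend with an arbitrary seed), not the Flights totals. The last lines fit noise-free linear data to show a score of 1.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Hypothetical data: a linear trend plus noise (illustrative values only)
rng = np.random.default_rng(seed=42)
X = np.arange(1949, 1961).reshape(-1, 1)
y = 1500.0 + 300.0 * (X.ravel() - 1949) + rng.normal(scale=150.0, size=len(X))

reg = LinearRegression().fit(X, y)
y_pred = reg.predict(X)

# R-squared by hand: 1 - (residual sum of squares / total sum of squares)
ss_res = np.sum((y - y_pred) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r2_manual = 1 - ss_res / ss_tot

print("manual:  ", r2_manual)
print("score(): ", reg.score(X, y))
print("r2_score:", r2_score(y, y_pred))

# A perfectly linear target is fit exactly, giving an R-squared of 1 (up to rounding)
y_exact = 2.0 + 3.0 * X.ravel()
print("perfect fit:", LinearRegression().fit(X, y_exact).score(X, y_exact))
```

All three approaches agree, which confirms that `score` is simply this ratio evaluated on the data you pass in.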
Now that our model is created and evaluated, we can use it to predict future passenger traffic.
```python
# Forecast for the next 10 years: 1961 through 1970
new_years = np.array(range(1961, 1971)).reshape(-1, 1)
new_passenger_numbers = reg.predict(new_years)

print("Year, Predicted Passengers")
for year, passengers in zip(new_years, new_passenger_numbers):
    # Each year is a one-element row, so take year[0]; int() drops the fractional part
    print(f"{year[0]}, {int(passengers)}")
```
Here, we're predicting passenger counts for the years 1961 through 1970 (`range(1961, 1971)` stops before 1971). The `LinearRegression.predict` method performs the predictions: it returns the predicted values of the dependent variable for each input year, based on the fitted linear regression model.
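Because the model is a straight line, `predict` amounts to evaluating $\hat{y} = \beta_0 + \beta_1 x$ with the fitted `intercept_` and `coef_`. The sketch below checks this on hypothetical training data (a noisy upward trend with an arbitrary seed, not the Flights totals). Note that the new inputs must be 2D, one row per year, just like the training `X`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical training data: 12 yearly totals following a noisy upward trend
# (illustrative values only, not the Flights totals)
rng = np.random.default_rng(seed=1)
years = np.arange(1949, 1961).reshape(-1, 1)
passengers = 1500.0 + 380.0 * (years.ravel() - 1949) + rng.normal(scale=120.0, size=12)

reg = LinearRegression().fit(years, passengers)

# Future years as a 2D array of shape (10, 1): 1961 through 1970
future_years = np.arange(1961, 1971).reshape(-1, 1)
predicted = reg.predict(future_years)

# The same forecasts computed directly from the fitted line y = beta_0 + beta_1 * x
manual = reg.intercept_ + reg.coef_[0] * future_years.ravel()
print(np.allclose(predicted, manual))

# Because the model is linear, consecutive forecasts differ by the slope
# (up to floating-point rounding)
print(np.diff(predicted))
print(reg.coef_[0])
```

This is why a fitted linear model's yearly forecasts rise by a constant amount; any small irregularities in a printed table come from rounding or truncating the predictions.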
The output is as follows:
```text
Year, Predicted Passengers
1961, 5853
1962, 6236
1963, 6619
1964, 7002
1965, 7386
1966, 7769
1967, 8152
1968, 8535
1969, 8918
1970, 9301
```
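This output reflects our fitted linear regression model, which suggests that the number of passengers will increase by roughly 383 each year from 1961 through 1970, assuming the linear trend observed from 1949 to 1960 continues.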
Congratulations! In this lesson, you have learned how to apply Linear Regression to predict future trends in the airline industry. In particular, we've learned how to:
- Understand the fundamental concept of linear regression.
- Use Python's `sklearn` library to create and fit a linear regression model.
- Assess the quality of our model's fit using the R-squared statistic.
- Apply the linear regression model to make future trend forecasts.
These skills will support your work in many fields, as forecasting is vital to decision-making. They will be especially useful whenever you draw conclusions about trends in historical data and make predictions about future data points.
Next, you'll take your newly acquired skills for a spin in a few hands-on exercises. Solving linear regression problems will consolidate your understanding of this powerful forecasting tool. By understanding the historical data and predicting future trends, you'll make data-driven decisions quickly!