Lesson 1
Saving Progress with Model Checkpointing in PyTorch
Introduction to Model Checkpointing

In today's lesson, we will focus on model checkpointing in PyTorch. Model checkpointing is a vital technique in machine learning that lets us save the state of a model during training so that the best-performing version is preserved. By the end of this lesson, you will be able to implement checkpointing that saves your model whenever it achieves a new best performance on a validation set.

What is Model Checkpointing?

Model checkpointing involves saving the state of a neural network model at various points during the training process. This is crucial for several reasons:

  • Prevent Loss of Progress: In case of unexpected interruptions (e.g., a power failure or hardware malfunction), checkpointing lets you resume training from the last saved state instead of starting over (see the sketch after this list).
  • Save Best Performing Models: By saving the model whenever it achieves a new best performance on a validation set, we ensure that we retain the best version of our model.
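
As a concrete illustration of resuming, a checkpoint intended for restarting training usually bundles more than just the model weights. Below is a minimal, hypothetical sketch (the tiny model, optimizer, and file name are placeholders for this example only, not part of the lesson's code) of saving and later restoring such a checkpoint:

Python
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and optimizer, just to make the sketch self-contained
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Save everything needed to resume training after an interruption
torch.save({
    "epoch": 10,                                # last completed epoch
    "model_state": model.state_dict(),          # model weights
    "optimizer_state": optimizer.state_dict(),  # optimizer internals (e.g., Adam moments)
}, "resume_checkpoint.pth")

# Later: restore both states and continue from the next epoch
checkpoint = torch.load("resume_checkpoint.pth")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"] + 1
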
Setting Up the Environment

First, let’s set up our environment. We will import the necessary libraries, load and split our dataset, and finally build our model:

Python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# Load dataset
wine = load_wine()
X = torch.tensor(wine.data, dtype=torch.float32)
y = torch.tensor(wine.target, dtype=torch.long)

# Split dataset into training and validation sets
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2)

# Define the model using nn.Sequential
model = nn.Sequential(
    nn.Linear(13, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 3)
)

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
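
A quick note on the architecture: the first nn.Linear layer takes 13 inputs because each sample in the Wine dataset has 13 numerical features, and the final layer has 3 outputs because the dataset contains 3 wine classes.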
Initialize Checkpoint Parameters

Before diving into the training loop, let's set up the initial parameters for checkpointing so that we can track the model's performance and save the best version. Specifically, we'll need to establish:

  • best_loss to keep track of the best validation loss. We initialize best_loss to float('inf') to ensure the first validation loss will trigger a model save.
  • checkpoint_path where the model will be saved.
Python
best_loss = float('inf')
checkpoint_path = "best_model.pth"
Training Loop with Checkpointing

Now, we will implement the training loop with validation and model checkpointing:

Python
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Validate the model
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_valid)
        val_loss = criterion(val_outputs, y_valid).item()

    # Save the model if the validation loss has decreased
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model, checkpoint_path)
        print(f"Model saved at epoch {epoch} with validation loss {val_loss:.4f}!")

In this training loop:

  • The model is trained on the training set.
  • The model's performance is validated on the validation set.
  • If the validation loss improves, the model is saved with torch.save(), so the checkpoint always holds the best-performing version of the model (loading it back after training is shown in the sketch after this list).
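
Once training finishes, the best checkpoint can be loaded back for evaluation or inference. The snippet below is a minimal sketch of that step, reusing checkpoint_path and the validation tensors from above; because we pickled the entire model object, newer PyTorch versions may require passing weights_only=False to torch.load, depending on your installed version:

Python
# Load the best model saved during training and evaluate it on the validation set.
# Note: weights_only=False may be unnecessary (or unavailable) on older PyTorch versions.
best_model = torch.load(checkpoint_path, weights_only=False)
best_model.eval()
with torch.no_grad():
    val_outputs = best_model(X_valid)
    predictions = val_outputs.argmax(dim=1)
    accuracy = (predictions == y_valid).float().mean().item()
print(f"Validation accuracy of the best checkpoint: {accuracy:.4f}")

As a design note, many projects save model.state_dict() rather than the whole model object; that avoids pickling concerns when loading, though the full-model approach used in this lesson is perfectly fine for a self-contained script.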
Lesson Summary

Congratulations! You've learned what model checkpointing is, why it matters, and how to implement it in a PyTorch training loop. Effective checkpointing protects you from losing training progress and ensures you always keep the best version of your model, which can save you a great deal of time in real-world machine learning tasks. Keep practicing and happy coding!
