In today's lesson, we will focus on model checkpointing using PyTorch. Model checkpointing is a vital technique in machine learning that allows us to save the state of a model during training, ensuring that the best-performing models are preserved. By the end of this lesson, you will understand how to implement model checkpointing, allowing you to save your model whenever it achieves the best performance on a validation set.
Model checkpointing involves saving the state of a neural network model at various points during the training process. This is crucial for several reasons:
- Prevent Loss of Progress: In case of unexpected interruptions (e.g., a power failure or hardware malfunction), checkpointing lets you resume training from the last saved state instead of starting over (see the resume sketch right after this list).
- Save Best-Performing Models: By saving the model whenever it achieves a new best performance on a validation set, we ensure that we retain the best version of our model.
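To illustrate the first point, here is a minimal, self-contained sketch of saving and restoring a training checkpoint that also stores the optimizer state and the current epoch. The tiny model, the file name `resume_checkpoint.pth`, and the dictionary keys are illustrative assumptions, not part of this lesson's main example:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A tiny illustrative model and optimizer (not the lesson's wine model)
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Save everything needed to resume training after an interruption
torch.save({
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "resume_checkpoint.pth")

# Later (or after a crash): rebuild the objects and restore their state
# (some PyTorch versions may require weights_only=False here)
checkpoint = torch.load("resume_checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue from the next epoch
```

In the rest of this lesson we will use the simpler approach of saving only the best model, but the same idea extends to full training checkpoints like the one above.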
First, let’s set up our environment. We will import the necessary libraries, load and split our dataset, and finally build our model:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# Load dataset
wine = load_wine()
X = torch.tensor(wine.data, dtype=torch.float32)
y = torch.tensor(wine.target, dtype=torch.long)

# Split dataset into training and validation sets
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2)

# Define the model using nn.Sequential
model = nn.Sequential(
    nn.Linear(13, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 3)
)

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
Before diving into the training loop, let's first set up the initial parameters for checkpointing. This will ensure we can effectively track the model's performance and save the best version. Specifically, we'll need to establish:
- `best_loss` to keep track of the best validation loss. We initialize `best_loss` to `float('inf')` so that the first validation loss will trigger a model save.
- `checkpoint_path`, the file path where the model will be saved.
```python
best_loss = float('inf')
checkpoint_path = "best_model.pth"
```
Now, we will implement the training loop with validation and model checkpointing:
```python
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Validate the model
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_valid)
        val_loss = criterion(val_outputs, y_valid).item()

    # Save the model if the validation loss has decreased
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model, checkpoint_path)
        print(f"Model saved at epoch {epoch} with validation loss {val_loss:.4f}!")
```
In this training loop:
- The model is trained on the training set.
- The model's performance is validated on the validation set.
- If the validation loss improves, the model is saved using `torch.save()`. This ensures that only the best-performing model is kept; a sketch of loading it back for inference follows this list.
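To use the checkpoint later, you can load the saved model back and run it in evaluation mode. The snippet below is a minimal sketch that assumes the training code above has already produced `best_model.pth`; because we saved the entire model object (rather than just its `state_dict()`), recent PyTorch versions may require `weights_only=False` when loading, while older versions accept a plain `torch.load(checkpoint_path)`:

```python
# Load the best model saved during training
best_model = torch.load(checkpoint_path, weights_only=False)

# Switch to evaluation mode and check predictions on the validation set
best_model.eval()
with torch.no_grad():
    predictions = best_model(X_valid).argmax(dim=1)
    accuracy = (predictions == y_valid).float().mean().item()
    print(f"Validation accuracy of best model: {accuracy:.4f}")
```

A common alternative in practice is to save only `model.state_dict()` and later call `load_state_dict()` on a freshly constructed model, which keeps the checkpoint independent of the exact model class. The approach shown in this lesson is simpler because the whole model object is restored in a single call.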
Congratulations! You've learned about the concept and importance of model checkpointing, as well as how to implement it in a PyTorch model. Remember, effective checkpointing protects you from losing training progress and ensures you always keep your best-performing model in real-world machine learning tasks. Keep practicing and happy coding!