PyTorch Lightning: How To Keep Your PyTorch Project Clean

As an engineer in perception and computer vision, I work with PyTorch a lot.  It’s a very flexible framework with many useful built-in components that save you from building them yourself.  However, if you have already built a few projects with it, you might have noticed that the same parts show up in every project.

This is where PyTorch Lightning comes in.  PyTorch Lightning is a framework built on top of PyTorch that gives structure to the components you commonly need and implements features that most deep learning projects share.  A good example is your model.  Each time you set up a model, you need a loss function, an optimizer and maybe a learning rate scheduler.  You may also need some custom layers.  PyTorch Lightning gathers all of these into a single class called a LightningModule, with a designated place for each one.

It also provides a Trainer component that runs the training loop for you, lets you choose which compute to use (for example a GPU or TPU), and accepts callbacks for things like logging and saving checkpoints of your model at various training steps.  It would be impossible to cover every feature it provides in one article, but take my word for it that PyTorch Lightning can make your life much easier when developing deep learning models.

An Overview of A Training Script

To get started, let’s take a look at what a PyTorch Lightning training script would look like.  You might notice that I am using Hydra for my configurations.  If you are unfamiliar with it, don’t worry, I will cover it in a future post.

Python
import hydra
from lightning.pytorch import Trainer

from fasion_mnist_dm import FashionMNISTDataModule
from resnet_lm import ResnetLightningModule


@hydra.main(version_base=None, config_path="config", config_name="config")
def train(config):

    datamodule = FashionMNISTDataModule(config.datamodule)

    model = ResnetLightningModule(config.model)

    trainer = Trainer()
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    train()

As you can see, the training script has only three components:

  1. Datamodule: Contains the logic for setting up datasets as well as the train, validation and test data loaders.  If you want to do things like split a dataset into train and validation sets, you can also do that here.
  2. Model: In PyTorch Lightning this is called a LightningModule.  In addition to the code for the model itself, it sets up the optimizer, loss function and metrics, and it provides hooks that run before and after training steps and epochs so you can define behavior inside the training loop.
  3. Trainer:  This is the one component that you don’t need to write yourself.   It takes a model and datamodule as input and it runs the model in a training loop for you using the train and validation data loaders from the datamodule. Additionally, it can also run the trained model on the test data loader.

Setting Up A DataModule

As an example, I’ll make a datamodule using the FashionMNIST dataset from torchvision.  The important point here is that you can use whatever dataset you want, including a custom dataset.

Python
import lightning as L

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


class FashionMNISTDataModule(L.LightningDataModule):

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.prepare_data_per_node = True

    def prepare_data(self):
        # download
        FashionMNIST("data", train=True, download=True)
        FashionMNIST("data", train=False, download=True)

    def setup(self, stage):

        self.train_dataset = FashionMNIST("data", train=True, transform=ToTensor())

        self.val_dataset = FashionMNIST("data", train=False, transform=ToTensor())

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

Let’s take a look at the internals of the datamodule.  It’s pretty basic.  Other than the standard init function, there’s a prepare_data function that downloads the data, a setup function that creates the datasets, and a function for each data loader: train, val and test.  In my example, I have left out test to keep things simple, but in general it’s better to have a test dataset than not.
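If you do want a separate test set, one option is to carve the validation set out of the training data and keep the official FashionMNIST test split for testing. Here is a small sketch of the splitting step using random_split from PyTorch. To keep it runnable without a download, I am using random tensors as a stand-in for the real dataset, and the 80/20 ratio and the seed are arbitrary choices.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for FashionMNIST("data", train=True, ...): 100 fake 1x28x28 images.
full_train = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# Hold out 20% of the training data for validation; the fixed seed keeps
# the split identical between runs.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_train, [80, 20], generator=generator)

print(len(train_dataset), len(val_dataset))
```

Inside setup, the two results would be assigned to self.train_dataset and self.val_dataset, which leaves FashionMNIST("data", train=False, ...) free to back a test_dataloader function.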

Setting Up a LightningModule (Model)

Now that we have a datamodule, I will set up a LightningModule.  Since this is just a simple classification model, I’ll use some of the standard stuff: a Resnet18, an Adam optimizer, a cross entropy loss function and accuracy as the metric.

Python
import torch
import torch.nn as nn
import lightning as L
from torchvision.models import resnet18

from torchmetrics import Accuracy


class ResnetLightningModule(L.LightningModule):

    def __init__(self, config):
        super().__init__()

        self.model = resnet18(weights='IMAGENET1K_V1')

        # Replace the first layer so it accepts 1-channel FashionMNIST images
        # (this new layer is randomly initialized, not pretrained)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Change the fully connected layer to output one feature per FashionMNIST class
        self.model.fc = nn.Linear(in_features=512, out_features=config.num_classes, bias=True)

        self.loss_function = nn.CrossEntropyLoss()

        self.metric = Accuracy(task="multiclass", num_classes=config.num_classes)

        self.training_outputs = []
        self.validation_outputs = []

        self.training_ground_truths = []
        self.validation_ground_truths = []
        

    def forward(self, inputs):
        return self.model(inputs)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
    
    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("train_loss", loss)

        # Detach so the stored logits do not keep each batch's autograd graph alive
        self.training_outputs.append(output.detach())
        self.training_ground_truths.append(target)
        
        return loss
    
    def on_train_epoch_end(self):
        all_preds = torch.cat(self.training_outputs, dim=0)
        all_targets = torch.cat(self.training_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("train_accuracy", accuracy)

        self.training_outputs.clear()
        self.training_ground_truths.clear()

    
    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("val_loss", loss)

        self.validation_outputs.append(output)
        self.validation_ground_truths.append(target)
        
        return loss
    
    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.validation_outputs, dim=0)
        all_targets = torch.cat(self.validation_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("val_accuracy", accuracy)

        self.validation_outputs.clear()
        self.validation_ground_truths.clear()

As you can see, I set up my model in the init function and defined how it runs in the forward function, similarly to how you would in a regular PyTorch module. I also set up my loss function and metric in the init.   In the configure_optimizers function, I set up my Adam optimizer.

After that, I defined how the training step and validation step should look.  (I did not add a test step this time because we have no test data loader, but feel free to add one.)  I also added methods that run at the end of each epoch.  These are PyTorch Lightning hooks as well, and you don’t have to add them, but I added them because they make it easy to evaluate metrics at the end of each epoch.

Source Code

As always, I have provided the source code and a Docker environment so you can try it out on your own.  (I will be updating the code, so be sure to check out the “using_lightning” tag.)  The readme has detailed instructions on how to run everything.

Conclusion

That covers all the primary components of a PyTorch Lightning project.  However, I have barely scratched the surface of the functionality here, so make sure to take a look at the documentation and try out some of the other features.

Although it is still possible to make great projects using PyTorch alone, PyTorch Lightning provides a lot of great features that can make your deep learning code both cleaner and more consistent from project to project.  It also contains a number of other useful features that I plan to write more about in the future.  If this article was useful, please feel free to let me know in the comments, or if you have questions I’ll do my best to answer them!
