PyTorch Lightning: An Easy Way to Organize Your PyTorch Project

As a perception and computer vision engineer, I work with PyTorch frequently. It is a fairly flexible framework and has many useful components already built in, so you don't have to write them yourself. However, if you have already done a few projects with PyTorch, you know that some parts are common to all of them.

That is where PyTorch Lightning comes in. PyTorch Lightning is a framework built on top of PyTorch that gives some structure to the components you use frequently, and it also includes implementations of things that exist in the majority of deep learning projects. The model is a good example. Every time you build a model, you need a loss function, an optimizer and maybe a scheduler for the learning rate. You may also need some custom layers for your model. PyTorch Lightning puts all of those components in one class called LightningModule, with a specific place for each of them.

It also provides a Trainer component that runs the training loop for you, lets you choose which hardware to run on (for example a GPU or TPU), and accepts callbacks for things like logging and saving checkpoints of your model at various training steps.  It would be impossible to cover every feature in one article, but take my word for it that PyTorch Lightning can make your life much easier when developing deep learning models.

An Overview of A Training Script

To get started, let’s take a look at what a PyTorch Lightning training script would look like.  You might notice that I am using Hydra for my configurations.  If you are unfamiliar with it, don’t worry, I will cover it in a future post.

Python
import hydra
from lightning.pytorch import Trainer

from fasion_mnist_dm import FashionMNISTDataModule
from resnet_lm import ResnetLightningModule


@hydra.main(version_base=None, config_path="config", config_name="config")
def train(config):

    datamodule = FashionMNISTDataModule(config.datamodule)

    model = ResnetLightningModule(config.model)

    trainer = Trainer()
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    train()

As you can see, there isn’t much here.  The script contains just three components:

  1. Datamodule: Contains the logic for setting up datasets as well as the train, validation and test data loaders.  Additionally, if you wanted to do things like split a dataset into train and validation sets, you could also do that here.
  2. Model: In PyTorch Lightning it is called a LightningModule, but in addition to containing the code for the model itself, it also sets up the optimizer, loss function and metrics, and it provides functions before and after training steps and epochs that allow you to define behavior inside the training loop.
  3. Trainer:  This is the one component that you don’t need to write yourself.   It takes a model and datamodule as input and it runs the model in a training loop for you using the train and validation data loaders from the datamodule. Additionally, it can also run the trained model on the test data loader.
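The script reads `config.datamodule` and `config.model`, and the classes in the following sections read `batch_size`, `num_workers`, `pin_memory` and `num_classes` from them.  If you want to try those classes without Hydra, for example in a notebook, a plain namespace with the same fields can stand in for the Hydra config.  This is only a sketch, and the values are assumptions, not necessarily the settings from the project's config file.

```python
from types import SimpleNamespace

# Stand-in for the object Hydra would build from config/config.yaml.
# Field names match what the code in this article reads; the values are assumed.
config = SimpleNamespace(
    datamodule=SimpleNamespace(batch_size=64, num_workers=4, pin_memory=True),
    model=SimpleNamespace(num_classes=10),
)

print(config.datamodule.batch_size)
print(config.model.num_classes)
```

With this in place, `FashionMNISTDataModule(config.datamodule)` and `ResnetLightningModule(config.model)` can be constructed the same way the training script does.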

Setting Up A DataModule

As an example, I’ll make a datamodule using the FashionMNIST dataset from PyTorch.  The important point here is that you can use whatever dataset you want, including a custom dataset.

Python
import lightning as L

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


class FashionMNISTDataModule(L.LightningDataModule):

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.prepare_data_per_node = True

    def prepare_data(self):
        # download
        FashionMNIST("data", train=True, download=True)
        FashionMNIST("data", train=False, download=True)

    def setup(self, stage):

        self.train_dataset = FashionMNIST("data", train=True, transform=ToTensor())

        self.val_dataset = FashionMNIST("data", train=False, transform=ToTensor())

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

    def val_dataloader(self):
        # Validate on the held-out split, in a fixed order (no shuffling needed)
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

Now let’s take a look at the internals of the datamodule.  It’s pretty basic.  Other than the standard init function, there’s a prepare_data function that downloads the data, a setup function that creates the datasets, and a function for each data loader: train, val and test.  In my example, I use FashionMNIST’s train=False split for validation and have left out test to keep things simple, but in general it’s better to have a separate test dataset than not.
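If you would rather hold out part of the training data for validation, which is the kind of split mentioned earlier, `torch.utils.data.random_split` is one way to do it inside `setup`.  The sketch below uses fake tensors instead of FashionMNIST so it runs without a download; the 80/20 sizes and the seed are assumptions.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Fake stand-in data: 100 grayscale 28x28 images with labels from 10 classes.
full_dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# A seeded generator makes the split reproducible between runs.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_dataset, [80, 20], generator=generator)

print(len(train_dataset), len(val_dataset))
```

In a datamodule, the two resulting subsets would be assigned to `self.train_dataset` and `self.val_dataset`, which would free up the original train=False split to serve as test data.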

Setting Up a LightningModule (Model)

Now that we have a data loader, I will set up a LightningModule.   Since this is just a simple classification model, I’ll use some of the standard stuff: a Resnet18, Adam optimizer, a cross entropy loss function and accuracy for a metric.

Python
import torch
import torch.nn as nn
import lightning as L
from torchvision.models import resnet18

from torchmetrics import Accuracy


class ResnetLightningModule(L.LightningModule):

    def __init__(self, config):
        super().__init__()

        self.model = resnet18(weights='IMAGENET1K_V1')

        # Change the first layer to use only 1 channel for FashionMNIST
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Change the fully connected layer to output one logit per FashionMNIST class (10 in total)
        self.model.fc = nn.Linear(in_features=512, out_features=config.num_classes, bias=True)

        self.loss_function = nn.CrossEntropyLoss()

        self.metric = Accuracy(task="multiclass", num_classes=config.num_classes)

        self.training_outputs = []
        self.validation_outputs = []

        self.training_ground_truths = []
        self.validation_ground_truths = []
        

    def forward(self, inputs):
        return self.model(inputs)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
    
    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("train_loss", loss)

        # Detach so the stored outputs don't keep each step's autograd graph in memory
        self.training_outputs.append(output.detach())
        self.training_ground_truths.append(target)
        
        return loss
    
    def on_train_epoch_end(self):
        all_preds = torch.cat(self.training_outputs, dim=0)
        all_targets = torch.cat(self.training_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("train_accuracy", accuracy)

        self.training_outputs.clear()
        self.training_ground_truths.clear()

    
    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("val_loss", loss)

        self.validation_outputs.append(output)
        self.validation_ground_truths.append(target)
        
        return loss
    
    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.validation_outputs, dim=0)
        all_targets = torch.cat(self.validation_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("val_accuracy", accuracy)

        self.validation_outputs.clear()
        self.validation_ground_truths.clear()

As you can see, I set up my model in the init function and defined how it runs in the forward function, similarly to how you would in a regular PyTorch module. I also set up my loss function and metric in the init.   In the configure_optimizers function, I set up my Adam optimizer.

After that, I defined how the training step and validation step should look.  (I did not add a test step this time because we have no test data loader, but feel free to add one). In addition to that, I also added methods for the end of epochs. These are also PyTorch Lightning methods and you don’t have to add them, but I added them because they make it easy to evaluate metrics at the end of each epoch.

Source Code

As always, I have provided the source code and a docker environment so you can try it out on your own.  The readme has detailed instructions on how to run everything.

Conclusion

That covers the primary components of a PyTorch Lightning project. However, I have barely scratched the surface of its functionality, so make sure to take a look at the documentation and try out some of the other features.

Although it is still possible to make great projects using PyTorch alone, PyTorch Lightning provides a lot of great features that can make your deep learning code both cleaner and more consistent from project to project.  It also contains a number of other useful features that I plan to write more about in the future.  If this article was useful, please feel free to let me know in the comments, or if you have questions I’ll do my best to answer them!
