PyTorch Lightning: An Easy Way to Organize Your PyTorch Project

As an engineer in perception and computer vision, I tend to work with PyTorch a lot.  It's a very flexible framework, and there are already many useful components built into it that keep you from having to build them yourself.  However, if you have already built a few projects with it, you might have noticed that every project tends to have some parts in common.

That's where PyTorch Lightning comes in. PyTorch Lightning is a framework built on top of PyTorch that adds some structure to the components you use most frequently, and it also ships implementations of things that show up in the majority of deep learning projects. A good example is the model. Every time you build a model, you need a loss function, an optimizer, and perhaps a learning rate scheduler. You may also need some custom layers for your model. PyTorch Lightning puts all of those components into a class called a LightningModule and gives each one a specific place.
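
To make that concrete, here is a bare-bones sketch of that structure (the class name and the layer choices are just placeholders; the full example appears later in this post):

Python
import torch
import torch.nn as nn
import lightning as L


class SketchModule(L.LightningModule):

    def __init__(self):
        super().__init__()
        # your network (or custom layers) lives here
        self.model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        # so does the loss function
        self.loss_function = nn.CrossEntropyLoss()

    def configure_optimizers(self):
        # the optimizer (and optionally a learning rate scheduler) is defined here
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        loss = self.loss_function(self.model(inputs), target)
        return loss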

It also provides things like a trainer component that runs the training loop for you, lets you choose which compute to use (for example, a GPU or TPU), and takes callbacks for things like logging and checkpointing your model at various training steps.  It would be impossible to cover all the features it provides in one article, but take my word for it that using PyTorch Lightning can make your life much easier when developing deep learning models.
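
As a rough illustration of those features, a more fully configured Trainer might look like the following (the values here are illustrative, not recommendations):

Python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Illustrative Trainer configuration: pick the compute, cap the number of
# epochs, and add a callback that checkpoints the best model by validation loss.
trainer = Trainer(
    max_epochs=10,
    accelerator="gpu",   # or "cpu", "tpu", "auto"
    devices=1,
    callbacks=[ModelCheckpoint(monitor="val_loss")],
)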

An Example Training Script

To get started, let’s take a look at what a PyTorch Lightning training script would look like.  You might notice that I am using Hydra for my configurations.  If you are unfamiliar with it, don’t worry, I will cover it in a future post.

Python
import hydra
from lightning.pytorch import Trainer

from fasion_mnist_dm import FashionMNISTDataModule
from resnet_lm import ResnetLightningModule


@hydra.main(version_base=None, config_path="config", config_name="config")
def train(config):

    datamodule = FashionMNISTDataModule(config.datamodule)

    model = ResnetLightningModule(config.model)

    trainer = Trainer()
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    train()

As you can see, there aren't many pieces. The script contains only these components:

  1. Datamodule: Contains the logic to prepare the datasets as well as the dataloaders for training, validation, and testing. Also, if you want to split a single dataset into training and validation sets, you can do that in the Datamodule.
  2. Model: In PyTorch Lightning this is called a LightningModule, but besides containing the code for the model, it also sets up the optimizer, the loss function, and the metrics, and it provides methods that define how the training loop behaves.
  3. Trainer: This is a component you don't have to write yourself. It takes a model and a datamodule as input and runs the model through the training loop using the training and validation dataloaders from the datamodule. It also has a method for running the trained model against the test dataloader (see the sketch below).
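
For reference, running the trained model on the test data is a single extra call (this assumes the datamodule defines a test_dataloader, which the example below leaves out):

Python
trainer = Trainer()
trainer.fit(model, datamodule)    # runs the training and validation loops
trainer.test(model, datamodule)   # runs the test loop using the datamodule's test_dataloader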

Preparing a Datamodule

As an example, let's build a datamodule using the FashionMNIST dataset from torchvision.  The important point here is that you can use whatever dataset you want, including a custom dataset.
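
If you do want to use your own data, a custom dataset only needs __len__ and __getitem__; here is a minimal sketch (MyImageDataset and its inputs are hypothetical):

Python
from torch.utils.data import Dataset


class MyImageDataset(Dataset):

    def __init__(self, images, labels):
        # images and labels are whatever in-memory or on-disk structures you have
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]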

Python
import lightning as L

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


class FashionMNISTDataModule(L.LightningDataModule):

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.prepare_data_per_node = True

    def prepare_data(self):
        # download
        FashionMNIST("data", train=True, download=True)
        FashionMNIST("data", train=False, download=True)

    def setup(self, stage):

        self.train_dataset = FashionMNIST("data", train=True, transform=ToTensor())

        self.val_dataset = FashionMNIST("data", train=False, transform=ToTensor())

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

Now that we have an idea of what it does, let's take a look at the internals of a datamodule.  It's pretty basic.  Other than the standard init function, there's a prepare_data function that downloads the data, a setup function that builds the datasets, and a function for each dataloader: train, val, and test.  In my example, I have left out test to keep things simple, but in general it's better to have a test dataset than not.
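
If you do have a test set, adding it is just one more method; a sketch, assuming setup also builds a self.test_dataset:

Python
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,          # assumes setup() also created this dataset
            batch_size=self.config.batch_size,
            shuffle=False,              # no shuffling needed for evaluation
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )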

Preparing a LightningModule (Model)

Now that we have a datamodule, I will set up a LightningModule.   Since this is just a simple classification model, I'll use some standard choices: a ResNet18, the Adam optimizer, a cross entropy loss function, and accuracy as the metric.

Python
import torch
import torch.nn as nn
import lightning as L
from torchvision.models import resnet18

from torchmetrics import Accuracy


class ResnetLightningModule(L.LightningModule):

    def __init__(self, config):
        super().__init__()

        self.model = resnet18(weights='IMAGENET1K_V1')

        # Change the first layer to use only 1 channel for FashionMNIST
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Change the fully connected layer to output 10 features, one for each MNIST class
        self.model.fc = nn.Linear(in_features=512, out_features=config.num_classes, bias=True)

        self.loss_function = nn.CrossEntropyLoss()

        self.metric = Accuracy(task="multiclass", num_classes=config.num_classes)

        self.training_outputs = []
        self.validation_outputs = []

        self.training_ground_truths = []
        self.validation_ground_truths = []
        

    def forward(self, inputs):
        return self.model(inputs)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
    
    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("train_loss", loss)

        self.training_outputs.append(output)
        self.training_ground_truths.append(target)
        
        return loss
    
    def on_train_epoch_end(self):
        all_preds = torch.cat(self.training_outputs, dim=0)
        all_targets = torch.cat(self.training_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("train_accuracy", accuracy)

        self.training_outputs.clear()
        self.training_ground_truths.clear()

    
    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("val_loss", loss)

        self.validation_outputs.append(output)
        self.validation_ground_truths.append(target)
        
        return loss
    
    def on_validation_epoch_end(self):
        all_preds = torch.cat(self.validation_outputs, dim=0)
        all_targets = torch.cat(self.validation_ground_truths, dim=0)

        accuracy = self.metric(all_preds, all_targets)

        self.log("val_accuracy", accuracy)

        self.validation_outputs.clear()
        self.validation_ground_truths.clear()

As you can see, I set up my model in the init function and defined how it runs in the forward function, similarly to how you would in a regular PyTorch module. I also set up my loss function and metric in the init.   In the configure_optimizers function, I set up my Adam optimizer.
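
If you also wanted a learning rate scheduler, configure_optimizers can return both; a sketch using StepLR (the step_size and gamma values are just placeholders):

Python
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        # decay the learning rate by half every 5 epochs (illustrative values)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}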

After that, I defined how the training step and validation step should look.  (I did not add a test step this time because we have no test data loader, but feel free to add one). In addition to that, I also added methods for the end of epochs. These are also PyTorch Lightning methods and you don’t have to add them, but I added them because they make it easy to evaluate metrics at the end of each epoch.
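
If you did add a test dataloader, the corresponding test step would look almost identical to the validation step; a sketch:

Python
    def test_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs)
        loss = self.loss_function(output, target)

        self.log("test_loss", loss)

        return loss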

Source Code

As always, I have provided the source code and the Docker environment so you can try it out on your own.  (I will be updating the code, so be sure to check out the “using_lightning” tag.) The readme has detailed instructions on how to run everything.

Conclusion

That wraps up a description of the primary components of a PyTorch Lightning project. However, there are still many parts of its functionality that I haven't covered in this article. I encourage you to look at the documentation and try out the other parts of the library.

Although it is still possible to make great projects using PyTorch alone, PyTorch Lightning provides a lot of great features that can make your deep learning code both cleaner and more consistent from project to project.  It also contains a number of other useful features that I plan to write more about in the future.  If this article was useful, please feel free to let me know in the comments, and if you have questions I'll do my best to answer them!
