An Overview of PyTorch Lightning with a Simple Code Walkthrough
As you continue your research into Python and explore the various frameworks it offers, you may come across PyTorch Lightning. You may even ask yourself, “Why use PyTorch Lightning when there are so many other frameworks out there, especially if I am already using PyTorch?”
It is a fair question, and we will help you understand why PyTorch Lightning could be effective for your AI project and how you can get started using it. We will go over three key points:
- Whether PyTorch Lightning is a better fit for your project than plain PyTorch
- How switching from PyTorch to PyTorch Lightning can help your workload
- How to get started with PyTorch Lightning and make your next AI project amazing
What is the difference between PyTorch and PyTorch Lightning?
PyTorch is an open-source machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing. PyTorch offers flexibility and a low-level interface, making it well suited to custom models and granular control.
PyTorch Lightning is a high-level Python framework built on PyTorch that aims to simplify the training and deployment of models by providing a lightweight, standardized interface. It was designed with researchers in mind: by abstracting away boilerplate code and repetitive tasks and encouraging a more structured, organized approach to development, it frees them to experiment with novel deep learning and machine learning models.
With this in mind, an academic doing heavy amounts of research and AI development would benefit from PyTorch Lightning's cleaner, modular structure, reproducibility, and scalability. It lets you experiment and push your research forward far more quickly than managing raw PyTorch boilerplate yourself. There are use cases for non-academic pursuits too!
Is PyTorch Lightning Better Than PyTorch?
If your project values code organization, reproducibility for experimentation, and a high degree of scalability, PyTorch Lightning can significantly simplify the development process. If your project values flexibility and fine-grained control, you may want to stick with PyTorch. There is no definitive answer; your deployment, training setup, and preferences will help you pick what's best for your model.
Switching from PyTorch to PyTorch Lightning
Switching from PyTorch to PyTorch Lightning can feel tricky. However, the core of what PyTorch Lightning does is simply streamline and clean up the coding process.
It might feel like magic, but PyTorch Lightning simply restructures the boilerplate of your PyTorch code so that it follows a standard layout rather than an arbitrary one. Most importantly, it does this for every loop of the model training process.
All the code that goes unchanged throughout training is reorganized and abstracted away, so your code looks cleaner, is easier to read, and makes it easier to track down flaws and errors. It also makes it faster for others to iterate from one AI model to another.
Switching from PyTorch to PyTorch Lightning is simply a matter of getting used to seeing that boilerplate structured into a few simplified lines. If you get the boilerplate right the first time, the rest of the process should go much more smoothly.
One other benefit of switching from PyTorch to PyTorch Lightning is that it comes with a suite of free features such as progress bars and checkpointing. Why use PyTorch Lightning? For simplicity!
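To make that concrete, here is a minimal sketch of how those built-ins show up in practice, assuming a recent pytorch-lightning release (the Trainer class is covered in the walkthrough below). A progress bar and end-of-epoch checkpointing are on by default, and a ModelCheckpoint callback customizes what gets saved; the "train_loss" monitor key is a hypothetical name you would log yourself.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# a progress bar and basic checkpointing are enabled by default;
# ModelCheckpoint customizes which checkpoints are kept
# ("train_loss" is a hypothetical metric you would log in training_step)
checkpoint_cb = ModelCheckpoint(monitor="train_loss", save_top_k=1)
trainer = Trainer(max_epochs=3, callbacks=[checkpoint_cb])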
How to Use PyTorch Lightning
Figuring out how to use PyTorch Lightning is simple, with only a few steps that will save you a lot of time down the road. PyTorch Lightning, like many other Python projects, installs with pip (the package installer for Python).
For this, we recommend choosing a favorite virtual environment manager to handle installs and dependencies without clogging up your main Python installation. Once installed, running PyTorch Lightning is fairly straightforward.
Using PyTorch Lightning is similar to using raw PyTorch. The main difference, as we have mentioned, is that you no longer have to write and manage the boilerplate yourself. Other than that, all you have to do is inherit from LightningModule instead of nn.Module. PyTorch Lightning handles all of the critical components of deep learning model training.
Another important thing to consider when using PyTorch Lightning is that it is hardware agnostic. Based on the kind of model you are building and the research you are performing, you can choose whether to run your AI models on a CPU, a GPU, or anything else you can make work. This makes PyTorch Lightning one of the most flexible options for creating reproducible AI models.
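To illustrate, here is a minimal sketch of how hardware is selected through the Trainer, assuming a reasonably recent PyTorch Lightning release (1.6 or later, where the accelerator and devices arguments are available); the model code itself stays the same.

from pytorch_lightning import Trainer

# same LightningModule, different hardware; no manual .to(device) calls needed
cpu_trainer = Trainer(accelerator="cpu")
gpu_trainer = Trainer(accelerator="gpu", devices=1)
auto_trainer = Trainer(accelerator="auto", devices="auto")  # let Lightning pick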
If you are wondering how to use PyTorch Lightning, the good news is that it is about as simple as using raw PyTorch. If you are still wondering why to use PyTorch Lightning instead of raw PyTorch, or any other framework for that matter, the primary reasons are speed, efficiency, and reproducibility.
PyTorch Lightning Example Walkthrough
Let's first install Lightning:
pip install pytorch-lightning
PyTorch should also be installed in your environment. As mentioned above, the key to organizing code with Lightning is the LightningModule class. You define a class that inherits from it and build from there.
The initialization may look like the following:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer
class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        # MNIST images are (1, 28, 28) (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)
Here we are building three fully connected layers of the neural net:
- an input layer that accepts a flattened 28 x 28 MNIST image (784 values) and maps it to 128 features
- a hidden layer with 256 neurons
- an output layer with 10 classes (one per digit)
Define the forward propagation method, just like in PyTorch:
def forward(self, x):
    batch_size, channels, height, width = x.size()
    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)
    x = F.log_softmax(x, dim=1)
    return x
Now we define the training step:
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    return loss
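If you also want the loss to appear in Lightning's built-in progress bar and logs, you can record it with the LightningModule's self.log method. A minimal variant of the step above:

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    # prog_bar=True shows the running loss in the default progress bar
    self.log("train_loss", loss, prog_bar=True)
    return loss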
And, lastly, the optimizer for training. Here there is a clear difference from PyTorch: in Lightning, you use the configure_optimizers method to define the optimizer. For example, to use the well-known Adam optimizer (imported above from torch.optim):
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
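configure_optimizers can also return a learning-rate scheduler alongside the optimizer if your training calls for one. A minimal sketch using PyTorch's StepLR (the step size and gamma values are arbitrary examples):

def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)
    # example schedule: halve the learning rate every 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}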
How do you handle and load data for training? It is advisable to use the DataLoader class. Here is the code to get the data:
import os

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

# prepare the standard MNIST transforms
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# data
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=64)
Note that mnist_train is now an instance of the DataLoader class. Finally, you can pass this object to Lightning's fit method like this to start training!
model = LitMNIST()
trainer = Trainer()
trainer.fit(model, mnist_train)
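The random_split import above is not used in this minimal example, but a common next step is to carve out a validation set with it. A hedged sketch, assuming you also add a validation_step method to LitMNIST (mirroring training_step and logging, say, "val_loss") so the Trainer knows what to compute on the validation data:

# split the 60,000 MNIST training images into train/validation subsets
mnist_full = MNIST(os.getcwd(), train=True, download=True, transform=transform)
train_ds, val_ds = random_split(mnist_full, [55000, 5000])

train_loader = DataLoader(train_ds, batch_size=64)
val_loader = DataLoader(val_ds, batch_size=64)

model = LitMNIST()
trainer = Trainer()
trainer.fit(model, train_loader, val_loader)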
For more details and examples, please see the official PyTorch Lightning documentation.
Looking For More Information On PyTorch and PyTorch Lightning?
As you can see, we are big fans of PyTorch Lightning, especially for those interested in experimentation and research into what machine learning and deep learning models can do. If you are focused on research or purely academic pursuits, you may find real value in switching from PyTorch to PyTorch Lightning.
Interested in more PyTorch Lightning tutorials? Check out the PyTorch Lightning website for more great walkthroughs, and see how the growing community is using it.
Feel free to contact us if you have any questions or take a look at our Deep Learning Solutions if you're interested in a workstation or server to run PyTorch/PyTorch Lightning on!