Using PyTorch Lightning For Image Classification
Building image classification models with PyTorch Lightning can reduce the friction between research and development. Creating machine learning models that identify and tag images accurately is both an art and a science.
PyTorch Lightning separates research code (the model and its training logic) from engineering code (hardware, loops, logging), so users can focus on one and then the other.
In this guide, we will walk through the role of image classification and how PyTorch Lightning can be used with images specifically, and finish with a PyTorch Lightning example for image classification.
What Is Image Classification?
Image classification is the task of extracting information classes, or clusters, from a multiband raster image. In the most basic terms, it means using the information in an image made up of pixels (as opposed to vectors) to categorize and understand that image.
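From there, classified images can be used to build a map of overlapping pieces of information and create more detailed images or data. Classification approaches fall into two categories, supervised and unsupervised, depending on how the analyst and the computer interact during categorization: in supervised classification the analyst provides labeled training examples, while in unsupervised classification the computer groups pixels into clusters on its own.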
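To make the unsupervised case concrete, here is a minimal sketch that assigns every pixel of a multiband raster to one of k classes using k-means clustering, without any labeled data. The function name kmeans_pixels and the farthest-point initialization are illustrative choices for this sketch, not part of any library; supervised classification, by contrast, learns from labeled examples, as the PyTorch Lightning model later in this article does.

```python
import torch


def kmeans_pixels(image: torch.Tensor, k: int, iters: int = 10) -> torch.Tensor:
    """Hypothetical helper: cluster the pixels of a (bands, H, W) raster with k-means.

    Returns an (H, W) tensor of cluster ids in [0, k). No labels are used.
    """
    bands, h, w = image.shape
    pixels = image.reshape(bands, -1).T.float()  # (H*W, bands): one row per pixel

    # Farthest-point initialization keeps this sketch deterministic
    centers = [pixels[0]]
    for _ in range(1, k):
        dist = torch.cdist(pixels, torch.stack(centers)).min(dim=1).values
        centers.append(pixels[dist.argmax()])
    centers = torch.stack(centers)

    for _ in range(iters):
        labels = torch.cdist(pixels, centers).argmin(dim=1)  # nearest center per pixel
        for c in range(k):
            members = pixels[labels == c]
            if len(members) > 0:
                centers[c] = members.mean(dim=0)  # move the center to its cluster mean

    labels = torch.cdist(pixels, centers).argmin(dim=1)
    return labels.reshape(h, w)


# Example: a 3-band 4x4 "raster" whose left half is dark and right half is bright
raster = torch.zeros(3, 4, 4)
raster[:, :, 2:] = 1.0
print(kmeans_pixels(raster, k=2))  # left two columns -> 0, right two columns -> 1
```

Each resulting cluster is an "information class" in the sense above; an analyst would still have to decide what each cluster means.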
Strong classification lets users tag far more images and extract usable data from them for other purposes. A lot of manual time is otherwise spent performing these mundane tasks; with properly trained models, that manual effort can be greatly reduced.
What Is the Purpose of Image Classification?
The question now is, "Why does PyTorch Lightning image classification matter, and what problem does it solve?" At the most basic level, it reduces the time, effort, and unnecessary friction of mundane, manual tasks, which speeds up various machine learning applications, as stated before.
However, taking it one step further, image classification can also be the end result of a machine or deep learning application. For example, data extracted from images can be rendered and overlaid to create complex visual models of what the various images and their data represent.
Suppose you want to filter through thousands of images of various dogs, sort out which dogs have long, black fur, and compile a composite image of what the average dog with long, black fur looks like. Image classification is a simple way to get that end result.
This example is a slight simplification, but it is similar to what users interact with every day on social media applications that use augmented reality face filters. Image classification is the basis for more complicated and intricate image-based machine learning operations.
Can PyTorch Be Used For Image Processing?
Yes, of course! PyTorch is an open-source machine and deep learning framework for Python, originally based on the Torch library, and PyTorch Lightning is a high-level framework built on top of it that is well suited to image classification. Most importantly, PyTorch Lightning makes machine learning scalable for those who want to keep iterating on their models and applications.
For a more in-depth look at how to use PyTorch Lightning, with examples, we suggest checking out our blog post on how to get started with PyTorch Lightning.
PyTorch Lightning can make image classification simple, and it is especially beneficial for those who want to scale their machine learning models and iterate on multiple image classification models.
In the latter half of this article, we will dive into exactly how to start using PyTorch Lightning for image classification.
How to Use PyTorch Lightning for Image Classification
In order to start using PyTorch Lightning, you are going to need to make sure you have everything installed and set up correctly. Check out our guide to getting started with PyTorch Lightning for an easy reference on how to do this.
Next, we will need to set up a DataModule. There are many ways to build a data pipeline, but a simple way to get a DataModule going is to:
- Subclass PyTorch Lightning's LightningDataModule, use the __init__ method to pass hyperparameters, and then define the data pipeline.
- Prepare the data by writing all the logic needed to download your dataset (in this instance, your images).
- Load the data, prepare tensor datasets, and create reproducible splits (for example, train and validation) so you can repeat experiments later.
- Wrap each dataset split in a PyTorch DataLoader so the images can be fed to your PyTorch Lightning image classification model.
The exact code will vary with the datasets and frameworks you use, but this is the general process for setting up PyTorch Lightning to feed data into a machine learning model for image processing and classification.
Afterward, you can use a LightningModule to configure your system. This is where you define the model's computations and the logic of the training loop for your image classification model.
Once you have gone through the training, validation, and test loops, you can optimize your PyTorch Lightning model for the specific needs and evaluations of your particular machine learning application.
PyTorch Lightning Example of Image Classification
Let's break down a basic PyTorch Lightning example for image classification on the CIFAR-10 dataset. To begin with, we define a CIFAR10DataModule class that subclasses PyTorch Lightning's LightningDataModule, and we pass in the hyperparameters required for our data pipeline using the __init__ method.
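```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Convert images to tensors and scale each RGB channel to [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # CIFAR-10 images are 3-channel, 32x32 pixels, spread over 10 classes
        self.dims = (3, 32, 32)
        self.num_classes = 10
```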
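Next, we can set up the logic to download the data. Lightning calls prepare_data() from a single process, so it is the place for downloads, not for assigning state.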
```python
    def prepare_data(self):
        # Download the train and test sets once; no state is assigned here
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
```
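Next, we load the data from disk and prepare PyTorch tensor datasets for each split. Unlike prepare_data(), setup() runs on every process (one per GPU in distributed training), so this is where per-process state is assigned, including the datasets and the transformation to PyTorch tensors. Passing a seeded generator to random_split makes the train/validation split reproducible, and the train_dataloader(), val_dataloader(), and test_dataloader() methods wrap each split in a PyTorch DataLoader for the Trainer.
```python
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # A fixed seed makes the 45,000/5,000 split identical on every run
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [45000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)
```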
Next, we define the model architecture in a LightningModule subclass, LitModel. The setup is easy and intuitive.
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        # Log hyperparameters (stored in self.hparams and in checkpoints)
        self.save_hyperparameters()
        self.learning_rate = learning_rate

        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)

        # Infer the flattened feature size from a dummy forward pass
        n_sizes = self._get_conv_output(input_shape)
        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def _get_conv_output(self, shape):
        batch_size = 1
        with torch.no_grad():
            dummy = torch.rand(batch_size, *shape)
            output_feat = self._forward_features(dummy)
        return output_feat.view(batch_size, -1).size(1)

    def _forward_features(self, x):
        # Two conv blocks, each ending in 2x2 max pooling
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x

    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Per-class log-probabilities, to be paired with F.nll_loss
        x = F.log_softmax(self.fc3(x), dim=1)
        return x
```
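The _get_conv_output helper exists because fc1 needs to know how many features come out of the convolutional layers. For 3 x 32 x 32 CIFAR-10 inputs, that number can also be worked out by hand with the standard output-size formula for unpadded convolution and pooling layers. This sketch does the arithmetic and cross-checks it with a dummy forward pass through a hypothetical nn.Sequential stand-in that mirrors LitModel's layers.

```python
import torch
import torch.nn as nn


def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a Conv2d or MaxPool2d layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Trace a 32x32 image through conv1, conv2, pool1, conv3, conv4, pool2
side = 32
side = conv_out(side)                      # conv1: 32 -> 30
side = conv_out(side)                      # conv2: 30 -> 28
side = conv_out(side, kernel=2, stride=2)  # pool1: 28 -> 14
side = conv_out(side)                      # conv3: 14 -> 12
side = conv_out(side)                      # conv4: 12 -> 10
side = conv_out(side, kernel=2, stride=2)  # pool2: 10 -> 5
by_hand = 64 * side * side                 # 64 channels * 5 * 5

# Cross-check with a dummy forward pass through the same layer stack
features = nn.Sequential(
    nn.Conv2d(3, 32, 3, 1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, 1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, 1), nn.ReLU(), nn.MaxPool2d(2),
)
with torch.no_grad():
    by_probe = features(torch.rand(1, 3, 32, 32)).view(1, -1).size(1)

print(by_hand, by_probe)  # 1600 1600
```

So fc1 in LitModel ends up as nn.Linear(1600, 512) for CIFAR-10. The dummy-pass approach in _get_conv_output has the advantage of staying correct if the input shape or layers change.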
Lightning automates most of the training process, including the epoch and batch loops. We only need to define the logic for a single training batch and return the computed loss, which Lightning then uses for the backward pass and optimizer step.
```python
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # log-probabilities, since forward() ends in log_softmax
        loss = F.nll_loss(logits, y)

        # Training metrics: accuracy is the fraction of correct argmax predictions
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss
```
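The training step pairs the model's log_softmax output with F.nll_loss. This small check, run on random scores with made-up targets, shows that the pair computes the same value as F.cross_entropy applied to raw scores, and that taking the argmax of log-probabilities yields the same predicted classes as taking it over the raw scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
raw_scores = torch.randn(4, 10)       # unnormalized scores for 4 images, 10 classes
targets = torch.tensor([3, 0, 9, 3])  # made-up true class index for each image

# What LitModel does: log_softmax in forward(), then nll_loss in training_step()
log_probs = F.log_softmax(raw_scores, dim=1)
loss_nll = F.nll_loss(log_probs, targets)

# The same quantity computed in one call from the raw scores
loss_ce = F.cross_entropy(raw_scores, targets)
print(torch.allclose(loss_nll, loss_ce))  # True

# Accuracy as in training_step: fraction of argmax predictions matching the targets
preds = torch.argmax(log_probs, dim=1)
acc = (preds == targets).float().mean()
```

Either design works; ending forward() in log_softmax simply means the model's outputs can be read directly as log-probabilities at inference time.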
Validation and test loop functions can be defined similarly. Finally, an optimizer needs to be defined in configure_optimizers().
```python
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
```
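Running the whole pipeline is now straightforward. The Trainer arguments below (accelerator and devices) are the ones used by recent Lightning releases; accelerator="auto" uses a GPU when one is available.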
```python
import pytorch_lightning as pl

dm = CIFAR10DataModule(batch_size=32)
dm.prepare_data()
dm.setup()

model = LitModel(dm.dims, dm.num_classes)
trainer = pl.Trainer(max_epochs=50, accelerator="auto", devices=1)

# Train the model
trainer.fit(model, datamodule=dm)

# Evaluate the trained model on the held-out test set
trainer.test(model, datamodule=dm)
```
Setting up PyTorch Lightning for image classification really is that simple. Lightning handles the training loop, device placement, and logging, so the same code can scale from a single machine to multiple GPUs by changing Trainer arguments, and stay reproducible with little extra code. From there, there are endless possibilities for what can be achieved with image classification models and the data they extract and create.
Interested In Learning More About PyTorch Lightning?
Image classification is nothing new in the world of machine and deep learning. However, as the pool of guides for PyTorch Lightning image classification keeps growing, Lightning is an easy choice for anyone building these kinds of machine learning models.
Did we miss anything, though? We would love to make sure we cover any questions or areas of interest we may have overlooked.
Feel free to contact us for any questions or see other articles related to PyTorch, PyTorch Lightning, or a plethora of other topics over on our blog.