From segmenting structures in biomedical images to creating green screen videos to self-driving cars, the UNet architecture has a wide range of applications.
UNets are mainly used for segmentation. In self-driving cars, the model takes camera footage and segments it into classes such as road, lane, traffic light, person, and so on. In cinematography, a UNet can produce a mask around an object or a person that can then be edited and fine-tuned. In the biomedical industry, UNets can be trained to take images from a microscope and segment the different structures they contain, which could help scientists uncover things that cannot be seen with the naked eye. They can also be used on satellite imagery or in military drones. The possibilities are endless.
In this post, we will develop a model using UNet architecture to mask out the birds from a given image. First, let's talk about the dataset we will be using.
Dataset
We will be using the Caltech-UCSD Birds-200-2011 dataset released by Caltech. It contains 11,788 images of birds belonging to 200 species, along with their segmentation masks.
You can download the image dataset and the segmentations from this link.
In the above screenshot, you can see the image of the bird and its mask pair.
Before we start building the model, let's implement the data pipeline that will feed it appropriate input.
Building Data Pipeline
We will be using PyTorch and Torchvision for this blog.
Let's look at a few things we can deduce by inspecting the downloaded dataset folders:
- The bird images and their corresponding masks have the same name, but the image is in jpg format and the mask is in png format.
- Each species is separated into its own folder, but the Caltech researchers were kind enough to provide a text file that lists the paths to all the images.
- There is other metadata such as bounding boxes and classes, but we do not need it to build this model.
- The bird images (1.1 GB) and their segmentations (37 MB) come in two separate archive files, so I extracted them and placed them in a single folder.
This is what the final folder structure looks like.
We only care about the images folder, segmentations folder and images.txt file.
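If you want to sanity-check your extraction before going further, a quick sketch like the one below can confirm everything is where the later code expects it. The root path here is the one used in the training config further down; adjust it if your layout differs.
# check_layout.py -- quick sanity check of the extracted dataset layout
# (the root path is an assumption matching the training config used later)
import os

root = "CUB_200_2011/CUB_200_2011"
for required in ["images", "segmentations", "images.txt"]:
    path = os.path.join(root, required)
    print(f"{path}: {'found' if os.path.exists(path) else 'MISSING'}")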
First, let's create a file called dataset.py and import the necessary libraries.
from torch.utils.data import Dataset
from PIL import Image
import os
Then create a class "BirdDataset" which extends PyTorch's Dataset class.
In its init method, we take image_paths, which is the path to the images.txt file that lists all the bird image paths. We also take image_dir and segmentation_dir because the paths in images.txt do not include the root folders, so we need to join them accordingly.
Finally, we take the image and mask transforms so we can apply the resize and normalise transformations.
class BirdDataset(Dataset):
    def __init__(self, image_paths, image_dir, segmentation_dir, transform_image, transform_mask):
        super(BirdDataset, self).__init__()
        self.image_dir = image_dir
        self.segmentation_dir = segmentation_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        with open(image_paths, 'r') as f:
            self.images_paths = [line.split(" ")[-1] for line in f.readlines()]

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, index):
        image_name = ".".join(self.images_paths[index].split('.')[:-1])
        image = Image.open(os.path.join(self.image_dir, f"{image_name}.jpg")).convert("RGB")
        seg = Image.open(os.path.join(self.segmentation_dir, f"{image_name}.png")).convert("L")
        image = self.transform_image(image)
        seg = self.transform_mask(seg)
        return image, seg
As you can see, in the init method we read the paths line by line and build a list out of them. The len method simply returns the number of image paths. In the getitem method, we first extract the name of the file so that we can append the right extension (jpg for images and png for masks), then open both files with PIL, apply the transformations, and return them as a tuple.
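As a quick check, here is a minimal sketch of using the dataset on its own. The paths match the training config used later, and the transform here is only a placeholder; the real transforms are defined in the training step.
# minimal sketch of using BirdDataset directly -- the transform is a placeholder
from torchvision.transforms import transforms
from dataset import BirdDataset

transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
dataset = BirdDataset(
    image_paths="CUB_200_2011/CUB_200_2011/images.txt",
    image_dir="CUB_200_2011/CUB_200_2011/images",
    segmentation_dir="CUB_200_2011/CUB_200_2011/segmentations",
    transform_image=transform,
    transform_mask=transform
)
image, seg = dataset[0]
print(len(dataset), image.shape, seg.shape)  # expect torch.Size([3, 256, 256]) and torch.Size([1, 256, 256])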
Next, we need to create a function that will take this dataset, split it into a training and a validation set, and supply the data in batches.
# utils.py file
from torch.utils.data import DataLoader
from dataset import BirdDataset
import torch


def load_data_set(image_paths, image_dir, segmentation_dir, transforms, batch_size=8, shuffle=True):
    dataset = BirdDataset(image_paths,
                          image_dir,
                          segmentation_dir,
                          transform_image=transforms[0],
                          transform_mask=transforms[1])
    # 11,788 images in total: 11,772 for training and 16 for validation
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [11772, 16])
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle
    ), DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=shuffle
    )
Pytorch's DataLoader class helps us to create batches and randomly shuffle the data.
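To sanity-check the loaders, you can pull a single batch and look at its shape. This sketch reuses the placeholder transform from the dataset example above; the real transforms come in the training step.
# pull one batch from the training loader and inspect its shape
# (a sketch -- reuses the placeholder transform and paths from the previous snippet)
from utils import load_data_set

train_loader, val_loader = load_data_set(
    "CUB_200_2011/CUB_200_2011/images.txt",
    "CUB_200_2011/CUB_200_2011/images",
    "CUB_200_2011/CUB_200_2011/segmentations",
    transforms=[transform, transform],
    batch_size=8
)
images, segs = next(iter(train_loader))
print(images.shape, segs.shape)  # expect torch.Size([8, 3, 256, 256]) and torch.Size([8, 1, 256, 256])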
Building UNet Architecture
TL;DR Scroll down to the section that says "Complete Code" to copy the entire model architecture.
Finally, we can proceed to build the UNet architecture but before that, let's write a quick test that will ensure our future model returns the output with expected dimensions.
def test():
    image = torch.randn((32, 3, 161, 161))
    model = UNet(in_channels=3)
    out = model(image)
    print(image.shape, out.shape)
    assert out.shape == (32, 1, 161, 161)
We will be building a model that takes an RGB image and returns a black-and-white mask of the same height and width, exactly as in the dataset.
We will follow this diagram of the UNet architecture.
You can see why this network got its name: it is shaped roughly like the letter "U".
If you look closely, this network can be split into 3 parts: down layers (blue), up layers (red), and a bottleneck (green).
We will deal with them one by one, but before that, notice that each level has one thing in common: the input is followed by two "same" convolution layers, meaning the convolutions do not change the height and width of the feature map, only the number of channels. So let's first implement that double convolution module.
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, strides, padding):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, strides, padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
This module is written in a general pattern so that I can reuse it when building other architectures. As you can see, we have a normal convolution layer with bias disabled; unlike the original paper, we add a BatchNorm layer right after it, which would cancel out the bias anyway.
After that, we add a second conv layer, again followed by BatchNorm and a ReLU activation.
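A quick sketch to confirm the claim above: passing a random tensor through this module keeps the spatial dimensions and only changes the channel count. This assumes it is run in the same file where DoubleConv is defined and torch is imported.
# quick shape check: DoubleConv keeps H and W, only the channels change
x = torch.randn(1, 3, 256, 256)
block = DoubleConv(in_channels=3, out_channels=64, kernel_size=3, strides=1, padding=1)
print(block(x).shape)  # torch.Size([1, 64, 256, 256])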
Next, we create the UNet class that will contain our crazy-looking network. In the init method we take in_channels, which will be 3 in our case, the number of segmentation channels (which is 1), and a list of the feature sizes for each level of the ups and downs.
class UNet(nn.Module):
    def __init__(self, in_channels, num_segmentations=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
Next, we declare the module lists that will hold the down layers and up layers. A ModuleList "can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods." (Source: https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html)
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
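As a small aside, here is a standalone sketch of why that registration matters: parameters stored in a plain Python list are invisible to the optimiser, while a ModuleList registers them properly.
# sketch: modules in a plain Python list are NOT registered with the parent module,
# so the optimiser would never see their parameters; ModuleList fixes that
import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4)]                  # plain list: not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])   # registered

print(len(list(WithPlainList().parameters())))   # 0
print(len(list(WithModuleList().parameters())))  # 2 (weight and bias)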
Let's also define the other layers (the bottleneck, max pool, and output layer) before we start populating the ups and downs lists.
self.bottleneck = DoubleConv(
    in_channels=features[-1],
    out_channels=features[-1]*2,
    kernel_size=3,
    strides=1,
    padding=1
)
self.output = nn.Conv2d(
    in_channels=features[0],
    out_channels=num_segmentations,
    kernel_size=1
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
By looking at the diagram, we can deduce a general pattern that we can implement in a loop.
For the downward layers, we just loop through the features and append a DoubleConv module, keeping track of the running channel count: the first block takes the input image channels, and each following block takes the output channels of the previous one (after a max pool). So the channel count goes 3 → 64 → 128 → 256 → 512 down the left side of the U.
in_channels_iter = in_channels
for feature in features:
    self.downs.append(DoubleConv(
        in_channels=in_channels_iter,
        out_channels=feature,
        kernel_size=3,
        strides=1,
        padding=1
    ))
    in_channels_iter = feature
As for the up layers, we will use a transposed convolution layer to upsample the tensors, followed by a DoubleConv layer. In this case, we have to loop through the features in reverse order.
for feature in reversed(features):
    up = nn.Sequential(
        nn.ConvTranspose2d(
            in_channels=feature*2,
            out_channels=feature,
            kernel_size=2,
            stride=2,
            padding=0
        ),
        DoubleConv(
            in_channels=feature*2,
            out_channels=feature,
            kernel_size=3,
            padding=1,
            strides=1
        )
    )
    self.ups.append(up)
Notice that if you tried to run this Sequential module end to end as it is, it would raise an error saying that the input channels do not match the expected channels: the ConvTranspose2d outputs only feature channels, while the DoubleConv expects feature*2. The missing channels come from an intermediate step in the forward method, represented by the grey arrows in the diagram, where we concatenate the upsampled tensor with the corresponding skip connection.
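Here is a small standalone sketch of that channel bookkeeping for the deepest up layer (feature = 512): the transposed convolution halves the channels, and concatenating with the skip connection doubles them again.
# channel bookkeeping sketch for the deepest up layer (feature = 512)
import torch
from torch import nn

x = torch.randn(1, 1024, 16, 16)    # output of the bottleneck
skip = torch.randn(1, 512, 32, 32)  # saved skip connection from the down path
x = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)(x)  # -> (1, 512, 32, 32)
x = torch.cat((skip, x), dim=1)     # -> (1, 1024, 32, 32), ready for the DoubleConv
print(x.shape)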
Now let's write the forward method.
First, we send the image down the UNet through the downward layers, saving the output of each DoubleConv before applying the max pool.
skip_connections = []
for down in self.downs:
    x = down(x)
    skip_connections.append(x)
    x = self.pool(x)
After that, we send the tensor through the bottleneck layer and prepare the skip connections for concatenation with the up layers by reversing the list.
x = self.bottleneck(x)
skip_connections = skip_connections[::-1]
As discussed earlier, we need to do some additional work on the output of the ConvTranspose before passing it through the DoubleConv layers.
for i in range(len(self.ups)):
    x = self.ups[i][0](x)  # Pass through ConvTranspose first
    skip_connection = skip_connections[i]
    # If the height and width of the output tensor and the skip connection
    # are not the same, resize the tensor
    if x.shape != skip_connection.shape:
        # TF => import torchvision.transforms.functional as TF
        x = TF.resize(x, size=skip_connection.shape[2:])
    # Concatenate the output tensor with the skip connection
    concat_x = torch.cat((skip_connection, x), dim=1)
    # Pass the concatenated tensor through DoubleConv
    x = self.ups[i][1](concat_x)
Then why did we use a Sequential layer at all? We could have appended the ConvTranspose and DoubleConv separately. Yes, we could have; I did it this way because I didn't want to deal with skipping every other layer in the for loop, and I think this is much cleaner.
And finally, we just pass it through the output layer.
return self.output(x)
Complete Code (UNet Architecture)
# unet.py file
from torch import nn
import torch
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, strides, padding):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, strides, padding, bias=False),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels, num_segmentations=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.bottleneck = DoubleConv(
            in_channels=features[-1],
            out_channels=features[-1]*2,
            kernel_size=3,
            strides=1,
            padding=1
        )
        self.output = nn.Conv2d(
            in_channels=features[0],
            out_channels=num_segmentations,
            kernel_size=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        in_channels_iter = in_channels
        for feature in features:
            self.downs.append(DoubleConv(
                in_channels=in_channels_iter,
                out_channels=feature,
                kernel_size=3,
                strides=1,
                padding=1
            ))
            in_channels_iter = feature
        for feature in reversed(features):
            up = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=feature*2,
                    out_channels=feature,
                    kernel_size=2,
                    stride=2,
                    padding=0
                ),
                DoubleConv(
                    in_channels=feature*2,
                    out_channels=feature,
                    kernel_size=3,
                    padding=1,
                    strides=1
                )
            )
            self.ups.append(up)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for i in range(len(self.ups)):
            x = self.ups[i][0](x)  # Pass through ConvTranspose first
            skip_connection = skip_connections[i]
            # If the height and width of the output tensor and the skip connection
            # are not the same, resize the tensor
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            # Concatenate the output tensor with the skip connection
            concat_x = torch.cat((skip_connection, x), dim=1)
            # Pass the concatenated tensor through DoubleConv
            x = self.ups[i][1](concat_x)
        return self.output(x)


def test():
    image = torch.randn((32, 3, 161, 161))
    model = UNet(in_channels=3)
    out = model(image)
    print(image.shape, out.shape)
    assert out.shape == (32, 1, 161, 161)


if __name__ == "__main__":
    test()
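If you are curious about the model's size, an optional check like the one below counts its trainable parameters. This is just a sketch you can drop at the bottom of unet.py; it is not needed for training.
# optional: count the trainable parameters of the model
model = UNet(in_channels=3)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{num_params:,} trainable parameters")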
Training Step
Create a file called train.py and import the necessary libraries.
import torch
from unet import UNet
from utils import load_data_set
from torchvision.transforms import transforms
from tqdm import tqdm
import torchvision
Next, we'll define the hyperparameters and config.
config = {
    "lr": 1e-3,
    "batch_size": 16,
    "image_dir": "CUB_200_2011/CUB_200_2011/images",
    "segmentation_dir": "CUB_200_2011/CUB_200_2011/segmentations",
    "image_paths": "CUB_200_2011/CUB_200_2011/images.txt",
    "epochs": 10,
    "checkpoint": "checkpoint/bird_segmentation_v1.pth",
    "optimiser": "checkpoint/bird_segmentation_v1_optim.pth",
    "continue_train": False,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}
I'm going with a learning rate of 0.001 and 10 epochs, training on an RTX 2080 Super with 8 GB of VRAM.
transforms_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0., 0., 0.), (1., 1., 1.))
])
transforms_mask = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])
I am also resizing all images to 256px by 256px and training with a batch size of 16.
Next, we will load the data using a helper function I wrote earlier called load_data_set.
train_dataset, val_dataset = load_data_set(
    config['image_paths'],
    config['image_dir'],
    config['segmentation_dir'],
    transforms=[transforms_image, transforms_mask],
    batch_size=config['batch_size']
)
print("loaded", len(train_dataset), "batches")
Now, let's define the model object and Adam optimiser.
model = UNet(3).to(config['device'])
optimiser = torch.optim.Adam(params=model.parameters(), lr=config['lr'])
If you want to continue training from a previously saved checkpoint, set continue_train to True in the config; the snippet below loads the saved model and optimiser states.
if config['continue_train']:
    state_dict = torch.load(config['checkpoint'])
    optimiser_state = torch.load(config['optimiser'])
    model.load_state_dict(state_dict)
    optimiser.load_state_dict(optimiser_state)
We will be using binary cross-entropy loss because we are dealing with binary segmentation here. We will also use PyTorch's Automatic Mixed Precision to set the precision of the computations automatically, which reduces the VRAM consumed and also speeds up training.
We need to use BCE with logits instead of plain BCE because we are not applying a sigmoid in the model architecture: BCEWithLogitsLoss passes the tensor through a sigmoid before calculating the loss.
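A tiny sketch of that equivalence (assuming torch is imported, as in train.py); the two values agree up to floating point error.
# BCEWithLogitsLoss applies a sigmoid internally, so these two losses match
logits = torch.randn(4, 1, 8, 8)
target = torch.randint(0, 2, (4, 1, 8, 8)).float()
a = torch.nn.BCEWithLogitsLoss()(logits, target)
b = torch.nn.BCELoss()(torch.sigmoid(logits), target)
print(a.item(), b.item())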
In case you want to segment more than one class, for example people, cars, trees, and sky in a photo, you can simply change the num_segmentations parameter of the UNet class and switch the loss function to cross-entropy loss (torch.nn.CrossEntropyLoss).
loss_fn = torch.nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler()
model.train()
The train function is pretty straightforward: we create two loops, one for the epochs and the other for the batches. We use the autocast wrapper to automatically cast the computations to float16 or float32 as required, and we update the weights through the gradient scaler and optimiser.
def train():
    for epoch in range(config['epochs']):
        loop = tqdm(train_dataset)
        for image, seg in loop:
            image = image.to(config['device'])
            seg = seg.float().to(config['device'])
            with torch.cuda.amp.autocast():
                pred = model(image)
                loss = loss_fn(pred, seg)
            optimiser.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimiser)
            scaler.update()
            loop.set_postfix(loss=loss.item())
        check_accuracy_and_save(model, optimiser, epoch)
check_accuracy_and_save is a helper function that checks the model against the validation set after every epoch and saves the states of the model and optimiser.
def check_accuracy_and_save(model, optimiser, epoch):
    torch.save(model.state_dict(), config['checkpoint'])
    torch.save(optimiser.state_dict(), config['optimiser'])
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in val_dataset:
            x = x.to(config['device'])
            y = y.to(config['device'])
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )
            torchvision.utils.save_image(preds, f"test/pred/{epoch}.png")
            torchvision.utils.save_image(y, f"test/true/{epoch}.png")
    print(
        f"Dice Score = {dice_score/len(val_dataset)}"
    )
    model.train()
As a measure of accuracy we will be using the Dice score, also known as the F1 score. If we used plain pixel accuracy, then, since roughly 80% of each mask image is black, the model could score around 80% just by outputting a black image every time.
"The Dice score is not only a measure of how many positives you find, but it also penalizes for the false positives that the method finds, similar to precision. So it is more similar to precision than accuracy. The only difference is the denominator, where you have the total number of positives instead of only the positives that the method finds. So the Dice score is also penalizing for the positives that your algorithm/method could not find."
Source: https://stats.stackexchange.com/a/195037
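As a toy illustration of the formula used in the validation code above, here is the Dice score for two tiny binary masks (a standalone sketch, not part of the training script).
# Dice score on two tiny binary masks: 2 * |intersection| / (|preds| + |target|)
import torch

preds = torch.tensor([[1., 1., 0., 0.]])
target = torch.tensor([[1., 0., 1., 0.]])
dice = (2 * (preds * target).sum()) / ((preds + target).sum() + 1e-8)
print(dice.item())  # 0.5 -- one overlapping pixel, two positives in each mask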
Complete Code (Training Step)
# train.py file
import torch
from unet import UNet
from utils import load_data_set
from torchvision.transforms import transforms
from tqdm import tqdm
import torchvision

config = {
    "lr": 1e-3,
    "batch_size": 16,
    "image_dir": "CUB_200_2011/CUB_200_2011/images",
    "segmentation_dir": "CUB_200_2011/CUB_200_2011/segmentations",
    "image_paths": "CUB_200_2011/CUB_200_2011/images.txt",
    "epochs": 10,
    "checkpoint": "checkpoint/bird_segmentation_v1.pth",
    "optimiser": "checkpoint/bird_segmentation_v1_optim.pth",
    "continue_train": False,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

print(f"Training using {config['device']}")

transforms_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0., 0., 0.), (1., 1., 1.))
])
transforms_mask = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])

train_dataset, val_dataset = load_data_set(
    config['image_paths'],
    config['image_dir'],
    config['segmentation_dir'],
    transforms=[transforms_image, transforms_mask],
    batch_size=config['batch_size']
)
print("loaded", len(train_dataset), "batches")

model = UNet(3).to(config['device'])
optimiser = torch.optim.Adam(params=model.parameters(), lr=config['lr'])

if config['continue_train']:
    state_dict = torch.load(config['checkpoint'])
    optimiser_state = torch.load(config['optimiser'])
    model.load_state_dict(state_dict)
    optimiser.load_state_dict(optimiser_state)

loss_fn = torch.nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler()
model.train()


def check_accuracy_and_save(model, optimiser, epoch):
    torch.save(model.state_dict(), config['checkpoint'])
    torch.save(optimiser.state_dict(), config['optimiser'])
    num_correct = 0
    num_pixel = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in val_dataset:
            x = x.to(config['device'])
            y = y.to(config['device'])
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixel += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )
            torchvision.utils.save_image(preds, f"test/pred/{epoch}.png")
            torchvision.utils.save_image(y, f"test/true/{epoch}.png")
    print(
        f"Dice Score = {dice_score/len(val_dataset)}"
    )
    model.train()


def train():
    step = 0
    for epoch in range(config['epochs']):
        loop = tqdm(train_dataset)
        for image, seg in loop:
            image = image.to(config['device'])
            seg = seg.float().to(config['device'])
            with torch.cuda.amp.autocast():
                pred = model(image)
                loss = loss_fn(pred, seg)
            optimiser.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimiser)
            scaler.update()
            loop.set_postfix(loss=loss.item())
            step += 1
        check_accuracy_and_save(model, optimiser, epoch)


if __name__ == "__main__":
    train()
After training for just two epochs, the model was able to produce pretty promising results.
Predicted Masks
Ground Truth
This post is inspired by this video.
With this, we have implemented the UNet architecture to generate masks for bird images. You can find the complete code in this Github repo.
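If you want to try the trained model on your own image, a minimal inference sketch could look like the following. The input image name is a placeholder; the checkpoint path, resize, sigmoid, and 0.5 threshold mirror the training and validation code above.
# inference sketch -- run the trained model on a single image
# (the input image path is a placeholder; the checkpoint path matches the config)
import torch
import torchvision
from PIL import Image
from torchvision.transforms import transforms
from unet import UNet

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(3).to(device)
model.load_state_dict(torch.load("checkpoint/bird_segmentation_v1.pth", map_location=device))
model.eval()

transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
image = transform(Image.open("some_bird.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    mask = (torch.sigmoid(model(image)) > 0.5).float()
torchvision.utils.save_image(mask, "some_bird_mask.png")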
Leave a comment if you have any concerns or queries, and I will get back to you as soon as possible.