Basics of torch image tensors: how to quickly get your forward pass working
Tensor dimensions, channels, and batching explained. Simply.
A very common difficulty when starting deep learning projects is understanding how to format your input tensors. This task feels like it should be relatively straightforward, but it can be a massive barrier for beginners trying to train even a simple model on a dataset they created.
Let’s start with a simple task: inputting a single image into an image classifier in torch. We’ll assume you read the image in with torchvision.
Tensor dimensions are confusing, but they don’t have to be
Let’s think of tensor “dimensions” as slices, since that’s how we access them. When you read an image into a torch tensor, it has a “shape”, given by the tensor’s shape attribute. You can think of the shape as a list of numbers, one per dimension (torch actually returns a torch.Size, which behaves like a tuple). Each number in that list is simply the number of elements along that dimension.
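A minimal sketch of inspecting a shape, using a random tensor in place of a real image (the name img and the sizes are illustrative assumptions):

```python
import torch

# A random tensor standing in for an RGB image: 3 channels, 64 rows, 64 columns
img = torch.rand(3, 64, 64)

print(img.shape)        # torch.Size([3, 64, 64])
print(list(img.shape))  # [3, 64, 64]
print(img.ndim)         # 3: the number of dimensions is the length of the shape
print(img.shape[0])     # 3: the number of elements along the first dimension
```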
How are colored images structured?
Typically, a single color image is defined as a tensor of shape [3, 64, 64], since in torch images need to be in the format [channel, height, width]. This means that indexing like [0, :, :] (or, equivalently, [0, ...]) gives you the entire matrix for the first channel: in this case, a 64x64 matrix holding the values of a single color.
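Here is a small sketch of that indexing on a random stand-in image (the variable names are hypothetical). Python’s Ellipsis (...) fills in all remaining dimensions, so both slices select the same values:

```python
import torch

img = torch.rand(3, 64, 64)  # stand-in for a color image, [channel, height, width]

first_channel = img[0, :, :]       # every row and column of channel 0
first_channel_too = img[0, ...]    # ... expands to ":, :" here

print(first_channel.shape)                             # torch.Size([64, 64])
print(torch.equal(first_channel, first_channel_too))   # True
```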
How are greyscale images structured?
A greyscale image might be read in as [height, width], since it is just a 2D matrix of numbers. However, torch expects the first dimension to be the channels. In this case, we need to artificially add a dimension of size 1 to our greyscale image, so it becomes “a list of length 1 of 2D matrices”. To do this, we use the unsqueeze operation along the zeroth dimension to get a tensor of shape [1, height, width].
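A minimal sketch, assuming the greyscale image arrived as a plain 2D tensor (simulated here with random values):

```python
import torch

grey = torch.rand(64, 64)      # a 2D greyscale image, [height, width]
grey_chw = grey.unsqueeze(0)   # add a channel dimension of size 1 at position 0

print(grey_chw.shape)                  # torch.Size([1, 64, 64])
print(torch.equal(grey_chw[0], grey))  # True: same values, one extra level of nesting
```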
squeeze() and unsqueeze()
Although the names are confusing, the operations are straightforward. The torch unsqueeze() operation adds a dimension of length 1 to your tensor (read: with dim=0, it wraps the tensor in a length-one list). For example, to turn a 2D greyscale image into a valid torch shape, we can use unsqueeze(dim=0) to go from [height, width] to [1, height, width], i.e. an image with only one channel.
So to input a single image, since we need the shape [batch, channel, height, width], we take our [channel, height, width] tensor and unsqueeze(dim=0) to get [1, channel, height, width]. Your image is now ready to be input into any standard image model!
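To see this end to end, here is a sketch of a forward pass. The model is a tiny hypothetical classifier built from standard torch.nn layers, standing in for a real image model; its layers and the choice of 10 classes are assumptions for illustration only:

```python
import torch
from torch import nn

# Hypothetical tiny classifier (10 classes assumed), not a real pretrained model
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # expects [batch, 3, height, width]
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                    # -> [batch, 8, 1, 1]
    nn.Flatten(),                               # -> [batch, 8]
    nn.Linear(8, 10),                           # -> [batch, 10]
)
model.eval()

img = torch.rand(3, 64, 64)  # a single image, [channel, height, width]
batch = img.unsqueeze(0)     # -> [1, 3, 64, 64]: a batch containing one image

with torch.no_grad():
    logits = model(batch)

print(batch.shape)   # torch.Size([1, 3, 64, 64])
print(logits.shape)  # torch.Size([1, 10])
```

The output keeps the batch dimension, which is why it has shape [1, classes] rather than [classes].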
Similarly, the squeeze() operation removes all dimensions of size 1 (read: removes the extra layer of “list-ness” if the list has length 1) when no dimension is specified; otherwise, it removes only the size-1 dimension at the given index. For example, if we input our single image of shape [1, channel, height, width] into the model and get an output of shape [1, classes], we can use squeeze() to get a 1D tensor of class logits of shape [classes].
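A short sketch of squeeze() on a stand-in model output (random values, 10 classes assumed):

```python
import torch

logits = torch.rand(1, 10)  # stand-in for a model output of shape [1, classes]

print(logits.squeeze().shape)   # torch.Size([10]): every size-1 dimension removed
print(logits.squeeze(0).shape)  # torch.Size([10]): only dimension 0 removed
print(logits.squeeze(1).shape)  # torch.Size([1, 10]): dimension 1 has size 10, so nothing changes

# Edge case: with a single class, squeeze() with no dim removes both dimensions
print(torch.rand(1, 1).squeeze().shape)  # torch.Size([])
```

Passing an explicit dim, as in squeeze(0), is a reasonable habit: it removes only the batch dimension and cannot accidentally drop another dimension that happens to have size 1.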
How are images structured for model input?
Torchvision models (and those from most other APIs) do not expect a single image as input in the [channel, height, width] format. Rather, they expect a 4D tensor (remember, this just means 4 numbers in the shape attribute) of shape [batch, channel, height, width]. The “batch dimension” just refers to how many images we want to input to the model. Unrolled, a batch of size N would look like [[channel, height, width]_1, … [channel, height, width]_N]. So slicing the tensor like [0, ...] would give back a single image, [channel, height, width].
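One way to build such a batch yourself is torch.stack, which joins equally shaped tensors along a new dimension. A minimal sketch with four hypothetical random images:

```python
import torch

# Four hypothetical images of the same size, each [channel, height, width]
images = [torch.rand(3, 64, 64) for _ in range(4)]

batch = torch.stack(images, dim=0)  # new batch dimension at position 0
print(batch.shape)                  # torch.Size([4, 3, 64, 64])

first = batch[0, ...]               # indexing the batch dimension gives back one image
print(first.shape)                  # torch.Size([3, 64, 64])
print(torch.equal(first, images[0]))  # True
```

torch.stack requires every tensor in the list to have the same shape, so images of different sizes need to be resized first.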
Models generally take this format instead of [channel, height, width] because we train with mini-batches containing more than one image.
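During training you rarely add the batch dimension by hand: torch’s DataLoader collates individual samples into mini-batches. A sketch, assuming a small in-memory dataset of random images and labels (TensorDataset is used here only for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 hypothetical images, [10, 3, 64, 64], with integer labels from 5 classes
images = torch.rand(10, 3, 64, 64)
labels = torch.randint(0, 5, (10,))
dataset = TensorDataset(images, labels)

x, y = dataset[0]
print(x.shape)  # torch.Size([3, 64, 64]): a single sample has no batch dimension

loader = DataLoader(dataset, batch_size=4)
xb, yb = next(iter(loader))
print(xb.shape)  # torch.Size([4, 3, 64, 64]): the loader adds the batch dimension
print(yb.shape)  # torch.Size([4])
```

Because 10 is not divisible by 4, the last batch here holds only 2 images; the batch dimension simply shrinks.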
This means that if you want to input a single image of shape [3, 64, 64] into your model, you must first unsqueeze the tensor along the zeroth dimension to get a tensor of shape [1, 3, 64, 64].
TLDR: [batch, channel, height, width]. Image models expect the first dimension of your tensor to be the “batch dimension”, i.e. the number of images in your list of images. The second dimension is the “channel dimension”, i.e. each index along the channel dimension gives you the values for that particular color channel, across a single image or a whole batch of images. The last two are height and width, which you don’t have to worry about.
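To tie the rules together, here is a sketch of a small helper that brings an image tensor into [batch, channel, height, width] form. The function name to_model_input is hypothetical, and it assumes a 3D tensor is always one [channel, height, width] image:

```python
import torch


def to_model_input(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return img as a [batch, channel, height, width] tensor."""
    if img.ndim == 2:  # [height, width]: greyscale, add channel then batch dimensions
        return img.unsqueeze(0).unsqueeze(0)
    if img.ndim == 3:  # [channel, height, width]: single image, add batch dimension
        return img.unsqueeze(0)
    if img.ndim == 4:  # already [batch, channel, height, width]
        return img
    raise ValueError(f"expected 2, 3 or 4 dimensions, got shape {tuple(img.shape)}")


print(to_model_input(torch.rand(64, 64)).shape)        # torch.Size([1, 1, 64, 64])
print(to_model_input(torch.rand(3, 64, 64)).shape)     # torch.Size([1, 3, 64, 64])
print(to_model_input(torch.rand(8, 3, 64, 64)).shape)  # torch.Size([8, 3, 64, 64])
```

Note the ambiguity this sketch does not resolve: a 3D tensor could also be a batch of greyscale images, [batch, height, width], which would need unsqueeze(1) instead.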