Understanding Squeeze and Unsqueeze in PyTorch


When working with tensors in PyTorch, you will often need to change the shape of a tensor by removing or adding dimensions. Two useful functions for these tasks are squeeze() and unsqueeze(). In this blog post, we will discuss these functions and their use cases, and walk through examples to help you use them effectively.

Squeeze

The squeeze() function in PyTorch removes dimensions of size 1 from a tensor. This is helpful when you want to drop unnecessary singleton dimensions, giving the tensor a simpler shape that is easier to work with.
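Because squeeze() only changes a tensor's shape, it does not copy any data: the PyTorch documentation states that the returned tensor shares storage with the input. The short sketch below illustrates this. The tensor x built with torch.arange is an illustrative input chosen for this sketch, not part of the walkthrough that follows.

```python
import torch

# Illustrative input with two singleton dimensions: shape (2, 1, 3, 1)
x = torch.arange(6).reshape(2, 1, 3, 1)

# squeeze() returns a view with a new shape over the same storage
y = x.squeeze()
print(y.shape)  # torch.Size([2, 3])

# Writing through the view is visible in the original tensor
y[0, 0] = 100
print(x[0, 0, 0, 0].item())  # 100

# Both tensors start at the same memory address
print(x.data_ptr() == y.data_ptr())  # True
```

Since no data is copied, calling squeeze() is cheap, but in-place changes to either tensor affect the other.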

Here’s an example of how to use squeeze() in PyTorch:

  1. Import the necessary libraries:
import torch
  2. Create a tensor with dimensions of size 1:
# Create a tensor of size (2, 1, 3, 1)
tensor = torch.tensor([[[[1], [2], [3]]], [[[4], [5], [6]]]])
print("Original tensor:")
print(tensor)
print("Shape:", tensor.shape)

Output:

Original tensor:
tensor([[[[1],
          [2],
          [3]]],


        [[[4],
          [5],
          [6]]]])
Shape: torch.Size([2, 1, 3, 1])
  3. Use squeeze() to remove dimensions of size 1:
# Remove dimensions of size 1
squeezed_tensor = tensor.squeeze()
print("Squeezed tensor:")
print(squeezed_tensor)
print("Shape:", squeezed_tensor.shape)

Output:

Squeezed tensor:
tensor([[1, 2, 3],
        [4, 5, 6]])
Shape: torch.Size([2, 3])

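As you can see, both dimensions of size 1 have been removed, and the resulting tensor has a shape of (2, 3). When called without arguments, squeeze() removes every dimension of size 1; passing a dimension index restricts it to that dimension only.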
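Removing every singleton dimension is not always what you want. The sketch below, using zero-filled illustrative tensors, shows squeeze(dim) removing only the requested dimension, and the common pitfall of a bare squeeze() also dropping a batch dimension whose size happens to be 1.

```python
import torch

tensor = torch.zeros(2, 1, 3, 1)

# Remove only dimension 1; the trailing size-1 dimension is kept
print(tensor.squeeze(1).shape)   # torch.Size([2, 3, 1])

# Negative indices count from the end
print(tensor.squeeze(-1).shape)  # torch.Size([2, 1, 3])

# Squeezing a dimension whose size is not 1 leaves the shape unchanged
print(tensor.squeeze(0).shape)   # torch.Size([2, 1, 3, 1])

# Pitfall: with a batch of size 1, a bare squeeze() also drops the batch dimension
batch = torch.zeros(1, 3, 1)
print(batch.squeeze().shape)     # torch.Size([3])
print(batch.squeeze(-1).shape)   # torch.Size([1, 3])
```

Naming the dimension explicitly makes the code robust to inputs where other dimensions happen to have size 1.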

Unsqueeze

The unsqueeze() function in PyTorch adds a dimension of size 1 at a specified position in a tensor. This is helpful when you need a singleton dimension so that a tensor's shape matches another tensor for broadcasting, or when an operation expects a particular number of dimensions, for example a leading batch dimension when you have only a single sample.
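A typical case is broadcasting. The sketch below scales each row of a (2, 3) matrix by its own factor; matrix and row_scale are illustrative names chosen for this example. Broadcasting aligns shapes from the trailing dimension, so a (2,) vector cannot be combined directly with a (2, 3) matrix, but unsqueeze(1) turns it into a (2, 1) column that broadcasts across the columns.

```python
import torch

# Illustrative inputs: a (2, 3) matrix and one scale factor per row
matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
row_scale = torch.tensor([10.0, 100.0])  # shape (2,)

# Broadcasting compares trailing dimensions (3 vs 2), so this fails
try:
    matrix * row_scale
except RuntimeError as err:
    print("Broadcasting failed:", err)

# unsqueeze(1) turns (2,) into (2, 1), which broadcasts across the 3 columns
scaled = matrix * row_scale.unsqueeze(1)
print(scaled.tolist())  # [[10.0, 20.0, 30.0], [400.0, 500.0, 600.0]]
```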

Here’s an example of how to use unsqueeze() in PyTorch:

  1. Import the necessary libraries:
import torch
  2. Create a tensor without dimensions of size 1:
# Create a tensor of size (2, 3)
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("Original tensor:")
print(tensor)
print("Shape:", tensor.shape)

Output:

Original tensor:
tensor([[1, 2, 3],
        [4, 5, 6]])
Shape: torch.Size([2, 3])
  3. Use unsqueeze() to add dimensions of size 1:
# Add a dimension of size 1 at the end (dim=-1) and then at position 1 (dim=1)
unsqueezed_tensor = tensor.unsqueeze(-1).unsqueeze(1)
print("Unsqueezed tensor:")
print(unsqueezed_tensor)
print("Shape:", unsqueezed_tensor.shape)

Output:

Unsqueezed tensor:
tensor([[[[1],
          [2],
          [3]]],


        [[[4],
          [5],
          [6]]]])
Shape: torch.Size([2, 1, 3, 1])

As you can see, the original tensor had a shape of (2, 3). After applying unsqueeze(-1).unsqueeze(1), the resulting tensor has a shape of (2, 1, 3, 1). The first call added a dimension of size 1 at the end (dim=-1), giving (2, 3, 1); the second call then inserted one at position 1 (dim=1), giving (2, 1, 3, 1).
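The same shape can be reached in other ways, and squeeze() reverses the operation. The sketch below compares chained unsqueeze() calls with indexing by None (which inserts a size-1 dimension at that position) and with an explicit view(); the variable names a, b and c are just illustrative.

```python
import torch

tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

a = tensor.unsqueeze(-1).unsqueeze(1)  # chained unsqueeze, as in the example above
b = tensor[:, None, :, None]           # indexing with None inserts size-1 dimensions
c = tensor.view(2, 1, 3, 1)            # explicit view with the target shape

print(a.shape, b.shape, c.shape)
print(torch.equal(a, b) and torch.equal(a, c))  # True

# For a 2-D tensor, unsqueeze accepts dim in [-3, 2]; dim=2 and dim=-1 both append
print(tensor.unsqueeze(2).shape)  # torch.Size([2, 3, 1])

# squeeze() removes the added singleton dimensions again
print(torch.equal(a.squeeze(), tensor))  # True
```

unsqueeze() makes the intent (add one axis at a given position) explicit, while view() requires spelling out the full target shape.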

Conclusion

In this blog post, we discussed the squeeze() and unsqueeze() functions in PyTorch and walked through examples of their usage. These functions come up constantly when working with tensors: they let you adjust a tensor's shape to match the requirements of specific operations or to make it compatible with other tensors, for example when broadcasting. Because both return views that share storage with the input rather than copies, they are also cheap to call. Understanding how they work will help you write clearer and more predictable PyTorch code.


Author: robot learner
Reprint policy: Unless otherwise stated, all articles in this blog are licensed under CC BY 4.0. If reproduced, please credit the source, robot learner.