[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입
본문 바로가기
AI 구현/Pytorch

[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입

by NEWSUN* 2024. 9. 3.

크기가 1인 차원을 삭제하는 squeeze 함수와 삽입하는 unsqueeze 함수에 대해 알아봅시다.

 

squeeze 함수

torch.squeeze(input, dim=None) → Tensor

[Parameters]
* input (Tensor): 입력 텐서
* dim (int or tuple of ints, optional): 값이 지정돼 있다면 특정 차원에서 squeeze됨

 

차원이 1인 차원을 제거해줍니다. 특정 차원을 지정하면, 해당 차원의 크기가 1인 경우만 제거하고 1이 아니라면 그대로 유지합니다.

 

import torch
x = torch.rand(1,2,3,4,1)
print(x.squeeze().size())
print(x.squeeze(0).size())
print(x.squeeze((0,1,2,3)).size())
print(x.squeeze((1,2,3,4)).size())

"""
[출력]
torch.Size([2, 3, 4])
torch.Size([2, 3, 4, 1])
torch.Size([2, 3, 4, 1])
torch.Size([1, 2, 3, 4])
"""

 

unsqueeze 함수

torch.unsqueeze(input, dim) → Tensor

[Parameters]
* input (Tensor) – 입력 텐서
* dim (int) – 단일 차원을 삽입할 인덱스

 

1인 차원을 생성해줍니다. 어느 차원에 생성할 지는 지정해주면 됩니다.

 

import torch
x = torch.rand(2,3,4)
print(x.unsqueeze(0).size())
print(x.unsqueeze(-1).size())
print(x.unsqueeze(-2).size())

"""
[출력]
torch.Size([1, 2, 3, 4])
torch.Size([2, 3, 4, 1])
torch.Size([2, 3, 1, 4])
"""

 

Reference

[1] “torch.squeeze — PyTorch 2.0 documentation,” pytorch.org. https://pytorch.org/docs/stable/generated/torch.squeeze.html  (accessed Sep. 03, 2024).

[‌2] “torch.unsqueeze — PyTorch 1.11.0 documentation,” pytorch.org. https://pytorch.org/docs/stable/generated/torch.unsqueeze.html  (accessed Sep. 03, 2024).