크기가 1인 차원을 삭제하는 squeeze 함수와 삽입하는 unsqueeze 함수에 대해 알아봅시다.
squeeze 함수
torch.squeeze(input, dim=None) → Tensor
[Parameters]
* input (Tensor): 입력 텐서
* dim (int or tuple of ints, optional): 값이 지정돼 있다면 특정 차원에서 squeeze됨
차원이 1인 차원을 제거해줍니다. 특정 차원을 지정하면, 해당 차원의 크기가 1인 경우만 제거하고 1이 아니라면 그대로 유지합니다.
import torch
x = torch.rand(1,2,3,4,1)
print(x.squeeze().size())
print(x.squeeze(0).size())
print(x.squeeze((0,1,2,3)).size())
print(x.squeeze((1,2,3,4)).size())
"""
[출력]
torch.Size([2, 3, 4])
torch.Size([2, 3, 4, 1])
torch.Size([2, 3, 4, 1])
torch.Size([1, 2, 3, 4])
"""
unsqueeze 함수
torch.unsqueeze(input, dim) → Tensor
[Parameters]
* input (Tensor) – 입력 텐서
* dim (int) – 단일 차원을 삽입할 인덱스
1인 차원을 생성해줍니다. 어느 차원에 생성할 지는 지정해주면 됩니다.
import torch
x = torch.rand(2,3,4)
print(x.unsqueeze(0).size())
print(x.unsqueeze(-1).size())
print(x.unsqueeze(-2).size())
"""
[출력]
torch.Size([1, 2, 3, 4])
torch.Size([2, 3, 4, 1])
torch.Size([2, 3, 1, 4])
"""
Reference
[1] “torch.squeeze — PyTorch 2.0 documentation,” pytorch.org. https://pytorch.org/docs/stable/generated/torch.squeeze.html (accessed Sep. 03, 2024).
[2] “torch.unsqueeze — PyTorch 1.11.0 documentation,” pytorch.org. https://pytorch.org/docs/stable/generated/torch.unsqueeze.html (accessed Sep. 03, 2024).
'AI 구현 > Pytorch' 카테고리의 다른 글
[Pytorch] view, reshape, permute 함수 : 차원 재구성 (0) | 2024.09.08 |
---|---|
[Pytorch] gather 함수 : index에 따라 값을 수집/추출 (0) | 2024.09.03 |
[Pytorch] scatter_ 함수 : index에 따라 특정 위치 값을 직접 삽입/수정 (0) | 2024.09.03 |
[Pytorch] broadcast_tensors 함수 : 텐서 확장 및 연산 (0) | 2024.09.03 |