브로드캐스팅 기능을 이용해 두 텐서를 같은 크기로 확장하는 broadcast_tensors 함수에 대해 알아봅시다.
broadcast_tensors 함수
브로드캐스팅은 서로 다른 크기의 텐서를 같은 크기로 확장하여 연산을 가능하게 합니다. 이 함수를 쓰려면, 아래 조건을 만족해야 합니다.
[브로드캐스팅 규칙]
- 비교하는 두 차원 크기가 같거나 두 차원 중 하나의 크기가 1이어야 함
* 브로드캐스팅 O
x = torch.rand(3,6)
y = torch.rand(1,6)
a,b = torch.broadcast_tensors(x,y)
print(a.size())
print(b.size())
"""
[출력]
torch.Size([3, 6])
torch.Size([3, 6])
"""
* 브로드캐스팅 X
x = torch.rand(4,1)
y = torch.rand(6,3)
a,b = torch.broadcast_tensors(x,y)
브로드캐스팅이 안되는 경우를 살펴보겠습니다. x, y의 첫 번째 차원은 서로 일치하지 않고 두 번째 차원은 1,3이기 때문에 브로드캐스팅이 가능합니다. 하지만 첫 번째 차원의 크기가 호환되지 않아 오류가 발생하게 됩니다. 아래는 이를 보완한 코드입니다.
x = torch.rand(4,1,1)
y = torch.rand(6,3)
a,b = torch.broadcast_tensors(x,y)
print(a.size())
print(b.size())
"""
[출력]
torch.Size([4, 6, 3])
torch.Size([4, 6, 3])
"""
Reference
[1] “torch.broadcast_tensors — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.broadcast_tensors.html (accessed Sep. 03, 2024).
'AI 구현 > Pytorch' 카테고리의 다른 글
[Pytorch] view, reshape, permute 함수 : 차원 재구성 (0) | 2024.09.08 |
---|---|
[Pytorch] gather 함수 : index에 따라 값을 수집/추출 (0) | 2024.09.03 |
[Pytorch] scatter_ 함수 : index에 따라 특정 위치 값을 직접 삽입/수정 (0) | 2024.09.03 |
[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입 (0) | 2024.09.03 |