[Pytorch] broadcast_tensors 함수 : 텐서 확장 및 연산
본문 바로가기
AI 구현/Pytorch

[Pytorch] broadcast_tensors 함수 : 텐서 확장 및 연산

by NEWSUN* 2024. 9. 3.

브로드캐스팅 기능을 이용해 두 텐서를 같은 크기로 확장하는 broadcast_tensors 함수에 대해 알아봅시다.

 

broadcast_tensors 함수

브로드캐스팅은 서로 다른 크기의 텐서를 같은 크기로 확장하여 연산을 가능하게 합니다. 이 함수를 쓰려면, 아래 조건을 만족해야 합니다.

 

[브로드캐스팅 규칙]

  • 비교하는 두 차원 크기가 같거나 두 차원 중 하나의 크기가 1이어야 함

 

* 브로드캐스팅 O

 

x = torch.rand(3,6)
y = torch.rand(1,6)
a,b = torch.broadcast_tensors(x,y)
print(a.size())
print(b.size())

"""
[출력]
torch.Size([3, 6])
torch.Size([3, 6])
"""

 

 

* 브로드캐스팅 X

 

x = torch.rand(4,1)
y = torch.rand(6,3)
a,b = torch.broadcast_tensors(x,y)

 

브로드캐스팅이 안되는 경우를 살펴보겠습니다. x, y의 첫 번째 차원은 서로 일치하지 않고 두 번째 차원은 1,3이기 때문에 브로드캐스팅이 가능합니다. 하지만 첫 번째 차원의 크기가 호환되지 않아 오류가 발생하게 됩니다. 아래는 이를 보완한 코드입니다.

 

x = torch.rand(4,1,1)
y = torch.rand(6,3)
a,b = torch.broadcast_tensors(x,y)
print(a.size())
print(b.size())

"""
[출력]
torch.Size([4, 6, 3])
torch.Size([4, 6, 3])
"""

 

Reference

[1] “torch.broadcast_tensors — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.broadcast_tensors.html (accessed Sep. 03, 2024).