[Pytorch] scatter_ 함수 : index에 따라 특정 위치 값을 직접 삽입/수정
본문 바로가기
AI 구현/Pytorch

[Pytorch] scatter_ 함수 : index에 따라 특정 위치 값을 직접 삽입/수정

by NEWSUN* 2024. 9. 3.

주어진 index에 따라 원본 텐서의 특정 위치에 값을 직접 수정하는 scatter_ 함수에 대해 알아봅시다.

 

torch.scatter 함수와 torch.scatter_ 함수 다른점

  • torch.scatter 함수 : 원본 텐서를 수정하지 않고, 새로운 텐서를 반환 (out-of-place)
  • torch.scatter_ 함수 : 원본 텐서를 직접 수정함 (in-place)

 

torch.scatter_ 함수

Tensor.scatter_(dim, index, src, *, reduce=None) → Tensor

[Parameters]
* dim: scatter 사용 시, 기준이 되는 축 (0이면 행 방향, 1이면 열 방향)
* index: element들의 index, 숫자를 어떤식으로 옮길지 결정하는 규칙
* src: 옮길 element를 담고 있는 tensor

 

 

1) dim=0 (행 방향)인 경우

 

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2, 0]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src))

"""
[출력]
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]]) 
"""

 

 

dim=0인 경우, index의 element 숫자들은 어떤 행에 값을 넣어줄지 나타냅니다. 값을 넣어줄 열은 각 element의 열에 해당합니다. 

 

 

2) dim=1 (열 방향)인 경우

 

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src))

"""
[출력]
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])
"""

 

 

dim=1인 경우, index의 element 숫자들은 어떤 열에 값을 넣어줄지 나타냅니다. 값을 넣어줄 행은 각 element의 행에 해당합니다. 

 

Reference

[1] “torch.Tensor.scatter_ — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html (accessed Sep. 03, 2024).