주어진 index에 따라 원본 텐서의 특정 위치에 값을 직접 수정하는 scatter_ 함수에 대해 알아봅시다.
torch.scatter 함수와 torch.scatter_ 함수 다른점
- torch.scatter 함수 : 원본 텐서를 수정하지 않고, 새로운 텐서를 반환 (out-of-place)
- torch.scatter_ 함수 : 원본 텐서를 직접 수정함 (in-place)
torch.scatter_ 함수
Tensor.scatter_(dim, index, src, *, reduce=None) → Tensor
[Parameters]
* dim: scatter 사용 시, 기준이 되는 축 (0이면 행 방향, 1이면 열 방향)
* index: element들의 index, 숫자를 어떤식으로 옮길지 결정하는 규칙
* src: 옮길 element를 담고 있는 tensor
1) dim=0 (행 방향)인 경우
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2, 0]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src))
"""
[출력]
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
"""
dim=0인 경우, index의 element 숫자들은 어떤 행에 값을 넣어줄지 나타냅니다. 값을 넣어줄 열은 각 element의 열에 해당합니다.
2) dim=1 (열 방향)인 경우
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
print(torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src))
"""
[출력]
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
"""
dim=1인 경우, index의 element 숫자들은 어떤 열에 값을 넣어줄지 나타냅니다. 값을 넣어줄 행은 각 element의 행에 해당합니다.
Reference
[1] “torch.Tensor.scatter_ — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html (accessed Sep. 03, 2024).
'AI 구현 > Pytorch' 카테고리의 다른 글
[Pytorch] view, reshape, permute 함수 : 차원 재구성 (0) | 2024.09.08 |
---|---|
[Pytorch] gather 함수 : index에 따라 값을 수집/추출 (0) | 2024.09.03 |
[Pytorch] broadcast_tensors 함수 : 텐서 확장 및 연산 (0) | 2024.09.03 |
[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입 (0) | 2024.09.03 |