[Pytorch] gather 함수 : index에 따라 값을 수집/추출
본문 바로가기
AI 구현/Pytorch

[Pytorch] gather 함수 : index에 따라 값을 수집/추출

by NEWSUN* 2024. 9. 3.

주어진 index 텐서에 따라 input 텐서의 값을 추출하여 새로운 텐서를 생성하는 gather 함수에 대해 알아봅시다.

 

gather 함수

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

[Parameters]
* input: 입력 텐서
* dim: gather 사용 시, 기준이 되는 축 (0이면 행 방향, 1이면 열 방향)
* index: 특정 값을 값들을 수집할 위치를 지정하는 index 텐서
(input tensor와 index tensor의 dimension이 동일해야함)

 

 

1) dim=0 (행 방향)인 경우

 

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
print(torch.gather(t, 0, index))

"""
[출력]
tensor([[1, 2],
        [3, 2]])
"""

 

 

dim=0인 경우, index 텐서의 element들은 어떤 행의 값을 추출할지 나타냅니다. 값을 추출할 열은 입력 텐서 t의 각 element의 열에 해당합니다.

 

 

2) dim=1 (열 방향)인 경우

 

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
print(torch.gather(t, 1, index))

"""
[출력]
tensor([[1, 1],
        [4, 3]])
"""

 

 

dim=1인 경우, index 텐서의 element들은 어떤 열의 값을 추출할지 나타냅니다. 값을 추출할 행은 입력 텐서 t의 각 element의 행에 해당합니다.

 

Reference

[1] “torch.gather — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.gather.html (accessed Sep. 03, 2024).

‌[2] “[Pytorch] gather 함수 설명 (특정 인덱스만 추출하기),” All I Need Is Data., Mar. 17, 2021. https://data-newbie.tistory.com/709 (accessed Sep. 03, 2024).