주어진 index 텐서에 따라 input 텐서의 값을 추출하여 새로운 텐서를 생성하는 gather 함수에 대해 알아봅시다.
gather 함수
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
[Parameters]
* input: 입력 텐서
* dim: gather 사용 시, 기준이 되는 축 (0이면 행 방향, 1이면 열 방향)
* index: 특정 값을 값들을 수집할 위치를 지정하는 index 텐서
(input tensor와 index tensor의 dimension이 동일해야함)
1) dim=0 (행 방향)인 경우
t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
print(torch.gather(t, 0, index))
"""
[출력]
tensor([[1, 2],
[3, 2]])
"""
dim=0인 경우, index 텐서의 element들은 어떤 행의 값을 추출할지 나타냅니다. 값을 추출할 열은 입력 텐서 t의 각 element의 열에 해당합니다.
2) dim=1 (열 방향)인 경우
t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
print(torch.gather(t, 1, index))
"""
[출력]
tensor([[1, 1],
[4, 3]])
"""
dim=1인 경우, index 텐서의 element들은 어떤 열의 값을 추출할지 나타냅니다. 값을 추출할 행은 입력 텐서 t의 각 element의 행에 해당합니다.
Reference
[1] “torch.gather — PyTorch 2.4 documentation,” Pytorch.org, 2023. https://pytorch.org/docs/stable/generated/torch.gather.html (accessed Sep. 03, 2024).
[2] “[Pytorch] gather 함수 설명 (특정 인덱스만 추출하기),” All I Need Is Data., Mar. 17, 2021. https://data-newbie.tistory.com/709 (accessed Sep. 03, 2024).
'AI 구현 > Pytorch' 카테고리의 다른 글
[Pytorch] view, reshape, permute 함수 : 차원 재구성 (0) | 2024.09.08 |
---|---|
[Pytorch] scatter_ 함수 : index에 따라 특정 위치 값을 직접 삽입/수정 (0) | 2024.09.03 |
[Pytorch] broadcast_tensors 함수 : 텐서 확장 및 연산 (0) | 2024.09.03 |
[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입 (0) | 2024.09.03 |