'분류 전체보기' 카테고리의 글 목록 (4 Page)
[Pytorch] gather 함수 : index에 따라 값을 수집/추출
주어진 index 텐서에 따라 input 텐서의 값을 추출하여 새로운 텐서를 생성하는 gather 함수에 대해 알아봅시다. gather 함수torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor[Parameters]* input: 입력 텐서* dim: gather 사용 시, 기준이 되는 축 (0이면 행 방향, 1이면 열 방향)* index: 특정 값을 값들을 수집할 위치를 지정하는 index 텐서(input tensor와 index tensor의 dimension이 동일해야함) 1) dim=0 (행 방향)인 경우 t = torch.tensor([[1, 2], [3, 4]])index = torch.tensor([[0, 0], [..
2024. 9. 3.
[Pytorch] squeeze,unsqueeze 함수 : 차원 삭제, 차원 삽입
크기가 1인 차원을 삭제하는 squeeze 함수와 삽입하는 unsqueeze 함수에 대해 알아봅시다. squeeze 함수torch.squeeze(input, dim=None) → Tensor[Parameters]* input (Tensor): 입력 텐서* dim (int or tuple of ints, optional): 값이 지정돼 있다면 특정 차원에서 squeeze됨 차원이 1인 차원을 제거해줍니다. 특정 차원을 지정하면, 해당 차원의 크기가 1인 경우만 제거하고 1이 아니라면 그대로 유지합니다. import torchx = torch.rand(1,2,3,4,1)print(x.squeeze().size())print(x.squeeze(0).size())print(x.squeeze((0,1,2,3))...
2024. 9. 3.