Torch 연산
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor >>> t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) >>> torch.flatten(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) (계속 정리...)
더보기