본문 바로가기

NLP/PyTorch

[PyTorch] torch.max returns also indices

 

3줄 요약

 

  • PyTorch 의 torch.max 함수는
  • 값뿐만 아니라
  • 색인(index) 도 함께 return 한다.

 

(시작하기 전에 PyTorch 고수분들은 나가주세요, 부끄러우니까.)

 

 

 


 

배경

 

Math Word Problem 분야의 최신 논문을 읽고, 그 코드를 활용하고자 코드 분석하는 과정에서 상당히 실력이 늘었다. PyTorch 에 능통한 저자는 forward 함수 하나에 300 줄을 넘게 태우는 분이셨다...

 

https://github.com/allanj/deductive-mwp

 

GitHub - allanj/Deductive-MWP

Contribute to allanj/Deductive-MWP development by creating an account on GitHub.

github.com

 

그 분의 코드를 보다가 도저히 이해가 되지 않는 부분이 있었다. 바로 torch.max 함수가 마치 tuple 형태로 return 하는 듯이 2개의 변수에 대입하는 것이었다.

 

 best_temp_logits, best_stop_label = m0_combined_logits.max(dim=-1)  ## batch_size, num_combinations/num_m0, num_labels
 best_temp_score, best_temp_label = best_temp_logits.max(dim=-1)  ## batch_size, num_combinations
 best_m0_score, best_comb = best_temp_score.max(dim=-1)  ## batch_size

 

 

변수 이름 상으론, argmax 를 지칭하는 것 같아서 PyTorch 공식 문서에 들어가봤다. 언제나 공식문서는 옳다. 실력을 키우려면, 최신 논문에 작성된 코드를 통해 무럭무럭 자라는 것이 참 좋은 방법인 것 같다.

 

 

(다들 알았던 것인가? 나 아직 멀은 것이었단 말이야?)

 

 


 

실험

 

import torch

a = torch.randn(4, 4)

"""
tensor([[-2.2851,  0.4925,  0.9951,  0.8681],
        [ 1.3086, -0.2662, -0.1281,  1.1681],
        [ 0.0693, -1.8767, -0.5615,  2.6629],
        [-0.4545, -0.5358, -1.3342,  0.7333]])
"""

i, j = a.max(dim=-1)

"""
i: tensor([0.9951, 1.3086, 2.6629, 0.7333])
j: tensor([2, 0, 3, 3])
"""

 

'NLP > PyTorch' 카테고리의 다른 글

[PyTorch] Is scheduler always good?  (0) 2022.08.03
[PyTorch] Auto Mixed Precision  (0) 2022.06.29