3줄 요약
- PyTorch 의 torch.max 함수는
- 값뿐만 아니라
- 색인(index) 도 함께 return 한다.
(시작하기 전에 PyTorch 고수분들은 나가주세요, 부끄러우니까.)
배경
Math Word Problem 분야의 최신 논문을 읽고, 그 코드를 활용하고자 코드 분석하는 과정에서 상당히 실력이 늘었다. PyTorch 에 능통한 저자는 forward 함수 하나에 300 줄을 넘게 태우는 분이셨다...
https://github.com/allanj/deductive-mwp
그 분의 코드를 보다가 도저히 이해가 되지 않는 부분이 있었다. 바로 torch.max 함수가 마치 tuple 형태로 return 하는 듯이 2개의 변수에 대입하는 것이었다.
best_temp_logits, best_stop_label = m0_combined_logits.max(dim=-1) ## batch_size, num_combinations/num_m0, num_labels
best_temp_score, best_temp_label = best_temp_logits.max(dim=-1) ## batch_size, num_combinations
best_m0_score, best_comb = best_temp_score.max(dim=-1) ## batch_size
변수 이름 상으론, argmax 를 지칭하는 것 같아서 PyTorch 공식 문서에 들어가봤다. 언제나 공식문서는 옳다. 실력을 키우려면, 최신 논문에 작성된 코드를 통해 무럭무럭 자라는 것이 참 좋은 방법인 것 같다.
(다들 알았던 것인가? 나 아직 멀은 것이었단 말이야?)
실험
import torch
a = torch.randn(4, 4)
"""
tensor([[-2.2851, 0.4925, 0.9951, 0.8681],
[ 1.3086, -0.2662, -0.1281, 1.1681],
[ 0.0693, -1.8767, -0.5615, 2.6629],
[-0.4545, -0.5358, -1.3342, 0.7333]])
"""
i, j = a.max(dim=-1)
"""
i: tensor([0.9951, 1.3086, 2.6629, 0.7333])
j: tensor([2, 0, 3, 3])
"""
'NLP > PyTorch' 카테고리의 다른 글
[PyTorch] Is scheduler always good? (0) | 2022.08.03 |
---|---|
[PyTorch] Auto Mixed Precision (0) | 2022.06.29 |