논문명: Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers
논문링크: https://arxiv.org/abs/2212.10559
요약
- In Context Learning 이 왜 되는가에 대해 수학적으로 접근한 논문
- 핵심: Attention mechanism = back propagation 을 보여서, In Context Learning 의 예제가 attention mechanism 을 통과하는 것은 gradient update 를 일으키는 back propagation 과 유사하다는 걸 수식으로 입증한다.
- 그리고 이 원리를 활용해 momentum 방식도 제안한다.
이 논문을 이해하려면 핵심은 크게 2가지라고 볼 수 있다.
- 1. Attention 과 back propagation 이 수식적으로 유사하다는 걸 증명한다
- 2. 그걸 토대로, In Context learning 이 back propagation 을 하는 것처럼 학습한다는 걸 증명한다
1번을 이해하기 위해선, 해당 논문보다 이 논문의 영감을 준 아래 논문을 읽는 것이 도움이 되는 것 같아 수식을 가져왔다.
https://arxiv.org/abs/2202.05798
1. Attention Mechanism = Back propagation
- Q, K, V 의 차원을 명심하자
- Q: (d-in), K: (d-in X T), V: (d-out X T)
- 당연히 Q 와 K 를 곱했을 때, key 의 개수만큼 각각의 점수가 나온다
- 그리고 그걸 weighted average 하듯이, value vector 를 합치는 게 attention mechanism 이다
여기까지 읽었을 때, 이해가 되지 않으면 attention 을 다시 공부하는 게 맞다
- 2번을 증명하기 위해, 3/4번 식을 보면 된다. 여기까지 간단하다.
- 개인적으로 왜 갑자기 이게 외적(outer product)으로 표기가 가능하지 했는데, 정의를 찾아보니까 맞았다.
- 그리고 처음에 언급했듯이 unnormalized 이므로 softmax 와 차원의 제곱근으로 나누는 것을 하지 않는다.
두 시스템이 같다는 걸 증명하는 건 아주 쉽다.
- 모든 입력에 대한 결과가 같으면 된다. 그러면 S1, S2 가 어떤 식으로 구성되었든 동치라고 본다고 한다.
- 이제 linear layer 에서 back propagation 을 할 때와 비교하여 증명하고자 한다.
- 우선 V x K 와 W 가 결국은 차원이 같기 때문에 W = V x K 와 동일하다고 볼 수 있다.
- q = x 로 치환해서 보면, 둘이 동치라는 것을 증명할 수 있다.
일단, 여기까지 보면 attention mechanism = linear transformation 까진 증명했다. 사실 back propagation 까지 가기 위해서 필요했던 증명이라고 보면 된다. 이제 그걸 증명하러 가보자.
- 먼저, W 를 back propagation 관점에서 2개로 나눌 수 있다.
- W0: 처음 초기화했던 값
- 그 옆에 있는 항은 외적으로, error back propagation 에 의해 만들어진 값이다. 즉, update 될 값이다.
- 이에 따라 위 2개 식이 동치임이 밝혀진다.
- 큰 X 는 예제(mini batch)에 포함된 것을 의미하며, 이는 unnormalized attention mechanism 과 linear back propagation 이 동치임을 증명한다.
2. In Context Learning = Back propagation
이제 원래 논문으로 돌아오자. 이 논문에서 애초에 GPT 라고 명시한 이유가 attention mechanism 을 쓰는 모델로 제한하기 위함이다. 당연히 Inference 에서도 attention 이 쓰인다. 그러므로, In Context Learning 역시 parameter update 는 없지만 attention mechanism 이 일어난다. 그러므로 In Context Learning = attention mechanism 이라서 제목을 저렇게 지었다.
- 위 그림이 말하고 싶은 건 이것이다.
- In Context Learning(이하 'ICL') 에 들어가는 예제들이 attention mechanism 으로 반영되는데, 수식적으로 back propagation in linear transformation 과 유사한 효과를 지닌다. 그래서 일시적으로 gradient update 가 되는 효과가 있으므로, In Context Learning 이 작동하는 것이다.
- X': historical inputs 라고 되어 있는데, demonstration exampes 를 의미한다.
- X: 입력을 의미하며, Wq 와 곱해지면서 query vector 가 된다.
- 위 식은 GPT 모델에서 당연한 결과다.
- 먼저 차원의 제곱을 나눠주는 분모, 그리고 softmax 함수를 제거했다. 그래서 근사 기호를 넣었다.
- 사견: 첫번째 줄에서 두번째 줄로 넘어올 때, 아무리 생각해도 항이 2개나 생략되었다고 생각한다. 이에 대한 언급이 없어서 아쉽다. 근사가 2번 일어나는 셈이므로, 사실 여기서부터는 이 논문을 신중히 바라봐야 한다.
- 이제 위에서 증명했던 걸 토대로, X'(예제 행렬)에 관한 식을 전개하면 ICL의 예제가 back propagation 처럼 작용한다는 걸 수식적으로 보여줬다.
실험: ICL vs fine tuning
- 1) ICL 과 붙으러면, 위 수식처럼 finetuning 이 key matrix, value matrix 에서만 일어나야 한다. 그래서 그렇게만 gradient update 가 일어나도록 변경했다고 한다.
- 2) 쓰이는 예제도 동일하게 했으며, finetuning 은 예제순서도 맞추고 1개씩 학습하며, 총 1 epoch 만 학습했다고 한다.
- 3) template 형식으로 학습하기 위해 causal language modeling objective 를 사용했다고 한다.
설정
- GPT fairseq 1.3B, 2.7B 사용
- optimizer: SGD
- zero shot 에 비해 모두 성능이 잘 나왔다. 딱 여기까지가 저자가 언급한 내용이다.
그런데 ICL 이 FT 보다 훨씬 뛰어나다는 점이 눈에 띈다. 다만 저자는 이 점에 대해서는 이야기하지 않는다.
- 위 지표를 해석하면, 분모는 zero shot 에 비해 fine tuning 이 몇개나 개선했는지를 계산한다.
- 분자는 ICL은 그걸 얼마나 따라잡았는가를 개수로 센다.
- 이걸 나누면, 이런 의미와 같다. Fine tuning 을 얼마나 ICL이 따라잡았는가?
- 성능을 보면, 대체로 평균 85% 이상 In Context Learning 이 Fine tuning 을 따라잡았다는 것을 볼 수 있다.
뒤에 새로운 방식 'momentum'을 제안하지만, In Context Learning 의 원리를 이해해보는 것이 목적이므로 여기서 글을 마치겠다.