본문 바로가기

NLP/논문이해

[논문이해] Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers

논문명: Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers

논문링크: https://arxiv.org/abs/2212.10559

 

Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers

Large pretrained language models have shown surprising in-context learning (ICL) ability. With a few demonstration input-label pairs, they can predict the label for an unseen input without parameter updates. Despite the great success in performance, its wo

arxiv.org

 


요약

  • In Context Learning 이 왜 되는가에 대해 수학적으로 접근한 논문
  • 핵심: Attention mechanism = back propagation 을 보여서, In Context Learning 의 예제가 attention mechanism 을 통과하는 것은 gradient update 를 일으키는 back propagation 과 유사하다는 걸 수식으로 입증한다.
  • 그리고 이 원리를 활용해 momentum 방식도 제안한다.

 

이 논문을 이해하려면 핵심은 크게 2가지라고 볼 수 있다.

  • 1. Attention 과 back propagation 이 수식적으로 유사하다는 걸 증명한다
  • 2. 그걸 토대로, In Context learning 이 back propagation 을 하는 것처럼 학습한다는 걸 증명한다

 

1번을 이해하기 위해선, 해당 논문보다 이 논문의 영감을 준 아래 논문을 읽는 것이 도움이 되는 것 같아 수식을 가져왔다.

 

https://arxiv.org/abs/2202.05798

 

The Dual Form of Neural Networks Revisited: Connecting Test Time Predictions to Training Patterns via Spotlights of Attention

Linear layers in neural networks (NNs) trained by gradient descent can be expressed as a key-value memory system which stores all training datapoints and the initial weights, and produces outputs using unnormalised dot attention over the entire training ex

arxiv.org

 

 

1. Attention Mechanism = Back propagation

  • Q, K, V 의 차원을 명심하자
  • Q: (d-in), K: (d-in X T), V: (d-out X T)
  • 당연히 Q 와 K 를 곱했을 때, key 의 개수만큼 각각의 점수가 나온다
  • 그리고 그걸 weighted average 하듯이, value vector 를 합치는 게 attention mechanism 이다
  • 여기까지 읽었을 때, 이해가 되지 않으면 attention 을 다시 공부하는 게 맞다

 

  • 2번을 증명하기 위해, 3/4번 식을 보면 된다. 여기까지 간단하다.
  • 개인적으로 왜 갑자기 이게 외적(outer product)으로 표기가 가능하지 했는데, 정의를 찾아보니까 맞았다.
  • 그리고 처음에 언급했듯이 unnormalized 이므로 softmax 와 차원의 제곱근으로 나누는 것을 하지 않는다.

 

두 시스템이 같다는 걸 증명하는 건 아주 쉽다.

  • 모든 입력에 대한 결과가 같으면 된다. 그러면 S1, S2 가 어떤 식으로 구성되었든 동치라고 본다고 한다.
  • 이제 linear layer 에서 back propagation 을 할 때와 비교하여 증명하고자 한다.

 

 

  • 우선 V x K 와 W 가 결국은 차원이 같기 때문에 W = V x K 와 동일하다고 볼 수 있다.

 

  • q = x 로 치환해서 보면, 둘이 동치라는 것을 증명할 수 있다.

 

일단, 여기까지 보면 attention mechanism = linear transformation 까진 증명했다. 사실 back propagation 까지 가기 위해서 필요했던 증명이라고 보면 된다. 이제 그걸 증명하러 가보자.

 

  • 먼저, W 를 back propagation 관점에서 2개로 나눌 수 있다.
  • W0: 처음 초기화했던 값
  • 그 옆에 있는 항은 외적으로, error back propagation 에 의해 만들어진 값이다. 즉, update 될 값이다.

 

  • 이에 따라 위 2개 식이 동치임이 밝혀진다.
  • 큰 X 는 예제(mini batch)에 포함된 것을 의미하며, 이는 unnormalized attention mechanism 과 linear back propagation 이 동치임을 증명한다.

 

2. In Context Learning = Back propagation

이제 원래 논문으로 돌아오자. 이 논문에서 애초에 GPT 라고 명시한 이유가 attention mechanism 을 쓰는 모델로 제한하기 위함이다. 당연히 Inference 에서도 attention 이 쓰인다. 그러므로, In Context Learning 역시 parameter update 는 없지만 attention mechanism 이 일어난다. 그러므로 In Context Learning = attention mechanism 이라서 제목을 저렇게 지었다.

 

  • 위 그림이 말하고 싶은 건 이것이다.
  • In Context Learning(이하 'ICL') 에 들어가는 예제들이 attention mechanism 으로 반영되는데, 수식적으로 back propagation in linear transformation 과 유사한 효과를 지닌다. 그래서 일시적으로 gradient update 가 되는 효과가 있으므로, In Context Learning 이 작동하는 것이다.

 

  • X': historical inputs 라고 되어 있는데, demonstration exampes 를 의미한다.
  • X: 입력을 의미하며, Wq 와 곱해지면서 query vector 가 된다.
  • 위 식은 GPT 모델에서 당연한 결과다.

 

  • 먼저 차원의 제곱을 나눠주는 분모, 그리고 softmax 함수를 제거했다. 그래서 근사 기호를 넣었다.
  • 사견: 첫번째 줄에서 두번째 줄로 넘어올 때, 아무리 생각해도 항이 2개나 생략되었다고 생각한다. 이에 대한 언급이 없어서 아쉽다. 근사가 2번 일어나는 셈이므로, 사실 여기서부터는 이 논문을 신중히 바라봐야 한다.

 

  • 이제 위에서 증명했던 걸 토대로, X'(예제 행렬)에 관한 식을 전개하면 ICL의 예제가 back propagation 처럼 작용한다는 걸 수식적으로 보여줬다.

 

실험: ICL vs fine tuning

  • 1) ICL 과 붙으러면, 위 수식처럼 finetuning 이 key matrix, value matrix 에서만 일어나야 한다. 그래서 그렇게만 gradient update 가 일어나도록 변경했다고 한다.
  • 2) 쓰이는 예제도 동일하게 했으며, finetuning 은 예제순서도 맞추고 1개씩 학습하며, 총 1 epoch 만 학습했다고 한다.
  • 3) template 형식으로 학습하기 위해 causal language modeling objective 를 사용했다고 한다.

 

설정

  • GPT fairseq 1.3B, 2.7B 사용
  • optimizer: SGD

  • zero shot 에 비해 모두 성능이 잘 나왔다. 딱 여기까지가 저자가 언급한 내용이다.
  • 그런데 ICL 이 FT 보다 훨씬 뛰어나다는 점이 눈에 띈다. 다만 저자는 이 점에 대해서는 이야기하지 않는다.

 

 

  • 위 지표를 해석하면, 분모는 zero shot 에 비해 fine tuning 이 몇개나 개선했는지를 계산한다.
  • 분자는 ICL은 그걸 얼마나 따라잡았는가를 개수로 센다.
  • 이걸 나누면, 이런 의미와 같다. Fine tuning 을 얼마나 ICL이 따라잡았는가?
  • 성능을 보면, 대체로 평균 85% 이상 In Context Learning 이 Fine tuning 을 따라잡았다는 것을 볼 수 있다.

 

뒤에 새로운 방식 'momentum'을 제안하지만, In Context Learning 의 원리를 이해해보는 것이 목적이므로 여기서 글을 마치겠다.