본문 바로가기

NLP/논문이해

[논문이해] training language models to follow instructions with human feedback

 

논문을 이해하고 싶다면 아래 글을 읽으세요. 너무 잘 써서 이것보다 더 잘 쓸 자신이 없어요.

 

https://taeyuplab.tistory.com/10

 

[논문 리뷰] InstructGPT: Training language models to follow instructions with human feedback

이 글에서는 InstructGPT를 제안한 논문인 Training language models to follow instructions with human feedback에 대해 살펴볼 것이다. 본 논문은 GPT-1, GPT-2, GPT-3 논문을 발표한 OpenAI로부터 2022년 NeurIPS에 발표되었다.

taeyuplab.tistory.com

 

근데 논문만 읽고서는 이해가 잘 안되기도 합니다. 저 같은 경우는 2가지가 와닿지 않았어요.

 

1) Reward Model

: 그래서 구현을 어떻게 했나?

 

https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py

 

trlx/examples/summarize_rlhf/reward_model/reward_model.py at main · CarperAI/trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) - CarperAI/trlx

github.com

 

 

    def forward(
        self,
        ...
    ):
        loss = None
        
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        rewards = self.v_head(hidden_states).squeeze(-1)
  • transformer 통과한 임베딩에 self.v_head 를 통과시키니 rewards 가 나온다고 함
  • 그렇다면 self.v_head 는 뭐냐?

 

self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
  • 그냥 linear 였다고 한다. linear layer 로 scalar value 를 측정하는 것으로 보임.

 

2) 손실함수

: 대체 아래 식은 코드로 구현하면 어떤 느낌인걸까.

(우선 이 식이 무슨 의미인지 모르겠다면 맨위에 언급한 블로그를 다시 읽을 것. 특히 log(sigmoid)도 직접 그려볼 것.)

  • 여기서 기댓값이라는 게 평균일텐데, 무엇에 대한 평균인지 궁금했다.

 

# Split the inputs and rewards into two parts, chosen and rejected
assert len(input_ids.shape) == 2
bs = input_ids.shape[0] // 2
chosen = input_ids[:bs]
rejected = input_ids[bs:]
chosen_rewards = rewards[:bs]
rejected_rewards = rewards[bs:]

 

  • 일단 위 식으로 유추할 수 있는 건, (사람이 선호하는 답, 선호하지 않는 답) 이렇게 1쌍으로 넣은 것 같다
  • chosen: 사람이 선호하는 답
  • rejected: 사람이 거절하는 답

 

loss = 0
inference = False
for i in range(bs):
    if torch.all(torch.eq(chosen[i], rejected[i])).item():
        c_inds = (chosen[i] == self.PAD_ID).nonzero()
        c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
        chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
        inference = True
        continue

 

  • 조건문의 의미가 뭐냐면, 인간이 선호하는 답과 거절하는 답이 완전히 동일한 경우를 의미한다
  • 즉, 모델이 생성한 결과가 아예 똑같은 경우엔 넘어간다고 합니다.

 

    # Check if there is any padding otherwise take length of sequence
    c_inds = (chosen[i] == self.PAD_ID).nonzero()
    c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
    r_inds = (rejected[i] == self.PAD_ID).nonzero()
    r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
    end_ind = max(c_ind, r_ind)

 

  • if 문을 무사히 넘겼으니, 일단 두 답변이 동일하지 않다는 건 확인이 됐다
  • 이때 padding 을 제외한 실제 답변이 끝나는 색인(index)을 찾는다

 

    # Retrieve first index where trajectories diverge
    divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
    assert divergence_ind > 0

    # Index into the correct rewards
    c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
    r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]
  • 두 답변을 비교하여 처음으로 답이 달라지는 토큰의 색인을 찾는다
  • 그리고 난 뒤, 달라진 색인부터 끝까지 보상값을 잘라서 저장한다
  • 앞부분의 답변이 동일한 동안, 어차피 보상값도 동일하니까 굳이 더하지 않는 것으로 보인다

 

    # Append the last rewards to the list of end scores
    chosen_end_scores.append(c_truncated_reward[-1])
    rejected_end_scores.append(r_truncated_reward[-1])

    # Compute loss based on truncated rewards (ignore padding)
    loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
loss = loss / bs
  • 그렇게  아까 말했던 log(sigmoid)에 삽입한다
  • 기댓값이라고 표현한 것은 각 토큰마다 생성한 보상값을 모두 넣어서 mean 함수로 평균내기 때문인 것 같다

 

(근데 의문점이 하나 있다면, padding 이 모두 지워지는 게 맞나? 예컨대, 선호하는 답의 길이가 20, 거절한 답의 길이가 10이라고 하자. 그런데 end_ind 가 20이니까 r_trucated_reward 에서 11번째부터는 padding 토큰이 만들어낸 보상값 아닌가? 흠....)