빅웨이브에이아이 기술블로그

RLHF - 어떻게 LLM의 성능을 향상시킬 수 있을까? 본문

기술 블로그

RLHF - 어떻게 LLM의 성능을 향상시킬 수 있을까?

빅웨이브 이현상 2024. 1. 24. 16:25
빅웨이브에이아이 이원석 님의 리뷰입니다.

 

LLM 모델인 GPT-4, PaLM, LLama 등은 범용적인 목적에 맞게, 매우 큰 모델 사이즈와 매우 방대한 양의 데이터로 사전 학습이 수행됨

일반적인 LLM의 경우 방대한 양의 데이터로 부터 매우 다양한 도메인 지식을 습득

But, 사전 학습 데이터에서 욕설, 편향적인 정보, 부정확한 정보를 담은 문서 등 적절치 못한 데이터를 다수 포함

데이터 클렌징 및 필터링 등 방대한 양의 데이터를 사람이 전부 처리하는 것은 한계가 존재

이에 따라, 모델이 부적절한 문장이나 단어를 선택하여 다음 문장을 생성하는 일이 빈번하게 발생

생성 모델 자체도 Next-token prediction 방식으로 학습 되기 때문에 최대한 확률적으로 높은 문장을 생성하는 것, 이로인한 환각 및 비윤리적 답변이 문제가 됨

 

그림 출처:  https://tech.scatterlab.co.kr/luda-rlhf/

 

LLM의 능력을 적절하게 활용하기 위해서 사람의 의도와 방향에 맞게 생성 모델을 통제할 수 있어야 하며 최근의 LLM들은 Supervised Fine-tuning (SFT) 방식과 Reinforcement Learning from Human Feedback (RLHF) 방식을 통해 생성 모델의 성능을 고도화함

그림 출처:  twitter.com/anthrupad

 

Unsupervised Learning (Pre-training)

  • 사전 학습을 통해서 생성 모델을 학습하는 단계로, 대형 생성 모델은 강력하지만 사람이 원하는 의도대로 동작하기 어려우며 서비스에 바로 적용하기 힘든 상태
  • 사전 학습된 Domain 지식을 통해서 모델의 성능이 크게 좌지우지 될 수 있으며 fine-tune 단계는 생성 모델이 답변을 잘 낼 수 있도록 보조하는 역할 정도로 생각할 수 있음

Supervised Fine-tuning(SFT)

  • 특정 도메인 혹은 크라우드 소싱 등으로 구축된 양 질의 데이터 Instruct, Input, Reponse 쌍을 통해서 fine-tune 하는 과정이며 instruct-tuning이라고도 함
  • 해당 과정을 통해 생성 모델이 사람의 의도에 맞는 문장을 생성하는 법을 익히며 사전 학습 단계에서 지식 습득이 제대로 되었다면 SFT 단계에서 적은 수의 양질의 데이터 만으로도 충분한 성능을 낼 수 있음**(LIMA : Less is more Alignment 참고!)**
  • 아래와 같이 지시사항에 대한 답변을 생성해낼 수 있도록 예시 데이터를 주고 학습을 수행하는 것
<|system|>
귀하는 지시를 매우 잘 따르는 인공지능 비서입니다. 최대한 많이 도와주세요.</s>
<|user|>
다음 질문에 답해 주세요: 문맥: 오스틴은 도시의 밤 문화를 좋아해서 외출할 준비를 했습니다.  
질문: 오스틴은 왜 이렇게 했을까요?  
다음 중 문맥에 따라 이 질문에 가장 잘 맞는 답은 무엇인가요?  
A: 최고의 모습을 보이다 
B: 클럽에서 어울리다 
C: 파티에 나가다
정답</s>
<|assistant|>
B: 클럽에서 놀기</s>

 

Reinforcement learning from human feedback(RLHF)

  • SFT 모델에 추가적으로 사람의 피드백을 보상으로 한 강화 학습을 적용하여 사람의 의도에 부합하는 답변을 생성 하도록 하는 과정

SFT의 경우 주어진 문맥에 대해서 모델이 모범 답안을 주어서 올바른 답변을 생성하도록 학습이 진행된다면, RLHF 의 경우 “너는 그렇게 답변해서는 안된다!” 를 알려주는 역할 또는 A 문장 보다 B 문장이 좋다고 알려주는 역할을 수행

위와 같이 답변에 대한 긍/부정적 평가 혹은 순위 정보를 피드백 하여 사람의 선호도를 학습하는 방법론을 Learning from Human Feedback 혹은 Human Preference Alignment 라함

가장 대표적인 방법이 RLHF

 

RLHF

강화 학습은 주어진 환경(Enviroment)에 대한 상태(State)에 따라 Policy 모델이 행동(Action)을 하게 되고, 그 일련의 과정에서 얻은 보상(Reward)를 기반으로 각 상태 또는 행동에 대한 가치를 평가하여 학습하는 방법론

여기서 중요한 점은 어떤 행동이 정답인지 레이블로 존재하는 supervised learning과는 달리, 모델이 한 행동이 적절한지 아닌지를 알려주게됨

바둑의 예시로 현재 상황에서 너가 둬야하는 적절한 수의 위치는 (6, 15)야라고 정답을 말해주는 supervised learning과는 달리, 강화 학습은 모델이 한 행동에 대해 너가 둔 수는 -1만큼 보상을 받아

라고 알려주는 것

 

위 과정을 LLM 학습에 빗대어 위 그림과 같이 표현 가능

instruct-tune을 통해 학습된 LLM이 입력된 prompt로 context를 생성하는데 context를 생성 한 것을 토대로 reward를 주는데 이 reward에 대한 부분을 사람이 직접 feedback을 의미함

예를 들어,

Human : 오늘 뭔가 아무것도 하기 싫은걸?

  1. Assistant : 그럼 때려쳐!
  2. Assistant : 왜 그래? 무슨일이야?!

위의 문맥에서 1번 보다는 2번이 좋다라는 피드백을 주는 것으로 1번은 부정적인 보상, 후자는 긍정적인 보상을 주어 모델을 학습하는 것

RLHF는 SFT가 완료된 모델에 대해서 수행하는 것을 전제로 한다. 그러나 Reward를 주는 행위를 전부 사람이 수행하는 것이 가능한가? 시간과 비용이 많이 들 것!

따라서 Reward Model을 통해서 이를 대신 수행

Reward Model

Reward Model은 모델이 한 행동(모델이 생성한 문장)에 대해서 사람이 매번 리워드를 평가 할 수 없기에 모델이 생성한 문장에 대해서 자동으로 평가할 수 있는 모델을 학습한 것

 

간단하게 생각하면 Reward Model 또한 지도 학습을 통해서 학습되는 것

따라서, Reward Model 학습을 위해서 초기에 Reward에 대해 사람이 레이블링 한 데이터셋이 필요

Reward 에 대한 레이블을 사람이 직접 매긴 점수의 형태로도 사용할 수 있으나 점수의 경우 사람마다 그에 대한 척도가 달라 편향이 발생할 수 있음

많이 채택하는 방식으로 (Prompt, Chosen Response, Reject Response) 의 형태로 데이터를 구성

초기 SFT로 학습 된 모델을 통해 각 Prompt로 부터 여러 가지 답변을 생성 후 생성 된 답변들 중 positive한 답변과 negative한 답변을 나누어 위와 같이 데이터 구성

이때 Bradley-Terry 모델이라는 두 후보 우위에 대한 확률을 계산하는 모델로 Chosen Response가 Reject Response 보다 답변이 더 좋을 확률을 계산

주어진 Prompt 에 대해

P(Chosen Response > Reject Response | Prompt)

Chosen Response가 좋을 확률을 높이는 방향으로 학습 ⇒ negative 답변에 대한 logit은 작아지고, positive 답변에 대한 logit은 커지게 학습되게 됨

학습이 완료된 Reward Model은 위 그림처럼 각 Response에 대한 응답 점수를 낼 수 있음

Fine Tuning with RLHF(PPO & KL Divergence)

학습된 Reward Model을 활용하여 RLHF 학습하는 과정은 아래와 같다.

  1. 준비된 Prompt 데이터 셋 instruct-LLM 입력
  2. Prompt 가 완성된 결과를 Reward Model 입력 및 Reward를 추출하고 강화학습 알고리즘으로 전달
  3. 강화학습에 사용하는 알고리즘은 Proximal Policy Optimization 으로 보상 점수에 대한 로스를 계산하고 LLM 높은 Reward를 산출할 수 있도록 최적화를 수행

그러나, 위의 학습 과정에서는 취약점이 존재

모델이 지나치게 높은 리워드를 내는 것에 목표를 가지다보니 리워드 모델의 취약점 및 편향을 찾아내고 그 부분만을 집중적으로 공략하여 높은 리워드 점수를 내도록 할수가 있는데 이를 리워드 해킹(Reward Hacking) 또는 Mode Collapse 라 함

이를 방지하기 위한 Regulrization term 이 추가되는 이 항이 KL divergence term ⇒ KL Penalty

KL divergence term의 경우 레퍼런스 모델을 기준으로 하여 학습되는 모델이 레퍼런스 모델의 분포로부터 너무 크게 벗어나지 않도록 하는 역할을 수행

레퍼런스 모델의 경우 전체 weight에 대한 freeze 후에 사용되는 형태로 보상을 높이다 모델의 생성형태가 환각의 형태에 가까울때 그에 벗어나지 않도록 중재해주는 역할을 수행할 수 있는 것!

 

RLHF 학습 불안정성 및 한계

RLHF는 생성 모델, 리워드 모델, 레퍼런스 모델들의 상호작용을 통해서 강화학습 수행

이에 따라 학습이 불안정하여 쉽게 over-fitting이 발생하고, 하이퍼 파라미터에 상당히 민감한 특징을 가짐

특히, Reward Model robust 하지 않으면 학습 성능에 치명적인 악영향을 초래할수 있음

이를 방지하기 위해 Reward Model 학습 시 negative sample을 증강 하거나 원래 모델 분포를 벗어나지 않기 위해 pre-training에 사용했던 데이터를 다시 넣어서 학습하는 Pre-training-mix 방법론 등을 사용

또한, 상호작용을 위한 모델의 수가 최소 3개로 학습 시 필요한 일반 학습보다 컴퓨팅 자원의 수가 더 많이 필요

 

DPO(Direct Preference Optimization)

RLHF의 구현 복잡성 및 학습 불안정성의 문제를 보완하기 위한 대체 방법론으로 아래와 같은 특징을 가짐

  • 학습 시 RLHF 보다 더 적은 수 모델을 활용하여 GPU 자원 사용량이 적으며 학습 속도가 빠름
  • 간단한 학습 방식과 더 적은 하이퍼 파라미터 튜닝으로 안정적인 학습
  • 기존 RLHF보다 더 좋은 성능

해당 방법론의 특징은 리워드 모델 학습용 데이터 셋을 직접적으로 사용하여 positive 답변에 대한 확률은 높아지게, negative 확률은 낮아지도록 학습하는 방식을 제안

학습 시 레퍼런스 모델이 계산한 확률 대비 학습하는 모델이 계산한 확률을 리워드 점수로 하며 이를 implict Reward라 정의

식을 보면 간단히 알 수 있는데 y_w을 positive 답변, y_l을 negative 답변이라 할 때

학습 중인 모델과 레퍼런스 모델의 positive 답변 생성에 대한 확률 비율이 negative 답변 생성에 대한q비율 보다 커지도록 학습하는 것

정리하면, DPO는 Reward Model 학습용 데이터를 직접 활용하여 수행될 수 있으며 학습 중 레퍼런스 모델은 필히 필요함

 

https://aihub.or.kr/leaderboard/view.do?currMenu=500&topMenu=102

aihub에서 진행중인 LLM 성능 리더보드에서도 DPO 방법론을 활용한 모델의 성능이 상위권에 위치하는 것을 확인가능

 

Rejection Sampling Fine-tuning(Best of N)

 

해당 방법론은 먼저 생성 모델이 각 컨텍스들 별로 모델 분포에 따라 답변 후보 문장 N개를 생성 후 답변 후보 문장에 대해서 리워드 스코어를 계산

특정 스코어 이상의 답변 후보들을 최종 답변으로 채택, 이때 특정 스코어 이상의 답변 후보를 채택하는 대신 가장 높은 스코어의 답변 한개를 고를 수 도 있음

이에 따라 Best-of N Samling이라 함

위에 방법대로 Sampling한 답변들을 정답 답변 레이블로 활용하여 SFT 방식과 동일하게 학습 수행

이를 통해 리워드가 높은 문장이 생성될 확률이 높아지는 모델이 학습

학습 과정 중 리워드 모델이나 레퍼런스 모델이 필요없는 간단한 방식임에도 매우 높은 성능을 보여줌!

하지만 Negative Sample에 대한 학습 기회가 적다는 단점이 존재

이애 따라 LLama-2-Chat은 1) Rejection Sampling fine-tunning 2) RLHF fine-tunning을 추가 진행하는 방식으로 결합하여 더 높은 성능을 추구하는 방식으로 학습

 

RLHF 코드

huggingface에 RLHF를 쉽게 수행할 수 있도록 TRL(Transformer Reinforcement Learning) 를 공개함

https://github.com/huggingface/trl

위 패키지를 레퍼런스로 학습을 수행할 수 있는 코드 레파지토리를 첨부

  • 추후업데이트 예정 github

 

참고자료

https://tech.scatterlab.co.kr/alt-rlhf/

https://tech.scatterlab.co.kr/luda-rlhf/

https://github.com/huggingface/trl

https://medium.com/@madhur.prashant7/rlhf-reward-model-ppo-on-llms-dfc92ec3885f

https://moon-walker.medium.com/리뷰-meta-ai의-논문-lima-less-is-more-for-alignment-결국-llm의-pre-training이-가장-중요하다-f3c9ea885f5a

https://github.com/huggingface/alignment-handbook/tree/main?tab=readme-ov-file

https://towardsdatascience.com/fine-tune-better-chat-models-with-distilled-identity-preference-optimization-ipo-99cddc819a48 (DPO 개선을 위한 IPO 로스)

Comments