DPO 학습 개요
Alignment tuning은 모델의 출력이 인간의 기대와 선호에 부합하도록 유도하는 과정이다.
이중 DPO는 이러한 강화학습 과정에서 보상 최적화 과정을 단순화하여 이러한 문제를 인간 선호 데이터에 기반한 single-stage policy training 문제로 취급하여 모델을 학습한다.
DPO 프레임워크는 두 가지 핵심 모델인 reference policy( π_ref )과 target policy ( π_tar )을 기반으로 구축된다.
여기서 reference는 일반적으로 사전 학습 및 감독 학습 기반 미세 조정이 완료된 언어 모델로, 학습 동안 고정된 상태로 유지된다. 반면, target policy는 eference policy 로부터 초기화되며, 선호도 기반 피드백을 통해 반복적으로 업데이트되어 인간 판단과의 alignment를 개선하게 된다.
데이터 수집 및 준비
DPO 학습을 위해서는 각각의 프롬프트 x에 대해 샘플링한 여러 후보 답변에 대한 선호를 라벨링한 데이터가 필요하다. 다음은 DPO 학습에 활용할 수 있는 대표적인 preference data인 UltraFeedback 데이터셋이다. (참고: openbmb/UltraFeedback)
instruction | models | completions |
모델에 입력하는 프롬프트 | 모델 후보에서 랜덤하게 샘플링되어 프롬프트에 대해 답변을 생성한 모델 | 4가지 후보 모델에 대해 - 모델별 align할 원칙(예. helpfulness) - custom system prompt - response (답변) - annotations: 평가 기준 (예. instruction-following, honesty, ...) 별 점수 |
I am going to cairo in June of this year, thinking of 4-5 days, what are the best things I could do? Going with family of 5, granddad, grandmom, sister, mom and me. No days can have too much walking but sister, me and mom like adventure sports | "falcon-40b-instruct", "gpt-4", "starchat", "wizardlm-7b" |
[ { "model": "falcon-40b-instruct", "principle": "helpfulness", "custom_system_prompt": "As an AI assistant, ....(중략)", "response": "Cairo is a city that has something for everyone. ,... (중략)", "annotations": { "instruction_following": [ { "Rating": "2", "Rationale": "The text only partially addresses the task goal by providing a general ..(중략)" } ], "honesty": [ { "Rating": "3", "Rationale": "..." } ]...} ... ] |
이렇게 수집한 데이터는 DPO 학습에 사용하기 위해 chosen (선호되는 답변)과 rejected(선호되지 않는 답변) 쌍으로 정제해야 한다. 다음은 DPO 학습을 위해 TRL library에서 UltraFeedback 데이터를 binary한 학습 데이터 예시이다 (참고: trl/dpo_trainer)
chosen | rejected | score_chosen | score_rejected |
선호 답변 | 선호하지 않는 답변 | 선호 답변의 점수 | 선호하지 않는 답변의 점수 |
[ { "content": "Where does candlewood grow? in ocean or desert?", "role": "user" }, { "content": "Candlewood mainly grows in the desert or dry areas, such as the southwestern United States, Mexico, and parts of Central America and South America. However, it can also be found in parts of Florida and Texas in the United States. It is not found in the deeper ocean, but it can grow in marshes and estuaries near the sea.", "role": "assistant" } ] | [ { "content": "Where does candlewood grow? in ocean or desert?", "role": "user" }, { "content": "Hello! I'm here to help answer your questions to the best of my ability. I understand that you're asking about candlewood and its growth habits.\n\nTo be honest, I'm not familiar with a plant called \"candlewood\" that grows in the ocean or the desert. It's possible that it's a rare or fictional plant, or it may be a misspelling or misidentification of a different plant.\n\nIf you have any more information or context about candlewood, I'd be happy to try and help you find the answer. Alternatively, if you have a specific question about a different topic, I'll do my best to assist you with that. Please feel free to ask!", "role": "assistant" } ] | 8.5 | 6 |
🛠️ DPO 학습을 위한 학습 데이터 수집 프로세스
1) 생성 단계: 후보 모델이 각각의 프롬프트 x에 대해 후보 출력을 생성한다.
2) 주석 단계: 인간 평가자가 생성된 출력들을 비교하여 상대적인 선호를 판단한다.
- 이때 인간 평가자는 후보 응답들을 일관성(coherence), 관련성(relevance), 명확성(clarity) 등의 기준에 따라 비교하거나 순위를 매긴다
3) 데이터 포맷 변경: DPO 학습을 위한 라이브러리에서 요구하는 데이터 규격으로 데이터 형식을 변환한다.
주요 고려 사항
✔ DPO를 효과적으로 초기화하기 위해 강력한 reference model을 사용하는 것이 중요하다.
- SFT를 거친 모델은 일반적으로 reference policy( π_ref )에 대해 뛰어난 성능을 제공함으로써, 이후 선호 기반의 업데이트 과정을 통해 기본적인 능력을 습득하는 것이 아니라 세밀하고 정교한 성능 향상에 집중할 수 있게 한다.
✔ 선호 데이터의 품질
- 선호 데이터는 사용자 기대치의 다양한 변동성을 포착할 수 있을 만큼 충분히 다양해야 하며,
이를 통해 모델의 적응성을 높이고 특정 작업에 과적합(overfitting)되는 것을 방지한다
DPO 학습의 변형 알고리즘
생성 최적화를 위한 DPO
토큰 단위의 반복형 DPO 전략은 인간 선호에 대한 정밀하고 연속적인 정렬을 가능하게 한다.
토큰 수준 DPO(token-level DPO)는 이를 밴딧(bandit) 문제로 재정의하고, (S, A, f, r, ρ₀)로 구성된 Markov Decision Process (MDP)을 채택한다. 이 방식은 비선호 토큰에 대해 과도한 KL divergence 문제가 발생하는 것을 완화한다.
- TDPO:
- 각 토큰에 대해 forward KL divergence을 적용하여 모델의 출력 분포를 제어함
- 이때 reverse KL divergence 대신 forward KL divergence를 사용하여 정렬 품질과 생성 다양성을 동시에 향상
- 기존 DPO는 문장 전체를 단위로 선호도를 평가하지만,TDPO는 각 토큰 수준에서 선호도를 평가하여 보다 세밀한 정렬을 가능하게 함
- Iterative DPO:
- 모델이 생성한 응답에 대해 반복적으로 선호도를 평가하고, 이를 기반으로 모델을 점진적으로 개선
- 모델이 생성한 응답을 자체적으로 평가하여 새로운 선호 데이터를 생성하고, 반복적인 학습을 통해 모델의 정렬 성능을 향상 시킴
- Pairwise Cringe Optimization:
- 기존의 Cringe Loss를 확장하여, 쌍(pairwise) 응답 간의 선호도를 기반으로 모델을 align함
- 소프트 마진을 활용하여 탐색(exploration)과 활용(exploitation) 간의 균형을 유지
- Step-wise DPO:
- 복잡한 추론 과정을 단계별로 분해하여, 각 단계에 대한 선호도를 기반으로 모델을 정렬
- 선호 데이터셋을 분할하고, 각 라운드에서 갱신된 정책을 다음 라운드의 기준점으로 삼아 반복적인 업데이트를 수행
- 복잡한 수학적 추론이나 다단계 문제 해결에 효과적
TRL library DPO
DPO 학습을 위해서는 먼저 SFT(Supervised fine-tuning) 모델을 학습하여, DPO 알고리즘이 학습하게 될 데이터가 분포 내(in-distribution)에 있도록 보장해야 한다. 그 다음은 두 단계로 DPO 학습을 수행할 수 있다.
- 데이터 수집: 프롬프트에 대해 긍정 응답과 부정 응답 쌍을 포함하는 선호 데이터셋을 수집한다.
- 최적화: DPO 손실 함수의 로그 가능도(log-likelihood)를 직접 최대화한다.
데이터 요구사항
- DPO 학습을 위해서는 preference 데이터셋이 필요하며, conversational 데이터셋과 standard 데이터셋 모두 지원
- conversational 데이터셋을 입력할 경우, 자동으로 chat template를 적용하여 학습을 구성함
- explicit과 implicit 프롬프트 데이터셋을 모두 지원하지만, explicit prompt 사용을 권장함
- 데이터셋은 preference style 데이터셋으로 구성
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
Logging되는 metric
- rewards/chosen: 정책 모델과 참조 모델의 선택된 응답에 대한 로그 확률 차이의 평균을 β로 스케일링하여 계산
- rewards/rejected: 정책 모델과 참조 모델의 거절된 응답에 대한 로그 확률 차이의 평균을 β로 스케일링하여 계산
- rewards/accuracies: 선택된 보상이 해당 거절된 보상보다 더 큰 경우의 비율의 평균을 계산
- rewards/margins: 선택된 보상과 해당 거절된 보상 간의 평균 차이(마진)를 계산
Loss 함수
DPO 알고리즘은 여러 loss 함수를 사용할 수 있으며, loss 함수는 DPOConfig에서 loss_type 파라미터를 셋팅함으로써 조정할 수 있다.
loss_type= | 설명 |
"sigmoid" (default) |
선호 데이터가 주어질 때 Bradley-Terry 모델에 따라 이진 분류기를 학습할 수 있으며, 로그-시그모이드(logsigmoid)를 적용하여 정규화된 우도에 기반한 로지스틱 회귀를 수행![]() 👉 chosen과 rejected 응답의 로그 확률 차이를 sigmoid 함수에 적용하여 선호되는 응답의 확률이 더 높아지도록 모델을 학습시킴. 이때 chosen과 rejected 사이의 차이를 최대한 벌리는 loss로 인해 모델이 positive에 과적합되는 문제가 발생하기도 함 |
"hinge" | SLiC 논문에 기반하여 힌지 손실을 사용하며, 이 경우 β는 마진의 역수로 설정![]() 👉 선호되는 응답과 덜 선호되는 응답의 로그 확률 차이가 특정 임계값(여기서는 1)을 넘지 못하면 손실이 발생하며, 이를 통해 모델이 더 확실하게 선호되는 응답을 선택하도록 학습 |
"ipo" | IPO 논문에서 제안한 손실함수는 선택된 응답과 거절된 응답의 로그 가능도 차이의 간격(gap)의 역수로 β를 설정하며, 평균 기반 손실을 사용![]() 👉 DPO의 일반화된 형태로, 로그 확률 차이가 특정 값(여기서는 1/2β )에 가까워지도록 모델을 학습시키며, 이를 통해 모델이 선호도 데이터를 과도하게 신뢰하지 않도록 정규화 효과를 제공 |
"exo_pair" | EXO 논문에서는 DPO의 로그-시그모이드 손실 대신 역방향 KL을 최소화하는 손실을 제안하며, label_smoothing을 0보다 크게 설정하면 단순화된 EXO 방식으로 동작 |
"nca_pair" | NCA 방식은 상대적인 likelihood 대신 각 응답의 절대 우도를 최적화하는 방식으로 정렬을 수행. |
"robust" | Robust DPO 는 선호 라벨에 노이즈가 존재할 수 있음을 가정하고, 이를 반영하여 레이블 노이즈에 강인한 손실을 추정. label_smoothing 값을 0.0보다 크게 설정해야 함 |
"bco_pair" | BCO 는 이진 분류기를 학습하고, 그 로짓을 보상으로 간주하여 {프롬프트, 선택 응답} 쌍은 1로, {프롬프트, 거절 응답} 쌍은 0으로 매핑하는 분류 기반 정렬을 수행 |
"sppo_hard" | SPPO 는 내쉬 균형(Nash Equilibrium)을 점진적으로 만족시키는 방식으로, 선택된 응답의 보상은 1/2로, 거절된 응답의 보상은 -1/2로 유도함 |
"aot" or loss_type="aot_pair" |
AOT 는 정렬을 확률분포 단위로 수행하는 방식으로, 긍정 샘플의 보상 분포가 부정 샘플보다 1차 확률 지배(stochastic dominance)를 하도록 최적화 - "aot": 선택-거절 쌍이 있는 데이터셋에 사용 - "aot_pair": 쌍이 없는 데이터셋에 사용 |
"apo_zero" or loss_type="apo_down" |
APO 는 기준 응답(anchor)을 설정하는 정렬 방식이며 - "apo_zero"는 이긴 응답의 가능도를 높이고 진 응답의 가능도를 낮춘다. - "apo_down"은 두 응답 모두의 가능도를 낮추되, 진 응답을 더 강하게 억제한다. |
"discopop" | DiscoPOP 논문은 더 효율적인 오프라인 선호 최적화 손실을 탐색하며, 로그 비율 기반 손실을 통해 IMDb, Reddit, Alpaca Eval 등에서 우수한 성능을 달성 |
✔ Label smoothing
cDPO 는 선호 레이블에 일정 확률로 노이즈가 존재한다고 가정하여 DPO 손실을 수정한 기법이다. 이 방법에서는 DPOConfig의 label_smoothing 파라미터를 사용하여 레이블 노이즈 존재 확률을 모델링한다.
보수적인 loss를 적용하려면 label_smoothing 값을 0.0보다 크고 0.5 이하(기본값은 0.0)로 설정한다.
✔ Syncing the reference model
TR-DPO 논문은 DPO 학습 중, 매 ref_model_sync_steps 스텝마다 참조 모델 가중치를 현재 모델과 ref_model_mixup_alpha 비율로 동기화할 것을 제안한다. 이 콜백 기능을 활성화하려면 DPOConfig에서 sync_ref_model=True로 설정한다.
✔ RPO loss
RPO 논문은 선택된 선호에 대해 가중치가 부여된 SFT 손실과 DPO 손실을 결합하여 반복적으로 선호를 튜닝하는 알고리즘을 구현한다. 이 loss를 사용하려면 DPOConfig에서 rpo_alpha를 적절한 값(논문에서는 1.0 추천)으로 설정한다.
✔ WPO loss
WPO 논문은 현재 정책 하에서 선호 쌍의 확률에 따라 가중치를 재조정하여 off-policy 데이터를 on-policy 데이터처럼 만드는 방식을 제안한다. 이 방식을 사용하려면 DPOConfig에서 use_weighting=True로 설정한다.
✔ For Mixture of Experts Models: Enabling the auxiliary loss
MOE 모델은 전문가들 간의 부하가 고르게 분산될 때 가장 효율적으로 작동한다. 선호 튜닝 동안 MOE를 비슷하게 학습시키기 위해 로드 밸런서의 보조 손실(auxiliary loss)을 최종 손실에 추가하는 것이 유익하다. 이 옵션을 활성화하려면 모델 설정(e.g., MixtralConfig)에서 output_router_logits=True로 설정한다. 보조 손실의 기여 비율을 조정하려면 모델 설정에서 router_aux_loss_coef 하이퍼파라미터(기본값: 0.001)를 사용한다.
Unsloth를 사용하여 DPO 학습 가속화하기
- unsloth library를 사용하면 QLoRA / LoRA 기반 미세조정 속도를 2배 빠르게 하고 메모리 사용량을 60% 줄일 수 있음
unsloth는 SFTTrainer와 완벽하게 호환되며, 현재 Llama 계열 모델(Yi, TinyLlama, Qwen, Deepseek 등)과 Mistral 아키텍처만 지원
GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | 1.88x | -11.6% |
Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |
PEFT를 사용할 때 reference model 설정하는 방법
- 방법 1: 모델 인스턴스를 두 번 생성하여 각각 어댑터를 로드하는 방법
- 동일한 베이스 모델을 두 번 로드하고, 각각에 어댑터를 적용하여 하나는 학습용, 다른 하나는 참조용으로 사용함
- 장점: 구현이 간단하며, 각 모델이 독립적으로 작동하므로 충돌이 없음
- 단점: 메모리 사용량이 두 배로 증가하여 비효율적
- 방법 2: 어댑터를 베이스 모델에 병합한 후, 새로운 어댑터를 추가하여, ref_model 파라미터를 null로 두는 방법
- 기존 어댑터를 베이스 모델에 병합하고, 새로운 어댑터를 추가하여 학습 하는 방법. 이때 ref_model 파라미터를 null로 설정하면, DPOTrainer가 참조 추론 시 어댑터를 언로드하여 베이스 모델만 사용
- 장점: DPOTrainer가 참조 추론 시 어댑터를 언로드하므로 효율적이며, 메모리 효율성이 높아 단일 모델로 학습과 참조를 모두 수행할 수 있음
- 단점: QLoRA와 같이 양자화된 모델에서는 병합 전에 디퀀타이즈(dequantize) 과정이 필요하며, 이로 인해 메모리 사용량이 증가할 수 있음
- 방법 3: 어댑터를 서로 다른 이름으로 두 번 로드한 다음, 학습 중에 set_adapter를 사용해 DPO용 어댑터와 참조용 어댑터를 전환하는 방법
- 동일한 베이스 모델에 두 개의 어댑터를 서로 다른 이름으로 로드하고, 학습 중에 set_adapter 메서드를 사용하여 활성화된 어댑터를 전환 하는 방법
- 장점: 메모리 사용량이 방법 1보다 적으며, 2번 방법에 비해 약간 비효율적이지만(VRAM에 어댑터 크기만큼 추가 부담) 어댑터 간의 충돌을 방지할 수 있음
✔ QLoRA 병합 후 DPO 적용(방법 2) 시 단점
- Benjamin Marie의 제안에 따르면, QLoRA 어댑터를 병합할 때는 먼저 베이스 모델을 디퀀타이즈(dequantize)한 후 어댑터를 병합하는 것이 가장 좋다.
- 그러나 이 방식으로 병합하면 퀀타이즈되지 않은 베이스 모델이 남게 된다.
- 따라서 DPO에 QLoRA를 사용하려면 다시 퀀타이즈하거나, 메모리 요구량이 큰 상태(비퀀타이즈 상태)로 사용해야 한다.
✔ 방법 3 - Adapter을 두 번 로드하는 방식
- 위의 단점을 피하기 위해, 학습할 모델에 미세조정된 어댑터를 서로 다른 이름으로 두 번 로드한다.
- 그리고 DPOTrainer 설정에서 학습용 어댑터 이름과 참조용 어댑터 이름을 지정하여 사용한다.
주요 참고 자료
- A SURVEY ON POST-TRAINING OF LARGE LANGUAGE MODELS
- https://huggingface.co/docs/trl/main/en/dpo_trainer
'LLM > LLM Customization' 카테고리의 다른 글
Model Compression Recipe - Generalized Knowledge Distillation (GKD) (1) | 2025.05.22 |
---|---|
Model Compression Recipe - Knowledge Distillation (KD) (0) | 2025.05.07 |
LLM 성능 향상을 위한 Post-training 방법론 개요 (0) | 2025.04.29 |
ChatGPT Fine-tuning 예시 | 언제, 어떻게 해야 하는가 (25) | 2023.11.11 |