LLM/LLM Customization

Model Compression Recipe - Generalized Knowledge Distillation (GKD)

LittleFox 2025. 5. 22. 18:39

Generalized Knowledge Distillation 개요

LLM은 대규모 파라미터를 활용하여 다양한 태스크에서의 가능성을 입증해왔으나, 이러한 규모로 인해 추론 비용 등 모델의 배포 관점에서 장벽이 있다. Knowledge Distillation(KD, 지식 증류)는 비교적 작은 학생(student) 모델을 학습하여 추론 비용과 메모리 사용량을 줄이기 위해 교사 모델을 압축하는 기법이다. Auto-regressive 모델에 대한 지식 증류는 ▲ teacher 모델이 생성한 고정된 아웃풋 시퀀스를 활용하거나 (Kim & Rush, 2016) ▲토큰 단위의 확률 분포를 지정함으로써 teacher 모델이 라벨을 지정할 수 있는 방법(Sanh et al., 2019)을 활용하여 이루어졌다. 그러나 이러한 방식은 학습 중에 사용하는 출력 시퀀스와 학생이 추론 중에 생성하는 시퀀스 간의 분포 불일치가 존재한다. 또한, distillation에서의 목적함수는 teacher과 student간의 forward KL을 최소화하는 경우가 일반적인데, 이 방법은 student 모델이 생성한 결과가 teacher가 생성한 것과 다른 결과로 이어질 수 있다. 

 

이러한 문제를 해결하기 위해 고정된 출력 시퀀스에 의존하는 대신, teacher 모델의 피드백을 활용하여 자체 생성된 출력 시퀀스로 student 모델을 학습하는 Generalized Knowledge Distillation(GKD) 방법이 제안되었다.

 

먼저, GKD에서는 auto-regressive sequence model에 대한 지식 증류를 상호작용하는 전문가와의 imitation leatning 문제로 접근한다. 이 관점을 활용하여 GCK에서는 고정된 아웃풋 시퀀스 대신 스스로 생성한 시퀀스를 on-policy로서 학습하게 된다. 더 나아가 GKD에서는 reverse KL과 generalized JSD와 같은 대안의 divergence measure을 사용할 수도 있다.  이는 student모델이 teacher 모델의 분포를 모방하는 능력이 부족한 경우에 유용할 수 있다. 또한 GKD는 언어 모델에 대한 강화학습 기법과도 쉽게 통합이 가능하다.

 

GKD 알고리즘 상세

더보기

 Preliminaries

● Auto-regressive generative sequence models

   - x : 인풋 시퀀스

   - y : 아웃풋 시퀀스

   - $\mathbb{V}$ :  M개의 토큰으로 이루어진 단어

   - $y_{ < n+1} = (y_{1}, y_{2}, ..., y_{n})$ : n번째 토큰까지 생성된 아웃풋 시퀀스

   - $L_y$ : 시퀀스 y의 길이

   - $p(.|y_{<n}) \in (0,1)^M$ : 토큰 단위의 auto-regressive policy, 인풋 x와 n번째 토큰 이전까지의 아웃풋 시퀀스의 조건부

   - $y ∼ p(·|x)$ : 인풋 x가 주어질 때 샘플링된 아웃풋 시퀀스 y

   - $ p(y_n|x) := p(y_n|y_{<n}, x)$ → 표기의 간편성을 위해 이와 같이 표기함

 

● KL-Based Divergence

    - $D_{KL}(P \Vert Q) = \sum_{c \in C} P(c) log \frac{P(c)}{Q(c)}$ : 두 개의 분포 P(C)와 Q(C)간의 KL divergence

    - KL divergence는 비대칭적, $D_{KL}(P \Vert Q)$ 를 forward KL,  $D_{KL}(Q \Vert P)$를 reverse KL로 부름

    - empirical data 하의 forward KL은 지도학습에서 최적화하는 mazimum likelihood에 해당함

    - P(C)를 $Q_\theta(C)$로 근사할 때 reverse와 forward KL을 최소화하는 것은 평균과 최빈값을 찾는 행동으로 이어짐

    - KL divergence는 unbound될 수 있기 때문에, bounded 된 divergence를 위해 generalized JSD(Jensen-Shannon divergence)를 쓸 수 있음.

    - JSD(beta)는 forward와 reverse KL을 0<beta<1의 계수를 사용해 interpolate함

문제 정의

두 개의 서로 다른 용량(capacity)을 가진 auto-regressive시퀀스 모델 p_S (student model)과 p_T(teacher model)이 있다.
학생 모델은 학습 가능한 파라미터 θ를 가지고 있고, pS_θ는 θ에 대해 미분 가능하다. 입력 데이터셋 X가 주어져 있으며, 선택적으로, 입력-출력 시퀀스 쌍 (X, Y)도 주어졌다고 가정한다. 만약 주어지지 않는다면, teacher 모델로부터 시퀀스를 샘플링하여 생성할 수 있다. 어떤 divergence 척도 D에 대해, 교사와 학생 모델의 토큰 수준 분포 간 차이(discrepancy)는 다음과 같이 정의된다:

 

여기서 x는 입력, y는 출력 시퀀스, Ly는 시퀀스의 길이다. 예를 들어, 위 식에서 발산 척도로 JSD(β)를 사용하면 다음과 같다:

 

◆ Supervised Fine-tuning

: teacher policy에 대한 접근이 불가하고, 정답 출력 시퀀스가 있는 고정된 데이터셋만 존재하는 경우,
  student model은 Negative log-likelihood를 최소화하는 방식으로 학습된다.

 

◆ Sequence level Knowledge Distillation

: Kim & Rush, 2016은 teacher 모델이 생성한 확률이 높은 시퀀스에 대해 likelihood를 최대화하는 방식으로 시퀀스 단위의 지식증류를 제안했다. 이는 teacher 모델의 출력에 대한 supervised fine-tuning이라고도 볼 수 있다.

 

◆ Supervised Knowledge Distillation

: Hilton이 제안하여 널리 사용되는 기법으로, student 모델 p_S는 teacher 모델의 토큰 수준의 확률분포를 모방하도록 학습된다. 이때 손실함수는 teacher의 전체 토큰 수준의 분포를 활용하여 풍부한 학습 신호를 제공하게 된다.

 

 Generalized Knowledge Distillation (GKD)

✔ On-policy Knowledge Distillation

GKD는 On-policy imitation learning을 지식 증류로 확장한다. 지식 증류에 On-policy 데이터를 사용할 경우, 학생 모델은 자신이 생성한 출력 시퀀스에서 잘못된 토큰들에 대해 teacher 모델의 로짓으로부터 토큰별 피드백을 받게 된다. 이 과정은 강화학습에서 관찰되는 피드백 루프와 유사하며, 학습과 추론에서의 분포의 불일치를 줄이는 데에 도움이 된다. 또한, 학습을 통해 student가 점점 발전함에 따라 student 모델이 생성하는 데이터의 품질 또한 점차 향상되게 된다. 인풋 x가 주어질 때, 학생 모델은 출력 시퀀스 y를 생성하고, intermediate state인 $y_{<n}$에 대해 teacher 모델의 토큰 단위의 분포 P_T(y_x|x)를 흉내내게 된다. 이때 on-policy loss L_{OD}는 아래와 같이 적용한다:

여기서 student의 sampling 분포 pS(·|x)에 대해서는 backprogate하지 않는다. 이는 학습을 안정적으로 만들며, 계산적으로도 효율적이다. 

 

학습 중, temperature=1로 설정함으로서 학생 모델이 다양한 시퀀스를 생성할 수 있도록 장려한다. 또한, 레이블이 주어지지 않은 인풋 프롬프트에 대해서 student를 사용해서 시퀀스를 생성하는 것은 teacher 모델을 사용하는 것보다 계산 비용이 적게 들 수 있다.

 

✔ Generalized Knowledge Distillation

이러한 on-policy KD를 기반으로, supervised learning과 on-policy 방식을 통합하여 보다 일반적인 방법인 Generalized KD(GKD)로 확장한다. GKD에서는 최적화할 목적함수인 divergence와 더불어 학습할 출력 시퀀스를 선택할 수 있다. Teacher과 student간의 토큰간의 divergence를 측정하는 어떠한 방법도 선택이 가능하다. 출력 시퀀스에 대해, GKD는 teacher 모델이 생성했거나, ground-truth이거나, student가 생성한 on-policy 시퀀스 중 사용이 가능하다. 

 

Algorithm:

 1: Given: teacher 모델 $p_T$, student 모델 $P_S^{\theta}$, (인풋, 아웃풋 쌍이 포함된) 데이터셋 (X,Y)이 주어질 때

 2: Hyperparameters: Student 데이터 비율 $\lambda \in [0,1]$, Divergence D, 학습률 η 에 대해

 3: for 각 스텝 k = 1, ..., K동안 do

 4:    Uniform(0,1)분포로부터 랜덤한 값 u를 생성한다

 5:    if 만약 u가 $\lambda$ 보다 작거나 같다면 ( u ≤ λ ), then

 6:          X에서 인풋들 x를 샘플링하여 student model로부터 아웃풋 y ∼ p_S^θ (·|x)를 생성하여 B= {x_b,y_b}를 구성한다

 7:    else

 8:          (X, Y)로부터 인풋과 아웃풋을 샘플링하여 배치 데이터  B = {(xb, yb)}를 구성한다

 9:    end if

10:   Loss 함수를 극소화하기 위해 파라미터 theta를 업데이트한다:

11: end for

 


이때 student는 teacher 모델이 피드백을 제공할 수 있는 수준의, 어느정도 퀄리티 있는 시퀀스를 생성할 수 있는 모델을 사용한다. GKD 논문에서는 supervised Fine-tuning을 거친 student model에서 시작한다.

 

✔  GKD에서 divergence의 선택:

지식 증류에서는 forward KL을 주로 사용하지만, 이는 student 모델이 teacher의 토큰 단위 분포의 전체를 커버할 수 있어야 한다. 이 과정에서 student는 teacher 분포 하에서 확률분포가 낮은 토큰들에도 확률 질량을 할당하고, 이로 인해 hallucination이나 low-quality 출력을 생성하는 결과를 낳을 수 있다. 특히 학생 모델의 능력이 teacher에 비해 많이 떨어질 때, temperature sampling 시 이러한 문제가 더 많이 발생할 수 있다.

이러한 문제에 대한 대안으로 reverse KL과 같은 mode-seeking divergence의 경우, teacher 모델이 높은 확률을 부여하는 토큰을 우선시하며, 저품질의 생성을 피하게 할 수 있지만, 대신 인풋이 주어졌을 때 생성의 다양성이 떨어지는 결과가 생길 수 있다.

이에 실험적으로 optimal한 divergence는 수행하고자 하는 태스크에 따라 다를 수 있음이 나타난다.

 

 

실험 결과

🔍 요약 태스크에서 Divergence 종류가 성능과 다양성에 미치는 영향

 

✔ 실험 개요:

다양한 발산(Divergence)을 활용한 온-폴리시 GKD(On-policy Generalized Knowledge Distillation) 실험을 통해 학생 모델이 생성한 결과물의 품질과 다양성 사이의 트레이드오프를 평가한다. 이때, 샘플링 temperature를 조절하여 영향을 관찰한다. 다양성(Diversity)은 Self-BLEU(Zhu et al., 2018)로 측정하며,

  • 점수 100은 **완전히 결정적인 출력(즉, 다양성 없음)**을 의미하고,
  • 점수 0은 최대의 다양성을 뜻한다.

✔ 실험 결과:

Forward KL에서 Reverse KL로 일반화된 JSD(Generalized Jensen-Shannon Divergence를 거쳐 전환할수록, 출력의 다양성이 감소하는데, 이는 divergence가 점점 더 mode-seeking 성향을 띠기 때문이다. 이러한 mode-seeking divergence는 특히 샘플링 온도가 높을 때(γ = 1) 일반적으로 더 우수한 품질을 보이는 경향이 있다. 한편, 온도를 낮추면(temperature ↓) 다양성이 줄어드는 동시에, divergence 선택 간 간 성능 차이도 줄어든다.

 

Huggingface TRL: GKD Trainer

Documentation: https://huggingface.co/docs/trl/en/gkd_trainer

 

Generalized Knowledge Distillation Trainer

class trl.GKDConfig < source > ( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str]

huggingface.co

 

 

GKDTrainer는 SFTTrainer 클래스를 감싸는 래퍼(wrapper)로, teacher model 인자를 추가로 받는다. 이 트레이너는 GKDConfig를 통해 설정되는 세 가지 주요 파라미터가 필요하다:

 

lmbda

  • 학생 데이터 비율을 조절하는 파라미터로,on-policy로 학생이 생성한 출력이 전체 학습 데이터에서 차지하는 비율
  • lmbda = 0.0일 때: 손실 함수는 supervised JSD가 되며, student는 teacher의 token-level probabilities로 학습됨
  • lmbda = 1.0일 때: 손실 함수는 on-policy JSD가 되며, student가 출력 시퀀스를 생성하고, 이에 대해 교사로부터 토큰별 피드백을 받는다.
  • 0과 1 사이의 값일 때: 각 배치마다 확률적으로 두 방식 중 하나를 선택하게 되며, 이 확률은 lmbda 값에 따라 결정된다.

seq_kd

  • 시퀀스 수준 지식 증류(Sequence-Level KD) 수행 여부를 조절한다. 이는 teacher 모델이 생성한 출력에 대해 student model을 supervised fine-tuning하는 것으로 볼 수 있다.
  • seq_kd = True이고 lmbda = 0.0일 경우: 손실 함수는 지도 방식 JSD가 되며, 교사가 출력 시퀀스를 생성하고, 학생은 이에 대해 토큰별 피드백을 받는다.

beta

  • generalized Jensen-Shannon divergenve에서의 interpolation을 조절한다.
  • beta = 0.0이면: 손실 함수는 정방향 KL 발산(forward KL divergence)에 근사된다.
  • beta = 1.0이면: 손실 함수는 역방향 KL 발산(reverse KL divergence)에 근사된다.
  • 0과 1 사이의 값일 때: 두 KL 발산 사이를 보간(interpolate)하게 된다.

🔍 참고 결과

  • 저자들은 온-폴리시 데이터 사용 비율이 높은 경우(lmbda가 큰 값) 성능이 더 좋다고 보고한다.
  • 최적의 beta 값은 작업(task)과 평가 방법(evaluation method)에 따라 달랐다고 한다.

참고자료

ON-POLICY DISTILLATION OF LANGUAGE MODELS: LEARNING FROM SELF-GENERATED MISTAKES

- Huggingface: Generalized Knowledge Distillation Trainer