본문 바로가기

AI

[논문리뷰] ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

ELECTRA: Efficiently Learning an Encoder that Classifies Token Replacements Accurately

생성 모델 대신 "판별모델"을 통해 인코더 사전 학습하기

(Pre-training Text Encoders as Discriminators Rather Than Generators) 

 

논문 - https://openreview.net/pdf?id=r1xMH1BtvB

깃허브 - https://github.com/google-research/electra


생성 모델 대신 "판별모델"을 통해 인코더 사전학습하기

 

대량의 코퍼스에 대해 모델을 사전학습하고, 풀고자 하는 태스크에 대해 fine-tuning 하는 방법은 NLP 과제를 수행하는 데에 있어 성공적이었다. 혁명의 시초였던 BERT에서 시작해서 RoBERTa, AlBERT, 그리고 최근에 발표된 google T5까지, 사전학습 방식에 각각의 노하우를 녹여냈지만, Masked Language Modeling(MLM) 혹은 Corrupted Language Modeling을 사전학습 과제로 사용했다는 공통점이 있다. 

MLM은 인풋 문장의 일부 토큰을 [MASK]와 같은 토큰으로 망쳐놓고, 모델이 이 오염된 자리에 어떤 단어가 왔을지 예측하도록 학습하는 방법이다. 이 사전학습을 수행하는 동안 모델은 자연스럽게 단어뿐만 아니라 인풋 문장의 문맥적인 의미를 포착하는 방법을 학습한다. 그 결과, 다운스트림 자연어 처리 태스크에 대해 모델을 transfer 하면 다량의 코퍼스에서 학습한 노하우를 바탕으로 좋은 성능을 내는 것이다.

하지만 이러한 사전학습 방법은 좋은 결과를 내기까지 대량의 계산이 필요했다. MLM은 마스킹된 일부 토큰에 대해서만 정의되는 태스크이기 때문에, 지식 습득을 위해 수많은 코퍼스가 필요하다.

 

ELECTRA는 이에 대한 대안으로  각 샘플을 더 효과적으로 활용하는 사전학습 태스크인 Replaced Token Detection을 제안한다. 이 접근법에서는 MLM에서 일부 토큰을 [MASK] 처리하듯이, 몇몇 토큰을 골라 작은 generator 네트워크가 샘플링한 가능성 있는 대안 토큰으로 대체한다. 이후 메인 모델은 generator가 오염시킨 인풋의 각 토큰이 대체된 것인지 아닌지를 맞추는, "discriminator"로서 학습된다.

 

논문의 실험 결과, 새로운 사전학습 방법은 MLM보다 효과적이었다. 특히 작은 모델도 강력한 성능을 보였는데, GPU로 4일 간만 학습한 모델이 GPT(30배 많은 계산양으로 학습)보다 GLUE 벤치마크에 대해 성능이 더 좋았다. 모델의 스케일을 키워도 성능은 계속 좋았다. RoBERTa와 XLNet에 대해, 그 계산량의 1/4만 사용해도 비슷한 성능을 보였고, 같은 만큼 계산을 했을 때는 더 뛰어난 성능을 보였다. 논문에서 제안한 사전학습 태스크는 마스킹된  일부 토큰이 아닌 전체 인풋 토큰에 대해 정의되는 태스크이기 때문이다. 

 

[Figure 1]같은 계산량에 대해 Replaced Token Detection으로 사전학습한 모델은 일관적으로 MLM보다 성능이 좋았다.

 

참고로 ELECTRA는 discriminator와 generator 등 GAN(Genarative Adversarial Network)의 아이디어를 차용했으나, adversarial 하게 학습을 한 것은 아니다. 논문에서 generator은 maximum likelihood를 통해 오염 토큰을 생성하는 방법을 학습하였는데, GAN을 자연어 텍스트에 적용하기가 어렵기 때문이다.


Method

사전학습 과제인 Replaced Token Detection의 기본 구조는 아래 그림과 같다.

 

 

ELECTRA는 두 개의 네트워크 - 생성 모델(G)과 판별 모델(D)을 학습시킨다.

 

각 네트워크는 하나의 인코더(Transformer)로 구성되어 있는데, 인풋 토큰  x = [x1, ... , xn]을 문맥적인 의미를 담은 벡터 representation의 시퀀스인 h(x) = [h1, ... , hn]로 매핑하는 역할을 한다. 생성모델 G는 주어진 위치 t에 대해 어떤 토큰 x_t를 생성할 확률을 아웃풋으로 리턴한다. 

 

 

판별 모델 D는 이제 위치 t가 주어질 때 판별 모델은 토큰 x_t가 "real(실제)", 즉 생성 모델로 인해 대체된 토큰이 아닌지, 혹은 "replaced(대체)"인지 구별해내야 한다. 이는 sigmoid 아웃풋 레이어로 수행한다. 

 

 

구체적으로 생성 모델은 MLM을 수행하는 방식으로 학습된다. MLM에서는 인풋 x가 주어질 때 랜덤한 위치들을 마스킹하고 (m = [m1, ... , m_k]) 마스킹된 토큰들은 [MASK] 토큰으로 대체한다. 이렇게 마스킹된 인풋은 x^{masked} = REPLACE(x, m, [MASK])로 표기한다. 이제 생성 모델은 마스킹된 토큰들에 대해 원래 어떤 토큰이 있었을지 예측하는 방법을 학습한다. 생성 모델이 만들어낸 토큰으로 대체한 인풋을 x^{corrupted}라고 표기하고, 이제 판별 모델은 이 오염된 인풋의 각 토큰이 원본인지, 대체된 것인지 판별해내야 한다. 일반적으로 k = [0.15n] 즉, 인풋의 15%를 마스킹하고, 이 과정을 수식으로 쓰면 다음과 같다. 

 

 

Loss 함수는 다음과 같다.

 

 

이는 GAN의 훈련 목적함수와 비슷해 보이지만, 몇 가지 중요한 차이점이 있다.

1. 생성 모델이 원본 토큰을 생성하는 데 성공해내면, 그 토큰은 'fake'가 아닌 'real'으로 간주한다.

2. 생성 모델은 판별 모델을 속이려는 적대적인 방식으로 학습하는 것이 아니라 Maximum likeligood를 통해 학습한다.

     -> 생성 모델이 샘플링한 결과에 대해 back-propagation이 어렵기 때문에, 텍스트에서 adversarial training이 어렵기 때문.

3. GAN에서는 생성 모델에 noise 벡터를 인풋으로 주는 반면, 본 논문에서는 그렇지 않다.

 

대량의 텍스트 Chi에 대해 위에서 정의한 두 loss를 하이퍼파라메터 lambda로 가중합한 손실 함수를 최소화한다.

 

 

위 식에서 알 수 있듯이 판별 모델의 loss를 생성 모델로 back-propagate하지는 않는다. (샘플링 때문에 어차피 이는 불가능하기 때문) 사전학습이 끝나면, 생성 모델은 더 이상 사용하지 않고, 판별 모델을 다운스트림 태스크에 대해 fine-tuning 한다.

 


Experiments

본 논문에서 모델과 대부분의 하이퍼 파라미터, 다운스트림 태스크(GLUE, SQuAD  등)에 대한 fine-tuning 기법 등은 BERT와 같은 것을 사용했다.

 

Model Extensions

본 논문에서는 모델링 측면에서 몇 가지 개선사항을 더하였다. 모델 개선 실험에서 특별한 언급이 없으면 모델 크기와 학습 데이터는 BERT-Base와 동일하다.

 

Weight Sharing : 사전학습 효율성을 증대하기 위해 생성 모델과 판별 모델의 가중치를 sharing 하는 방식으로 학습했다. G와 D의 모델 크기가 같으면 Transformer의 모든 weight를 묶어서 사용할 수는 있다. 그러나 연구 결과 생성 모델은 더 작은 것을 사용하는 것이 효과적이었기 때문에, 임베딩 (토큰 임베딩 & 위치 임베딩) 웨이트만을 셰어 하였다. 임베딩 크기는 판별 모델의 히든 차원과 같은 것을 사용했고, BERT와 마찬가지로 생성 모델에서 인풋과 아웃풋의 토큰 임베딩은 tie 하여 학습하였다.

 * G와 D의 크기를 같게 하고 5천만 스텝 동안 훈련한 결과

     -> GLUE에서 no tie = 83.6 / share embedding = 84.3 / share all = 84.4

-> ELECTRA는 토큰 임베딩 부분을 셰어함으로써 성능이 개선되었는데, MLM 태스크가 토큰 임베딩을 학습하는 데에 효과적인 방법이기 때문이다. 판별 모델은 인풋 토큰들에 대해서만 업데이트되는 반면 생성 모델은 단어 사전의 softmax에 대해 전반적으로 업데이트되기 때문이다. 반면에 모든 encoder weight를 셰어 하는 것은 미미하게 성능을 향상시켰으나, 생성 모델을 판별 모델 크기와 같게 해야 한다는 큰 단점이 있었다. 

-> 실험의 인사이트에 의해, 이후 실험에서는 G와 D의 임베딩 레이어만 쉐어하는 모델을 사용한다. 

 

Smaller Generators : 생성 모델은 판별 모델만큼 사이즈가 클 필요가 없을 수 있다. 따라서 레이어 개수를 줄이며 생성 모델의 크기를 줄이는 실험을 진행해 보았다. 또한, 극단적으로 간단한 "unigram 생성 모델"도 실험해 보았는데, 훈련 코퍼스에서 나온 토큰의 빈도에 기반해 토큰을 생성하는 방식이다. 

 

 

생성/ 판별 모델의 히든 차원에 따른 GLUE 점수를 살펴본 결과, 흥미롭게도 생성 모델의 크기를 판별 모델보다 작게 가져갈 때 오히려 성능이 더 좋았다. (그래프에서 생성 모델의 히든이 일정 개수(예. 512)를 넘어가면 GLUE 점수가 오히려 감소함) 

-> 일반적으로, 생성 모델의 크기를 판별 모델의 1/4~1/2의 크기로 설정할 때 성능이 가장 좋았다.

-> 너무 뛰어난 생성 모델은 판별 모델에게 지나치게 도전적인 과제를 부여하고, 오히려 학습 효율을 떨어뜨리는 것.

 

Training Algorithms : ELECTRA 학습에 있어 다른 학습 방법도 적용해 보았지만, 성능 향상에 도움이 되지는 않았다. 기본적으로는 생성 모델과 판별 모델을 같이 학습하는 방법을 사용하는데, 다음과 같은 두 단계의 학습 방법도 시도해 보았다.

  1. MLM loss를 이용해 n step동안 생성 모델만 학습하기 

  2. 판별 모델의 weight를 생성 모델의 것으로 초기화하고, 생성 모델의 weight는 고정하고 판별 모델만 n 스텝 학습하기

 

또한 강화 학습을 이용하여 생성 모델을 GAN과 같이 적대적인 방법으로 학습하는 방법도 시도해 보았다.

 

 

실험 결과, 생성 모델 대에서 판별 모델에 대한 목적함수로 옮겨갈 때 다운스트림 태스크에 대한 성능이 크게 증가한 것을 볼 수 있다. 하지만 G와 D를 함께 (jointly training, 파란 선) 학습한 것보다 성능이 좋지는 않았다. adversarial training을 한 것도 원래 방법보다 좋지는 않았지만, 모든 방법에 있어 BERT보다는 성능이 좋았던 것을 볼 수 있다.

참고로 adversarial 생성 모델은 MLM에서 정확도가 58%로, MLE 학습 정확도가 65%에 비해 좋지 않았다. 저자는 이 이유가 적대 방법으로 학습할 때 sample efficiency가 떨어졌기 때문인 것으로 분석한다. 텍스트를 생성한다라는 아주 큰 공간(action space)에서 강화학습을 통해 토큰을 생성하는 방법을 학습하는 것이 비효율적이었던 것이다. 또한, 적대적으로 학습된 생성모델은 아웃풋 분포에 있어 엔트로피가 더 낮았다. 이 두 가지 문제는 텍스트에 GAN을 적용하는 기존의 연구에서도 발견된 문제점들이었다.

 

Small Models

논문의 목표는 사전학습의 효율성을 높이는 것이기 때문에, 하나의 GPU에서 빠르게 학습할 수 있는 작은 모델을 실험해 보았다. BERT-base에 해당하는 하이퍼 파라미터부터 시작하여 시퀀스 길이를 줄이고 (512->128), 배치 크기를 줄이고(256->128) 모델의 히든 차원을 줄였으며(768->256) 토큰 임베딩 크기도 줄였다(768->128). 이 작은 모델(BERT-small)을 150만 스탭 동안 학습하였고, 학습 FLOPs를 유지하여 ELECTRA-small을 1백만 스텝 동안 학습하였다. 

 

 

그 결과 ELECTRA-small은 같은 계산량의 BERT-small보다 성능이 좋았고, 훨씬 계산이 많이 필요했던 GPT보다도 좋았다.

 

Large Models

사전학습에서 Replaced Token Detection 과제가 얼마나 효율적이었는지 분석하기 위해 큰 모델에 대해서도 실험해 보았다. ELECTRA-Large 모델은 BERT-large와 같은 모델 크기를 가지지만, 더 오래 학습하였다. 모델을 40만 스텝 동안 학습 (이는 RoBERTa의 약 1/4에 해당)해보고, 더 오래 (175만 스텝, 이는 RoBERTa와 동일) 학습한 결과를 비교해 보았다. XLNet에서 사용한 사전학습 데이터를 사용해 2048배치로 학습하였다. 

 

<GLUE dev set에 대한 결과>
<GLUE test set에 대한 결과>

 

큰 모델을 학습시켜 보았을 때, ELECTRA는 XLNet이나 RoBERTa 사전학습에 필요한 계산량의 1/4만 사용해도 비슷한 성능이 나왔다.

 

Efficiency Analysis

논문에서는 인풋의 일부 토큰을 마스킹하고, 이 작은 부분집합에 대해서만 학습 목적함수를 적용하는 것은 효율성이 떨어진다고 주장하였다. 하지만, 이 주장이 사실인지 확실하지 않고, 모델은 일부 토큰만을 예측하더라도 수많은 인풋을 받아들이기 때문에, MLM이 정말 비효율적이라는 증거는 없다. ELECTRA의 효율성이 어디서 오는지 분석하기 위해, BERT와 ELECTRA 방법의 중간에 있는 방법들을 실험해 봄으로써 사전학습 과제의 효과를 검증해 보았다.

1. ELECTRA 15% : 원래 ELECTRA와 동일하나, 판별 모델의 loss는 마스킹된 15%의 토큰에서 온 것만을 사용한다.

2. Replace MLM : MLM과 비슷하지만, 마스킹된 토큰 [MASK]를 인풋으로 받는 대신, 생성 모델이 만들어낸 토큰으로 대체하여 MLM을 진행한다. MLM 사전학습에서는 [MASK]라는 토큰이 나오는데, 다운스트림 태스크에서는 이 토큰을 절대 볼 일이 없다는 점이 성능 하락을 야기한다는 논의가 있었다. 이 실험을 통해 ELECTRA의 효과가 사전학습 단계에서 모델이 [MASK]토큰을 보지 않게 하는 점에 있는 것인지를 검증한다.

3. All-Tokens MLM : 마스킹된 토큰을 생성 모델이 만든 샘플로 대체하고, 모델은 오염된 토큰뿐만 아니라 모든 토큰에 대해 원본 토큰을 맞추는 MLM을 진행한다. 이 실험에서 각 토큰에 대해 원래 토큰을 copy할 것인지 결정하는 시그모이드 레이어를 도입하면 성능이 더 좋아진다는 것을 발견했다. 이 모델은 BERT와 ELECTRA의 혼종(?)으로, 생성 모델이 만들어낸 토큰이 과제를 더 어렵게 만들었다.

 

 

-> ELECTRA에서 모든 토큰에 대해 loss를 계산하는 것이 확실히 효과적이었다. (vs ELECTRA 15%)

-> BERT 성능은 [MASK] 토큰의 존재로 인해 성능이 약간 하락했을 수 있다. (BERT vs Replaced MLM)

-> All-Tokens MLM의 결과를 볼 때, BERT와 ELECTRA의 중간에 있던 이 모델의 성능이 ELECTRA에 가장 근접하게 나왔다.

 

 

추가적으로 ELECTRA는 학습도 더 빠른 것을 확인할 수 있다!

 


Conclusion

- 새로운 Self-supervised learning 사전 과제로 Replaced Token Detection을 제안

- 핵심 아이디어는 작은 생성 네트워크가 만들어낸 고품질의 negative sample을 사용해 텍스트 인코더가 인풋 토큰과 가짜 토큰을 구분할 수 있도록 학습시키는 것. 

- MLM과 비교할 때, 본 논문에서 제안한 방법은 계산적으로 효율적이면서 다운스트림 태스크에서 성능도 좋았음.

- 성능뿐만 아니라 모델의 효율성까지 고려한 방법이다!!!