본문 바로가기

AI

[논문리뷰] DeepMind RETRO - 수 조개의 토큰 DB로부터 정보를 검색해 강화된 언어모델

블로그 : https://www.deepmind.com/publications/improving-language-models-by-retrieving-from-trillions-of-tokens

논문 : https://arxiv.org/abs/2112.04426


Motivation

언어 모델이란 '가장 자연스러운 단어 시퀀스를 찾아내는 모델'로 단어의 시퀀스에 확률을 할당(assign) 하는 일을 하는 모델이다. 그리고 이러한 작업을 학습하기 위해 가장 보편적으로 사용하는 방법은 언어모델이 이전 단어들이 주어졌을 때 다음 단어를 예측하도록 훈련시키는 것이다.

 

지난 몇 년 동안 컴퓨팅 자원의 발달에 힘입어 언어모델은 더 큰 파라미터더 많은 데이터에 대해 학습하는 방향으로 발전해왔다. 

2020년 1750억 규모의 GPT-3 모델의 발표 이후 1780억 파라미터의 Jurassic-1, 2800억 파라미터의 Gopher, 5300억 파라미터의 Megatron-LM 등이 등장하며 모델 규모의 확장에 따른 성능 향상의 가능성을 보여주었다.

 

이러한 대규모 언어모델의 특징은 대량의 코퍼스를 수천억 개의 파라미터를 활용해 토큰의 분포, 언어의 의미, 그리고 세상에 대한 상식까지 모델링한다는 것이다. 이때 언어 및 세계에 대한 정보는 모델의 파라미터들이 오롯이 학습하고 기억해야 한다.

 

RETRO는 정보의 저장을 모델에게 모두 할당하는 대신, 모델이 외부 지식 DB에 접근해 필요한 지식을 활용할 수 있는 semi-parametric 접근법을 통해 기존 모델 대비 25배 적은 70억 개의 파라미터만으로 다운스트림 태스크에 있어 비슷한 성능을 보였다. RETRO에서 참조한 DB는 2조 개의 토큰들로 이루어져 있으며, 이들을 고정된 BERT모델을 사용해 인코딩해두고 사용한다.

 

RETRO 모델의 특징은 다음과 같다 : 

1. RETRO는 정보검색(retrieval)을 통해 강화된 autoregressive 언어모델이다. 받아온 텍스트의 정보를 언어모델에 전달하기 위해 chunked cross-attention 모듈을 사용하는데, 이 계산은 받아온 데이터의 양에 대해 선형적인 시간 복잡도를 가진다. 정보 검색에는 고정된 BERT모델을 사용할 수 있다는 것을 실험적으로 밝혔고, 따라서 검색을 위한 네트워크를 별도로 학습하거나 업데이트할 필요가 없다.

 

2. 이 방법은 모델 사이즈 혹은 데이터 규모를 스케일-업할 때 잘 작동한다. 1억 5천만 파라미터부터 70억 파라미터 규모 모델에 대해 RETRO 기법은 계속적은 성능 향상을 보였고, 추론 시 데이터베이스의 규모와 검색해온 정보 양을 늘릴 때 성능 향상을 보였다. 가장 큰 모델은 Wikitext103, Pile 등의 다운스트림 데이터셋에서 SOTA 성적을 얻었고, 질의응답 태스크에 대해  fine-tuning 시에도 성능 향상의 가능성을 보였다.

 

3. 논문에서는 테스트 데이터셋 유출 문제를 해결하기 위해 학습셋과 테스트 문서의 유사성을 평가하는 방법을 제안하였다. 이는 모든 언어모델이 풀어야 하는 문제이기도 한데, 특히 retrieval 방식으로 성능을 강화하는 모델은 평가 중에 학습 데이터에 직접적으로 접근할 수 있기 때문에 중요한 문제이다. 제안한 방법을 사용해 평가해 보았을 때 RETRO의 성능 형상은 명시적으로 찾아온 이웃 문서를 복사하는 것과 일반적인 지식의 추출 둘 다에서 기인한다는 것을 알 수 있다. 

 


Method

정보 검색을 통해 강화된 아키텍처는 수조 개의 토큰들로 이루어진 데이터로부터 정보를 받아올 수 있어야 한다.

저장 공간과 계산을 효율화하기 위해, RETRO는 인풋에 있는 각각의 토큰 별로 정보를 받아오는 대신, 토큰의 chunk 단위로 이웃 문서를 검색해오는 방법을 선택하였다. RETRO 모델의 학습 과정은 요약하자면 아래와 같다.

 

Step 1. 정보를 검색할 key-value 데이터베이스 구축하기

  > 이 DB에서 key는 고정된 BERT 임베딩, value는 텍스트 토큰의 원본 chunk를 가진다.

  > 논문에서는 MassiveText의 다국어 버전을 학습 및 검색용 DB 데이터로 활용하였다. 

 

  > Details :

     - Massive Text는 웹사이트, 뉴스, 깃허브를 포함한 다양한 출처에 대해 여러 가지 언어로 이루어진 텍스트 문서를 포함

     - 원본 데이터는 총 5조 개의 토큰으로 이루어져 있는데, 학습/평가 시에는 일정 비율로 샘플링한 데이터를 구축해 사용

        · 학습 데이터 : 아래와 같은 비율로 샘플링한 데이터셋을 구성해 6천억 개의 토큰으로부터 정보를 받아옴

        · 평가 데이터 : Books 도메인에 대해서만 4%의 샘플링을 적용, 총 1조 7500개의 토큰으로부터 정보를 받아옴

     -  테스트 데이터 유출을 방지하기 위해 13-gram Jaccard similarity를 계산해 평가/테스트 데이터와 0.8 이상 유사한 문서는 삭제하고 테스트

     - 토크나이저는 SentencePiece를 사용해 128,000 토큰으로 이루어진 사전 구축

 

 

Step 2. 학습 시퀀스를 chunk 단위로 쪼개고, 각각에 대해 k-nearest neighbour 받아오기

  > 먼저 학습 시퀀스는 일정한 길이 n을 가지는 토큰의 chunk들로 쪼갠다.

  > 이후 해당 chunk의 임베딩을 DB의 key들과 비교하여 각각 k개의 가장 가까운 이웃 문서(neighbour)를 찾아온다.

 

  > 기술적 구현:

     - T개의 구성요소를 가진 데이터베이스에 대해 쿼리할 시 계산 시간은 O(log T) 복잡도를 가졌다.

     - SCANN 라이브러리를 활용해 구현, 평균적으로 2조 개의 토큰 데이터베이스에 대해 10ms의 시간이 소요되었다. 

 

  > Details:

     - 외부 데이터베이스 구성 :

        · key-value 메모리로 이루어져 있는데, 각각의 value는 2개의 인접 토큰의 chunk로 이루어짐 (표기 :  [ N,F ])

        · N은 이웃 chunk로, key 계산에 활용, F는 원본 문서에서 N에서 이어지는 원본 문서.

        ·  key는 N에 대한 BERT 임베딩으로 계산하며, 시간축에 대해 평균한 값을 사용한다. (표기 : BERT(N))  

     - 각각의 chunk C에 대해 key-value DB로부터 BERT 임베딩에 대한 L2거리를 계산해 가장 가까운 k개의 이웃을 받아온다

    - 모델은 받아온 value들 RET(𝐶) = ( [𝑁 1 , 𝐹1 ], . . . , [𝑁 𝑘 , 𝐹𝑘 ])을 활용하게 되고, 여기서 이웃 chunk(𝑁) 를 활용하는 것과 그의 이어지는 텍스트(𝐹)를 활용하는 것은 모두 의미가 있는 것으로 나타났다.

     - 언어모델 학습에 있어 해당 chunk를 받아오는 것은 학습 인과관계가 깨질 수 있기 때문에, 같은 시퀀스에서 이웃을 받아오는 것은 방지하였다.

     - 실험을 해 보았을 때, 받아오는 이웃 문서는 일반적으로 주어진 아티클에서 2~3단계의 링크로 도달할 수 있는 것으로 나타났다. 이는 랜덤한 문서는 5단계 이상의 링크가 소요되는 것에 대비되는 결과이다.   

 

Step 3. 받아온 이웃 문서의 정보를 학습 시퀀스에 통합하기

  > Encoder-decoder 구조는 chunk별로 받아온 정보를 통합하여 모델의 예측에 사용한다. 

  > Details :

     - 토크나이저로 정수 인코딩된 인풋 시퀀스 토큰 X = (x1, ... , xn) 를 인풋으로 받아

     - 이를 각각 m=n/의 길이를 가지는 시퀀스 chunk ℓ개로 쪼갠다. (C1, ... , C)

       ( 즉, C1 = (x1, ... , xm) , ... , Cℓ = (x_{n-m+1} , ... , xn) )

     - 이후 각각의 chunk Cu는 데이터베이스로부터 받아온 k개의 이웃을 활용하여 강화된다.

     - 모델은 토큰의 likelihood를 계산하는데, 이때 기존의 토큰과 받아온 이웃의 정보를 고려하게 된다. 

     - 논문에서는 인풋 시퀀스 길이를 2,048 , 각 chunk의 길이는 64로 지정하였다. (n=2048 , m = 64)  

 


Architecture

RETRO는 기본적으로 encoder-decoder 트랜스포머 구조를 가지며, cross-attention을 통해 받아온 문서의 정보를 통합한다.

 

전체 구조 :

아래 그림은 인풋 문장을 각각 4개 토큰을 가지는 Chunk로 나누고,

각각의 Chunk에 대해 5토큰으로 이루어진 이웃 문서를 2개를 받아와 CCA를 통해 augment하여 결과를 내는 도식이다.  

 

 

Step 1. Chunk에 대해 받아온 토큰들 RET(C)는 인코더 트랜스포머에 입력되어 벡터화된 이웃 집합 E로 인코딩된다. 

Step 2. 중간 activation을 H로 표기할 때, 디코더는 RETRO 블록 RETRO(H,E)와 일반적인 트랜스포머 블록 LM(H)를 교차하여 배치해 구성한다. 이때 어떤 위치에서 RETRO 블록을 사용할지는 하이퍼파라미터 P ⊆ [1 , L]를 통해 조절한다. 

Step 3. 이러한 블록들은 아래의 세 가지의 residual operator을 조합하여 구성된다 : 

       1) 일반적인 시퀀스 레벨의 self-attention 레이어 - ATTN

       2) Chunked cross-attention layer - CCA ( · , E) : retrieval encoder에서 나온 정보를 통합하는 역할 수행

       3) Fully-connected layer - FC layer 

RETRO 알고리즘 전체 구조

Hyperparam : 𝑃 와 𝑃_enc - 디코더/인코더에서 cross-attention을 사용하는 레이어의 인덱스

Hyperparam : 𝐿 과 𝐿enc - 디코더의 레이어 수와 인코더의 레이어 수 의미

Input : 𝑋 ∈ 𝕍^𝑛 - 토큰의 시퀀스 , (Ret(𝐶𝑢)) : 받아온 neighbour

Output : 𝑂 ∈ ℝ𝑛× |𝕍 | - 아웃풋 로짓

 

 

 

아키텍처 디테일 :

DETAIL 1. Retrieval neighbor 인코딩하기

- 각각의 chunk에 대해 k개의 이웃 Ret(𝐶𝑢)는 bi-directional transformer 인코더를 통해 𝐸𝑢 로 인코딩된다.

   : 𝐸𝑢^𝑗  =  Encoder(Ret(𝐶𝑢)^𝑗 , 𝐻𝑢) ∈ ℝ𝑟×𝑑'

     > j는 각각의 이웃의 인덱스를 나타냄

     > 이웃 문서 내 r개 토큰에 대해 d' 차원의 벡터를 만들어냄

- retrieval encoder은 non-casual한 트랜스포머로, Cross-attention layer을 통해 𝐻𝑢에 조건부로 계산됨.

- 이 과정으로 인해 retrieval encoder이 생성한 representation은 받아온 chunk에 따라 미분 가능한 방식으로 달라질 수 있음

- 즉, u번째 chunk의 j번째 이웃에 대한 인코딩은 min(P)레이어에서 attended된 activation 𝐻𝑢에 의존함

- 모든 chunk에 대한 이웃은 병렬적으로 계산되어 아래의 전체 인코딩된 텐서를 계산해냄

- 여기서 𝐸𝑢 ∈ ℝ𝑘×𝑟×𝑑' 를 chunk u에 대해 인코딩된 이웃이라고 표기

- 즉, 최종적으로 l개 chunk에 대해 k개 이웃의 r개 토큰들에 대한 d' 차원의 representation을 얻어냄

 

 

DETAIL 2. Chunked cross-attention

 

 

- CCA 연산을 하기 위해, 먼저 중간 activation인 𝐻 ∈ ℝ𝑛×d 을 l-1개의 attending chunk로 분할한다.  

- 𝐻𝑢+와 𝐸𝑢 사이의 cross-attention을 계산하는데, 이 작업은 시간축과 neighbor 축에 대해 동시에 이루어지기 때문에 𝐸𝑢에 있는 이웃과 시간 축을 통합하여 계산한다. 물론 여기서 데이터 chunk 및 chunk 내 토큰 순서에 대해 정렬이 맞아야 하기 때문에 relative positional encoding을 적용하였다. 

- 최종적으로는 per-chunk cross-attention의 결과로 얻게 되는 (mxd)차원의 (l-1)개 아웃풋을 concat하고, 모자란 부분은 pad하여 아웃풋 차원을 맞추었다.  

 

- 수식적으로는 각각의 chunk Cu와 그 안의 각 토큰 𝑖 ∈ [1, 𝑚] 에 대해 Chunked Cross Attention은 다음과 같이 계산한다 :

  > 여기서 CA는 인코딩된 이웃에 대해 수행되는 cross-attention residual operator이다. 

  > 위 식에서 softmax는 2번째 차원에 대해 수행되며 multi-head cross-attention을 적용하고, softmax에 positional encoding을 더하여 사용한다. 

  > 첫 (m-1)개의 토큰은 어떠한 이웃에도 접근할 수 없기 때문에 이들에 대해서는 CCA의 결과물은 원래의 representation을 사용한다 (identity setting)

  > 마지막 토큰의 경우에는 그림에는 표기되어 있지 않지만, 마지막 chunk의 이웃에 attend한다. 

  > 이러한 CCA 연산은 autoregressive한 특성을 가진다. 비록 각 chunk에서 attention은 바로 직전의 chunk의 이웃에 대해서만 이루어지지만, self-attention 연산을 통해 정보가 전파되기 때문에 지나친 계산 비용의 증가 없이 이전 chunk들에서 받아온 이웃의 정보에 접근할 수 있게 되는 특성을 가진다. 

 

DETAIL 3. Sampling

Sampling 시에는 u번째 chunk의 마지막에 SCaNN을 사용해 BERT 임베딩 Bert(𝐶𝑢)에 기반하여 이웃 Ret(𝐶𝑢)을 받아온다. 

이후 인코딩된 이웃 𝐸𝑢 = Encoder(Ret(𝐶𝑢))은 다음 chunk를 생성할 때에 조건부로 사용된다. 

이 과정은 incremental하게 수행될 수 있으며, 샘플링에 드는 비용은 일반적은 Transformer와 마찬가지로 샘플링된 시퀀스의 크기에 quadratic하게 증가한다. 여기에 retrieval에 따른 비용은 chunk의 개수는 l에 대해 선형적으로 증가하기 때문에, 실질적으로는 무시할 수 있는 숫자가 된다. 

 

 


Experimental Results

(language modeling 관련 실험 결과는 정리 생략)

 

RETRO를 사용하여 문장을 샘플링할 때, 주어진 프롬프트에 대해 더 일관된 문장을 생성하였다. 

위의 예시에서 프롬프트로 <비버는 강가에 사는 흥미로운 동물이다. 그들은 >이라는 문장을 주었을 때,

RETRO를 사용하지 않은 1열의 결과에서는 생성의 결과가 여러 주제로 발산하는 것을 볼 수 있다.

(개구리, 골든리트리버 등등 언어모델이 알고 있는 다양한 정보에 대해 나열함)

 

RETRO를 사용한 2열의 결과에서는 주어진 프롬프트와 관련이 있는 문장을 지속해서 생성한다는 것을 볼 수 있다. 

 


의견🦊

세상에 있는 수많은 코퍼스를 <언어모델의 파라미터 규모>에 의존하여 모든 지식을 학습시키려던 기존의 시도(GPT-3, Gopher, Jurrasic-1 등)는 아키텍처 및 하드웨어 엔지니어링 관점에서 흥미로운 부분이 많았고, 괄목할 성적을 보였지만 한편으로는 불편한 구석이 있었다. 언어모델이 무엇을 학습했는지 알 수 없으며, 때때로 적당히 아는 척하는 말을 생성하는 등 신뢰성이 떨어졌기 때문이다.RETRO에서 사용하는 neighbor retrieval 및 CCA 모듈에서는 어떤 정보를 참고하여 토큰을 생성했는지 트래킹할 수 있다. 이는 모델이 이상한 말을 생성했을 때 어떤 정보를 사용했는지 확인해볼 수 있다는 점에서 모델의 설명 가능성이 높아 보인다. 뿐만 아니라 정보를 받아올 DB를 정제하여 양질의 정보만을 사용하거나 최신의 정보로 업데이트하는 방식으로 문장 생성의 퀄리티를 높일 수 있지 않을까 하는 생각도 들었다. 다운스트림 태스크 수행에 도움이 되는 정보를 제공할 수 있는 DB를 잘 구축하고, 정보를 활용할 수 있도록 CCA 모듈을 pre-finetuning한다면 대화모델, 오픈 도메인 질의응답, 멀티-홉 질의응답, 코드 생성 등 다양한 태스크에 활용하고 성능 개선을 할 수 있지 않을까 하는 기대감이 든다.