본문 바로가기

AI_Paper/NLP

REALM: Retrieval-Augmented Language Model Pre-Training 논문리뷰

[ Abstract ] 

REALM은 LM(language model) pretrain 과정에 retriever를 결합해, 모델이 Wikipedia 같은 대규모 말뭉치에서 직접 문서를 검색·활용할 수 있도록 설계되었습니다.

이를 통해 지식을 파라미터에만 저장하던 기존 한계를 극복하고, Open-domain QA에서 기존 모델 대비 4~16% 정확도 향상 

 

[ Introduction ] 

REALM은 pre-train 단계부터 retrieval을 통합해, 언어 모델이 외부 지식을 검색·활용하며 학습하도록 설계된 최초의 프레임워크로, Open-QA에서 성능과 해석 가능성을 동시에 개선했습니다. 특히 지식을 파라미터에만 의존하지 않고 필요할 때 검색해 쓰는 구조 덕분에, 모델의 지식 확장성과 업데이트 용이성이 크게 향상되었으며, 복잡한 질문에 대해서도 보다 신뢰할 수 있는 답변을 제공할 수 있게 되었습니다.

 

📌 모델 아키텍처 (Model Architecture)

1) Knowledge Retriever

  • 목표: p(z∣x), 즉 어떤 문서가 입력 xx와 가장 관련 있는지 확률로 계산
  • 구현:
    • 입력 xx와 문서 zz를 각각 BERT-style Transformer에 넣어 [CLS] 벡터를 뽑음
    • 선형 변환 후 **내적(inner product)**으로 유사도 측정
    • 소프트맥스 → 검색 확률 분포

2) Knowledge-Augmented Encoder

  • 목표: p(y∣z,x), 즉 “검색한 문서 + 입력”을 보고 답변 생성
  • 구현:
    • xxzz를 하나의 시퀀스로 합쳐 Transformer에 입력
    • cross-attention을 통해 문맥과 지식이 결합됨
  • Pre-training (MLM):
    • [MASK] 위치의 토큰을 맞히는 전통적인 MLM loss 사용
  • Fine-tuning (Open-QA):
    • 답변이 문서 zz 내 연속된 토큰(span)이라고 가정
    • Transformer 출력에서 시작 토큰(start) & 끝 토큰(end) 위치를 예측

REALM의 Training: 검색기(Retriever)는 어떻게 학습될까?

REALM은 사전학습과 파인튜닝 모두에서 “정답 yy가 나올 확률 p(y∣x)”을 최대화하는 방식으로 학습됩니다. 여기서 중요한 점은, 단순히 답을 잘 맞히는 것뿐 아니라 어떤 문서를 검색해야 도움이 되는지까지 학습된다는 점입니다.

실제 계산에서 모든 문서를 다 합산하기는 어렵기 때문에, Retriever가 뽑은 상위 k개 문서만을 대상으로 확률을 계산합니다. 이때 문서와 입력의 유사도는 내적(inner product) 기반으로 계산되고, 빠른 검색을 위해 MIPS(Maximum Inner Product Search) 알고리즘을 사용합니다.

훈련 중에는 Retriever의 파라미터가 계속 바뀌기 때문에, 미리 만들어둔 인덱스가 금방 낡아집니다. REALM은 이를 해결하기 위해 비동기 인덱스 갱신(asynchronous refresh) 방식을 사용합니다. 즉, 한쪽에서는 모델이 계속 학습하고, 다른 한쪽에서는 최신 파라미터로 문서 임베딩과 인덱스를 재구축해 교체하는 구조입니다.

재미있는 점은 Retriever가 학습되는 방식입니다. 모델이 특정 문서를 참조했을 때 정답 예측 확률이 올라가면 그 문서의 점수를 올려주고, 도움이 안 되면 점수를 깎아버립니다. 즉, 예측 정확도 향상에 기여하는 문서만 살아남는 식으로 Retriever가 점점 더 똑똑해지는 구조입니다.