Abstract
- Problem
- LLM의 놀라운 능력에도 불구, 부정확한 정보를 포함한 응답 多 → RAG가 그런 이슈를 감소시킴
- (but 기존의 RAG도) 무차별적으로 passage를 검색하고 포함하는 문제 → LM의 다양성↓, 도움이 되지 않는 답변을 생성
- Solution
- Self-RAG: 검색과 자아성찰(self-reflection)을 통해 LM의 퀄리티와 사실성을 강화하는 방법
- 필요에 따라(on-demand) 구절을 적응적으로 검색하는 단일한 임의의 LM을 학습 → reflection token이라는 특수 토큰을 이용하여 검색된 구절과 자체적인 생성물을 생성 및 자아성찰함
- reflection token을 생성하는 것은 추론 단계에서 LM을 통제 가능하게 하고, 다양한 작업 요구사항에 맞게 동작을 조정 가능함
- Contribution
- Self-RAG는 SOTA LLM 및 RAG 모델을 능가함
- 구체적으로는 Open-domain QA의 추론 및 사실 확인 작업에서 ChatGPT와 retrieval-augmented Llama2-chat을 능가하며, 이러한 모델들에 비해 긴 형식의 생성물에 대한 사실성과 인용 정확도에 상당한 이점을 보임
1. Introduction
- Problem
- SOTA LLM은 모델과 데이터 규모가 증가했음에도 불구, factual error로 인해 어려움을 겪고 있음
- RAG는 관련된 구절을 검색하는 것으로 LLM의 input을 증강 → 지식 집약적 작업에서 factual error↓
- 하지만, RAG와 같은 방법은 사실적 근거가 도움이 되는지와 상관없이 무차별적으로 구절을 검색 → LLM 다양성 저해하거나 주제에서 벗어난 구절을 가져와서 저품질 생성으로 이어질 수 있음
- 모델이 제공된 구절의 사실을 활용하고 따르도록 명시적으로 학습되지 않았기 때문에 LM의 출력이 검색된 관련 구절과 일치한다고 보장할 수 없음
- Solution
- Self-RAG: on-demand 검색과 자아성찰(self-reflection)을 통해 다양성을 해치지 않으면서도, LLM의 생성 품질을 향상시키는 방법
- task output과 간헐적인 특수 토큰(즉, reflection token)을 생성하여 task input이 주어지면 자체 생성 과정을 성찰(reflect)하는 방법을 학습하도록 임의의 LM을 E2E 방식으로 학습시킴
- reflection token은 성찰의 필요성과 생성 품질을 나타내기 위해 각각 retrieval token과 critique token으로 분류됨
- 프로세스
- input prompt와 그에 따른 generation이 주어지면, 검색된 구절로 생성을 증강하는 것이 도움이 되는지 판단 만약, 도움이 된다고 판단되면 retrieve model을 호출하는 retrieve token을 출력 (retrieve)
- 검색된 여러 구절을 동시에 처리하여 관련성을 평가 → 해당 output을 생성 (generate)
- critique token을 생성하여 자신의 output을 비평 → 사실성과 전반적인 품질 측면에서 가장 좋은 것을 선택 (critique)
- 이런 프로세스는 검색 필요성과 상관없이 생성을 위해 일정 수의 문서를 지속적으로 검색하고 생성 품질을 재검토하지 않는 기존의 RAG와 다름
- 또한 Self-RAG는 각 부분에 대한 인용문과 함께 output이 검색된 구절에 의해 지지되었는지에 대한 자체 평가를 제공 → 사실 검증을 쉽게 함
- 확장된 모델 어휘에서 다음 토큰의 예측으로 통합시킴으로써 reflection token이 포함된 텍스트를 생성하도록 임의의 LM을 학습시킴
- reflection token과 검색된 구절이 삽입된 다양한 텍스트 모음으로 generator model을 학습시킴
- 강화 학습에 사용되는 보상 모델에서 영감을 얻은 reflection token은 학습된 critic model에 의해 원본 말뭉치에 오프라인으로 삽입됨 → 따라서 학습 중에 critic 모델을 호스트할 필요가 없으므로 overhead가 줄어듦
- critic model은 부분적으로 proprietary LM(즉, GPT-4)에 프롬프팅하여 수집된 input, output 및 해당 reflection token의 데이터셋으로 supervised 학습됨
- Contributions
- 또한 Self-RAG는 reflection token 예측으로 정의되는 하드 또는 소프트 제약 조건을 충족하는 customizable 디코딩 알고리즘을 가능하게 함
- 특히, inference 단계의 알고리즘을 통해 (1) 다양한 다운스트림 애플리케이션에 맞게 검색 빈도를 유연하게 조정하고 (2) reflection token 확률의 가중 선형 합을 세그먼트 점수로 사용하는 세그먼트 수준 beam search를 통해 reflection token을 활용하여 사용자 선호도에 맞게 모델 동작을 커스텀할 수 있음
- 추론 및 long-form 생성을 포함한 6가지 작업에 대한 경험적 결과에 따르면 Self-RAG는 더 많은 매개변수와 더 높은 인용 정확도로 널리 채택된 RAG 접근 방식을 가진 pre-trained 및 instruction-tuned LLM을 훨씬 능가하는 것으로 나타났음
- 특히 Self-RAG는 네 가지 작업에서 검색 증강 ChatGPT보다 성능이 뛰어나며, 모든 작업에서 Llama2-chat 및 Alpaca보다 성능이 뛰어남
- 우리의 분석은 전반적인 성능 개선과 test-time model customization(ex. 인용 예측과 완성도 간의 균형 조정)을 위해 reflection token을 사용한 학습 및 추론의 효과를 입증함
3. Self-RAG: Learning to Retrieve, Generate and Critique
- Self-RAG는 검색과 자기 성찰을 통해 LLM의 품질과 사실성을 향상시키면서도 LLM 고유의 창의성과 다양성을 희생하지 않는 프레임워크임
- 우리의 E2E 학습을 통해 LM은 필요한 경우 검색된 구절의 정보를 바탕으로 텍스트를 생성하고, 특수 토큰 생성 학습을 통해 결과물을 비평할 수 있습니다.
- 이러한 reflection token은 검색의 필요성을 알리거나 output의 relevance, support, completeness를 확정함
- 반면 일반적인 RAG 접근 방식은 인용된 출처의 완전한 지원을 보장하지 않고 무차별적으로 구절을 검색함
3.1. Problem Formalization and Overview
- $\mathcal{M}$: Generator model
- $\mathcal{C}$: Critic model
- $\mathcal{R}$: Retriever model
- $d$: relevant passage, $d \in \mathbf{D}$
- $\mathbf{D}$: relevant passages
- $x$: input
- $y$: output
- $y=[y_1, …, y_T]$
- $y_t$: $t$번째 세그먼트의 토큰 시퀀스
- $y_t$에서 생성된 토큰은 원본 어휘의 텍스트와 reflection token을 포함함
이 논문에서는 한 문장을 segment로 취급함 (하지만 다른 어떤 segment unit도 적용 가능함)
Inference overview
- Input: $x$(Prompt How did US ~) and $y_{<t}$() / Output: $y_t$(US states got ~)
- 주어진 $(x,y_{<t})$로 Retrieve token**(retrieval token)**을 예측 # Retrieve 토큰은 검색이 필요한지 아닌지를 판단
- 만약 Retrieve token 이 ‘Yes’라면 # 검색이 필요하다면
- ▹Retrieve Retriever $\mathcal{R}$은 주어진 $(x, y_{t-1})$로 관련있는 텍스트 구절들($\mathbf{D}$)을 검색
- ▹Generate Model $\mathcal{M}$은 각 $d$에 대해 주어진 $x, d$로 IsRel token(critique token) 토큰을 예측하고 # 검색된 구절의 연관성을 평가 Model $\mathcal{M}$은 각 $d$에 대해 주어진 $x, d, y_{<t}$로 $y_t$를 예측
- ▹Critique Model $\mathcal{M}$은 각 $d$에 대해 주어진 $x, y_t, d$로 IsSup token**(critique token)을 예측하고 # 응답의 정보가 검색된 구절에 의해 지지되는지 평가 Model $\mathcal{M}$은 각 $d$에 대해 주어진 $x, y_t, d$로 IsUse token(critique token)**을 예측 # 응답의 전반적인 유용성을 평가
- IsRel, IsSup, IsUse token에 기반하여 $y_t$를 rank
- 만약 Retrieve token이 ‘No’라면 # 검색이 필요하지 않다면
- ▹Generate Model $\mathcal{M}$은 주어진 $x$로 $y_t$를 예측(생성) # 기존의 LM이 하는 방식
- ▹Critique Model $\mathcal{M}$은 주어진 $x, y_t$로 IsUse token**(critique token)**을 예측
Training overview
- 확장된 모델 어휘(즉, 원본 어휘 + reflection tokens)에서 reflection token을 다음 토큰 예측으로 통합함으로써 Self-RAG는 임의의 LM이 reflection token이 포함된 텍스트를 생성할 수 있도록 함
- 구체적으로, retriever $\mathcal{R}$이 검색한 interleaving passage와 critic model $\mathcal{C}$가 예측한 reflection token으로 큐레이션된 말뭉치를 generator model $\mathcal{M}$에 학습시킴
- (검색된 구절과 주어진 output)의 품질을 평가하기 위한 reflection token을 생성하도록 $\mathcal{C}$를 학습시킴 (Section 3.2.1)
- 오프라인에서 $\mathcal{C}$를 사용하여, output에 reflection token을 삽입하는 방식으로 학습 말뭉치를 업데이트함
- 그 후, inference 시점에는 $\mathcal{C}$에 의존하지 않고 $\mathcal{M}$이 스스로 reflection token을 생성할 수 있도록 기존의 LM objective(Section 3.2.2)를 사용하여 최종적인 $\mathcal{M}$을 학습함
3.2. Self-RAG Training
3.2.1. Training the Critic Model
- Data collection for critic model
- 각 segment의 reflection token에 대한 수동 annotation은 비쌈 → LLM으로 피드백을 받을 수 있음
- GPT-4가 reflection token을 생성하도록 프롬프팅하여 supervised data를 생성 → 내부의 $\mathcal{C}$에 그 지식을 증류함(distill)
- Table 1.과 같이 reflection token 그룹마다 고유한 정의와 입력이 있으므로, 각 그룹에 대해 서로 다른 명령 프롬프트를 사용함
- 아래는 Retrieve token 예시
-
- 수동 평가 결과 위와 같이 GPT-4의 reflection token 예측은 사람의 평가와 높은 일치도를 보임
- 각 타입별 4k-20k supervised training data를 수집하고, 이를 결합하여 $\mathcal{C}$용 학습 데이터를 형성함
- Critic learning
- $\mathcal{C}$를 pre-trained LM으로 초기화하고 수집한 학습 데이터 $\mathcal{D}_{critic}$를 이용하여 학습함, 다음과 같은 obejective:
- pre-trained LM으로는 generator LM과 같은 Llama 2-7B 사용
3.2.2. Training the Generator Model
- Data collection for generator
-
- input-output pair $(x,y)$가 주어지면, $\mathcal{R}$과 $\mathcal{C}$를 사용하여 원래의 output $y$를 증강하여 supervised data $\mathcal{D}_{gen}$을 만듦
- 위 그림처럼, inference 프로세스와 같은 방식으로 데이터를 증강함
- (augmented output w/ reflection token, original input) 쌍이 $\mathcal{D}_{gen}$에 추가됨Table 4: List of the training examples
- Generator learning
- $\mathcal{M}$을 수집한 학습 데이터 $\mathcal{D}_{critic}$를 이용하여 학습함, 다음과 같은 obejective:
- $\mathcal{C}$와 달리 $\mathcal{M}$은 output 뿐만 아니라 reflection token 또한 예측하도록 학습시킴
- 학습 중, loss 계산을 위해 검색된 텍스트 청크(Figure 2.에서 <p>와 </p>로 둘러싸인 부분)를 마스킹하고 원본 어휘 $\mathcal{V}$를 일련의 reflection token {critique, retrieve}로 확장함
- Connections to prior work on learning with critique
- PPO를 통한 RLHF 등의 최근 연구들은 학습 중에 추가적인 critique(feedback)을 포함함
- RLHF, PPO
- RLHF(Reinforcement Learning from Human Feedback)는 사람의 피드백을 통해 학습한 리워드 모델을 이용하여, 생성 모델이 생성한 답변에 대해서 좋은 답변의 경우 긍정적인 신호를, 올바르지 않은 답변의 경우는 부정적인 신호를 주면서 강화학습을 하게 되고, 이를 통해 좀 더 안전하고 유용한 답변을 할 수 있도록 모델을 fine-tuning
- PPO(Proximal Policy Optimization) 등의 강화학습 기반으로 학습을 진행하는 RLHF 방법론은 실제 학습을 하는 생성 모델(Actor) 뿐만 아니라 리워드 모델, Critic 모델, Reference 모델까지 총 4개의 모델이 필요함 → LLM 모델 자체의 크기가 이미 큰데, 4개의 모델이 필요하기 때문에 훨씬 더 많은 GPU 리소스 필요
- RLHF, PPO
- PPO가 별도의 학습 중에 reward 모델에 의존해야하는 것에 비해, Self-RAG는 critique를 오프라인으로 학습(별도로 학습)하고 이를 훈련 말뭉치($\mathcal{D}_{gen}$)에 직접 삽입하여 $\mathcal{M}$이 표준 LM objective로 학습함 → PPO 대비 학습 비용을 상당히 줄일 수 있음
- Self-RAG는 각 segment가 생성된 후 모델 자체의 예측(출력)을 평가하기 위해 special token을 생성하는 방법을 학습하여, inference 단계에서 soft re-ranking 메커니즘 또는 hard constraints를 가능하게 함
- PPO를 통한 RLHF 등의 최근 연구들은 학습 중에 추가적인 critique(feedback)을 포함함
3.3. Self-RAG Inference
- 모델 자체의 output을 평가하기 위해 reflection token을 생성하면 inference 단계에서 Self-RAG를 통제 가능하게 하여, 다양한 작업 요구사항에 맞게 동작을 조정할 수 있음
- 사실적 정확성이 요구되는 작업의 경우, 모델이 구절을 더 자주 검색하여 output이 evidence와 밀접하게 일치하도록 하는 것을 목표로 함
- 반대로, 에세이 작성과 같은 open-ended 작업의 경우, 검색 횟수를 줄이고 전반적인 창의성 또는 유용성 점수를 우선시하는 방향에 중점을 둠
- 이 섹션에서는 inference 과정에서 이러한 목표를 달성하기 위해 control을 강화하는 접근 방식에 대해 설명함
Adaptive retrieval with threshold
- Self-RAG는 Retrieve token을 예측하여 언제 구절을 검색할 지 동적으로 결정함
- 구체적으로, Retrieve token의 모든 output 토큰에 대해 정규화된 'Retrieve token=Yes token' 생성 확률이 지정된 임계값을 초과하면 검색을 실행함
Tree-decoding with critique tokens
- 각 segment step $t$에서 hard 또는 soft 조건에 따라 검색이 필요한 경우 $\mathcal{R}$은 $K$개의 구절을 검색하고, $\mathcal{M}$은 각 구절을 병렬로 처리하여 $K$개의 다른 continuation candidates(구절별 output)를 출력함
- segment-level beam search(w/ beam size=$B$)를 수행하여 각 타임스탬프 $t$에서 segment continuation을 얻고, 생성 종료 시 best sequence를 반환함
- 구절 $d$에 대한 각 segment $y_t$의 점수는 각 critic score $\mathcal{S}$로 업데이트됨 (critic score $\mathcal{S}$는 각 Critique token 타입의 정규화된 확률의 선형 가중치 합계)
- 각 Critique token 그룹 $G$(ex. IsRel token)에 대해 타임스탬프 $t$에서 해당 점수를 $s_t^G$로 표시하고, 다음과 같이 segment score를 계산함:
- $w^G$: inference 시점에 조정할 수 있는 하이퍼파라미터로 맞춤형 동작을 가능하게 함
- 예를 들어, 결과 $y$가 evidence에 의해 support되도록 하려면 IsSup 점수에 대한 가중치를 높게 설정하고 다른 측면에 대한 가중치는 상대적으로 낮출 수 있음
4. Experiments
4.1. Tasks and Datasets
Closed-set tasks
- PubHealth: fact verification dataset about public health
- ARC-Challenge: a multiple-choice reasoning dataset created from scientific exams
Short-form generations tasks
- Open-domain Question Answering (QA) datasets
- PopQA
- TriviaQA-unfiltered
4.2. Baselines
Baselines without retrievals
- Llama2 7B, 13B: strong publicly available pre-trained LLMs
- Alpaca 7B, 13B: instruction-tuned models (replication based on Llama2)
- Chat-GPT: models trained and reinforced using private data
- Llama2-chat 13B: models trained and reinforced using private data
- CoVE 65B: introduces iterative prompt engineering to improve the factuality of LLM generations
Baselines with retrievals
- Standard RAG baselines
- Llama2, Alpaca, Llama2-FT
- Ret-ChatGPT
- Ret-Llama2-chat
- perplexity.ai
- concurrent methods that are trained with retrieved text passages
- SAIL
- Toolformer
4.3. Experimental Settings
Training data and settings
- Open-Instruct processed data와 knowledge-intensive datasets에서 인스턴스를 샘플
- 150k instruction-output pairs
- Generator model $\mathcal{M}$: Llama2 7B, 13B
- Critic model $\mathcal{C}$: Llama2 7B
- Retriever model $\mathcal{R}$: Contriever-MS MARCO
Inference settings
- Weight term
- IsRel = 1.0
- IsSup = 1.0
- IsUse = 0.5
- Retrieval Threshold = 0.2 for most tasks
- vllm: speed up inference
- 각 segment level에서 beam width = 2
- token level generation에 대해 greedy decoding 적용
- Contriever-MS MARCO에서 기본값으로 top 5 document 사용
5. Results and Analysis
Comparison against baselines without retrieval
- Table 2.의 윗부분은 retrieval을 제외한 baseline
- Self-RAG는 모든 작업에서 fine-tuned LLM에 비해 상당한 성능을 보여줌 심지어 PubHealth, PopQA, biograph generations, ASQA(Rouge 및 MAUVE)에서 ChatGPT보다 성능이 뛰어남
Comparison against baselines with retrieval
- 많은 작업에서 non-proprietary LM-base 모델들 중 가장 우수한 성능