일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |
- domain-adapted pre-training
- Fine-tuning
- Text-to-Image
- langchain
- ViT
- llm
- PEFT
- backbone
- llm tuning
- prompt
- instruct pre-training
- CPT
- continued pre-train (cpt)
- Mac
- transformer
- cross-document attention
- full fine-tuning (fft)
- lora+
- instruction tuning
- gemma2
- sfttrainer
- continual pre-training
- diffusion
- ubuntu
- instruct-pt
- error: mkl-service + intel(r)
- continued pre-training
- instruction tuning (it)
- glibcxx
- Lora
- Today
- Total
꾸준하게
[논문리뷰] Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum 본문
[논문리뷰] Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
yeonsikc 2024. 9. 16. 16:30arXiv 2024 [paper]
Hadi Pouransari, Chun-Liang Li, Jen-Hao Rick Chang, Pavan Kumar Anasosalu Vasu, Cem Koc, Vaishaal Shankar, Oncel Tuzel
Apple
submitted on 21 May 2024
Abstract
본 논문에서는 고정된 길이의 토큰 시퀀스로 구성된 데이터셋을 사용하는 기존 방식의 한계를 지적하며, 이를 해결하기 위해 '데이터셋 분해'라는 새로운 가변 시퀀스 길이 학습 기법을 제안하였다. 기존 방법은 다양한 길이의 문서를 무작위로 연결한 후 고정된 길이로 잘라내는 방식인데, 이로 인해 시퀀스 내에서 문서 간 attention이 발생하는 비효율성과 계산 비용이 증가하는 문제가 있다. 특히 긴 시퀀스를 학습하는 것은 주의 계산 비용이 기하급수적으로 증가하기 때문에 매우 비효율적이다.
이에 비해, 제안된 '데이터셋 분해' 방식은 동일한 크기의 시퀀스를 가진 버킷(bucket)으로 데이터셋을 분해하고, 각 버킷에서 동시에 샘플링하여 가변 시퀀스 길이와 배치 크기를 사용해 학습한다. 이를 통해 attention 비용이 실제 문서 길이에 비례하게 되어 학습 시간이 크게 단축된다. 8k 컨텍스트 길이의 10억 개 파라미터 모델을 2k 컨텍스트 길이 모델과 동일한 비용으로 학습할 수 있었으며, 웹 스케일 코퍼스 실험에서 제안된 방법이 기존 방법보다 3배 빠르게 목표 정확도에 도달하는 등 성능이 크게 향상되었다. 이 방법은 긴 시퀀스에서의 효율적인 pre-train을 가능하게 하고, 데이터셋 크기에 따라 효과적으로 확장된다.
Introduction
LLM 학습을 위해선 매우 큰 규모의 데이터 셋이 필요하며 이는 메신저처럼 짧거나 책처럼 긴 텍스트가 포함된다. 배치 내 시퀀스 킬이는 한정적이기 때문에 매우 긴 텍스트를 여러 텍스트로 청킹하여 학습하게 된다. 최근 연구에서는 각기 다른 길이를 가진 문서들을 청킹하고 concat(concat-and-chunk)하는 전략을 취하고 있다. 이는 학습 전에 문서를 청킹 후 랜덤하게 concat하게 된다. 연이어 이어진 문서들은 문서 경계를 파악하기 위해 <EOT>라는 스페셜 토큰으로 구분하게 된다. 그 후, 각 모델마다 정해진 고정 시퀀스 길이에 맞춰 학습이 이루어진다. (Llama-1 : 2048, Lama-2 : 4096)
이러한 concat-and-chunk 전략은 여러 문제가 존재한다.
- 랜덤하게 concat하게 된다면 모델은 전혀 상관없는 문서의 컨텍스트를 참고하여 다음 토큰을 예측하게 된다.
- 문서 간 cross-attention은 학습을 촉진하지 않는 쓸데없는 토큰에 불필요한 연산을 수행토록 한다.
- 문서가 타겟 시퀀스 길이보다 짧아도 두 시퀀스 사이의 경계에 있다면 두 개의 청크로 분리될 수 있다
본 논문에서는 위의 여러 문제들을 해결하기 위해 데이터를 길이에 따라 분해하고 가변 시퀀스 길이(variable sequence length; VSL)로 학습하는 새로운 접근 방식인 dataset decomposition (DD)라는 기술을 소개한다. DD는 가변의 길이를 가진 문서들이 있는 데이터셋(dataset)을 각기 고정된 길이를 가진 데이터셋들(datasets)/버켓(bucket) $\cup _i \mathcal{D}_i$로 통합한다. 각 버켓 $\mathcal{D}_i$는 $2^i$ 길이를 갖는 시퀀스를 가지며 각각은 고유한 문서에서 추출된다. (즉 한 버켓에는 하나의 문서에서 추출된 고정된 길이의 시퀀스들만 담김)
VSL 환경에서 학습하는 동안, optimizer 프로세스의 모든 step에서 (커리큘럼에 따라) i를 샘플링하여 버켓 $\mathcal{D}_i$에서 $b/2^i$ 시퀀스로 배치를 형성하고, 샘플링된 $\mathcal{D}_i$에 관계없이 배치의 총 토큰 수를 =b로 일정하게 유지한다.
이러한 전략은 앞서 말한 concat-and-chunk의 이슈를 해결하면서 여러 이점을 제공한다.
- DD는 간단하고, 데이터 준비 단계에서 계산 오버헤드가 거의 없어 큰 데이터셋에도 적용할 수 있다.
- 모든 버켓의 각 시퀀스에 있는 토큰은 구조상 동일한 문서에서 가져온 것이므로 cross-document attention를 피할 수 있다.
- VSL 학습 전략은 학습 시간을 단축한다. $\rightarrow$ 작은 $i$를 가진 $\mathcal{D}_i$에서 샘플링 되었을 때, 한 번의 optimizer step 시간이 단축됨(attention의 quadratic complexity 때문)
Contributions
- 가변길이의 데이터셋을 고정된 길이의 시퀀스 데이터셋 버켓으로 분해하는 방법을 제안하였으며 이는 효과적이고 robust한 학습을 가능케 함
- 다양한 모델, 데이터셋, 평가에서 대규모 실험을 수행하였으며 데이터 효율성(2배 이상)과 컴퓨팅 효율성(11%~45%)이 크게 향상되어(Fig. 1 참고) LLM 사전 학습 가속화가 최대 3배(기준선 대비 특정 정확도에 도달하는 시간) 빨라진 것으로 확인 하였음
- 자연어와 긴 문맥의 작업에 대한 pre-train 시 시퀀스 길이 분포와 혼합의 중요성을 연구하였으며, 시퀀스 길이를 합성적으로 변경하기 위한 연결 및 청킹(concat and chunking) 작업의 효과를 보여줌
Method
Dataset decomposition
토큰화된 문서 $\{d_1, d_2, ... , d_n\}$의 데이터셋 $\mathcal{D}$가 주어졌을 때, dataset decomposition (DD)의 목표는 다음 조건에 맞는 여러 버켓들로 $\mathcal{D}$을 재구성하는 것이다. 이는 결과적으로 각 버켓에는 하나의 문서의 시퀀스만 담겨있으며, cross-document attention 문제를 방지한다. 추가적으로, 버켓 $\mathcal{D}_i$의 시퀀스들은 같은 시퀀스 길이$\_i$를 가지므로 효율적인 batch가 가능하다.
- 각 버켓 $\mathcal{D}_i$는 $l_i$ 길이를 가진 토큰들의 시퀀스로 구성한다.
- 각 시퀀스 $s \in \mathcal{D}_i$는 하나의 문서 $d \in \mathcal{D}$의 서브시퀀스이다.
- $\mathcal{D}$의 각 토큰은 정확히 하나의 $\mathcal{D}_i$에서 나타난다.
위 DD는 유니크하지 않다. 저자는 원본 문서 시쿼스 길이 분포를 최적으로 유지하면서 효율적인 batch pre-train을 가능하게 하기 위해 $l_i = 2^i$로 하는 특별한 decomposition을 제안하였다. 저자는 문서 수준에서 분해 방식을 적용하기 때문에 기존 데이터 준비 파이프라인(모델 학습 전 단계)에 매우 쉽게 통합할 수 있으며 대규모 데이터셋에도 확장할 수 있다고 한다. 저자는 길이가 $l$($l = 2^{i_1} + 2^{i_2} + ... + 2^{i_k}$ : binary decomposition)로 토큰화된 문서 $d \in \mathcal{D}$의 경우, $d$를 $2^{i_1}, ..., 2^{i_k}$ 길이를 가진 인접 시퀀스$s_1, ..., s_k$로 각각 나누었다. 길이 $2^{i_j}$의 각 시퀀스 $s_j$는 버켓 $D_{ij}$로 할당된다.
저자가 제안한 DD를 사용하면 각 버킷 $D_i$에는 원본 문서 d에서 추출한 시퀀스가 포함되며, d의 길이는 최소 $2^i$이다. (토큰 수가 8개($2^3$) 이상 16개($2^4)$ 미만이라면 $D_3$에 소속됨) 아래 Fig. 3-b.에서는 여러 버켓에 대한 RefinedWeb 데이터 셋 토큰의 분포를 보여준며 D9(길이 512의 시퀀스)가 가장 많은 토큰을 보유하고 있음을 알 수 있다. 버켓 $D_i$에 있는 대부분의 토큰은 길이가 $2^i \leq l < 2^{i+1}$인 문서에서 추출되며, 일부 토큰은 길이가 $2^{i+1}$ 이상인 문서에서 롤오버된다. 이는 원본 문서 길이를 유지하는데 있어 이 방법이 특히 희소성이 있는 긴 문서에 효과적이라는 것을 보여준다.
Variable sequence length training
한 번의 optimizer step 당 사용되는 토큰 수를 타겟 배치사이즈 $b$라고 가정하자. variable sequence length (VSL) 학습에서, 매 step마다 버켓 $\mathcal{D}_i$로부터 $b/2^i$ 시퀀스의 샘플$i$를 추출한다. $\mathcal{D}_i$가 길이 $2^i$의 시퀀스로 구성되어 있기에, optimizer step당 사용되는 토큰 수는 어떤 $i$를 선택하는지와는 독립적으로 여전히 $b$로 남게된다. VSL 알고리즘을 사용한 LLM 학습은 여러 이점을 가져온다.
- optimizer step당 사용되는 토큰 수는 변하지 않는다.
- 고정된 $b$(스텝 당 토큰)한번의 optimizer step(forward+backward)를 완료하는 시간이 attention의 quadratic cost 때문에 시퀀스 길이에 따라 달라진다. VSL 학습을 하게되면 매 optimizer step당 cost는 버켓 $\mathcal{D}_i$에 의존된다. 따라서, 보다 고비용의 step(긴 시퀀스)는 저비용의 step(짧은 시퀀스)으로 보상된다. (긴 시퀀스 step 할 때 시간이 오래걸려도, 고정적으로 짧은 시퀀스 버켓도 있기 때문에 짧은 시퀀스 step 시 그만큼 시간이 짧게걸린다.)
- VSL의 샘플링 component를 사용하면 시퀀스 길이에 따른 다양한 커리큘럼이 가능하다.
Experiments and analysis
저자는 다양한 규모의 데이터셋에 대해 LLM을 tuning하는 실험을 진행하였다.
- Dataset : RefinedWeb (Common Crawl의 필터링 버전)
- 525B tokens
- Tokenizer : EleutherAI/gpt-neox tokenizer
- Vocab size : 50,432
- Model, Train Code : OpenLM
- # of params : 1B
- context length : 8k
- Positional encoding : Rotary Positional Embedding (RoPE)
- 일반적인 base frequency $f_b = 10,000$이지만, 최근에 pre-trained model을 더 긴 시퀀스로 fine-tuning 할 때에 $f_b$를 높이는게 효과적이라는 연구 결과가 나왔기에 $f_b = 100,000$까지 늘려서 실험을 진행하였다. 저자는 from scratch에서도 $f_b$를 높이는 것이 효과적이었다고 주장하였다.
- Evaluation
- 총 14개의 언어모델 벤치마크의 평균을 계산
- Commonsense Reasoning (CSR) : PIQA-0-shot, COPA-0-shot, OpenBookQA-10- shots
- Language Understanding (LU) : Lambada-OpenAI, Hellaswag-0-shot, Winograd-3- shots, WinoGrande-5-shots
- Reading Comprehension (RC) : SQuAD-3-shots, BoolQ-0-shot, CoQA-0-shot
- World Knowledge (WK) : Jeopardy-3-shots, ArcEasy-3-shots, ArcChallenge-3- shots, WikiDataQA-3-shots
- 긴 컨텍스트 태스크 평가를 위해 다음의 real-world 벤치마크를 추가하였다.
- Multi-Document Question Answering (MDQA)
- TOEFL
- QuALITY
- 총 14개의 언어모델 벤치마크의 평균을 계산
Training efficiency
먼저, VSL 학습이 throughput을 향상시키는 지에 대한 실험이다.
- 모델 사이즈 : OpenLM-1B/3B/7B
- 컨텍스트 길이 : $2^6 ~ 2^{13}$
- 총 step 수 : 100
- global batchsize b : $8 \times 8192$
- GPU : H100 $\times$ 8 (single node)
위 Fig. 4.는 일반적인 베이스라인(concat-and-chunk)의 평균 step 시간을 측정한 결과다. 타겟 컨텍스트 길이(2048 ~ 8192)에 따라 1B 모델 기준 243ms ~ 304ms 시간이 소요된다. 이에반해, VSL은 위 Table 1.에서 볼 수 있듯이 같은 성능(Regular Avg. : 54.0)을 베이스라인이 타겟 컨텍스트 길이 2048일때와 거의 같은 시간(243ms)인 244ms를 기록한 것을 알 수 있다.
Sequence length bias
다음으로, 사전학습용 데이터 시퀀스 길이와 모델 성능 간의 관계성에 대한 실험이다. 전체 토큰 수 $2^{34}$, 시퀀스 길이 $2^i$일때의 단일 버켓 $\mathcal{D}_i$을 데이터 셋을 사용한다고 가정하자. 이때, optimization step당 토큰 수는 시퀀스 길이와 관계없이 256으로 고정되어있다. 모든 실험 간에 학습 하이퍼파라미터는 동일하게 적용하였다.(결론적으로, Appendix C.2에서 하이퍼파라미터 선택에 의존되지 않다는 것을 확인할 수 있다.) 통계적 오차를 줄이기 위해 각 모델은 from scratch로 각기 다른 seed에서 두 번씩 학습한 뒤 각 벤치마크에 평균낸 결과를 사용하였다(관측 표준편차는 regular benchmarks에서 ~0.3, multi-document QA에서 ~1.6을 기록하였다.).
위 실험에서 사전 학습에서의 시퀀스 길이가 모델 성능에 큰 영향을 끼친다는 것을 알 수 있다. 특히, commonsense reasoning, language understanding, world knowledge 벤치마크에서 $\cap$ 형태인 것을 확인할 수 있다. 이는 학습셋과 평가셋의 시퀀스 길이의 분포가 연관된 결과일 수 있기 때문에 평가셋의 시퀀스 길이 분포를 시각화 하였다(위 Fig. 5-b.). 그 결과를 보니, 평가셋의 시퀀스 길이 분포와 시퀀스 길이 별 모델 성능간에 유의한 상관관계가 있는 것으로 드러났다. 특히, 시퀀스 길이 분포가 2k, 4k, 6k인 MDQA 데이터셋의 경우, 그보다 짧은 시퀀스 길이로 학습할 경우 테스트 정확도가 0으로 드러났다. 이를통해 문서의 내용이 시퀀스 길이에 달라질 수 있기 때문에 저자는 이를 Sequence length bias라고 지칭하였다. 저자는 기존의 $\mathcal{D}_{13}$과 $\mathcal{D}_7$을 각각 $\mathcal{D}_{13 \rightarrow 10}$(하위 8개의 서브 시퀀스로 청킹 후 global shuffle을 통해 $2^{10}$길이의 시퀀스를 사용), $\mathcal{D}_{7 \rightarrow 10}$ ($\mathcal{D}_7$에서 랜덤으로 8개의 시퀀스를 concat하여 시퀀스 길이 $2^{10}$를 사용)로 chunking, concat하여 추가 실험을 진행하였다. 위 그림 Fig. 5-c.를 보면 $\mathcal{D}_{13 \rightarrow 10}$일때 $\mathcal{D}_{13}$ 대비 2.6points 성능 향상이 있는 것을 보았을 때, 사실상 $\mathcal{D}_{13}$의 내용은 같으므로 순전히 시퀀스 길이가 모델 성능에 영향을 끼친다는 것을 의미한다고 볼 수 있다. 또한, $\mathcal{D}_{13}$보다 $\mathcal{D}_{13 \rightarrow 10}$이 0.9points 낮은 성능을 보였는데, 이는 벤치마크에 긴 문장보다 짧은 문장이 더 있었다는 것을 의미한다. 마지막으로, concat은 chunking과 달리 시퀀스 길이 상관관계에 영향을 끼치지 않은 것을 알 수 있다. (저자는 이는 $\mathcal{D}_{7 \rightarrow 10}$ $\mathcal{D}_{10}$과 점수가 동일하지만 여전히 $\mathcal{D}_{10}$보다 나쁘다는 것을 의미한다고 하였다.)
Data mixture
위 Table 1.은 시퀀스 길이 분포에 따른 모델 성능을 나타낸다. 시퀀스 길이가 짧은 mixture는 긴 문맥 이해가 필요한 MDQA에서 성능이 좋지 않으며 평균 문맥이 클수록 RC 능력과도 양의 상관관계가 있다(이는 Fig. 5-a.의 내용과 일치하지만 학습 단계가 길어진다는 손해가 있다.).
또한, 앞선 실험(Fig. 5-c.)에서 가장 좋은 시퀀스 길이(1024; $\mathcal{D}_{10}$)만 사용하여 학습하는 1k-only는 regular eval 특히, language understanding, world knowledge 태스크에서 좋은 성능을 보이지만 긴 맥락 과제에서는 성능이 떨어진다. 마지막으로, "Nature" Mixture가 regular 및 MDQA 모두에서 거의 최적의 성능을 얻는 것을 관측하여, 저자가 제안한 접근법이 large dataset으로의 확장됨을 입증하였다.
Length-based curriculum
학습 과정에서 다양한 하이퍼파라미터(lr, lr_scheduler, weight decay 등)가 존재하기에 length-based curriculum은 잠재적 편향을 초래할 수 있다. 예를들어, 학습이 끝나갈 때 쯤에만 긴 시퀀스를 보게되면 학습 속도가 너무 작게되는 문제가 발생할 수 있다. 이를위해 cyclic learning rate schedule과 유사하게 커리큘럼을 주기적으로 적용하는 주기적 커리큘럼도 존재한다. 실험 결과, 주기적인 'Grow-P2' 커리큘럼은 다양한 지표에서 최적에 가깝게 나타났다. 어느 연구에서 긴 시퀀스가 특히 학습 초기에 극단적인 경사도 편차를 유발하여 불안정성을 초래한다는 사실을 발견하였다. 또한, 커리큘럼을 사용하여 제안한 접근 방식이 더 안정적인 학습 역학을 가져와 더 큰 배치 크기와 학습 속도로 더 효율적인 학습이 가능하다는 것을 관찰하였다고 한다. (Appendix E 참고)
limitations
저자는 끝으로, 타겟 시퀀스의 길이가 클 때만 학습 속도의 이점을 볼 수 있을 것이라고 언급하였다. 그 이유는, 주로 병목이 발생하는 부분은 attention 연산 부분인데 타겟 시퀀스 길이가 짧다면 이에 대한 cost 개선이 크게 이루어지지 않기 때문이다.