RegMix: Data Mixture as Regression for Language Model Pre-training
Still relying on human intuition to mix corpora from different sources for pre-training?
Everyone says that data mixture has a big impact on model performance, but how, and why?
Did you know that web corpora are actually highly impactful for downstream tasks?
Check out our preprint "RegMix: Data Mixture as Regression for Language Model Pre-training"!
In this paper, we propose RegMix, an automatic data mixture method that achieves a 6.3% improvement over human selection on the widely used HellaSwag benchmark, while requiring only 2% extra training FLOPs!
- Demo: https://huggingface.co./spaces/sail/RegMix
- Paper: https://huggingface.co./papers/2407.01492
- Code: https://github.com/sail-sg/regmix
- Model & Data: https://huggingface.co./collections/sail/regmix-data-mixture-as-regression-6682b6caab37b9442877f0ce
Data Mixture is Important, but Challenging
Large Language Models (LLMs) are powered by vast, diverse datasets from the Internet, including academic papers, books, and various online sources (Gao et al. 2020). As LLMs grow in scale and complexity, the composition of their training data becomes increasingly crucial. The importance of data mixture was recognized early on by the creators of GPT-3, one of the pioneering LLMs, who deliberately chose to upsample Wikipedia content due to its perceived high quality.
The challenge: as the volume and diversity of data used in LLM pre-training continue to expand, determining the ideal data mixture becomes increasingly complex, and manual data selection may lead to suboptimal choices.
Key research question: How can we decide on a high-performing data mixture for training LLMs in a scalable and automatic manner?
Gao et al. 2020. The Pile: An 800GB Dataset of Diverse Text for Language Modeling, https://arxiv.org/abs/2101.00027
Core idea: small to large generalization
With this challenge in mind, our core idea is straightforward: train small-scale models on different data mixtures, identify the best-performing ones, and then directly generalize those findings to large-scale model training.
RegMix: Data Mixture as Regression
Concretely, our method RegMix treats data mixture selection as a regression task. Here's how it works:
- Train some small-scale proxy models on various data mixtures for a small number of tokens
- Fit a regression model using these results
- Use the regression model to predict the best mixture for large-scale training
- Train the large-scale model on this optimized mixture
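The four steps above can be sketched end to end in plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation: `train_proxy_and_eval` is a hypothetical placeholder that returns a made-up loss instead of training a model, candidate mixtures are drawn from a Dirichlet distribution (our assumption about how mixtures are generated), and the regression step is plain linear least squares, which may differ from the regressor used in the paper.

```python
import random

# Illustrative domain names taken from the example in this post.
DOMAINS = ["hackernews", "github", "philpapers"]


def sample_mixture(rng: random.Random, n: int, alpha: float = 1.0) -> list[float]:
    """Draw mixture weights from a symmetric Dirichlet(alpha) via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n)]
    total = sum(draws)
    return [d / total for d in draws]


def train_proxy_and_eval(weights: list[float]) -> float:
    """HYPOTHETICAL stand-in for step 1: train a small proxy model on this
    mixture and return its target validation loss. Here it is a made-up
    linear function so the sketch runs instantly."""
    hidden = [0.8, -0.3, -0.5]
    return 3.0 + sum(c * w for c, w in zip(hidden, weights))


def solve(a: list[list[float]], b: list[float]) -> list[float]:
    """Solve the square system a @ x = b with Gauss-Jordan elimination."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                factor = m[r][col] / m[col][col]
                for c in range(col, n + 1):
                    m[r][c] -= factor * m[col][c]
    return [m[i][n] / m[i][i] for i in range(n)]


def fit_linear(xs: list[list[float]], ys: list[float]) -> list[float]:
    """Step 2: least-squares fit of loss ~ sum_i coef_i * w_i via the normal
    equations. No separate intercept: the weights sum to 1, so it is redundant."""
    k = len(xs[0])
    xtx = [[sum(x[i] * x[j] for x in xs) for j in range(k)] for i in range(k)]
    xty = [sum(x[i] * y for x, y in zip(xs, ys)) for i in range(k)]
    return solve(xtx, xty)


def predict(coefs: list[float], weights: list[float]) -> float:
    """Predicted target loss for one mixture under the fitted linear model."""
    return sum(c * w for c, w in zip(coefs, weights))


rng = random.Random(0)
# Step 1: a handful of (simulated) proxy runs on random mixtures.
train_x = [sample_mixture(rng, len(DOMAINS)) for _ in range(32)]
train_y = [train_proxy_and_eval(w) for w in train_x]
# Step 2: fit the regression model.
coefs = fit_linear(train_x, train_y)
# Step 3: score many unseen mixtures and keep the lowest predicted loss.
candidates = [sample_mixture(rng, len(DOMAINS)) for _ in range(10_000)]
best = min(candidates, key=lambda w: predict(coefs, w))
# Step 4 (not shown): train the large model on `best`.
print({d: round(w, 3) for d, w in zip(DOMAINS, best)})
```

Because the fitted regressor is cheap to evaluate, step 3 can score thousands of simulated mixtures at negligible cost; only step 4 needs a large training run.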
Training the small-scale proxy models costs only about 2% of the compute (in FLOPs) of the final large-scale training run.
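As a rough sanity check of the ~2% figure, training compute can be estimated with the common rule of thumb of about 6 FLOPs per parameter per training token. The model sizes and token counts below match the experiments mentioned later in this post (1M-parameter proxies on 1B tokens, a 1B-parameter model on 25B tokens); the number of proxy runs, 512, is an assumption for illustration, and the estimate ignores details such as embedding parameters.

```python
def train_flops(params: float, tokens: float) -> float:
    """Rule-of-thumb training cost: ~6 FLOPs per parameter per token."""
    return 6 * params * tokens


N_PROXIES = 512  # ASSUMPTION: number of proxy runs, chosen for illustration
proxy_total = N_PROXIES * train_flops(1e6, 1e9)  # 1M-parameter proxies, 1B tokens each
target = train_flops(1e9, 25e9)                  # 1B-parameter model, 25B tokens
ratio = proxy_total / target
print(f"proxy / target FLOPs = {ratio:.2%}")     # prints 2.05%
```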
To visualize the procedure, we provide a concrete example using Hacker News, GitHub, and PhilPapers as the training domains. The validation loss on StackExchange serves as the target metric to optimize during the proxy model training phase.
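For this three-domain example, one simple regression form (a sketch assuming a linear regressor; other regressors can be swapped in) predicts the StackExchange validation loss directly from the mixture weights, and the selected mixture is the one with the lowest prediction. Here HN, GH, PP, and SE abbreviate Hacker News, GitHub, PhilPapers, and StackExchange, and the coefficients $\beta$ are fitted on the proxy runs. Because the weights sum to one, a separate intercept would be redundant, so it is folded into the coefficients.

```latex
\begin{aligned}
\hat{L}_{\text{SE}}(\mathbf{w}) &= \beta_{\text{HN}}\, w_{\text{HN}} + \beta_{\text{GH}}\, w_{\text{GH}} + \beta_{\text{PP}}\, w_{\text{PP}},
\qquad w_i \ge 0,\quad w_{\text{HN}} + w_{\text{GH}} + w_{\text{PP}} = 1,\\
\mathbf{w}^\star &= \arg\min_{\mathbf{w}} \hat{L}_{\text{SE}}(\mathbf{w}).
\end{aligned}
```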
Regression Works Well Across Model Scales
What's particularly exciting about RegMix is its efficiency: it lets you explore a vast space of potential mixtures (even with 40+ domains) by training only a small number of models.
Specifically, training 1M-parameter models on 1B tokens is enough to predict the performance of 256 1M-parameter models trained on unseen data mixtures, with a 98.45% correlation.
Moreover, RegMix can automatically identify the best-performing data mixture among 64 models with 1B parameters trained on 25B tokens, before actually training any of them.
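The post does not say which correlation statistic the 98.45% figure refers to. As a hedged sketch, assuming a rank (Spearman) correlation between the losses predicted by the regression model and the losses actually observed for held-out proxy models, the computation looks like this; all numbers are hypothetical.

```python
def ranks(values: list[float]) -> list[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg
        i = j + 1
    return result


def pearson(xs: list[float], ys: list[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)


def spearman(xs: list[float], ys: list[float]) -> float:
    """Spearman rank correlation: Pearson correlation of the ranks."""
    return pearson(ranks(xs), ranks(ys))


# HYPOTHETICAL predicted vs. observed validation losses for five held-out proxies.
predicted = [3.10, 3.25, 2.98, 3.40, 3.05]
observed = [3.07, 3.22, 3.01, 3.45, 3.08]
rho = spearman(predicted, observed)
print(f"{rho:.2f}")  # prints 0.90 for these made-up numbers
```

A rank correlation fits the use case because RegMix only needs to order mixtures correctly to pick the best one, not to predict exact loss values.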
Insight 1: Data mixture significantly impacts downstream performance
We experiment with 64 models, each with 1B parameters and trained on a different data mixture, and evaluate their performance across various benchmarks. The results show that the data mixture significantly impacts downstream performance, with differences of up to 14.6% on some tasks!
Insight 2: Web corpora benefit downstream performance the most
Web corpora like CommonCrawl surprisingly show the strongest positive correlation with downstream performance, even stronger than curated sources like Wikipedia! This pattern holds across most web domains, suggesting that the diversity of CommonCrawl drives much of today's LM success.
Moreover, domains as different as gaming sites like IGN and YouTube exhibit similar patterns, whereas http://patents.google.com and http://springer.com seem to follow different trends.
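The correlations in Insight 2 relate how much of a domain each model saw to how well it scores downstream. A minimal sketch of that kind of computation, assuming a Pearson correlation across models; every number below is made up for illustration and is not a result from the paper, whose exact statistic and aggregation may differ.

```python
def pearson(xs: list[float], ys: list[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)


# HYPOTHETICAL data for five models: one benchmark score per model, and the
# weight each model's training mixture gave to two example domains.
scores = [0.41, 0.44, 0.39, 0.47, 0.43]
weights = {
    "commoncrawl": [0.30, 0.50, 0.20, 0.55, 0.40],
    "wikipedia": [0.20, 0.05, 0.25, 0.10, 0.15],
}
# Correlate each domain's weight with the downstream score across models.
corrs = {domain: pearson(ws, scores) for domain, ws in weights.items()}
for domain, r in sorted(corrs.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{domain:12s} r = {r:+.2f}")
```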
Insight 3: Domain interactions are challenging for humans to understand
Domain interactions are complex and often counterintuitive, highlighting the need for automated approaches like RegMix.
For example, the PhilPapers domain appears to provide gains for all other domains under linear regression modeling, which challenges intuitive human understanding. So, what is PhilPapers? It is a database for philosophy…
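Insight 3 rests on reading the coefficients of per-target linear regressions. The sketch below shows one way such a statement could be checked; the coefficient table is entirely hypothetical, and the decision rule is our own illustration, not necessarily the paper's analysis. Because mixture weights sum to one, a coefficient is only meaningful relative to the others: raising one domain's weight while lowering all others equally reduces the predicted loss exactly when that domain's coefficient is below the mean coefficient.

```python
# HYPOTHETICAL coefficients of linear models that predict each target domain's
# validation loss from the training-mixture weights: COEFS[target][train_domain].
COEFS: dict[str, dict[str, float]] = {
    "hackernews_loss": {"hackernews": 2.9, "github": 3.4, "philpapers": 3.0},
    "github_loss": {"hackernews": 2.6, "github": 1.9, "philpapers": 2.2},
    "philpapers_loss": {"hackernews": 3.5, "github": 3.6, "philpapers": 2.8},
}


def helps(coefs: dict[str, float], domain: str) -> bool:
    """True if shifting weight toward `domain` (taking it equally from all
    other domains) lowers the predicted loss under a linear mixture model,
    which holds exactly when its coefficient is below the mean coefficient."""
    mean = sum(coefs.values()) / len(coefs)
    return coefs[domain] < mean


helped_by_philpapers = [t for t, c in COEFS.items() if helps(c, "philpapers")]
helped_by_github = [t for t, c in COEFS.items() if helps(c, "github")]
print(helped_by_philpapers)  # every target in this made-up table
print(helped_by_github)      # only github_loss in this made-up table
```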
RegMix accounts for token availability
Previous data mixture methods struggle to balance token availability and usefulness. RegMix, however, can easily control token availability by constraining the simulation space, especially in light of the finding by Niklas et al. 2023 that repeating data for up to 4 epochs is nearly as effective as using unique data.
For example, suppose Hacker News contains only 3% as many tokens as your planned training budget. If you can afford to repeat it for 4 epochs, you can simply cap its weight at 12% (3% × 4) in the simulation.
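The cap from the Hacker News example can be computed directly and then enforced when generating candidate mixtures. A minimal sketch, assuming a 100B-token training budget and Dirichlet-sampled mixtures filtered by simple rejection sampling; the budget, the helper names, and the rejection-sampling choice are illustrative assumptions, not necessarily what the RegMix code does.

```python
import random


def max_weight(available_tokens: float, training_tokens: float, max_epochs: int = 4) -> float:
    """Largest mixture weight a domain can receive without being repeated
    more than `max_epochs` times over the whole training run."""
    return min(1.0, max_epochs * available_tokens / training_tokens)


def sample_capped(rng: random.Random, caps: list[float], alpha: float = 1.0,
                  max_tries: int = 10_000) -> list[float]:
    """Rejection-sample Dirichlet(alpha) mixtures until every weight respects its cap."""
    if sum(caps) < 1.0:
        raise ValueError("caps must sum to at least 1")
    for _ in range(max_tries):
        draws = [rng.gammavariate(alpha, 1.0) for _ in caps]
        total = sum(draws)
        weights = [d / total for d in draws]
        if all(w <= cap for w, cap in zip(weights, caps)):
            return weights
    raise RuntimeError("no valid mixture found; caps may be too tight")


# HYPOTHETICAL budget: 100B training tokens; Hacker News has 3B tokens (3%).
hn_cap = max_weight(available_tokens=3e9, training_tokens=100e9)  # 4 epochs x 3% = 12%
caps = [hn_cap, 1.0, 1.0]  # hackernews, github, philpapers (others uncapped here)
mixture = sample_capped(random.Random(0), caps)
print([round(w, 3) for w in mixture])
```

Rejection sampling is the simplest way to respect the caps; it becomes slow when caps are tight, in which case sampling directly inside the constrained region would be preferable.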
Niklas et al. 2023. Scaling Data-Constrained Language Models, https://arxiv.org/abs/2305.16264
RegMix has already been applied to a 14B model
While the experiments in this paper are limited to models of up to 1B parameters due to computational constraints, we have successfully applied the same data mixture approach in our Sailor paper (Dou et al. 2024).
Notably, we found that the optimal data mixture identified with a 0.5B proxy model scaled well, performing effectively across models of up to 14B parameters!
Dou et al. 2024. Sailor: Open Language Models for South-East Asia, https://arxiv.org/abs/2404.03608 You can also find the paper at https://huggingface.co./papers/2404.03608
Try RegMix on your dataset
We also provide instructions on how to apply RegMix to your own dataset. Please try it and leave a comment here!