DATAMASK

English | 中文README

Joint Selection for Large-Scale Pre-Training Data via Policy Gradient-based Mask Learning


Ziqing Fan 1,2, Yuqiao Xian 1,*, Yan Sun 3, Li Shen 4

1 ByteDance Seed, 2 Shanghai Jiao Tong University, 3 University of Sydney, 4 Sun Yat-sen University Shenzhen Campus.

📖 Overview

Figure: The DATAMASK pipeline.

Motivation: In this study, we revisit metric-based selection and observe that selecting samples by quality metrics (as in FineWeb-Edu, Ultra-FineWeb, and FineWeb-DCLM) shows severely diminishing returns during long-term pre-training, while selecting by diversity metrics (as in FineWeb-Semdedup) removes too many valuable high-quality samples; both limit the capabilities of the pre-trained LLMs.

Method: DATAMASK. To address this, as shown in the pipeline above, we propose a novel and efficient optimization framework for large-scale pre-training data selection that optimizes multiple types of metrics simultaneously in a unified process. It casts selection as a mask learning problem: iteratively sampling data masks, computing policy gradients of predefined objectives with the sampled masks, and updating the mask sampling logits.
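For intuition, below is a minimal, hypothetical sketch of such a policy-gradient mask-learning loop. The reward function and update rule here are illustrative assumptions, not the exact objectives used by the repository; the real implementation in train_mask.py differs in its objectives, batching, and acceleration tricks.

```python
# Minimal, illustrative sketch of policy-gradient mask learning.
# The reward is a toy stand-in; real objectives live in train_mask.py / utils/utils.py.
import torch
import torch.nn.functional as F

def reward(mask, quality, features, lamb=0.5):
    # Toy combined objective: mean quality of kept samples plus a simple
    # diversity proxy (low mean pairwise similarity among kept samples).
    kept = mask.bool()
    if kept.sum() < 2:
        return torch.tensor(-1.0)
    feats = F.normalize(features[kept], dim=-1)
    diversity = -(feats @ feats.T).mean()
    return (1 - lamb) * quality[kept].mean() + lamb * diversity

def learn_mask(quality, features, n_epochs=100, n_rollout=32,
               lr=1.0, select_ratio=0.3):
    n = quality.shape[0]
    logits = torch.zeros(n)                     # "same" initialization
    for _ in range(n_epochs):
        probs = torch.sigmoid(logits)
        rewards, score_grads = [], []
        for _ in range(n_rollout):
            mask = torch.bernoulli(probs)       # sample a data mask
            rewards.append(reward(mask, quality, features))
            score_grads.append(mask - probs)    # d log p(mask) / d logits
        rewards = torch.stack(rewards)
        advantage = rewards - rewards.mean()    # baseline for variance reduction
        g = (torch.stack(score_grads) * advantage.unsqueeze(1)).mean(0)
        logits += lr * g                        # gradient ascent on the objective
    # keep the top select_ratio fraction by learned logit
    return torch.topk(logits, int(select_ratio * n)).indices
```

Here, quality and features would correspond to the quality_score and feature_arr columns described in the Quick Start below.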

Results: FineWeb-Mask. Through policy gradient-based optimization and several acceleration enhancements, DATAMASK reduces selection time by 98.9% compared with a greedy algorithm (estimated on the DiSF algorithm), enabling our study to explore joint selection over trillion-scale tokens. With DATAMASK, we select a subset of about 10% of the 15-trillion-token FineWeb dataset, termed FineWeb-Mask, which achieves significant improvements after pre-training with hundreds of billions of tokens, demonstrating its effectiveness.

🚀 Quick Start

  1. Clone the repository:
```bash
git clone https://github.com/ByteDance-Seed/DATAMASK.git
cd DATAMASK
```
  2. Prepare your data. Based on our original code, your data should be in .parquet format. The parquet file should provide an index, scores, and text features. The index (e.g. chunk_id) is used to identify selected samples in your dataset during large-scale distributed selection. In our paper, the data looks like the table below; a preparation sketch follows it.

| chunk_id | quality_score | feature_arr |
| --- | --- | --- |
| 0 | 8 | size(768) |
| 1 | 4 | size(768) |
| ... | ... | ... |
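For example, an input file with this layout could be produced roughly as follows. This is a sketch assuming pandas with a parquet engine installed; the random values stand in for your real quality scores and 768-dimensional text embeddings.

```python
# Illustrative only: write a DATAMASK-style input parquet with an index,
# a quality score, and a 768-dim embedding per sample.
import numpy as np
import pandas as pd

n_samples, dim = 1000, 768
df = pd.DataFrame({
    "chunk_id": np.arange(n_samples),                           # sample identifier
    "quality_score": np.random.randint(0, 10, size=n_samples),  # e.g. an edu-style score
    "feature_arr": [np.random.randn(dim).astype(np.float32)     # text embedding
                    for _ in range(n_samples)],
})
df.to_parquet("your_input_path.parquet", index=False)
```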
  3. Define your optimization metrics. In the code, we provide implementations of three types of diversity scores, which are combined with the quality score to form the optimization objective. You can change the number of metrics and their combination by modifying the corresponding functions defined in utils/utils.py; a hedged sketch follows below.
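As a rough illustration of what such a metric can look like, here is a sketch of a facility-location style diversity score combined with quality via a weight lamb. The function names and exact formulas are assumptions, not the definitions used in utils/utils.py.

```python
# Sketch of a facility-location style diversity score and its combination
# with quality; the real metrics are defined in utils/utils.py.
import torch
import torch.nn.functional as F

def facility_location(features, mask):
    # features: (n, d) embeddings; mask: (n,) 0/1 selection vector.
    feats = F.normalize(features, dim=-1)
    sim = feats @ feats.T                    # pairwise cosine similarity
    covered = sim[:, mask.bool()]            # similarity of every sample to the kept set
    return covered.max(dim=1).values.mean()  # how well the kept set covers the data

def combined_objective(quality, features, mask, lamb=0.5):
    # Weighted mix of mean quality of kept samples and diversity, balanced by lamb.
    return (1 - lamb) * quality[mask.bool()].mean() + lamb * facility_location(features, mask)
```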

  4. Hyper-parameters. DATAMASK optimization introduces several hyper-parameters. We provide extensive ablation studies on them in the paper; please refer to it for details.

  • n_epochs: number of update steps.
  • partial: fraction of the total samples used as one batch for batch training.
  • algorithm: diversity score type, chosen from ["DiSF", "Facility", "Pair_simi"].
  • max_lr and min_lr: initial learning rate and final learning rate reached at n_epochs; we use a linear scheduler (see the sketch after this list).
  • select_ratio: selection ratio, i.e. the fraction of samples to keep.
  • lamb: weighting factor lambda that balances quality and diversity.
  • n_rollout: number of rollouts per epoch.
  • init: logit initialization strategy, chosen from same and quality.
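For reference, the linear schedule mentioned above amounts to a simple interpolation from max_lr to min_lr over n_epochs. This is a sketch; the exact scheduler is the one implemented in the training script.

```python
def linear_lr(step, n_epochs, max_lr=10.0, min_lr=1.0):
    # Decay linearly from max_lr at step 0 to min_lr at step n_epochs.
    return max_lr + (min_lr - max_lr) * min(step / n_epochs, 1.0)
```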
  5. Quick test. After settling on your data, optimization metrics, and hyper-parameters, you can run the following command as a quick test:
```bash
python3 train_mask.py \
  --input your_input_path.parquet \
  --output your_output_path.parquet \
  --device cuda:0 \
  --n_epochs 5000 \
  --algorithm DiSF \
  --partial 0.1 \
  --max_lr 10 \
  --min_lr 1 \
  --n_rollout 128 \
  --select_ratio 0.3 \
  --init quality \
  --lamb 0.5
```
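After the run, the selection result can be inspected from the output parquet. The snippet below is a hypothetical example: the column name "selected" is an assumption, so check the schema actually written by train_mask.py.

```python
# Hypothetical post-processing: load the output and keep the selected rows.
import pandas as pd

out = pd.read_parquet("your_output_path.parquet")
selected = out[out["selected"] == 1]            # "selected" is an assumed column name
print(f"kept {len(selected)} of {len(out)} chunks")
selected["chunk_id"].to_csv("selected_chunk_ids.csv", index=False)
```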

🔎 Text Feature Visualizations

Figure: t-SNE visualization of text features on random subsets of FineWeb.
To illustrate the dilemma, we visualize text embeddings via t-SNE on random subsets of FineWeb. White, light blue, and dark blue points correspond, respectively, to the most diverse samples, the highest-quality samples, and samples selected by algorithms that exhibit both high diversity and quality. Light blue points show tighter clustering, and dark blue points are sparse for all algorithms except ours. This means that selecting samples based on quality scores (dark blue and purple data points) leads to tighter clustering than the raw data distribution, indicating higher semantic redundancy and reduced information diversity.
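A projection like the one described above can be reproduced generically with scikit-learn's t-SNE. The snippet below is a sketch with stand-in data, not the exact plotting code behind the figure.

```python
# Generic t-SNE visualization of text embeddings, colored by quality score.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

features = np.random.randn(2000, 768).astype(np.float32)  # stand-in embeddings
quality = np.random.randint(0, 10, size=2000)              # stand-in quality scores

xy = TSNE(n_components=2, init="pca", perplexity=30).fit_transform(features)
plt.scatter(xy[:, 0], xy[:, 1], c=quality, cmap="Blues", s=3)
plt.colorbar(label="quality score")
plt.savefig("tsne_fineweb_subset.png", dpi=200)
```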

🌟 Optimization Curves and One Optimization Ablation

Here we show the optimization curves and one optimization ablation on the rollout number G. Ablation on choosing G: as shown in the figure, we record the optimized Facility Location values while tuning G against computational time. The results show that a G that is too small causes training to diverge, while a larger G incurs excessive computational cost. After tuning on the three diversity metrics, we recommend G = 128 or 256, the smallest values that remain stable and yield near-optimal objective values across all cases. For more ablations, please refer to the paper.

Figure: Optimization curves and an ablation on the rollout number G.

📈 Detailed Performance

Below, we report the per-task performance during pre-training of a 1.5B dense model and a 7B MoE model. The upper table is for the dense model and the lower table for the MoE model.

| 1.5B Dense | RACE-H | RACE-M | HellaSwag | NQ | OBQA | KQAPro | MMLU | TriviaQA | ARC-Challenge | SIQA | PIQA | WinoGrande |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| FineWeb | 40.8 | 51.2 | 58.9 | 11.1 | 48.6 | 45.1 | 33.9 | 35.2 | 31.9 | 48.9 | 75.1 | 58.5 |
| FineWeb-Semdedup | 40.5 | 51.7 | 57.2 | 11.0 | 45.0 | 43.5 | 33.4 | 30.2 | 31.6 | 48.7 | 74.9 | 57.9 |
| FineWeb-Edu | 42.3 | 51.4 | 57.6 | 12.1 | 53.6 | 42.5 | 37.8 | 37.8 | 44.5 | 49.2 | 74.2 | 59.0 |
| UltraFineWeb-en | 41.8 | 53.0 | 57.8 | 10.9 | 50.6 | 42.2 | 37.2 | 30.6 | 44.2 | 48.9 | 75.9 | 57.1 |
| FineWebPro | 43.1 | 52.2 | 61.3 | 12.4 | 51.5 | 42.0 | 36.2 | 38.6 | 43.0 | 50.1 | 75.2 | 61.0 |
| FineWeb-DCLM | 43.6 | 52.9 | 61.4 | 11.1 | 48.2 | 43.4 | 34.8 | 37.8 | 40.7 | 50.4 | 76.2 | 61.4 |
| FineWeb-Mask (Ours) | 43.8 | 53.7 | 56.4 | 14.1 | 51.4 | 47.0 | 36.5 | 47.3 | 40.9 | 51.4 | 74.4 | 59.8 |
| 7B MoE | RACE-H | RACE-M | HellaSwag | NQ | OBQA | KQAPro | MMLU | TriviaQA | ARC-Challenge | SIQA | PIQA | WinoGrande |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| FineWeb | 42.1 | 54.1 | 69.9 | 17.9 | 53.6 | 49.8 | 35.8 | 53.8 | 40.7 | 52.8 | 78.6 | 63.7 |
| FineWeb-Semdedup | 41.4 | 55.3 | 67.7 | 17.0 | 53.0 | 49.3 | 35.6 | 49.1 | 38.4 | 49.4 | 77.4 | 65.1 |
| FineWeb-Edu | 42.8 | 56.5 | 65.8 | 16.6 | 53.8 | 48.5 | 41.3 | 50.4 | 50.0 | 50.4 | 77.1 | 64.5 |
| UltraFineWeb-en | 42.4 | 54.0 | 64.9 | 13.9 | 55.4 | 42.0 | 41.4 | 38.6 | 49.3 | 49.7 | 78.3 | 60.9 |
| FineWebPro | 42.8 | 55.8 | 68.9 | 15.5 | 55.2 | 46.3 | 40.0 | 49.8 | 47.9 | 52.8 | 77.9 | 64.3 |
| FineWeb-DCLM | 43.2 | 54.8 | 69.3 | 18.7 | 54.0 | 49.8 | 40.4 | 54.6 | 47.0 | 52.1 | 78.5 | 64.0 |
| FineWeb-Mask (Ours) | 42.8 | 55.4 | 66.1 | 19.5 | 55.0 | 51.4 | 39.5 | 61.7 | 45.9 | 51.2 | 77.5 | 65.1 |

📚 Citation

@misc{fan2025jointselectionlargescalepretraining,
      title={Joint Selection for Large-Scale Pre-Training Data via Policy Gradient-based Mask Learning}, 
      author={Ziqing Fan and Yuqiao Xian and Yan Sun and Li Shen},
      year={2025},
      eprint={2512.24265},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2512.24265}, 
}

📧 Contact

📄 License

Apache 2.0
