
Characterizing Datasets and Building Better Models with Continued Pre-Training



While large language models (LLMs) are increasingly adept at solving general tasks, they can often fall short on specific domains that are dissimilar to the data they were trained on. In such cases, how do you effectively and efficiently adapt an open-source LLM to your needs? This can be challenging due to the many decisions involved, such as training methods and data selection. This blog post will explore one method for customizing LLMs, Continued Pre-Training (CPT), and provide guidance on executing this process effectively. Furthermore, we consider how CPT can be used as a tool to efficiently characterize datasets, i.e., to better understand which evaluation metrics are helped, hurt, or unaffected by the data.

Effective CPT requires attention to three key hyperparameters: (1) learning rate, (2) training duration, and (3) data mixture. In addition, simple weight averaging is an easy way to mitigate forgetting caused by CPT. This blog outlines these processes from start to finish to help you unlock the most value from your CPT runs.

Continued Pre-Training vs. Fine-Tuning

What is Continued Pre-Training (CPT), and how does it differ from fine-tuning?

When working with a new, specific domain (e.g., a medical domain) that was not well represented in a model’s pre-training corpus, the model might lack factual knowledge critical to performing well on that domain. While one could pre-train a new model from scratch, this is not a cost-effective strategy, as pre-trained models already possess many core language and reasoning capabilities that we want to leverage on the new domain. Continued Pre-Training is the cost-effective alternative to pre-training from scratch: we further train a base pre-trained LLM on a large corpus of domain-specific text documents. This augments the model’s general knowledge with specific information from the particular domain. The data typically consists of large amounts of raw text, such as medical journals or mathematical texts.

Fine-Tuning, on the other hand, involves training a language model on a much smaller, task-specific dataset. This dataset often contains labeled input-output pairs, such as questions and answers, to align the model’s behavior to perform a specific, well-defined task. While a CPT dataset might contain billions of tokens of raw, unstructured text, a fine-tuning dataset will contain millions of tokens of structured input-output pairs. This is often not a sufficient amount of data to teach a model factual information from a completely new domain. In this case, it would be more effective to fine-tune for style and alignment after CPT. 
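To make the distinction concrete, here is a minimal sketch of what each data regime typically looks like. The records are invented for illustration, not real training data.

```python
# Illustrative records only; real corpora would contain billions (CPT) or
# millions (fine-tuning) of tokens.

# Continued pre-training data: large volumes of raw, unstructured domain text.
cpt_examples = [
    {"text": "Diabetes mellitus is a group of metabolic disorders characterized by ..."},
    {"text": "Theorem 2.1. Let f be a continuous function on a closed interval ..."},
]

# Fine-tuning data: a much smaller set of structured input-output pairs.
sft_examples = [
    {
        "prompt": "What is a typical first-line treatment for type 2 diabetes?",
        "response": "Metformin is commonly used as a first-line treatment ...",
    },
]
```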

In this post, we focus on the case of continued pre-training. We demonstrate how CPT can enhance a small LLM’s factual knowledge performance to match that of a much larger LLM. We will outline the entire process for:

  1. Showing how to optimize hyperparameters.
  2. Measuring the effect of different datasets.
  3. Developing heuristics for mixing datasets.
  4. Mitigating forgetting.

Finally, we consider how the performance gains from continued pre-training scale with training FLOPs, a measure of the amount of compute used to train the model.

How to Do Continued Pre-Training

Task and Evaluation

For our experiments, we will evaluate our model on the MMLU benchmark, which tests the model’s ability to recall a wide range of facts. This benchmark serves as a stand-in for the general process of factual knowledge acquisition in LLMs.

In addition to the MMLU benchmark, we will monitor the Gauntlet Core Average[1], which averages a large set of language modeling benchmarks. This allows us to track the core language and reasoning capabilities of the model, ensuring it doesn’t lose skills in reading comprehension and language understanding, which are essential for other downstream tasks. Monitoring Core Average is a great way to keep track of forgetting in LLMs.
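As a concrete illustration, the sketch below shows one way to track the target metric alongside a core average across checkpoints. The benchmark list and the `scores` structure are placeholders for whatever evaluation harness you use, not the actual Gauntlet implementation.

```python
# Hypothetical bookkeeping for CPT evaluations: track the target metric (MMLU)
# and an average over general-capability benchmarks to watch for forgetting.
# The benchmark list is an illustrative subset, not the full Gauntlet.

CORE_BENCHMARKS = ["arc_easy", "hellaswag", "winogrande", "boolq"]

def core_average(scores: dict) -> float:
    """Mean accuracy over the core benchmarks."""
    return sum(scores[b] for b in CORE_BENCHMARKS) / len(CORE_BENCHMARKS)

def report(checkpoint_scores: dict) -> None:
    """checkpoint_scores maps a checkpoint name to its per-benchmark accuracies."""
    for ckpt, scores in checkpoint_scores.items():
        print(f"{ckpt}: MMLU={scores['mmlu']:.3f}, core average={core_average(scores):.3f}")
```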

Models

We aim to see if we can start with a Llama-2-7B base model and elevate its performance to match that of a Llama-2-13B base model using CPT. To study CPT across model scales, we also demonstrate its efficacy at improving Llama-2-13B and Llama-2-70B.

Example Datasets

For this experiment, we considered five datasets that we intuited could potentially help MMLU: OpenWebMath, FLAN, Wikipedia, Stack Exchange, and arXiv. These datasets, ranging from 8B to 60B tokens, were chosen for their high-quality sources and dense information to maximize general knowledge exposure.

Hyperparameters: The Key to Performance

When further training open-source base models, two critical hyperparameters are the learning rate (LR) and the training duration. The optimal values for these hyperparameters can vary based on the model, dataset size, dataset composition, and benchmark. Therefore, it is essential to sweep both hyperparameters while iterating.

We use the following procedure to set these hyperparameters for OpenWebMath. We swept the LR over the values 3e-6, 10e-6, 3e-5, and 10e-5, training on roughly 15B tokens (about one epoch of OpenWebMath) for each run. In Figure 1a, we can see that the accuracy on MMLU can vary by as much as 5 percentage points based on the LR, indicating the importance of this hyperparameter. Typically, 1B to 10B tokens are sufficient for identifying the optimal LR.
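A sketch of this sweep is shown below. The `run_cpt` and `evaluate` callables are stand-ins for your own training and evaluation stack, not a specific library API.

```python
def lr_sweep(run_cpt, evaluate, dataset="openwebmath",
             learning_rates=(3e-6, 10e-6, 3e-5, 10e-5),
             sweep_tokens=15_000_000_000):
    """Sweep peak learning rates with short CPT runs and pick the best by MMLU.

    run_cpt(dataset, max_tokens, peak_lr) should return a trained checkpoint;
    evaluate(checkpoint) should return {"mmlu": float, "core_average": float}.
    Both are placeholders for your own training and evaluation code.
    """
    results = {}
    for lr in learning_rates:
        checkpoint = run_cpt(dataset=dataset, max_tokens=sweep_tokens, peak_lr=lr)
        results[lr] = evaluate(checkpoint)
    best_lr = max(results, key=lambda r: results[r]["mmlu"])
    return best_lr, results
```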

Figure 1. Performance of continued pre-training with OpenWebMath on Llama-2-7B. (a) To determine which learning rate to use for CPT, we sweep the learning rate for CPT on OpenWebMath (~14.5B tokens) for 1 epoch. For each learning rate, a constant learning rate schedule with linear warmup (500 batches) and cooldown (700 batches) was used. There is a substantial difference in improving MMLU and preventing forgetting (as measured by Gauntlet Core Average) across learning rates, which emphasizes the importance of hyperparameter tuning when performing CPT. (b) A sweep over the number of tokens with the optimal learning rate. For each duration in the sweep, we again use a constant learning rate with linear warmup and decay. At a duration of 30B tokens (approximately 2 epochs of OpenWebMath), CPT produced a 6 percentage point improvement in MMLU with no degradation in Gauntlet Core Average.

After identifying the optimal LR, we trained on OpenWebMath for longer durations to determine the optimal training period (Figure 1b). In addition to measuring the performance on our target MMLU metric, we also measure the Core Average to monitor forgetting.
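Both sweeps use the constant-with-warmup-and-cooldown schedule described in the Figure 1 caption. A minimal sketch of that schedule, written as a multiplier on the peak LR, might look like the following (batch counts follow the caption; this is our own illustration, not the exact training code):

```python
def lr_multiplier(step: int, total_batches: int,
                  warmup: int = 500, cooldown: int = 700) -> float:
    """Constant LR schedule with linear warmup and linear cooldown.

    Returns a multiplier in [0, 1] to apply to the peak learning rate.
    """
    if step < warmup:                        # linear warmup: 0 -> peak
        return step / warmup
    if step >= total_batches - cooldown:     # linear cooldown: peak -> 0
        return max(0.0, (total_batches - step) / cooldown)
    return 1.0                               # constant at peak in between
```

A multiplier like this can be plugged into a standard scheduler (for example, torch.optim.lr_scheduler.LambdaLR), with total_batches set to the run length of each point in the duration sweep.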

Capturing the Impact of the Dataset

We repeated the LR sweep (like the one shown in Figure 1a for OpenWebMath) for each of our five datasets, training for between 1B and 10B tokens each time. Surprisingly, only two of these high-quality datasets improved our model’s performance: OpenWebMath and FLAN. The other datasets reduced accuracy across all LRs. Notably, the optimal learning rates for the different datasets were not the same.

Figure 2 shows the duration sweep at the optimal learning rate for OpenWebMath and FLAN, the two datasets that resulted in MMLU improvement. The red horizontal dashed line is the performance of Llama-2-7B base, the model before training. The black horizontal dashed line is the performance of Llama-2-13B base, a model that is roughly twice as big. Both datasets led to substantial improvements over Llama-2-7B base but had very different optimal durations. While FLAN (pink line in Figure 2) improved performance at 8B tokens but hurt performance with more training, OpenWebMath (blue line in Figure 2) showed consistent performance improvements up to 40B tokens. Additionally, monitoring our Core Average metric revealed that over-training on certain datasets could lead to forgetting.

Thus, we see that running LR sweeps in the 1B to 10B token regime is a fast and effective way to identify which datasets enhance model performance. This allows us to discard ineffectual datasets, mix the beneficial ones, and train on the mix for longer, making CPT an efficient tool for characterizing datasets.
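Putting these pieces together, a dataset-characterization pass might look like the sketch below; as before, `run_cpt` and `evaluate` are placeholders for your training and evaluation stack, and the thresholding is deliberately simple.

```python
def characterize_datasets(run_cpt, evaluate, base_model,
                          candidates=("openwebmath", "flan", "wikipedia",
                                      "stack_exchange", "arxiv"),
                          learning_rates=(3e-6, 10e-6, 3e-5, 10e-5),
                          sweep_tokens=5_000_000_000):
    """Run short CPT LR sweeps (1B-10B tokens) per candidate dataset and keep
    the datasets whose best run beats the base model on the target metric.

    run_cpt and evaluate are placeholders, as in the earlier sketch.
    """
    base_scores = evaluate(base_model)
    helpful, unhelpful = [], []
    for dataset in candidates:
        best_mmlu = max(
            evaluate(run_cpt(dataset=dataset, max_tokens=sweep_tokens,
                             peak_lr=lr))["mmlu"]
            for lr in learning_rates
        )
        (helpful if best_mmlu > base_scores["mmlu"] else unhelpful).append(dataset)
    # In our runs, only OpenWebMath and FLAN ended up in `helpful`.
    return helpful, unhelpful
```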

Figure 2. Continued pre-training on different datasets and data mixes. Each line shows a sweep of training duration with the optimal learning rate; the markers correspond to individual training runs for the duration on the x-axis with a full learning rate schedule. Of the five datasets considered, only the two that improved MMLU are included in this sweep: OpenWebMath (blue) and FLAN (pink). We then experimented with mixes of these two datasets. Mix 1 (orange) is 66% OpenWebMath and 34% FLAN, and Mix 2 (green) is 84% OpenWebMath and 16% FLAN; overall, Mix 2 trained for 40B tokens achieved the highest performance on MMLU out of the datasets considered. Finally, the best performing model was linearly merged with the base model; the best performing merge is indicated by the red star and achieves comparable performance to Llama-2-13B.

Mixing Datasets for Better Performance

After identifying individual datasets that improve performance and their optimal training durations, we mixed them to achieve further improvements. We recommend a simple heuristic: mix the datasets in the ratio of the number of tokens each dataset needed to reach its optimal performance. For example, we found success mixing FLAN and OpenWebMath in a ratio of 8:40, so that 16% of the data comes from FLAN (pink) and 84% comes from OpenWebMath (blue).
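Written out as a small sketch, with the optimal token counts taken from the duration sweeps above:

```python
# Token-ratio mixing heuristic: weight each helpful dataset by the number of
# tokens at which it performed best on its own (from the duration sweeps).
optimal_tokens = {"flan": 8e9, "openwebmath": 40e9}

total = sum(optimal_tokens.values())
mix_weights = {name: tokens / total for name, tokens in optimal_tokens.items()}

print(mix_weights)  # {'flan': 0.1666..., 'openwebmath': 0.8333...} -> roughly 16% / 84%
```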

This simple heuristic outperforms mixing the datasets in a 1:1 ratio or merely concatenating them. With the new mixed dataset, we again swept the LR at 1B tokens and then swept the training duration at the optimal learning rate. This resulted in our CPT model (green line in Figure 2) at 40B tokens outperforming the Llama-2-7B base on both MMLU and Gauntlet Core Average, nearly matching the performance of the Llama-2-13B base.

Mitigating Forgetting with Model Soups

While the model trained for 40B tokens on our mix had the best performance on MMLU, it performed slightly worse on Core Average than models trained for a shorter duration. To mitigate forgetting, we used model souping: simply averaging the weights of two models that have the same architecture but were trained differently. We averaged the model trained for 40B tokens on the mixed dataset with the Llama-2-7B base model from before CPT. This not only improved Core Average, reducing forgetting, but also enhanced performance on MMLU, resulting in our best model yet (red star in Figures 2 and 3). In fact, this model matches or exceeds the performance of the Llama-2-13B base on both metrics.
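A minimal sketch of the souping step, assuming both models are PyTorch modules with identical architectures and floating-point weights (an illustration of the idea, not our exact merging code):

```python
import torch

def soup(cpt_model: torch.nn.Module, base_model: torch.nn.Module,
         alpha: float = 0.9) -> dict:
    """Linearly interpolate weights: alpha * CPT model + (1 - alpha) * base model.

    Assumes identical architectures and floating-point parameters;
    alpha = 0.9 worked best in our sweep (see Figure 3a).
    """
    cpt_sd, base_sd = cpt_model.state_dict(), base_model.state_dict()
    return {
        name: alpha * cpt_sd[name] + (1.0 - alpha) * base_sd[name]
        for name in cpt_sd
    }

# Usage: merged_model.load_state_dict(soup(cpt_model, base_model, alpha=0.9))
```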

Figure 3. Performance gains from model merging. (a) We seek to improve the model with the best MMLU score after continued pre-training (40B tokens of 85% OpenWebMath and 15% FLAN) by averaging with the base model. Here we plot the performance of the merged model for different values of the mixing coefficient alpha, and find that an alpha of 0.9 (0.9 times the CPT model + 0.1 times the base model) yields both the best MMLU and core average (red star). (b) Scaling plot illustrating the efficiency of continued pre-training. For only 40B additional tokens of training, we were able to push Llama-2-7B to MMLU and Core Average performance that matches or exceeds Llama-2-13B.

Does Continued Pre-Training Scale Well?

Finally, we consider how well CPT with OpenWebMath scales to models at larger FLOP scales. We repeat the learning rate and duration sweep with OpenWebMath performed for Llama-2-7B above, but now with Llama-2-13B and Llama-2-70B. As shown in Figure 4, we continue to see improvements at the 10^24 FLOP scale, and the scaling curve indicates that we could potentially see gains at even higher FLOP counts. Each marker for the CPT runs represents the best MMLU performance following the learning rate and duration sweep.
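For context on the FLOP axis, a rough estimate using the common "6 x parameters x tokens" approximation for transformer training shows why the largest models sit near the 10^24 scale and how little compute CPT adds on top. These are our own back-of-the-envelope numbers, not measured values.

```python
def approx_train_flops(n_params: float, n_tokens: float) -> float:
    """Rough training compute via the common 6 * N * D approximation."""
    return 6 * n_params * n_tokens

# Original Llama-2 pre-training used roughly 2T tokens:
print(f"Llama-2-70B pre-training: ~{approx_train_flops(70e9, 2e12):.1e} FLOPs")  # ~8.4e+23
# Extra compute added by a 40B-token CPT run on the 7B model:
print(f"Llama-2-7B CPT (40B tokens): ~{approx_train_flops(7e9, 40e9):.1e} FLOPs")  # ~1.7e+21
```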

Figure 4. Scaling of performance gains for CPT with OpenWebMath. Note that we are now plotting error instead of accuracy, as is customary in scaling law plots. Each CPT model is the result of a learning rate and duration sweep with training on just the OpenWebMath dataset. While there are diminishing returns, we still see strong performance gains up to the 10^24 FLOP scale. For Llama-2-13B we obtain a 3.3 percentage point improvement on MMLU without degradation on Core Average, and for Llama-2-70B we obtain a 1.8 percentage point improvement without degradation.

Conclusion

In this blog post, we explored the process of Continued Pre-Training (CPT) to raise a small LLM’s general knowledge performance to that of a larger model. We demonstrated how to effectively sweep hyperparameters, identify beneficial datasets, and mix datasets for improved performance. Additionally, we discussed strategies to mitigate forgetting through model souping. By following these guidelines, you can use CPT both to quickly measure whether different datasets are effective at teaching models new information and to customize and enhance your LLMs efficiently, achieving substantial performance improvements.

An important consideration is that the success of CPT likely depends on the original pre-training data mix. For example, because OpenWebMath was released after the Llama-2 family, our continued pre-training introduced the model to a novel mix of high-quality mathematical data; the results could differ if OpenWebMath had been included in the pre-training corpus. Regardless, the results demonstrate the ability of CPT to adapt a model to novel data in a FLOP-efficient manner.


[1] In this blog, reported scores are the Gauntlet v0.2 core average. In a recent blog post, Calibrating the Mosaic Evaluation Gauntlet, we discussed our process of building the Gauntlet v0.3 core average, in which we removed several evals based on poor scaling with training FLOPs. The v0.2 and v0.3 scores will be similar but should not be directly compared.