Summary of "Scaling Laws for Neural Language Models"

After yesterday's look at BERT[1] and the benefits of bidirectional attention, we are covering a different direction today. While previous papers focused on how to build the model (architecture), this paper[2] focuses on how big to make it, how much data to use, and how much compute is required to improve the loss. It marks the moment the industry moved from "architecture engineering" to "scale engineering."

Introduction

  • The end of tweaking: The authors empirically show that details like network width, depth, or number of attention heads don't matter much. Performance depends mostly on three things: Parameters, Data, and Compute.

  • Predictable scaling: They found that performance scales as a precise power law. You can predict the loss of a massive model by training small ones.

  • The "Overtraining" Hypothesis: A key finding (at the time) was that models should be very large but not trained to convergence.

  • Legacy: This paper defined the "GPT-3 era" of training, convincing OpenAI to build massive models on relatively small datasets (a view that was later corrected by Chinchilla[3]).

Abstract & Introduction

  • empirically show that model performance scales predictably as a power law with parameters, data and compute
  • other details, like width, depth, and number of attention heads, do not matter (within a range spanning multiple orders of magnitude)
  • finding that most current models are overtrained and that very large models require only a modest amount of data
  • power-law scalings for performance as a function of training time, context length, dataset size, model size, and compute budget
  • increasing the model size by 8x requires increasing the dataset size by only about 5x (see the short derivation after this list)
  • larger models are more sample efficient, reaching a lower loss with the same amount of data
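
A rough way to recover the 8x/5x rule of thumb (a sketch, assuming the dataset is grown just fast enough to keep the model-size and data-size terms of the $L(N, D)$ fit from the Power Laws section below balanced): $$\begin{align}D &\propto N^{\alpha_N/\alpha_D} = N^{0.076/0.095} \approx N^{0.8} \\ \frac{D_{new}}{D_{old}} &\approx 8^{0.8} \approx 5.3\end{align}$$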

Setup

  • mainly decoder-only Transformer[4], comparison to LSTM
  • dataset is WebText2, vocab size 50257, context length 1,024
    • $1.62 \cdot 10^{10}$ words, resulting in $2.29\cdot 10^{10}$ tokens
  • Heuristics:
    • number of non-embedding parameters $N= 12 n_\text{layer}d^2_\text{model}$
    • required floating point operations per token $C\approx 6\cdot N$ (see the sketch after this list)
  • 3,000 step linear warm-up with cosine decay to zero
  • total training $2.5\cdot 10^5$ steps with batch size $512$
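
A minimal sketch of the two heuristics above in Python (the function names and the GPT-2-small-like example shape are mine, not from the paper):

```python
# Sketch of the parameter-count and compute heuristics from the Setup section.
def non_embedding_params(n_layer: int, d_model: int) -> int:
    """N ~= 12 * n_layer * d_model^2 (ignores embeddings, biases, layer norms)."""
    return 12 * n_layer * d_model**2

def flops_per_token(n_params: int) -> int:
    """C ~= 6 * N floating point operations per token (forward + backward pass)."""
    return 6 * n_params

# Example: a GPT-2-small-like shape (12 layers, d_model = 768).
N = non_embedding_params(n_layer=12, d_model=768)
print(f"N ~ {N:.2e} non-embedding parameters, ~{flops_per_token(N):.2e} FLOPs/token")
```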

Results

  • the width-to-depth ratio does not matter much within $20\leq d_\text{model}/n_\text{layer} \leq 300$
  • number of attention heads is also rather robust, especially for $d_\text{model}=1,024$ (largest reported)
  • feedforward ratio $d_\text{ff}/d_\text{model}$ should be kept in the range $[1, 4]$
  • Transformers asymptotically outperform LSTMs and show continued improvement for larger contexts, while LSTMs plateau after 100 tokens
  • critical batch size is independent of model size, only depends on loss
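
This claim becomes concrete with the $B_\text{crit}(L)$ fit given in the next section; a quick sketch (measured in tokens, as in the paper; names are mine):

```python
# B_crit(L) = B_* / L^(1/alpha_B), in tokens; constants are the paper's fit.
B_STAR, ALPHA_B = 2e8, 0.21

def critical_batch_size_tokens(loss: float) -> float:
    """Critical batch size as a function of the current loss only."""
    return B_STAR / loss ** (1.0 / ALPHA_B)

for loss in (4.0, 3.0, 2.5):
    tokens = critical_batch_size_tokens(loss)
    print(f"L = {loss}: B_crit ~ {tokens:.1e} tokens "
          f"(~{tokens / 1024:.0f} sequences of 1,024 tokens)")
```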

Power Laws

  • $L(N) = (N_\text{c}/N)^{\alpha_\text{N}} ; \alpha_\text{N} \sim 0.076, N_\text{c} \sim 8.8 \times 10^{13}$
  • $L(D) = (D_\text{c}/D)^{\alpha_\text{D}} ; \alpha_\text{D} \sim 0.095, D_\text{c} \sim 5.4 \times 10^{13}$
  • $L(C_\text{min}) = (C^\text{min}_\text{c}/C_\text{min})^{\alpha^\text{min}_\text{C}} ; \alpha^\text{min}_\text{C} \sim 0.050, C_\text{c}^\text{min} \sim 3.1 \times 10^{8}$ (PF-days)
  • $B_\text{crit}(L) = \frac{B_\ast}{L^{1/\alpha_\text{B}}} ; \alpha_\text{B} \sim 0.21, B_\ast \sim 2 \times 10^8$ (tokens)
  • $L(N, D) = \left[ \left( \frac{N_c}{N} \right)^{\frac{\alpha_N}{\alpha_D}} + \frac{D_c}{D} \right]^{\alpha_D}$
  • $L(N, S) = \left( \frac{N_c}{N} \right)^{\alpha_N} + \left( \frac{S_c}{S_\text{min}(S)} \right)^{\alpha_S} ; \alpha_\text{S} \approx 0.76, S_c \approx 2.1\times 10^3$
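
A small Python sketch that plugs the fitted constants above into $L(N)$, $L(D)$, and the combined $L(N, D)$ (units as in the paper: loss in nats, $N$ in non-embedding parameters, $D$ in tokens; names are mine). The loop at the end previews the bottleneck effect discussed under Interpretations: with a small, fixed dataset, growing the model barely moves the loss:

```python
# Fitted constants reported in the paper.
ALPHA_N, N_C = 0.076, 8.8e13
ALPHA_D, D_C = 0.095, 5.4e13

def loss_from_params(n: float) -> float:
    """L(N): loss when data is not the limiting factor."""
    return (N_C / n) ** ALPHA_N

def loss_from_data(d: float) -> float:
    """L(D): loss when model size is not the limiting factor."""
    return (D_C / d) ** ALPHA_D

def loss_from_params_and_data(n: float, d: float) -> float:
    """Combined L(N, D) fit."""
    return ((N_C / n) ** (ALPHA_N / ALPHA_D) + D_C / d) ** ALPHA_D

# Bottleneck illustration: D fixed at 1B tokens while N grows 100x.
for n in (1e8, 1e9, 1e10):
    print(f"N = {n:.0e}: L(N, D=1e9) = {loss_from_params_and_data(n, 1e9):.2f}")
```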

Interpretations

  • $L(N)$: The loss goes down very slowly with respect to increases in the number of model parameters
    • $\alpha_N \sim 0.076$: To improve the loss by 10% solely by increasing the number of parameters, we need $$\begin{align}L&\propto N^{-0.076} \\ \frac{L_{new}}{L_{old}} &= \left( \frac{N_{new}}{N_{old}} \right)^{-0.076} = 0.9 \\ \frac{N_{new}}{N_{old}} &= 0.9^{-\frac{1}{0.076}} \approx 0.9^{-13.16} \approx \mathbf{4.0}\end{align}$$ a 4x bigger model. For 20%, a $0.8^{-13.16}\approx 18.9$x increase would be required (all three single-variable calculations are collected in a sketch after this list)
  • $L(D)$: For data, the relationship is slightly better:
    • A 10% loss decrease caused solely by increasing the dataset would require $0.9^{-1/0.095}\approx 3.0$x the data (for a 20% improvement, a $0.8^{-10.5}\approx 10.4$x bigger dataset would be required)
  • $L(C)$: Improving the loss with respect to compute has the worst exponent so far, resulting in massive increases of computation required for small improvements:
    • For a 10% lower loss, $0.9^{-1/0.05}\approx 8.2$x the compute would be required; for a 20% lower loss, even an $0.8^{-20}\approx 86.7$x increase!
  • $L(N, D)$: The loss depends on both number of parameters and size of dataset. If one is too small, an increase of the other factor has almost no effect on the loss (effectively bottlenecking any improvement)
  • $L(N, S)$: The same is true for model size and training steps. Notably, the exponent $\alpha_S$ is pretty large compared to the others, indicating that less dramatic increases are required for lower loss values. This formula effectively forms the basis for their claim that models should not be trained to convergence ($\left( \frac{S_c}{S_\text{min}(S)} \right)^{\alpha_S}$ shrinks pretty fast); instead the compute should be invested in a larger model (bringing down the first term), trained for slightly fewer steps
  • Even though they have not observed it, there has to be an end to these trends, since the loss would otherwise decrease towards zero, which is impossible because natural language has non-zero entropy
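
The three single-variable calculations above all follow the same pattern $x_\text{new}/x_\text{old} = (L_\text{new}/L_\text{old})^{-1/\alpha}$; a minimal Python sketch (the helper name is mine) that reproduces them:

```python
# Required scale-up in a single variable (N, D, or C) for a target loss ratio,
# using L ∝ x^(-alpha)  =>  x_new / x_old = ratio^(-1 / alpha).
# Exponents are the fitted values reported in the paper.
EXPONENTS = {"parameters": 0.076, "data": 0.095, "compute": 0.050}

def scale_up_factor(loss_ratio: float, alpha: float) -> float:
    """How much the variable must grow so that L_new / L_old == loss_ratio."""
    return loss_ratio ** (-1.0 / alpha)

for name, alpha in EXPONENTS.items():
    for ratio in (0.9, 0.8):  # 10% and 20% lower loss
        print(f"{name:10s}: {1 - ratio:.0%} lower loss needs "
              f"{scale_up_factor(ratio, alpha):.1f}x")
```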

Note

  • While these scaling laws are interesting and this paper provides empirical evidence for them, their concrete compute-optimal prescriptions are no longer considered valid today. Instead, the findings of the Chinchilla paper[3] (tomorrow's summary) showed that most models are undertrained and require a lot more data. Therefore modern models, even small ones, are usually trained on massive datasets.

Anki

Today there are only two cards, since most of the details are no longer that relevant (at least in my opinion).

  • Question: Kaplan: What is the proposed heuristic for the number of non-embedding parameters $N$, given $n_\text{layer}$ layers and a model dimension $d_\text{model}$?
    • $N= 12 n_\text{layer}d^2_\text{model}$
  • Question: Kaplan Scaling Laws: According to Kaplan, which factor scales the loss the slowest (has the worst exponent)?
    • Compute ($C$). (It requires massive increases in compute for small loss gains.)

References

  1. BERT: Pre-training of deep bidirectional transformers for language understanding
    Proceedings of the 2019 conference of the North American chapter of the association for computational linguistics: human language technologies, volume 1 (long and short papers), 2019
  2. Scaling laws for neural language models
    arXiv preprint arXiv:2001.08361, 2020
  3. Training compute-optimal large language models
    Proceedings of the 36th International Conference on Neural Information Processing Systems, 2022
  4. Attention is all you need
    Advances in neural information processing systems, 2017