Weight tying does not imply embedding tying

Small LLMs

The first "L" in LLM stands for "Large" and that is meant quite literally: modern large language models commonly have several billion, or even up to a trillion parameters. [1] Those parameters are distributed over the different parts of the model, mainly in the embedding layer, attention calculation, MLPs and a final prediction layer.

To scale an LLM down to, e.g., just a few billion parameters, there are a few common strategies:

- Reduce layers: The most obvious one is to simply reduce the number of Transformer blocks. While it has not been empirically validated that the scaling of LLMs is completely architecture-agnostic, as proposed by Kaplan et al. [2], there seem to be diminishing returns after just a couple of layers [3][4].
- Alternatives to Multi-Head Attention (MHA): While MHA works great and allows the model to focus on a lot of different things [5], it is also pretty parameter-hungry. The most radical solution is to keep just one set of key and value projection matrices for many different queries [6], with the currently preferred solution being a compromise between the two: Grouped-Query Attention [7]. This keeps almost all of the expressiveness of MHA but significantly reduces the number of parameters in an LLM.
- Weight tying: What remains is the size of the embedding and next-token prediction layers (the latter commonly referred to as lm_head), i.e., the mappings from vocabulary tokens to embeddings and back. For multi-language models with sub-word tokens, the vocabulary size can easily exceed 100k or even 200k. With an embedding dimension of 3072 or 4096, this leads to huge matrices, which especially in smaller models take up a large portion of the full model. E.g., for a vocab size of 256k and an embedding dimensionality of 3072, the embedding layer alone already has 786M parameters. Since a matrix of the same size is required for the output mapping, this leaves us with over 1.5B parameters just for these two mappings. A common trick for smaller models is therefore to combine the two layers. They share a similar task anyway (both map between embeddings and the vocabulary), and the saving potential is huge. Best of all, it is actually quite simple and seems to work pretty well: it is commonly known as weight tying [8] (see the sketch after this list). This technique, and the effect it has on the embedding right before the lm_head layer, will be the main topic of this blog post.
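As a minimal sketch of the idea (plain PyTorch with toy sizes; the parameter math uses the realistic numbers from the text), tying simply means that the lm_head reuses the embedding matrix:

```python
import torch.nn as nn

# Toy sizes so the example stays lightweight; the arithmetic below uses the
# realistic numbers from the text (256k vocabulary, 3072-dimensional embeddings).
vocab_size, embed_dim = 1_000, 64

embedding = nn.Embedding(vocab_size, embed_dim)          # token id -> embedding
lm_head = nn.Linear(embed_dim, vocab_size, bias=False)   # embedding -> logits

# Weight tying: the output projection simply reuses the embedding matrix.
lm_head.weight = embedding.weight
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()

# Savings at realistic scale: one (256k x 3072) matrix instead of two.
print(256_000 * 3_072)  # 786_432_000 parameters saved by tying
```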

Weight tying

The idea is quite simple: when we use the same weights for both mappings, we get these two behaviours:

1. For the embedding layer we basically have a look-up table, since the input tokens are represented as one-hot encoded vectors over the full vocabulary (in practice we use the index instead of a fully materialized one-hot vector, but the concept stays the same). So each row of the embedding matrix corresponds to one token.
2. The task of the lm_head layer is to map the final token embedding to logits for all possible tokens of the vocabulary. If we just multiply the final token embedding with the lm_head matrix (which is the transposed embedding matrix), we get a vector of the correct shape. Each entry in this vector is the result of a dot product between the final embedding and the token embedding at that index. Intuitively, that does not sound too bad: the dot product between (close to) parallel vectors is large, it is zero for orthogonal vectors, and negative for opposing ones. (Both behaviours are sketched below.)
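A short sketch of these two behaviours, again with toy sizes and a random tensor standing in for the final hidden state:

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 1_000, 64
embedding = nn.Embedding(vocab_size, embed_dim)
lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
lm_head.weight = embedding.weight  # weight tying

# 1. Embedding as look-up table: row `token_id` of the shared matrix.
token_id = torch.tensor([42])
x = embedding(token_id)                       # shape (1, embed_dim)
assert torch.equal(x[0], embedding.weight[42])

# 2. lm_head as dot products: the logit for every vocabulary token is the
#    dot product between the final hidden state h and that token's embedding row.
h = torch.randn(embed_dim)                    # stand-in for the final hidden state
logits = lm_head(h)                           # shape (vocab_size,)
assert torch.allclose(logits, embedding.weight @ h, atol=1e-4)
```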

But what does that mean for the final token embedding, right before the lm_head? Does the space these tokens live in have to be similar to the one the input token embeddings populate? Should the last token embedding, given the context "The capital of France is", be similar to the input embedding of "Paris"? And what does similarity even mean in these high-dimensional spaces?

Let's have a look at some of these questions.

Token decomposition

We start by having a detailed look at the final hidden token embedding, right before the lm_head. Is it similar to the embedding of the predicted token? Or more precisely, how are these two related?

In a weight-tied architecture the final logit for each token is calculated as a dot product between the last token embedding $h$ and each embedding vector $w$:

$$ h \cdot w = ||h|| \cdot ||w|| \cdot \cos(\theta) $$

where $\theta$ is the angle between $h$ and $w$. To further investigate the content of $h$ we can decompose it relative to a specific token into a parallel $h_{\parallel}$ and an orthogonal part $h_{\perp}$:

$$ h = h_{\parallel} + h_{\perp} $$

where $h_{\parallel}$ is the component of $h$ aligned with $w$ and $h_{\perp}$ is orthogonal to $w$. The purpose of this decomposition is shown in the next equation:

$$ h\cdot w=(h_{\parallel}+h_{\perp})\cdot w=(h_{\parallel}\cdot w)+(h_{\perp}\cdot w)=(h_{\parallel}\cdot w)+0 $$

As shown, the orthogonal component of $h$ has no effect on the outcome of the dot product with $w$, so it is effectively ignored for this particular $w$ (but not for other embedding vectors, as we will see).
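Here is a small numeric check of this decomposition, with random double-precision vectors standing in for $h$ and $w$:

```python
import torch

torch.manual_seed(0)
dim = 3_072
h = torch.randn(dim, dtype=torch.float64)  # stand-in for the final hidden state h
w = torch.randn(dim, dtype=torch.float64)  # stand-in for one embedding row w

# Decompose h into a component parallel to w and an orthogonal remainder.
h_par = (h @ w) / (w @ w) * w
h_perp = h - h_par

# The decomposition is exact, h_perp is orthogonal to w, and only the
# parallel part contributes to the logit h . w.
assert torch.allclose(h_par + h_perp, h)
assert abs(h_perp @ w) < 1e-8
assert torch.allclose(h @ w, h_par @ w)
```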

Before further investigating the properties of the final embedding vectors, I expected this "unrelated" part to be relatively small compared to the aligned part $h_{\parallel}$. So let's check this intuition.

Analysis of final hidden tokens

I took three recent, small language models, extracted their final hidden token embedding $h$, and decomposed it as described in the previous section into $h_{\parallel}$ and $h_{\perp}$.
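For reference, a sketch of how such a measurement could look with the Hugging Face transformers library. The model name is just a placeholder for any small weight-tied causal LM; the numbers reported below come from my own runs, not from this exact snippet:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "Qwen/Qwen2.5-0.5B"  # placeholder; any small weight-tied causal LM works

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)
model.eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

h = out.hidden_states[-1][0, -1]             # final hidden state of the last token
next_id = int(out.logits[0, -1].argmax())    # predicted next token
w = model.get_input_embeddings().weight[next_id]  # its (tied) embedding row

h_par = (h @ w) / (w @ w) * w                # component aligned with w
h_perp = h - h_par                           # component ignored by this logit

print(tokenizer.decode(next_id))
print((h_perp.norm() / h_par.norm()).item())  # ratio of "irrelevant" to aligned part
```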

The first surprising thing was that the magnitude of $h_{\perp}$ is usually $4-6$ times larger than the magnitude of the relevant part $h_{\parallel}$. After some short research on this topic, I have some ideas why this might be the case:

- Even though $h_{\perp}$ has a large magnitude, it is not well aligned with any other embedding vector. While there can be up to 200k embedding vectors, the high dimensionality of the embedding space makes it difficult for any of them to align well with $h_{\perp}$.
- Embedding spaces usually do not have a uniform density distribution. Instead, they are anisotropic and only occupy a narrow cone [9], and while this seems suboptimal, it does not significantly decrease performance and can even help the models [10]. In this case a vector $h_{\perp}$ that is perpendicular to one of the embedding vectors $w$ is also almost orthogonal to all other embedding vectors.
- Final output embeddings are vastly different from the input embedding of the same token, since they contain additional context [11]. This is especially helpful

References

  1. Kimi K2: Open Agentic Intelligence
    2025
  2. Scaling laws for neural language models
    arXiv preprint arXiv:2001.08361, 2020
  3. The depth-to-width interplay in self-attention
    arXiv preprint arXiv:2006.12467, 2020
  4. The impact of depth on compositional generalization in transformer language models
    Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers), 2024
  5. Attention is all you need
    Advances in neural information processing systems, 2017
  6. Fast transformer decoding: One write-head is all you need
    arXiv preprint arXiv:1911.02150, 2019
  7. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
    Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, 2023
  8. Using the output embedding to improve language models
    Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 2, Short Papers, 2017
  9. Representation degeneration problem in training natural language generation models
    arXiv preprint arXiv:1907.12009, 2019
  10. Stable anisotropic regularization
    ICLR, 2024
  11. How contextual are contextualized word representations? Comparing the geometry of BERT, ELMo, and GPT-2 embeddings
    arXiv preprint arXiv:1909.00512, 2019