Mastering Tensor Dimensions in Transformers
Hugging Face blog post by Not Lain, published 2025-01-12. Step-by-step tensor shape walkthrough for a decoder-only transformer.
Abstract
This article traces the exact tensor shape at every stage of a decoder-only transformer, from raw token IDs through embedding, masked multi-head attention, feed-forward layers, and the language model head. Using a worked example (input [batch=1, seq_len=4]) with concrete numbers (embed_dim=768, 8 heads, head_size=96, vocab_size=9735), it shows why shapes must stay consistent across layers, how the embedding dimension propagates through the whole stack, and exactly where the K and V tensors arise — making it a direct prerequisite for understanding KV Caching. The article also covers cross-attention in encoder-decoder models, showing how differing source and target sequence lengths are reconciled.
Key Concepts
- Decoder-only architecture: used in generative models (GPT family); each token attends only to itself and preceding tokens (causal masking)
- Embedding layer: maps token IDs → dense vectors; output shape [batch, seq_len, embed_dim]
- Positional encoding: adds position information without changing shape; injected before the attention stack
- Multi-head attention (MHA): splits embed_dim across n_heads independent attention heads; each head operates on head_size = embed_dim / n_heads (768 / 8 = 96 here)
- Causal mask: lower-triangular binary matrix applied to the attention logits; it prevents each token from attending to future positions by setting those positions to -inf before the softmax
- Residual (skip) connection + layer norm: the tensor entering the attention sub-layer is added to that sub-layer's output, and the sum is normalised; this preserves shape and helps gradients flow through deep stacks, mitigating vanishing gradients
- Feed-forward network (FFN): two linear layers with a non-linear activation between them, expanding then contracting: embed_dim → 3×embed_dim → embed_dim; the activation supplies the non-linearity and the wider hidden layer enables richer pattern capture
- Language model head: final linear layer embed_dim → vocab_size; produces logits over the vocabulary for next-token prediction
- Cross-attention: in encoder-decoder models, Q comes from the decoder state while K and V come from the encoder output; this handles differing source and target sequence lengths
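The causal mask from the list above can be sketched in a few lines of PyTorch. This is a minimal illustration, not code from the article: it assumes a single head with seq_len = 4 and uses random numbers as a stand-in for the scaled logits.

```python
import torch

seq_len = 4
# Lower-triangular boolean matrix: entry (i, j) is True when query i may attend to key j (j <= i)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

torch.manual_seed(0)
# Random stand-in for one head's scaled logits QK^T / sqrt(d_k): shape [seq_len, seq_len]
logits = torch.randn(seq_len, seq_len)

# Future positions (above the diagonal) become -inf, so softmax gives them exactly zero weight
masked_logits = logits.masked_fill(~causal_mask, float("-inf"))
weights = torch.softmax(masked_logits, dim=-1)

print(causal_mask.int())
print(weights)  # upper triangle is all zeros; each row sums to 1
```

Because exp(-inf) is 0, the first token ends up attending only to itself with weight exactly 1.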
Key Equations
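Scaled dot-product attention (from Attention Is All You Need):
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Scaling factor: $1/\sqrt{d_k}$, where $d_k$ is the per-head key dimension (head_size = 96 in the walkthrough below, so $\sqrt{d_k} \approx 9.8$). Without it, the variance of the dot products grows linearly with $d_k$ (for unit-variance entries), which pushes the softmax into saturated regions where gradients become very small.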
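A minimal PyTorch sketch of scaled dot-product attention, softmax(QKᵀ/√d_k)V, for tensors shaped [batch, n_heads, seq_len, head_size], with the decoder's causal mask applied before the softmax. The function name causal_attention is our own; the cross-check uses torch.nn.functional.scaled_dot_product_attention, which assumes PyTorch 2.0 or later. The last lines illustrate the scaling argument empirically with random unit-variance vectors.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """softmax(QK^T / sqrt(d_k)) V with a causal mask; inputs are [batch, n_heads, seq_len, head_size]."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [batch, n_heads, seq_q, seq_k]
    seq_q, seq_k = scores.shape[-2:]
    mask = torch.tril(torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~mask, float("-inf"))  # hide future positions
    weights = torch.softmax(scores, dim=-1)
    return weights @ v  # [batch, n_heads, seq_q, head_size]

torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 4, 96) for _ in range(3))
out = causal_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 4, 96])

# Cross-check against PyTorch's built-in implementation (PyTorch >= 2.0)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out, ref, atol=1e-5))

# Why scale: with unit-variance entries, a raw 96-dim dot product has std of about sqrt(96)
a, b = torch.randn(10_000, 96), torch.randn(10_000, 96)
raw = (a * b).sum(dim=-1)
print(raw.std().item(), (raw / math.sqrt(96)).std().item())  # roughly 9.8 and 1.0
```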
Complete Shape Walkthrough (embed_dim=768, heads=8, head_size=96)
| Stage | Operation | Output shape |
|---|---|---|
| Input | Token IDs | [1, 4] |
| Embedding | nn.Embedding(vocab, 768) | [1, 4, 768] |
| Positional encoding | Add position signal | [1, 4, 768] (unchanged) |
| Q, K, V projection | nn.Linear(768, 768) × 3 | [1, 4, 768] each |
| Head split | Reshape: 8 heads × 96 head_size | [1, 4, 8, 96] |
| Transpose | Swap seq_len and head dims | [1, 8, 4, 96] (Q, K, V) |
| K transposed | For matrix multiply | [1, 8, 96, 4] |
| QKᵀ | [1,8,4,96] × [1,8,96,4] | [1, 8, 4, 4] |
| Scale + causal mask + softmax | Divide by √96, set future positions to -inf, softmax over last dim | [1, 8, 4, 4] (values changed) |
| Attention × V | [1,8,4,4] × [1,8,4,96] | [1, 8, 4, 96] |
| Concat heads | Transpose + reshape | [1, 4, 768] |
| Projection | nn.Linear(768, 768) | [1, 4, 768] |
| Residual + norm | Skip connection | [1, 4, 768] |
| FFN expand | nn.Linear(768, 3×768) | [1, 4, 2304] |
| FFN contract | nn.Linear(2304, 768) | [1, 4, 768] |
| Residual + norm | Skip connection | [1, 4, 768] |
| LM head | nn.Linear(768, vocab_size) | [1, 4, 9735] |
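The table can be reproduced with a short PyTorch script that records the shape at each stage. This is a sketch with random weights under assumptions the article does not state: learned absolute positional embeddings (the article's positional encoding may be sinusoidal), a GELU activation inside the FFN, and post-norm placement (add, then LayerNorm). Only the shapes, not the values, are meant to match the table.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len = 1, 4
vocab_size, embed_dim, n_heads = 9735, 768, 8
head_size = embed_dim // n_heads  # 768 / 8 = 96
shapes = {}

token_ids = torch.randint(0, vocab_size, (batch, seq_len))
shapes["input"] = tuple(token_ids.shape)                    # (1, 4)
tok_emb = nn.Embedding(vocab_size, embed_dim)
pos_emb = nn.Embedding(seq_len, embed_dim)  # assumption: learned absolute positions
x = tok_emb(token_ids)
shapes["embedding"] = tuple(x.shape)                        # (1, 4, 768)
x = x + pos_emb(torch.arange(seq_len))  # [4, 768] broadcasts over the batch dim
shapes["positional"] = tuple(x.shape)                       # (1, 4, 768)

q_proj, k_proj, v_proj = (nn.Linear(embed_dim, embed_dim) for _ in range(3))
q_lin = q_proj(x)
shapes["q_projection"] = tuple(q_lin.shape)                 # (1, 4, 768)

def split_heads(t):
    # [batch, seq, 768] -> [batch, seq, 8, 96] -> [batch, 8, seq, 96]
    return t.view(batch, seq_len, n_heads, head_size).transpose(1, 2)

shapes["head_split"] = tuple(q_lin.view(batch, seq_len, n_heads, head_size).shape)  # (1, 4, 8, 96)
q, k, v = split_heads(q_lin), split_heads(k_proj(x)), split_heads(v_proj(x))
shapes["transpose"] = tuple(q.shape)                        # (1, 8, 4, 96)
k_t = k.transpose(-2, -1)
shapes["k_transposed"] = tuple(k_t.shape)                   # (1, 8, 96, 4)

scores = q @ k_t / math.sqrt(head_size)
shapes["qk_t"] = tuple(scores.shape)                        # (1, 8, 4, 4)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
attn = weights @ v
shapes["attn_x_v"] = tuple(attn.shape)                      # (1, 8, 4, 96)
concat = attn.transpose(1, 2).reshape(batch, seq_len, embed_dim)
shapes["concat"] = tuple(concat.shape)                      # (1, 4, 768)

out_proj, norm1 = nn.Linear(embed_dim, embed_dim), nn.LayerNorm(embed_dim)
h = norm1(x + out_proj(concat))  # residual + norm
shapes["residual_norm_1"] = tuple(h.shape)                  # (1, 4, 768)

ffn_up, ffn_down = nn.Linear(embed_dim, 3 * embed_dim), nn.Linear(3 * embed_dim, embed_dim)
expanded = ffn_up(h)
shapes["ffn_expand"] = tuple(expanded.shape)                # (1, 4, 2304)
contracted = ffn_down(F.gelu(expanded))  # assumption: GELU activation
shapes["ffn_contract"] = tuple(contracted.shape)            # (1, 4, 768)
h = nn.LayerNorm(embed_dim)(h + contracted)
shapes["residual_norm_2"] = tuple(h.shape)                  # (1, 4, 768)

logits = nn.Linear(embed_dim, vocab_size)(h)  # LM head
shapes["lm_head"] = tuple(logits.shape)                     # (1, 4, 9735)

for stage, shape in shapes.items():
    print(f"{stage:>16}: {list(shape)}")
```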
Key invariant: every sub-layer in the decoder preserves [batch, seq_len, embed_dim], allowing arbitrary stacking of decoder blocks.
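To illustrate this invariant, here is a hedged sketch of one decoder block that swaps in PyTorch's built-in nn.MultiheadAttention in place of hand-written attention, uses post-norm residuals, and puts a GELU in the FFN (the activation and the block count of six are our assumptions). Because the block maps [batch, seq_len, 768] to the same shape, blocks can be chained freely with nn.Sequential.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Post-norm decoder block: masked MHA and FFN, each wrapped in residual + LayerNorm."""

    def __init__(self, embed_dim=768, n_heads=8, ffn_mult=3):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_mult * embed_dim),  # 768 -> 2304
            nn.GELU(),
            nn.Linear(ffn_mult * embed_dim, embed_dim),  # 2304 -> 768
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        seq_len = x.size(1)
        # For a boolean attn_mask, True marks positions that may NOT be attended (the future)
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_out, _ = self.attn(x, x, x, attn_mask=future, need_weights=False)
        x = self.norm1(x + attn_out)          # [batch, seq_len, 768]
        return self.norm2(x + self.ffn(x))    # [batch, seq_len, 768]

torch.manual_seed(0)
blocks = nn.Sequential(*[DecoderBlock() for _ in range(6)])
x = torch.randn(1, 4, 768)
y = blocks(x)
print(y.shape)  # torch.Size([1, 4, 768])
```

Stacking works only because input and output shapes agree; the LM head is applied once, after the last block.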
Cross-Attention Shape Example (encoder-decoder)
Source: I am at home → encoder output [1, 4, 768]
Target: <bos> je suis à la maison → decoder masked-MHA output [1, 6, 768]
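In cross-attention, Q comes from the decoder [1, 8, 6, 96] while K and V come from the encoder [1, 8, 4, 96]:
- QKᵀ: [1,8,6,96] × [1,8,96,4] → [1, 8, 6, 4]
- Softmax over the last (source) dimension → [1, 8, 6, 4]
- Attention × V: [1,8,6,4] × [1,8,4,96] → [1, 8, 6, 96]
Output after concat: [1, 6, 768], which matches the decoder sequence length, as expected.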
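The cross-attention shapes can be checked with random tensors standing in for the decoder queries (6 target tokens) and the encoder keys and values (4 source tokens). This is a minimal sketch that omits the Q/K/V projections; as in standard encoder-decoder attention, no causal mask is applied here because every target position may see the whole source (padding masks aside).

```python
import math
import torch

torch.manual_seed(0)
n_heads, head_size = 8, 96
q = torch.randn(1, n_heads, 6, head_size)  # from the decoder: 6 target tokens
k = torch.randn(1, n_heads, 4, head_size)  # from the encoder: 4 source tokens
v = torch.randn(1, n_heads, 4, head_size)

scores = q @ k.transpose(-2, -1) / math.sqrt(head_size)  # [1, 8, 6, 4]
weights = torch.softmax(scores, dim=-1)                  # normalise over source positions
out = weights @ v                                        # [1, 8, 6, 96]
concat = out.transpose(1, 2).reshape(1, 6, n_heads * head_size)  # [1, 6, 768]
print(scores.shape, out.shape, concat.shape)
```

The source length (4) appears only in the inner dimension of both matrix products, which is why it disappears from the output.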
Connections to Existing Wiki Pages
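- Why K and V are cached: the shapes here show that K [1, 8, 4, 96] and V [1, 8, 4, 96] for past tokens are fixed once computed, because causal masking means earlier positions never depend on later ones; at each new decode step only the new token's Q, K, and V need to be computed, and its K and V rows are appended to the cache → see KV Caching Explained
- Original architecture reference: Attention Is All You Need, the source of the attention formula and transformer architecture this article annotates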
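To make the caching point concrete, here is a minimal sketch for a single attention layer with random weights (the helper heads and all sizes are illustrative assumptions, not from the article). The K/V cache grows from [1, 8, 4, 96] to [1, 8, 5, 96] while the new query is only [1, 8, 1, 96], and the result matches recomputing causal attention over all five tokens.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, n_heads = 768, 8
head_size = embed_dim // n_heads  # 96
q_proj, k_proj, v_proj = (nn.Linear(embed_dim, embed_dim) for _ in range(3))

def heads(t):
    # Hypothetical helper: [1, T, 768] -> [1, 8, T, 96]
    return t.view(1, t.size(1), n_heads, head_size).transpose(1, 2)

x = torch.randn(1, 5, embed_dim)  # hidden states: 4 past tokens + 1 new token

# Prefill: K and V for the 4 past tokens are computed once and cached
k_cache = heads(k_proj(x[:, :4]))  # [1, 8, 4, 96]
v_cache = heads(v_proj(x[:, :4]))  # [1, 8, 4, 96]

# Decode step: project only the new token, then append its K and V to the cache
x_new = x[:, 4:]                   # [1, 1, 768]
q_new = heads(q_proj(x_new))       # [1, 8, 1, 96]
k_cache = torch.cat([k_cache, heads(k_proj(x_new))], dim=2)  # [1, 8, 5, 96]
v_cache = torch.cat([v_cache, heads(v_proj(x_new))], dim=2)  # [1, 8, 5, 96]

# The newest token may attend to every cached position, so no mask is needed here
scores = q_new @ k_cache.transpose(-2, -1) / math.sqrt(head_size)  # [1, 8, 1, 5]
out_cached = torch.softmax(scores, dim=-1) @ v_cache               # [1, 8, 1, 96]

# Reference: full causal attention over all 5 tokens, keeping only the last row
q_all, k_all, v_all = heads(q_proj(x)), heads(k_proj(x)), heads(v_proj(x))
full = q_all @ k_all.transpose(-2, -1) / math.sqrt(head_size)
mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))
out_full = torch.softmax(full.masked_fill(~mask, float("-inf")), dim=-1) @ v_all
print(torch.allclose(out_cached, out_full[:, :, -1:], atol=1e-5))
```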