Mastering Tensor Dimensions in Transformers

Hugging Face blog post by Not Lain, published 2025-01-12. Step-by-step tensor shape walkthrough for a decoder-only transformer.

Abstract

This article traces the exact tensor shape at every stage of a decoder-only transformer, from raw token IDs through embedding, masked multi-head attention, feed-forward layers, and the language model head. Using a worked example (input [batch=1, seq_len=4]) with concrete numbers (embed_dim=768, 8 heads, head_size=96, vocab_size=9735), it shows why shapes must stay consistent across layers, how the embedding dimension propagates through the whole stack, and exactly where the K and V tensors arise — making it a direct prerequisite for understanding KV Caching. The article also covers cross-attention in encoder-decoder models, showing how differing source and target sequence lengths are reconciled.

Key Concepts

  • Decoder-only architecture: used in generative models (GPT family); each token attends only to itself and preceding tokens (causal masking)
  • Embedding layer: maps token IDs → dense vectors; output shape [batch, seq_len, embed_dim]
  • Positional encoding: adds position information without changing shape; injected before the attention stack
  • Multi-head attention (MHA): splits embed_dim across n_heads independent attention heads; each head operates on head_size = embed_dim / n_heads
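  • Causal mask: lower-triangular boolean matrix applied to the attention logits; positions above the diagonal (future tokens) are set to -inf before softmax, so each token gives zero attention weight to future positions
  • Residual (skip) connection + layer norm: the sub-layer's input is added to its output and the sum is normalised (post-norm, as in the original Transformer); this preserves shape and helps gradients flow through deep stacks, mitigating vanishing gradients
  • Feed-forward network (FFN): two linear layers with a non-linear activation between them, expanding then contracting: embed_dim → 3×embed_dim → embed_dim (this article uses 3×; the original Transformer uses 4×); the activation adds non-linearity and the wider hidden layer gives capacity to capture richer patterns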
  • Language model head: final linear layer embed_dim → vocab_size; produces logits over vocabulary for next-token prediction
  • Cross-attention: in encoder-decoder models, Q comes from the decoder state while K and V come from the encoder output; handles differing source/target sequence lengths
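The causal-mask bullet can be made concrete with a minimal NumPy sketch. NumPy stands in for PyTorch here, and the helper names `causal_mask` and `masked_softmax` are hypothetical. With all-zero logits, the weights make the mask's effect easy to read: row i spreads its attention evenly over positions 0..i and gives exactly zero weight to later positions.

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    # Lower-triangular boolean matrix: True where query i may attend to key j (j <= i)
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def masked_softmax(logits: np.ndarray) -> np.ndarray:
    # logits: [..., seq_len, seq_len]; future positions are set to -inf before softmax
    seq_len = logits.shape[-1]
    masked = np.where(causal_mask(seq_len), logits, -np.inf)
    # Subtract the row max for stability (the diagonal is never masked, so it is finite)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)  # exp(-inf) == 0, so future positions get zero weight
    return weights / weights.sum(axis=-1, keepdims=True)

logits = np.zeros((1, 8, 4, 4))  # [batch, heads, seq_len, seq_len], uniform scores
weights = masked_softmax(logits)
print(weights[0, 0])
```

The printed 4×4 matrix has rows [1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0] and [1/4, 1/4, 1/4, 1/4].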

Key Equations

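Scaled dot-product attention (from Attention Is All You Need):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Scaling factor: $1/\sqrt{d_k}$, where $d_k$ is the per-head key dimension (head_size = 96 in this walkthrough). Without it, dot products grow with dimension: for unit-variance entries their standard deviation is $\sqrt{d_k}$. Large logits push softmax into saturated regions where gradients are tiny.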
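The formula softmax(QKᵀ/√d_k)V translates almost line for line into code. The following is a minimal NumPy sketch, not the article's implementation. The function name and the boolean `mask` convention (True = may attend) are assumptions. It runs on tensors shaped like this walkthrough's [1, 8, 4, 96] Q, K and V. It then checks the scaling argument empirically: dot products of random 96-dimensional unit-variance vectors have a standard deviation near √96 ≈ 9.8, and dividing by √96 brings it back to about 1.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: [..., L_q, d_k], k: [..., L_k, d_k], v: [..., L_k, d_v]
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)  # [..., L_q, L_k]
    if mask is not None:
        # mask is True where attention is allowed; disallowed positions get -inf
        scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v  # [..., L_q, d_v]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 8, 4, 96)) for _ in range(3))
causal = np.tril(np.ones((4, 4), dtype=bool))
out = scaled_dot_product_attention(q, k, v, mask=causal)
print(out.shape)  # (1, 8, 4, 96)

# Why scale: raw dot products of unit-variance 96-dim vectors have std ≈ sqrt(96)
q_big = rng.standard_normal((10_000, 96))
k_big = rng.standard_normal((10_000, 96))
dots = (q_big * k_big).sum(axis=-1)
print(dots.std(), (dots / np.sqrt(96)).std())
```

Because of the causal mask, the first query can only see the first key, so its output row equals the first row of V.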

Complete Shape Walkthrough (embed_dim=768, heads=8, head_size=96)

Stage                  | Operation                                   | Output shape
Input                  | Token IDs                                   | [1, 4]
Embedding              | nn.Embedding(vocab_size, 768)               | [1, 4, 768]
Positional encoding    | Add position signal                         | [1, 4, 768] (unchanged)
Q, K, V projection     | nn.Linear(768, 768) × 3                     | [1, 4, 768] each
Head split             | Reshape: 8 heads × 96 head_size             | [1, 4, 8, 96]
Transpose              | Swap seq_len and head dims                  | [1, 8, 4, 96] (Q, K, V)
K transposed           | Swap last two dims for the matrix multiply  | [1, 8, 96, 4]
QKᵀ / √96              | [1,8,4,96] × [1,8,96,4], then scale         | [1, 8, 4, 4]
Causal mask + softmax  | Set future positions to -inf, normalise     | [1, 8, 4, 4] (values change)
Attention × V          | [1,8,4,4] × [1,8,4,96]                      | [1, 8, 4, 96]
Concat heads           | Transpose back + reshape                    | [1, 4, 768]
Output projection      | nn.Linear(768, 768)                         | [1, 4, 768]
Residual + norm        | Add sub-layer input, then layer norm        | [1, 4, 768]
FFN expand             | nn.Linear(768, 3×768), then activation      | [1, 4, 2304]
FFN contract           | nn.Linear(2304, 768)                        | [1, 4, 768]
Residual + norm        | Add sub-layer input, then layer norm        | [1, 4, 768]
LM head                | nn.Linear(768, vocab_size)                  | [1, 4, 9735]
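The following minimal NumPy sketch pushes a [1, 4] batch of token IDs through every row of the shape walkthrough and ends with [1, 4, 9735] logits. Several pieces are assumptions rather than the article's code. All weights are random placeholders. The `linear` helper omits the bias. Layer norm has no learned scale or shift. Positional encoding is sinusoidal, though any additive scheme keeps the shape. ReLU stands in for whatever activation the article's FFN uses. The trailing comments give the shape after each line.

```python
import numpy as np

# Sizes from the walkthrough; weights are random placeholders, not trained values
batch, seq_len, embed_dim, n_heads, vocab_size = 1, 4, 768, 8, 9735
head_size = embed_dim // n_heads  # 96
rng = np.random.default_rng(0)

def linear(x, d_in, d_out):
    # Stand-in for nn.Linear(d_in, d_out) with a random weight; bias omitted
    w = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)
    return x @ w

def layer_norm(x, eps=1e-5):
    # Normalise over embed_dim (no learned scale/shift in this sketch)
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def split_heads(t):
    # [1, 4, 768] -> [1, 4, 8, 96] -> [1, 8, 4, 96]
    return t.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)

token_ids = rng.integers(0, vocab_size, size=(batch, seq_len))       # [1, 4]
embedding = rng.standard_normal((vocab_size, embed_dim))
x = embedding[token_ids]                                              # [1, 4, 768]

# Sinusoidal positional encoding (an assumption; any additive scheme keeps the shape)
pos = np.arange(seq_len)[:, None]
i = np.arange(0, embed_dim, 2)[None, :]
pe = np.zeros((seq_len, embed_dim))
pe[:, 0::2] = np.sin(pos / 10000 ** (i / embed_dim))
pe[:, 1::2] = np.cos(pos / 10000 ** (i / embed_dim))
x = x + pe                                                            # [1, 4, 768]

# Masked multi-head attention sub-layer
q, k, v = (split_heads(linear(x, embed_dim, embed_dim)) for _ in range(3))  # [1, 8, 4, 96]
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)            # [1, 8, 4, 4]
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
weights = softmax(np.where(mask, scores, -np.inf))                   # [1, 8, 4, 4]
attn = weights @ v                                                    # [1, 8, 4, 96]
attn = attn.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)  # [1, 4, 768]
x = layer_norm(x + linear(attn, embed_dim, embed_dim))                # [1, 4, 768]

# Feed-forward sub-layer (ReLU assumed) and LM head
hidden = np.maximum(0, linear(x, embed_dim, 3 * embed_dim))           # [1, 4, 2304]
x = layer_norm(x + linear(hidden, 3 * embed_dim, embed_dim))          # [1, 4, 768]
logits = linear(x, embed_dim, vocab_size)                             # [1, 4, 9735]
print(logits.shape)  # (1, 4, 9735)
```

Each position yields one row of 9735 logits. For next-token prediction, the row at the last position is the one that matters.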

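Key invariant: every decoder sub-layer (masked MHA, FFN), taken as a whole, maps [batch, seq_len, embed_dim] to the same shape; only intermediate tensors such as [1, 8, 4, 4] or [1, 4, 2304] differ. This is what allows decoder blocks to be stacked to any depth.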
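Part of what makes the invariant hold is that the "head split" and "concat heads" steps are exact inverses: they rearrange the 768 features without changing them. The following minimal NumPy sketch uses `reshape`/`transpose` in place of the PyTorch tensor ops, and the helper names `split_heads` and `merge_heads` are hypothetical. It checks the round trip and shows that head h of token t sees features h·96 to h·96+95 of that token's embedding.

```python
import numpy as np

batch, seq_len, n_heads, head_size = 1, 4, 8, 96
embed_dim = n_heads * head_size  # 768

def split_heads(x):
    # [batch, seq_len, embed_dim] -> [batch, seq_len, n_heads, head_size] -> [batch, n_heads, seq_len, head_size]
    return x.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [batch, n_heads, seq_len, head_size] -> [batch, seq_len, n_heads, head_size] -> [batch, seq_len, embed_dim]
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)

# Distinct values make it easy to see where each feature ends up
x = np.arange(batch * seq_len * embed_dim, dtype=float).reshape(batch, seq_len, embed_dim)
heads = split_heads(x)
print(heads.shape)                                       # (1, 8, 4, 96)
print(np.array_equal(merge_heads(heads), x))             # True
print(np.array_equal(heads[0, 2, 1], x[0, 1, 192:288]))  # True: head 2, token 1
```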

Cross-Attention Shape Example (encoder-decoder)

Source: "I am at home" (4 tokens) → encoder output [1, 4, 768]
Target: "<bos> je suis à la maison" (6 tokens) → decoder masked-MHA output [1, 6, 768]

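In cross-attention, Q comes from the decoder ([1, 8, 6, 96]) and K and V come from the encoder ([1, 8, 4, 96]):

  • QKᵀ: [1, 8, 6, 96] × [1, 8, 96, 4] → [1, 8, 6, 4], with one row per target token and one column per source token
  • Softmax over the last (source) dimension; no causal mask is applied here, because every target token may attend to the whole source sentence
  • Attention × V: [1, 8, 6, 4] × [1, 8, 4, 96] → [1, 8, 6, 96]

Output after concat: [1, 6, 768]. This matches the decoder (target) sequence length, as expected.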
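The cross-attention shapes can be checked with a minimal NumPy sketch. Random tensors stand in for the already-projected and head-split decoder queries and encoder keys/values, and the Q/K/V projections themselves are omitted. The `attention` helper is hypothetical and unmasked, matching the cross-attention case.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, head_size = 8, 96
embed_dim = n_heads * head_size  # 768

def attention(q, k, v):
    # Unmasked scaled dot-product attention; q: [..., L_q, d], k and v: [..., L_k, d]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)  # softmax over source positions
    return w @ v, w

q = rng.standard_normal((1, n_heads, 6, head_size))  # from the decoder: 6 target tokens
k = rng.standard_normal((1, n_heads, 4, head_size))  # from the encoder: 4 source tokens
v = rng.standard_normal((1, n_heads, 4, head_size))

out, weights = attention(q, k, v)
print(weights.shape)  # (1, 8, 6, 4): one row per target token, one column per source token
print(out.shape)      # (1, 8, 6, 96)
merged = out.transpose(0, 2, 1, 3).reshape(1, 6, embed_dim)
print(merged.shape)   # (1, 6, 768)
```

The source length (4) appears only in the inner dimension of the two matrix multiplies and disappears from the output. That is why the two sequence lengths never need to match.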

Connections to Existing Wiki Pages

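  • Why K and V are cached: in a causal decoder, the K and V rows for past tokens (here K [1, 8, 4, 96] and V [1, 8, 4, 96]) never change once computed, because earlier tokens cannot attend to later ones; each new decode step only needs the new token's Q and appends one row to K and V → see KV Caching Explained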
  • Original architecture reference: Attention Is All You Need — source of the attention formula and transformer architecture this article annotates
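The KV-caching argument can be checked numerically. The following minimal NumPy sketch is deliberately simplified: one head, one layer, no batch dimension and random weights (`w_q`, `w_k`, `w_v` are hypothetical placeholders). It decodes token by token, projecting only the newest token and appending its K and V rows to a cache. It then confirms that the result matches recomputing full causal attention over the whole sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 96  # head_size; a single head with no batch dim keeps the sketch short
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def full_causal_attention(x):
    # Recompute Q, K, V for the whole sequence; x is [seq_len, d]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(d)
    mask = np.tril(np.ones((len(x), len(x)), dtype=bool))
    return softmax(np.where(mask, scores, -np.inf)) @ v

x = rng.standard_normal((4, d))  # hidden states for 4 tokens

# Incremental decoding: each step projects only the newest token and appends to the cache
k_cache = np.empty((0, d))
v_cache = np.empty((0, d))
step_outputs = []
for t in range(len(x)):
    x_t = x[t:t + 1]                                 # [1, d]
    k_cache = np.concatenate([k_cache, x_t @ w_k])   # grows to [t + 1, d]
    v_cache = np.concatenate([v_cache, x_t @ w_v])
    q_t = x_t @ w_q                                  # only the new token's query
    weights = softmax(q_t @ k_cache.T / np.sqrt(d))  # cache holds only past + current, so no mask
    step_outputs.append(weights @ v_cache)           # [1, d]

incremental = np.concatenate(step_outputs)           # [4, d]
print(np.allclose(incremental, full_causal_attention(x)))  # True
```

The cache removes the need for a mask at decode time: it only ever contains positions up to the current one.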