Machine Learning

Notes on Attention Mechanisms

January 2025

These are running notes — updated as I read more. Not a polished explainer. More like thinking out loud.


What "attention" actually means

The word "attention" in transformers is a bit of a metaphor that got reified. The original insight in Bahdanau et al. (2014) was simple: instead of squishing the entire input sequence into a single fixed-size vector before decoding, let the decoder look back at all the encoder states at each step and weight them.

The weights are computed from a compatibility function between the decoder's current state and each encoder state, then normalized with a softmax so they sum to 1. High compatibility → high weight → more "attention" to that position.
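
A minimal sketch of one decoder step of Bahdanau-style (additive) attention, in plain Python. The score form v · tanh(W s + U h_j) follows the additive compatibility function from that paper; the parameter values W, U, v, the decoder state s, and the encoder states are hypothetical toy numbers chosen only to keep the arithmetic easy to follow.

```python
import math

def matvec(m, x):
    """Multiply matrix m (a list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def additive_attention(s, hs, W, U, v):
    """One decoder step of additive attention.

    s:  current decoder state
    hs: list of encoder states h_j
    W, U, v: learned parameters (toy values below)
    score_j = v . tanh(W s + U h_j)
    """
    ws = matvec(W, s)  # the decoder-side term is shared by every position
    scores = []
    for h in hs:
        hidden = [math.tanh(a + b) for a, b in zip(ws, matvec(U, h))]
        scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    weights = softmax(scores)  # non-negative, sums to 1
    # Context vector: encoder states averaged with the attention weights.
    context = [sum(w * h[i] for w, h in zip(weights, hs)) for i in range(len(hs[0]))]
    return weights, context

# Hypothetical toy parameters and states (2-d everything).
W = [[1.0, 0.0], [0.0, 1.0]]
U = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]
s = [1.0, 0.0]
hs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
weights, context = additive_attention(s, hs, W, U, v)
```

The decoder would then combine `context` with its own state to produce the next output; that part is omitted here.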

What we call self-attention extends this idea: every position attends to every position (itself included) in the same sequence. The sequence attends to itself.

The Q, K, V framing

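The queries, keys, and values framing is borrowed from information retrieval. You have a query, you compare it to the keys in a database, and you retrieve the associated values. Attention makes the lookup soft: instead of returning only the best match, it returns a blend of all the values, weighted by how well each key matches the query.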

In practice:

  • Q (query): "What am I looking for?"
  • K (key): "What do I have to offer?"
  • V (value): "What do I actually contribute if selected?"
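
The database analogy can be made concrete by contrasting a hard lookup (return the value of the best-matching key) with an attention-style soft lookup (blend every value by its softmax-normalized match score). A sketch assuming dot-product matching and scalar values; the keys, values, and query are made-up toy data.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def hard_lookup(query, keys, values):
    """Database-style retrieval: return the value whose key matches best."""
    scores = [dot(query, k) for k in keys]
    return values[scores.index(max(scores))]

def soft_lookup(query, keys, values):
    """Attention-style retrieval: weight every value by softmax(match score)."""
    scores = [dot(query, k) for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return sum(w * v for w, v in zip(weights, values)), weights

keys = [(1.0, 0.0), (0.0, 1.0), (0.7, 0.7)]
values = [10.0, 20.0, 30.0]
query = (1.0, 0.2)

hard = hard_lookup(query, keys, values)          # only the best key's value
soft, weights = soft_lookup(query, keys, values)  # a mixture of all three
```

The hard lookup is not differentiable with respect to the query; the soft version is, which is what lets Q, K, and V projections be trained by gradient descent.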

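The attention weights are softmax(QKᵀ / √d_k), where d_k is the dimension of the query and key vectors and the softmax is applied row-wise, giving one distribution over the keys for each query. The output is those weights multiplied by V: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.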
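
A direct, unoptimized transcription of the formula, as a sketch in plain Python lists (no NumPy assumed). The function name scaled_dot_product_attention and the toy Q, K, V are my own choices.

```python
import math

def matmul(a, b):
    """Plain-Python matrix product of a (n x m) and b (m x p)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Row-wise, numerically stable softmax."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))  # n_q x n_k compatibility scores
    scaled = [[x / math.sqrt(d_k) for x in row] for row in scores]
    weights = softmax_rows(scaled)    # one distribution over keys per query
    return matmul(weights, V), weights

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
output, weights = scaled_dot_product_attention(Q, K, V)
```

Because V here is the identity matrix, each output row is exactly that query's attention distribution, which makes it easy to see that the first query splits its weight evenly between the first and third keys (both have dot product 1 with it).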

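The √d_k scaling factor is there to keep the dot products from growing too large in magnitude. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance d_k, so unscaled scores spread out as the key dimension grows and push the softmax into saturated regions with tiny gradients. Dividing by √d_k brings the variance back to 1. It's easy to overlook in the formula but important in practice.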
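
A quick numerical check of both halves of that argument, assuming query and key components are i.i.d. standard normal (the idealized setting, not what a trained model actually produces). It measures how the spread of q·k grows with d_k, then compares the largest softmax Jacobian entry for a fixed set of scores versus the same scores multiplied by √256 = 16. The helper names are hypothetical.

```python
import math
import random

def dot_product_std(d_k, trials=1000, seed=0):
    """Empirical std of q.k when q and k have i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(samples) / trials
    return math.sqrt(sum((x - mean) ** 2 for x in samples) / trials)

def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def largest_softmax_gradient(logits):
    """Largest |d p_i / d z_j| = |p_i (delta_ij - p_j)| in the softmax Jacobian."""
    p = softmax(logits)
    n = len(p)
    return max(abs(p[i] * ((1.0 if i == j else 0.0) - p[j]))
               for i in range(n) for j in range(n))

# Theory predicts std = sqrt(d_k), i.e. 4, 8, 16 for these sizes.
stds = {d: dot_product_std(d) for d in (16, 64, 256)}

# Same relative scores: scaled (as attention would see them) vs. 16x larger.
base = [1.0, 0.5, -0.5, 0.0]
grad_scaled = largest_softmax_gradient(base)
grad_unscaled = largest_softmax_gradient([16.0 * x for x in base])
```

With the larger logits the softmax is nearly one-hot and its largest gradient entry shrinks by a factor of several hundred, which is the saturation the scaling avoids.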

Multi-head attention

Running one attention operation gives you one "perspective" on how tokens relate to each other. Multi-head attention runs several attention operations in parallel (each with its own Q, K, V projections), then concatenates and projects the results.

Different heads learn to attend to different things. In analyses of trained models, some heads track syntactic structure; others track coreference; others track positional relationships (e.g. always attending to the previous token). None of this is explicitly programmed: it emerges from training.
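
A sketch of multi-head self-attention in plain Python, assuming the common layout where each of h heads projects the d_model-dimensional input down to d_head = d_model / h dimensions. The random Gaussian weights stand in for learned ones and the helper names are hypothetical; the sketch shows only the data flow (project, attend per head, concatenate, output projection), not anything a trained head would learn.

```python
import math
import random

def matmul(a, b):
    """Plain-Python product of a (n x m) and b (m x p)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax_rows(m):
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(x - top) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_t = [list(col) for col in zip(*K)]
    scores = [[x / math.sqrt(d_k) for x in row] for row in matmul(Q, K_t)]
    return matmul(softmax_rows(scores), V)

def random_matrix(rows, cols, rng):
    """Stand-in for learned weights (hypothetical, not a real init scheme)."""
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

def multi_head_self_attention(X, heads, W_O):
    """X: n x d_model; heads: list of (W_Q, W_K, W_V), each d_model x d_head."""
    head_outputs = []
    for W_Q, W_K, W_V in heads:
        # Self-attention: Q, K and V are all projections of the same input X.
        head_outputs.append(attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)))
    # Concatenate each token's head outputs: n x (h * d_head).
    concat = [[x for out in head_outputs for x in out[i]] for i in range(len(X))]
    # Output projection back to d_model.
    return matmul(concat, W_O)

rng = random.Random(0)
n, d_model, h = 4, 8, 2
d_head = d_model // h
X = random_matrix(n, d_model, rng)
heads = [tuple(random_matrix(d_model, d_head, rng) for _ in range(3)) for _ in range(h)]
W_O = random_matrix(h * d_head, d_model, rng)
Y = multi_head_self_attention(X, heads, W_O)  # n x d_model
```

Splitting d_model across heads keeps the total projection size the same as one full-width head, so the extra "perspectives" come at roughly the same cost.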

Observations from reading papers

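Attention is not explanation. A common mistake is treating attention weights as a causal explanation of model behavior. A high attention weight on a token doesn't mean that token caused the prediction; it means the model weighted that token's value heavily in that layer, during that forward pass. These are different things. Jain & Wallace (2019) make this point empirically: attention weights often correlate poorly with gradient-based measures of feature importance, and quite different attention distributions can produce the same prediction.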

Sparse attention scales better. Full self-attention is O(n²) in sequence length, in both compute and memory. For long sequences (documents, code), this becomes the bottleneck. Sparse variants restrict which positions can attend to each other: typically a local sliding window, combined with a few global tokens (Longformer) or with global and random connections (BigBird). This trades some expressiveness for roughly linear scaling in n.
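
The local component of these patterns is easy to sketch: a banded mask where position i sees only positions at most `window` steps away on either side, applied by sending disallowed scores to -inf before the softmax. The sequence length and window size are arbitrary, and the sketch leaves out the global (and, for BigBird, random) connections the real models add on top.

```python
import math

def sliding_window_mask(n, window):
    """mask[i][j] is True when position i may attend to j, i.e. |i - j| <= window."""
    return [[abs(i - j) <= window for j in range(n)] for i in range(n)]

def masked_softmax(scores, allowed):
    """Softmax over one row; disallowed positions get -inf, hence exactly zero weight."""
    masked = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    top = max(masked)
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

def allowed_pairs(mask):
    return sum(sum(row) for row in mask)

n, window = 256, 16
mask = sliding_window_mask(n, window)
full = n * n                  # dense self-attention scores every (i, j) pair
banded = allowed_pairs(mask)  # grows like n * (2 * window + 1), linear in n

# Token 0 with toy scores: everything beyond the window gets zero weight.
first_row_weights = masked_softmax([0.1 * j for j in range(n)], mask[0])
```

A real implementation would not build the n × n mask at all; it would compute only the scores inside the band, which is where the savings come from.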

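FlashAttention is algorithmic, not architectural. It computes exactly the same result as standard attention (it is not an approximation) but reorders the computation, processing keys and values in tiles with an online softmax, so the full n × n attention matrix is never materialized in HBM (the GPU's high-bandwidth main memory). This is a systems optimization, not a model change. Worth understanding if you care about training efficiency.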
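
The arithmetic that makes tiling exact is the online softmax: stream over blocks of keys while keeping only a running max, a running normalizer, and a running weighted sum of values, rescaling the partial sums whenever the max grows. A single-query sketch in plain Python with hypothetical function names; it models none of the GPU side (SRAM tiles, kernel fusion, the backward pass).

```python
import math
import random

def naive_attention_row(q, K, V):
    """Reference: materialize every score, then softmax and weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, V)) / total for j in range(len(V[0]))]

def streaming_attention_row(q, K, V, block_size):
    """Online softmax: visit K/V one block at a time, never storing all scores."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")  # largest score seen so far
    running_sum = 0.0            # sum of exp(score - running_max)
    acc = [0.0] * len(V[0])      # sum of exp(score - running_max) * v
    for start in range(0, len(K), block_size):
        block_scores = [scale * sum(a * b for a, b in zip(q, k))
                        for k in K[start:start + block_size]]
        new_max = max(running_max, max(block_scores))
        # Rescale earlier partial sums to the new reference max
        # (on the first block the factor is exp(-inf) = 0, which is harmless).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(block_scores, V[start:start + block_size]):
            p = math.exp(s - new_max)
            running_sum += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

rng = random.Random(1)
q = [rng.gauss(0.0, 1.0) for _ in range(4)]
K = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(10)]
V = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(10)]

exact = naive_attention_row(q, K, V)
streamed = {bs: streaming_attention_row(q, K, V, bs) for bs in (1, 3, 4, 10, 16)}
```

Every block size gives the same answer up to floating-point rounding, which is why tiling changes the memory traffic but not the model.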

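Position encodings are underrated. Transformers have no built-in notion of sequence order, so it has to be injected. Sinusoidal encodings (original Transformer), learned absolute positions, RoPE (rotary position embedding), ALiBi: each has trade-offs. ALiBi was designed to extrapolate beyond the training context length. RoPE on its own extrapolates poorly, but because it encodes relative offsets, its context window can be stretched after training (e.g. with position interpolation), which is particularly interesting.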
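
A sketch of RoPE showing the property that makes it relative: rotate the query and the key by angles proportional to their positions, and their dot product then depends only on the offset between them. This assumes adjacent-dimension pairing (2i, 2i+1) and base 10000 as in the RoFormer formulation; some implementations pair the first and second halves of the vector instead, which changes the layout but not the property. The vectors are toy values.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by angle pos * base^(-2i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even dimension")
    out = []
    for i in range(0, d, 2):  # i = 2 * pair_index
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = x[i], x[i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5, -0.4, 1.1]
k = [1.0, 0.2, -0.7, 0.9, 0.6, -0.3]

near = dot(rope(q, 7), rope(k, 3))      # query at 7, key at 3: offset 4
far = dot(rope(q, 107), rope(k, 103))   # same offset, 100 positions later
other = dot(rope(q, 7), rope(k, 5))     # a different offset
```

`near` and `far` agree up to rounding, while `other` differs. Rotations also preserve vector norms, so RoPE changes only the angle between query and key, never their lengths.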

Things I don't fully understand yet

  • Why does attention head pruning work so well? Why do models train so many redundant heads?
  • The precise relationship between attention patterns and induction heads in in-context learning
  • How cross-attention in diffusion models (e.g. image features cross-attending to text embeddings) really works at the mechanistic level

Last updated: January 2025. Will add sections on mixture of experts and state space models when I get there.