Attention Mechanism: GQA, FlashAttention, and Paged Attention

This note keeps the core attention math, then separates three ideas that are easy to mix up in LLM systems: architecture, kernel / execution schedule, and KV-cache layout. The goal is to make those three layers easy to tell apart while studying.

Topics: scaled dot-product attention, multi-head attention, grouped-query attention, FlashAttention, paged attention, autoregressive inference.


1. Study map (big picture first)

Attention is the math. GQA changes how heads are parameterized, FlashAttention changes how the same attention is executed, and paged attention changes how KV cache memory is organized.
Architecture. GQA changes the model itself by using fewer key/value heads than query heads.
Kernel. FlashAttention keeps exact attention but changes the compute schedule to reduce memory traffic.
Serving layout. Paged attention keeps the formula and stores cached K/V in fixed-size pages.
Diagram separating architecture changes, FlashAttention kernel tiling, and paged KV-cache layout
This is the most important distinction in the whole note. When a paper or system mentions one of these methods, first ask which layer it changes: model weights, compute schedule, or memory layout.
Common confusion 1. FlashAttention does not mean a different attention formula. It is still exact softmax attention.
Common confusion 2. Paged attention is not a new attention score. It is a way to store and fetch cached keys and values.
Common confusion 3. GQA is not only a systems trick. It changes the learned parameterization because K/V heads are shared.

2. Core attention math (foundation)

Let sequence length be n. Queries, keys, and values are

Main formula
\[ Q \in \mathbb{R}^{n \times d_k}, \qquad K \in \mathbb{R}^{n \times d_k}, \qquad V \in \mathbb{R}^{n \times d_v} \] \[ \mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

Token-wise view

For query token i, attention first scores all key tokens:

\[ s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} \]

Then it normalizes those scores row-wise:

\[ \alpha_{ij} = \frac{\exp(s_{ij})}{\sum_m \exp(s_{im})} \]

And finally computes a weighted sum of values:

\[ o_i = \sum_j \alpha_{ij} v_j \]

What the matrix form is doing

1. \(QK^T\) compares every query with every key.
2. Softmax converts raw scores into attention weights that sum to 1 across each row.
3. Multiplying by \(V\) mixes information from tokens that received high weight.
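The three steps can be sketched in plain Python. This is a minimal illustration, not an efficient implementation: it assumes matrices are lists of row vectors so the example needs no tensor library, and the function names `softmax` and `attention` are illustrative.

```python
import math

def softmax(row):
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on lists of row vectors.

    Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v. Returns an n_q x d_v list.
    """
    d_k = len(Q[0])
    d_v = len(V[0])
    out = []
    for q in Q:
        # Step 1: one row of QK^T / sqrt(d_k), comparing q with every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Step 2: row-wise softmax, so the weights sum to 1.
        weights = softmax(scores)
        # Step 3: weighted sum of the value vectors.
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(d_v)])
    return out
```

With two identical keys the weights are 0.5 and 0.5, so `attention([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]], [[2.0, 0.0], [0.0, 4.0]])` returns `[[1.0, 2.0]]`, the mean of the two value rows.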

Why divide by \(\sqrt{d_k}\)

The variance of the dot product grows with dimension. If scores become too large, softmax saturates toward a one-hot distribution and its gradients become very small, which slows training.

Variance intuition (assuming the components of \(q\) and \(k\) are independent with mean 0 and variance 1)
\[ q \cdot k = \sum_{r=1}^{d_k} q_r k_r \] \[ \mathrm{Var}(q \cdot k) \approx d_k \qquad \Longrightarrow \qquad \mathrm{Std}(q \cdot k) \approx \sqrt{d_k} \] \[ s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} \qquad \Longrightarrow \qquad \mathrm{Var}(s_{ij}) \approx 1 \]
Study takeaway. The scaling factor is mainly about keeping score magnitudes stable as head dimension grows, so softmax stays trainable.
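A quick Monte Carlo check of the variance argument, assuming (as above) independent standard-normal components for \(q\) and \(k\). The helper `score_std` is hypothetical and uses only the standard library.

```python
import math
import random

def score_std(d_k, trials=2000, scale=True, seed=0):
    """Empirical std of q.k (divided by sqrt(d_k) if scale=True).

    Components of q and k are drawn i.i.d. from N(0, 1), matching the
    assumption behind Var(q.k) ~ d_k.
    """
    rng = random.Random(seed)
    vals = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        s = sum(a * b for a, b in zip(q, k))
        vals.append(s / math.sqrt(d_k) if scale else s)
    mean = sum(vals) / len(vals)
    return math.sqrt(sum((v - mean) ** 2 for v in vals) / len(vals))

for d_k in (16, 64, 256):
    # Unscaled std should track sqrt(d_k); scaled std should stay near 1.
    print(d_k, round(score_std(d_k, scale=False), 2), round(score_std(d_k), 2))
```

Under this assumption the unscaled standard deviation should come out near \(\sqrt{d_k}\) (roughly 4, 8, and 16 here), while the scaled scores stay near 1. The exact printed values depend on the random seed.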

Multi-head attention

Instead of using one attention map, the model uses multiple heads that learn different projections of the same input.

Multi-head attention
\[ Q_h = QW_h^Q, \qquad K_h = KW_h^K, \qquad V_h = VW_h^V \] \[ \mathrm{head}_h = \mathrm{Attention}(Q_h, K_h, V_h) \] \[ \mathrm{MultiHead}(Q,K,V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_H)W^O \]

3. Multi-head and grouped-query attention (architecture change)

Standard multi-head attention gives each query head its own key head and value head. GQA keeps many query heads but uses fewer K/V heads, so several query heads share the same K/V stream.

Comparison of standard multi-head attention, grouped-query attention, and multi-query attention
GQA sits between full multi-head attention and multi-query attention. It preserves multiple query views while shrinking the number of K/V streams that must be cached during inference.

Definition and shapes

Let \(H_q\) be the number of query heads and \(H_{kv}\) the number of key/value heads, with \(H_{kv}\) dividing \(H_q\). Define the group size

\[ g = \frac{H_q}{H_{kv}}. \]

For input \(X \in \mathbb{R}^{n \times d_{\mathrm{model}}}\), the projected tensors, reshaped into heads of width \(d_h\), are

\[ Q = XW^Q \in \mathbb{R}^{n \times H_q \times d_h} \] \[ K = XW^K \in \mathbb{R}^{n \times H_{kv} \times d_h}, \qquad V = XW^V \in \mathbb{R}^{n \times H_{kv} \times d_h} \]

For query head \(h \in \{0, 1, \dots, H_q - 1\}\), the shared K/V head index is

\[ r(h) = \left\lfloor \frac{h}{g} \right\rfloor. \]

Per-head output

GQA head computation
\[ O_h = \mathrm{softmax}\!\left(\frac{Q_h K_{r(h)}^T}{\sqrt{d_h}}\right)V_{r(h)} \] \[ O = \mathrm{Concat}(O_0, O_1, \dots, O_{H_q-1})W^O \]
Main idea. The model still emits one output per query head. What shrinks is the number of distinct K/V streams that must exist and be cached.
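A minimal plain-Python sketch of the head mapping \(r(h) = \lfloor h/g \rfloor\) and the per-head GQA computation, again assuming lists of row vectors. It stops at the per-head outputs \(O_h\): the concatenation and the \(W^O\) projection are omitted, and helper names such as `kv_head_for` and `gqa` are hypothetical.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(Qh, Kh, Vh):
    # Single-head scaled dot-product attention on lists of row vectors.
    d_h = len(Qh[0])
    out = []
    for q in Qh:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_h) for k in Kh])
        out.append([sum(wi * v[c] for wi, v in zip(w, Vh)) for c in range(len(Vh[0]))])
    return out

def kv_head_for(h, H_q, H_kv):
    # r(h) = floor(h / g) with g = H_q / H_kv; H_kv must divide H_q.
    assert H_q % H_kv == 0, "H_kv must divide H_q"
    g = H_q // H_kv
    return h // g

def gqa(Q_heads, K_heads, V_heads):
    """Q_heads: H_q per-head matrices (n x d_h); K_heads, V_heads: H_kv each.

    Returns the H_q per-head outputs O_h (before Concat and W^O).
    """
    H_q, H_kv = len(Q_heads), len(K_heads)
    outputs = []
    for h in range(H_q):
        r = kv_head_for(h, H_q, H_kv)
        # Several query heads read the same K/V head r; nothing is copied.
        outputs.append(attend(Q_heads[h], K_heads[r], V_heads[r]))
    return outputs
```

For example, with \(H_q = 8\) and \(H_{kv} = 2\), `[kv_head_for(h, 8, 2) for h in range(8)]` gives `[0, 0, 0, 0, 1, 1, 1, 1]`. Setting \(H_{kv} = H_q\) recovers standard multi-head attention, and \(H_{kv} = 1\) recovers multi-query attention.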

Special cases

Choice | Meaning
\(H_{kv} = H_q\) | Standard multi-head attention
\(1 < H_{kv} < H_q\) | Grouped-query attention
\(H_{kv} = 1\) | Multi-query attention

Why GQA matters at inference time

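During autoregressive generation, the cached keys and values are often the largest memory cost after the model weights, and they grow linearly with context length and batch size. Per token, per layer, the number of cached elements is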

\[ \text{KV elements per token} = 2 H_{kv} d_h. \]

For standard multi-head attention, this would be \(2 H_q d_h\). So

\[ \frac{\text{KV cache in GQA}}{\text{KV cache in MHA}} = \frac{H_{kv}}{H_q} = \frac{1}{g}. \]
Example. If \(H_q = 32\) and \(H_{kv} = 8\), then \(g = 4\). The cache is 4x smaller than standard multi-head attention.
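The cache ratio can be checked with a few lines of arithmetic. The configuration below (32 layers, \(d_h = 128\), 2-byte fp16 entries, a 4096-token context, one sequence) is a hypothetical example, not a specific model; only \(H_q = 32\) and \(H_{kv} = 8\) come from the example above.

```python
def kv_cache_bytes(n_tokens, n_layers, H_kv, d_h, bytes_per_elem=2):
    # Per token, per layer: 2 tensors (K and V) x H_kv heads x d_h elements.
    return 2 * H_kv * d_h * n_layers * n_tokens * bytes_per_elem

# Hypothetical configuration: 32 layers, d_h = 128, fp16, 4096 cached tokens.
mha_bytes = kv_cache_bytes(4096, 32, H_kv=32, d_h=128)  # H_kv = H_q: standard MHA
gqa_bytes = kv_cache_bytes(4096, 32, H_kv=8, d_h=128)   # GQA with g = 4
print(mha_bytes / 2**30, gqa_bytes / 2**30, mha_bytes // gqa_bytes)  # 2.0 0.5 4
```

So this hypothetical configuration needs 2 GiB of KV cache per sequence with standard multi-head attention and 0.5 GiB with GQA.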

4. FlashAttention (exact but memory-aware)

FlashAttention does not change the attention formula. It changes how exact attention is executed so large intermediate matrices do not have to be materialized in slow memory.

Important distinction. Correct FlashAttention implementations compute the same exact attention as the standard formula, up to floating-point rounding differences caused by a different summation order. The gain comes from lower memory traffic, not approximate math.
Simplified GPU memory hierarchy showing SRAM, HBM, and CPU DRAM with bandwidth and capacity differences
FlashAttention is motivated by this hierarchy: on-chip SRAM is very fast but small, HBM on the GPU is much larger but slower, and CPU DRAM is slower still. Good kernels try to keep active blocks in SRAM and avoid repeatedly writing large attention matrices back to HBM.

Why memory hierarchy matters

Modern GPUs can do a lot of arithmetic, but moving data is often the real bottleneck. In practice:

  • SRAM / shared memory is closest to the compute units and has the highest bandwidth, but it is tiny.
  • HBM is much larger and still on the GPU, but access is slower than on-chip memory.
  • CPU DRAM is larger again, but far too slow for repeatedly materializing attention intermediates during a kernel.
FlashAttention in one sentence. It reorganizes exact attention so score blocks are loaded, normalized, and consumed while they are still in fast on-chip memory instead of bouncing large matrices through HBM.

Standard computation

Same formula
\[ S = \frac{QK^T}{\sqrt{d_h}}, \qquad A = \mathrm{softmax}(S), \qquad O = AV. \]

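The problem is that many naive implementations store \(S \in \mathbb{R}^{n \times n}\) and sometimes \(A \in \mathbb{R}^{n \times n}\) in HBM (global GPU memory). These tensors grow quadratically with sequence length: with \(n = 8192\) and 2-byte fp16 entries, one \(n \times n\) matrix takes \(8192^2 \times 2\) bytes \(= 128\) MiB per head, and writing it out and reading it back is expensive.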

Tiling idea

Split \(K\) and \(V\) along the sequence dimension into \(T = \lceil n/B \rceil\) blocks of \(B\) rows each:

\[ K = \begin{bmatrix} K^{(1)} \\ K^{(2)} \\ \vdots \\ K^{(T)} \end{bmatrix}, \qquad V = \begin{bmatrix} V^{(1)} \\ V^{(2)} \\ \vdots \\ V^{(T)} \end{bmatrix}. \]

For one query row \(q\), process those blocks one at a time instead of forming all scores at once.

Online softmax update

For block \(t\), compute local scores

\[ s^{(t)} = \frac{q K^{(t)T}}{\sqrt{d_h}} \in \mathbb{R}^{B}. \]

Maintain a running row maximum \(m\), denominator \(\ell\), and numerator \(z\):

\[ m^{(0)} = -\infty, \qquad \ell^{(0)} = 0, \qquad z^{(0)} = \mathbf{0} \in \mathbb{R}^{d_h}. \]
Block merge rule
\[ m^{(t)} = \max\!\left(m^{(t-1)}, \max_j s_j^{(t)}\right) \] \[ p_j^{(t)} = \exp\!\left(s_j^{(t)} - m^{(t)}\right) \] \[ \ell^{(t)} = e^{m^{(t-1)} - m^{(t)}}\ell^{(t-1)} + \sum_j p_j^{(t)} \] \[ z^{(t)} = e^{m^{(t-1)} - m^{(t)}} z^{(t-1)} + \sum_j p_j^{(t)} v_j^{(t)} \] \[ o = \frac{z^{(T)}}{\ell^{(T)}}. \]
Notation check. Here \(p_j^{(t)}\) is not yet a normalized attention probability. It is the block-local exponential weight \(\exp(s_j^{(t)} - m^{(t)})\). Also, \(v_j^{(t)} \in \mathbb{R}^{d_h}\) is the j-th value vector inside block \(t\). So \(z^{(t)}\) is a running weighted sum of value vectors, while \(\ell^{(t)}\) is the running scalar denominator.
Why it is exact. The running max and denominator let the kernel merge blocks without losing the softmax normalization over the whole row.
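Why it is fast. Arithmetic stays \(O(n^2 d_h)\), but the full \(S\) and \(A\) matrices are never written to HBM: each score block is produced, exponentiated, and folded into \(\ell\) and \(z\) while it is still on chip. Beyond the output itself, only per-row running statistics are kept, so extra memory drops from \(O(n^2)\) to \(O(n)\) and HBM traffic falls accordingly.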
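The merge rule can be checked in plain Python for a single query row. This sketch illustrates the math only: it assumes lists of row vectors and does not model SRAM tiling, query-block tiling, or any real kernel API; the function names are hypothetical.

```python
import math

def naive_attention_row(q, K, V):
    # Reference: form all scores at once, then normalize.
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    return [sum(pi * v[c] for pi, v in zip(p, V)) / total for c in range(len(V[0]))]

def online_attention_row(q, K, V, B):
    """Exact attention for one query row, visiting K/V in blocks of B rows."""
    d = len(q)
    m = -math.inf            # running max m^{(t)}
    ell = 0.0                # running denominator ell^{(t)}
    z = [0.0] * len(V[0])    # running numerator z^{(t)}
    for start in range(0, len(K), B):
        Kb, Vb = K[start:start + B], V[start:start + B]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
        m_new = max(m, max(s))
        # Rescale old accumulators to the new max; exp(-inf) == 0.0 on the first block.
        scale = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]   # block-local weights p_j^{(t)}
        ell = scale * ell + sum(p)
        z = [scale * zc + sum(pj * v[c] for pj, v in zip(p, Vb))
             for c, zc in enumerate(z)]
        m = m_new
    return [zc / ell for zc in z]
```

For any block size, including one that leaves a partial last block, the two functions agree up to floating-point rounding.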

5. Paged attention (KV-cache layout)

Paged attention also keeps the same attention formula. It changes how the KV cache is stored during autoregressive decoding, especially when a serving system must handle many sequences with different lengths.

Standard cache view

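At generation step \(t\) (positions indexed from 0), once the current token's key and value have been appended, one layer stores keys and values for positions \(0, \dots, t\):

\[ K_{\mathrm{cache}}, V_{\mathrm{cache}} \in \mathbb{R}^{(t+1) \times H_{kv} \times d_h}. \]

If each sequence's cache is one contiguous tensor, the server must either reserve space for the maximum length up front, wasting memory on sequences that stop early, or reallocate and copy as the sequence grows. With many sequences of different lengths, both options fragment memory.

Split the cache into pages

Choose a page size of \(B\) tokens and split the sequence dimension into pages:

\[ K^{(p)}, V^{(p)} \in \mathbb{R}^{B \times H_{kv} \times d_h}, \qquad p = 0, 1, \dots, P-1, \qquad P = \left\lceil \frac{t+1}{B} \right\rceil. \]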

Logical pages and physical pages

For each sequence \(s\), maintain a page table that maps logical page \(p\) to a physical page id \(b_s^{(p)}\):

\[ \beta_s = [b_s^{(0)}, b_s^{(1)}, \dots, b_s^{(P-1)}]. \]

Token index \(\tau\) maps to page and offset by

\[ p = \left\lfloor \frac{\tau}{B} \right\rfloor, \qquad o = \tau \bmod B. \]

The cached key and value for token \(\tau\) and K/V head \(h\) are then read from physical page \(b_s^{(p)}\):

\[ k_{s,\tau,h} = K^{\mathrm{phys}}_{b_s^{(p)}}[o,h,:], \qquad v_{s,\tau,h} = V^{\mathrm{phys}}_{b_s^{(p)}}[o,h,:]. \]
Interpretation. Paged attention adds one level of indirection: logical token index to page id and offset to physical cache block.
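A toy, single-layer page table makes the indirection concrete. This is a hypothetical data structure for illustration, not the API of any serving system: each slot stores an arbitrary Python object standing in for the \(H_{kv} \times d_h\) key or value, pages come from a simple free list, and `divmod` computes \((p, o)\) exactly as in the formula above.

```python
class PagedKVCache:
    """Toy paged KV cache for one layer (hypothetical, for illustration).

    Physical storage is a pool of pages; each page holds B token slots.
    Each slot stores one key (or value) object, e.g. an H_kv x d_h array.
    """

    def __init__(self, num_pages, B):
        self.B = B
        self.k_pages = [[None] * B for _ in range(num_pages)]
        self.v_pages = [[None] * B for _ in range(num_pages)]
        self.free = list(range(num_pages))  # free physical page ids
        self.page_table = {}                # seq id -> physical page ids (beta_s)
        self.length = {}                    # seq id -> number of cached tokens

    def append(self, seq, k, v):
        table = self.page_table.setdefault(seq, [])
        tau = self.length.get(seq, 0)
        p, o = divmod(tau, self.B)          # logical page and offset
        if p == len(table):                 # first token or last page full: claim a page
            if not self.free:
                raise MemoryError("no free pages")
            table.append(self.free.pop())
        phys = table[p]
        self.k_pages[phys][o] = k
        self.v_pages[phys][o] = v
        self.length[seq] = tau + 1

    def lookup(self, seq, tau):
        p, o = divmod(tau, self.B)
        phys = self.page_table[seq][p]      # b_s^{(p)}
        return self.k_pages[phys][o], self.v_pages[phys][o]

    def release(self, seq):
        # Return a finished sequence's pages to the pool for reuse.
        self.free.extend(self.page_table.pop(seq, []))
        self.length.pop(seq, None)
```

Because a sequence only claims a new page when its last page fills, at most \(B - 1\) slots per sequence sit unused, and a finished sequence returns its pages to the pool.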

The attention math is unchanged

Same decoding rule
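\[ \alpha_{s,t,\tau,h} = \frac{\exp\!\left(q_{s,t,h} \cdot k_{s,\tau,r(h)} / \sqrt{d_h}\right)} {\sum_{u=0}^{t} \exp\!\left(q_{s,t,h} \cdot k_{s,u,r(h)} / \sqrt{d_h}\right)} \] \[ o_{s,t,h} = \sum_{\tau=0}^{t} \alpha_{s,t,\tau,h} v_{s,\tau,r(h)}. \]
Here \(h\) is a query head and \(r(h)\) is the K/V head it reads, as in GQA (\(r(h) = h\) for standard multi-head attention).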

Why serving systems like it

  • Less memory fragmentation.
  • No huge reallocation when a sequence grows.
  • Easier continuous batching across requests with different lengths.
  • Possible prefix sharing when multiple requests reuse the same prompt pages.

How it composes with FlashAttention

A kernel can read one page or one block of K/V at a time and still apply the same online-softmax update used by FlashAttention.

Short version. Paged attention decides where cached K/V live. FlashAttention decides how exact attention is computed over them efficiently.
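Composing the two ideas: the sketch below walks a sequence's page table page by page and applies the same online-softmax merge to each page. It is a hypothetical single-head, plain-Python illustration (physical pages as lists of vectors, the page table as a list of physical page ids), not a real kernel; with GQA, query head \(h\) would pass the pages of K/V head \(r(h)\).

```python
import math

def paged_decode_attention(q, k_pages, v_pages, page_table, length, B):
    """Attention for one query over a paged cache, one page at a time.

    k_pages, v_pages: physical pool; each page is a list of B vectors.
    page_table: this sequence's physical page ids; length: tokens cached.
    """
    d = len(q)
    m, ell, z = -math.inf, 0.0, None
    for p, phys in enumerate(page_table):
        n_valid = min(B, length - p * B)      # the last page may be partially filled
        if n_valid <= 0:
            break
        Kb = k_pages[phys][:n_valid]
        Vb = v_pages[phys][:n_valid]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)           # 0.0 on the first page
        w = [math.exp(x - m_new) for x in s]
        if z is None:
            z = [0.0] * len(Vb[0])
        ell = scale * ell + sum(w)
        z = [scale * zc + sum(wi * v[c] for wi, v in zip(w, Vb))
             for c, zc in enumerate(z)]
        m = m_new
    return [zc / ell for zc in z]
```

The physical pages can sit anywhere in the pool and in any order; the page table restores logical order, and the result matches ordinary attention over the same tokens up to rounding.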

6. Position and causality (sequence order)

Self-attention by itself is permutation-equivariant, so the model needs extra structure to know token order and to prevent future-token leakage during autoregressive decoding.

Sinusoidal positional encoding

\[ \mathrm{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right) \] \[ \mathrm{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right) \] \[ x_{pos} = \mathrm{Embed}(token_{pos}) + \mathrm{PE}(pos). \]

The base 10000 spreads wavelengths across dimensions so some coordinates capture short-range order while others vary slowly over long distances.
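A direct transcription of the sinusoidal formula for one position, as a plain-Python sketch; `sinusoidal_pe` is a hypothetical helper name and \(d_{\mathrm{model}}\) is assumed to be even.

```python
import math

def sinusoidal_pe(pos, d_model):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(same angle)."""
    assert d_model % 2 == 0, "d_model is assumed to be even"
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)
        pe[2 * i + 1] = math.cos(angle)
    return pe
```

At position 0 every sine coordinate is 0 and every cosine coordinate is 1, so `sinusoidal_pe(0, 4)` returns `[0.0, 1.0, 0.0, 1.0]`. Coordinate pair \(i\) has wavelength \(2\pi \cdot 10000^{2i/d_{\mathrm{model}}}\) tokens.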

Causal masking

\[ M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases} \] \[ S = \frac{QK^T}{\sqrt{d_k}} + M. \]

That mask enforces the autoregressive rule: query position \(i\) may attend only to key positions \(j \le i\). Masked scores become \(-\infty\), so their softmax weights are exactly 0.
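A plain-Python sketch of masked attention that adds \(M_{ij}\) to each scaled score (equivalent to adding it before scaling, since \(-\infty / \sqrt{d_k} = -\infty\)). The name `causal_attention` is hypothetical, and lists of row vectors are assumed as in the earlier sketches.

```python
import math

def causal_attention(Q, K, V):
    """Masked scaled dot-product attention: query i sees keys j <= i only."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        s = []
        for j, k in enumerate(K):
            score = sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k)
            s.append(score if j <= i else -math.inf)   # add mask M_ij
        m = max(s)                          # finite, because j = 0 is never masked
        w = [math.exp(x - m) for x in s]    # exp(-inf) == 0.0: masked keys get zero weight
        total = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / total
                    for c in range(len(V[0]))])
    return out
```

Row 0 can only see key 0, so its output equals the first value row, and changing a later value row never changes earlier outputs.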

7. Summary and self-check (revision)

Method | What changes? | Does the formula change? | Main gain
GQA | Head sharing pattern: fewer K/V heads than query heads | Yes, at the model level: query heads share K/V projections, though each head still applies the same softmax attention | Smaller KV cache while keeping many query heads
FlashAttention | Execution schedule and tiling | No | Lower memory traffic for exact attention
Paged attention | KV-cache allocation and lookup | No | Efficient long-context serving across many requests