Table of Contents
- The Core Problem: Why Standard Attention Breaks
- The Intuition: A Three-Part Detective Strategy
- The Underlying Math: Building the Formalism
- Dataset Preparation
- The Architecture: Macro View
- Step 5: The Rigorous Manual Walkthrough (SWA + CSA + HCA)
- Step 6: Full PyTorch Implementation (Hybrid Attention)
- DeepSeek-V4's Secret Weapon: Multi-head Latent Attention
The Core Problem: Why Standard Attention Breaks
Every senior engineer who has deployed large language models has hit this wall. The standard Multi-Head Attention (MHA) computation is:
$$\text{Attention}(Q, K, V) = \text{Softmax}\!\left(\frac{Q K^T}{\sqrt{d}}\right) V$$
The shape of $Q K^T$ is [seq_len, seq_len]. For a sequence of n tokens, you compute n² dot products and store n² attention scores. At 1 million tokens, that's 10¹² dot products per head per layer (each of length d), and the score matrix alone would occupy about 2 TB per head in FP16.
The KV cache (storing the Key and Value tensors from all previous tokens so they don't need recomputation) avoids that recomputation during inference but creates a memory problem, because it grows linearly with sequence length. For standard MHA with h heads, d_head dimensions, and sequence length n:
$$\text{KV cache size} = 2 \times h \times d_{head} \times n \times \text{precision\_bytes}$$
At n = 1,000,000 tokens with 128 heads, 128-dim heads, and FP16 (2 bytes): that's about 65.5 GB (61 GiB) for the KV cache alone, on a single layer, for a single sequence.
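The arithmetic above can be checked by evaluating the formula directly. This is a minimal sketch; `kv_cache_bytes` is a hypothetical helper name, and it follows the same per-layer, per-sequence framing (no batch, no layer count, no cache-management overhead).

```python
def kv_cache_bytes(n_heads: int, d_head: int, seq_len: int, precision_bytes: int) -> int:
    """Standard-MHA KV cache size for one layer and one sequence.

    The leading factor of 2 accounts for storing both Keys and Values.
    """
    return 2 * n_heads * d_head * seq_len * precision_bytes


size = kv_cache_bytes(n_heads=128, d_head=128, seq_len=1_000_000, precision_bytes=2)  # FP16
print(f"{size:,} bytes = {size / 1e9:.1f} GB = {size / 2**30:.1f} GiB")
# 65,536,000,000 bytes = 65.5 GB = 61.0 GiB
```

Multiplying by the number of layers and by the batch size gives the full-model figure, which is why long contexts exhaust GPU memory so quickly.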
The three-mechanism hybrid in DeepSeek-V4 is the engineering answer to this impossibility.
The Intuition: A Three-Part Detective Strategy
Imagine you are a detective solving a mystery by reading a 1,000-page book. You are currently on page 700.
In standard attention, every time you read a new word, you re-read every single word on every previous page to find connections. At 1 million words, this becomes impossibly slow. You run out of mental energy (compute) and desk space (GPU memory / KV cache).
DeepSeek-V4 solves this with a three-part reading strategy:
1. Sliding Window Attention (SWA) — The Short-Term Memory
You remember the exact, word-for-word text of the last 3 pages you just read. This is crucial because the immediate context typically contains the most relevant grammatical and logical flow. You don't need to summarize it — you have it verbatim.
2. Compressed Sparse Attention (CSA) — The Page Summaries
For the rest of the book, you don't keep every word. Instead, for every single page, you write a one-sentence summary (Compression). When you need a clue, you don't read all the summaries. You do a quick keyword scan (Sparse Indexing via the Lightning Indexer) and only read the top-K most relevant page summaries.
3. Heavily Compressed Attention (HCA) — The Chapter Summaries
For vast amounts of text, even page summaries are too many to scan. So you write a single summary for every entire chapter (Heavy Compression). Because there are so few chapters, you can read all the chapter summaries directly without filtering — the chapter level is small enough by definition.
Why all three? Each mechanism covers a different region of the token timeline at a different resolution:
| Mechanism | Coverage | Resolution | Indexing |
|---|---|---|---|
| SWA | Last n_win tokens | Exact / lossless | None (always included) |
| CSA | All past tokens | Low-resolution summaries (m=4) | Lightning Indexer, Top-K |
| HCA | All past tokens | Very-low-resolution summaries (m=128) | None (too few blocks to filter) |
The Underlying Math: Building the Formalism
Phase 1: The Concrete Base Case
Let's compress $m = 2$ tokens into $1$ summary token. Suppose the two tokens have Key values:
- Token 1 Key: $K_1 = 4$
- Token 2 Key: $K_2 = 8$
The model assigns importance weights using a learned linear layer followed by a Softmax, so the weights sum to 1.0:
- $W_1 = 0.25$, $W_2 = 0.75$
The compressed Key $K_{comp}$ is the weighted sum:
$$K_{comp} = (W_1 \times K_1) + (W_2 \times K_2)$$
$$K_{comp} = (0.25 \times 4) + (0.75 \times 8) = 1 + 6 = 7$$
Two tokens (values 4 and 8) have been collapsed into a single summary token (value 7).
Phase 2: The Show-Your-Work Evolution
Now let's see how a new Query $Q = 2$ interacts with this summary versus with the original tokens.
Standard Attention Logic — individual keys:
- Score 1 = $Q \times K_1 = 2 \times 4 = 8$
- Score 2 = $Q \times K_2 = 2 \times 8 = 16$
- Total: 2 multiplications
Compressed Attention Logic — single summary key:
- Compressed Score = $Q \times K_{comp} = 2 \times 7 = 14$
- Total: 1 multiplication
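By substituting $K_{comp}$ for the individual keys, we halved the computational work. The price is approximation: the single score 14 stands in for the two separate scores 8 and 16. Scaling this to $m = 128$ gives a 128× reduction in both memory (KV cache entries) and compute (dot products) for the compressed portion of the context.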
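Phases 1 and 2 can be reproduced in a few lines of plain Python. Scalars stand in for Key vectors, as in the text, and the weights are the illustrative values above rather than the output of a trained layer.

```python
# Phase 1: collapse m = 2 keys into one summary key using weights that sum to 1.0.
keys = [4.0, 8.0]
weights = [0.25, 0.75]  # in the model these come from a learned layer + Softmax
assert abs(sum(weights) - 1.0) < 1e-12

k_comp = sum(w * k for w, k in zip(weights, keys))  # 0.25*4 + 0.75*8

# Phase 2: a query Q = 2 scored against the original keys vs. the summary key.
q = 2.0
individual_scores = [q * k for k in keys]  # 2 multiplications
compressed_score = q * k_comp              # 1 multiplication

print(k_comp, individual_scores, compressed_score)  # 7.0 [8.0, 16.0] 14.0
```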
Phase 3: Generalization and Formalism
For a sequence of hidden states $H$, the complete pipeline formalizes as follows:
Step 1 — Compression (The Summary):
$$C^{Comp}_i = \sum_{j=mi}^{m(i+1)-1} S_j \odot C_j$$
This takes the $i$-th block of $m$ consecutive tokens, multiplies each token's raw Key-Value data $C_j$ by its learned importance weight $S_j$, and sums them into a single compressed block $C^{Comp}_i$.
Step 2 — Lightning Indexer (The Quick Scan for CSA):
$$I_{t,s} = Q^{indexer}_t \cdot K^{indexer}_s$$
The Query and the compressed block keys are first projected into a smaller indexer dimension $d_{indexer}$. The dot product in that smaller space produces a scalar relevance score $I_{t,s}$ between query $t$ and compressed block $s$.
Step 3 — Sparse Selection (Picking the Top-K):
$$C^{SprsComp}_t = \{ C^{Comp}_s \mid I_{t,s} \in \text{Top-}k(I_{t,:}) \}$$
Only the $k$ compressed blocks with the highest indexer scores are retrieved from the full $C^{Comp}$ set. The rest are discarded for this query step.
Step 4 — Final Hybrid Attention:
$$\text{Output}_t = \text{Attention}(Q_t, \text{Concat}(C^{SprsComp}_t, K^{SWA}_t))$$
The query $Q_t$ attends over a concatenation of: (a) the top-K selected CSA blocks, (b) the direct SWA window keys, and optionally (c) all HCA blocks.
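The four steps can be sketched as small pure-Python functions. This is an illustration under simplifying assumptions, not DeepSeek's implementation: scalars stand in for Key-Value vectors, the indexer projection is applied by the caller, and the function names (`compress`, `indexer_scores`, `select_top_k`, `hybrid_kv`) are hypothetical.

```python
from typing import List, Sequence


def compress(c: Sequence[float], s: Sequence[float], m: int) -> List[float]:
    """Step 1: C^Comp_i = sum_{j=mi}^{m(i+1)-1} S_j * C_j (scalars stand in for vectors)."""
    if len(c) != len(s) or len(c) % m != 0:
        raise ValueError("need one weight per token and a length divisible by m")
    return [sum(s[j] * c[j] for j in range(m * i, m * (i + 1))) for i in range(len(c) // m)]


def indexer_scores(q_idx: float, k_idx: Sequence[float]) -> List[float]:
    """Step 2: I_{t,s} = Q^indexer_t . K^indexer_s (1-D here, so the dot product is a product)."""
    return [q_idx * k for k in k_idx]


def select_top_k(comp: Sequence[float], scores: Sequence[float], k: int) -> List[float]:
    """Step 3: keep the k compressed blocks with the highest index scores."""
    best = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:k]
    return [comp[idx] for idx in sorted(best)]  # back in time order for readability


def hybrid_kv(selected: Sequence[float], swa: Sequence[float]) -> List[float]:
    """Step 4: Concat(C^SprsComp_t, K^SWA_t); HCA blocks would be prepended the same way."""
    return list(selected) + list(swa)


keys = [10.0, 20.0, 30.0, 40.0]                       # older tokens, to be compressed
comp = compress(keys, [0.2, 0.8, 0.5, 0.5], m=2)      # [18.0, 35.0]
scores = indexer_scores(2.0, [0.5 * x for x in comp])  # [18.0, 35.0]
print(hybrid_kv(select_top_k(comp, scores, k=1), [50.0, 60.0]))  # [35.0, 50.0, 60.0]
```

Attention over a set of keys does not depend on their order (absent positional information), so returning the selected blocks in time order is only a readability choice in this sketch.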
Symbol glossary:
| Symbol | Meaning |
|---|---|
| $H$ | Sequence of input hidden states |
| $m$ | Compression rate (e.g., 4 for CSA, 128 for HCA) |
| $S_j$ | Softmax-normalized importance weight for token $j$ |
| $C_j$ | Raw Key-Value data for token $j$ |
| $C^{Comp}_i$ | Single compressed Key-Value entry for block $i$ |
| $I_{t,s}$ | Index score between query at step $t$ and compressed block $s$ |
| $C^{SprsComp}_t$ | Selected top-$k$ compressed blocks for step $t$ |
| $K^{SWA}_t$ | Uncompressed sliding-window Keys |
| $\text{Concat}$ | Concatenation of selected compressed summaries with recent exact memory |
Key Takeaway: HCA uses the same weighted-sum compression formula as CSA, but with a much larger $m$ and no Lightning Indexer step. Compressing 128 tokens into 1 leaves 128× fewer blocks than raw tokens (about 7,800 for a 1M-token context), few enough that attending to all of them is cheap.
Phase 4: The Loss Function
All the compression intelligence in the world is useless if the model can't learn it. The training signal is the standard Cross-Entropy Loss for next-token prediction:
$$\mathcal{L} = -\sum_{i=1}^{V} y_i \log(\hat{y}_i)$$
- $V$: vocabulary size
- $y_i$: ground-truth label (1 for the correct next token, 0 for everything else)
- $\hat{y}_i$: the model's predicted probability for token $i$
The penalty mechanism: If the model predicts a very low probability for the correct word, say $\hat{y}_{correct} = 0.01$, then $\log(0.01) \approx -4.6$ and the loss is $-\log(0.01) \approx 4.6$, a large positive penalty. The gradient of this loss propagates backward through the attention mechanism into the compression weights $S_j$, pushing them toward the tokens that actually matter for downstream prediction. A compressor that consistently discards important information incurs a larger loss, and gradient descent moves it toward a better strategy over training.
Cross-Entropy fits here because language modeling is fundamentally a probability-distribution problem: we want the model's predicted distribution $\hat{y}$ to match the one-hot true distribution $y$. Cross-Entropy equals the KL divergence $D_{KL}(y \,\|\, \hat{y})$ plus the entropy of $y$. Since $y$ is fixed (and a one-hot $y$ has zero entropy), minimizing Cross-Entropy is exactly minimizing that KL divergence, which makes it the natural choice.
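Both the penalty numbers and the Cross-Entropy/KL relationship can be checked numerically. The probability vectors below are made up for illustration over a toy three-token vocabulary.

```python
import math


def cross_entropy(y_true, y_pred):
    """L = -sum_i y_i * log(y_hat_i); terms with y_i = 0 contribute nothing and are skipped."""
    return -sum(y * math.log(p) for y, p in zip(y_true, y_pred) if y > 0)


def kl_divergence(p, q):
    """D_KL(p || q) = sum_i p_i * log(p_i / q_i), skipping p_i = 0 terms."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


y = [0.0, 1.0, 0.0]              # one-hot: token 1 is the correct next token
confident = [0.05, 0.90, 0.05]   # most probability mass on the right token
wrong = [0.495, 0.01, 0.495]     # only 1% on the right token

print(round(cross_entropy(y, confident), 3))  # 0.105
print(round(cross_entropy(y, wrong), 3))      # 4.605
# With a one-hot target, entropy(y) = 0, so Cross-Entropy and KL coincide.
print(math.isclose(cross_entropy(y, wrong), kl_divergence(y, wrong)))  # True
```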
Phase 5: Where Do the Compression Weights Come From?
This is the most common point of confusion. The Softmax that generates compression weights $S_j$ is completely different from the Softmax inside core attention. They are independent mechanisms operating at different stages of the pipeline.
The confusion arises because standard Transformer papers use "Softmax" to refer exclusively to the attention score normalization step. DeepSeek-V4's compression introduces a second, earlier Softmax.
Here is exactly where the compression weights originate:
Step A — Project to "Importance" Scores ($Z$):
$$Z = H \cdot W^Z$$
$H$ is the sequence of hidden states entering the layer. $W^Z$ is a dedicated, learnable weight matrix, not one of the Q, K, or V projection matrices. Intuitively, it acts as a "highlighter" that the network can train to favor informative tokens (for example nouns, verbs, and entities) over filler (articles, punctuation, common conjunctions); which tokens it actually favors is learned, not specified.
Step B — The Compression Softmax ($S$):
$$S = \text{Softmax}(Z)$$
This is applied within each block of $m$ tokens, forcing the weights across those $m$ tokens to sum to 1.0. The result is a normalized importance distribution per block.
Step C — The Weighted Sum ($C^{Comp}$):
$$C^{Comp} = \sum_{j=1}^{m} S_j \cdot C_j$$
The actual Key-Value data $C_j$ for each token is multiplied by its normalized importance $S_j$ and summed. High-importance tokens dominate the resulting summary; low-importance tokens fade.
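Steps A to C for a single block can be sketched in plain Python. The inputs are hypothetical one-dimensional values chosen so that $Z = [0, \ln 3]$, which yields the Phase 1 weights $S = [0.25, 0.75]$ and the summary value 7. Note that the scores are computed from $H$ while the content being summed comes from $C_j$.

```python
import math
from typing import List

Vector = List[float]


def softmax(z: List[float]) -> List[float]:
    """Numerically stable Softmax over the m scores of one block."""
    peak = max(z)
    exps = [math.exp(v - peak) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def compress_block(h_block: List[Vector], c_block: List[Vector], w_z: Vector) -> Vector:
    """Compress one block of m tokens into a single C^Comp vector."""
    # Step A: Z = H . W^Z, one scalar importance score per token (W^Z maps d -> 1)
    z = [sum(h_i * w_i for h_i, w_i in zip(h, w_z)) for h in h_block]
    # Step B: Compression Softmax over the m tokens of this block only
    s = softmax(z)
    # Step C: C^Comp = sum_j S_j * C_j, applied to every dimension of C_j
    return [sum(s_j * c[dim] for s_j, c in zip(s, c_block)) for dim in range(len(c_block[0]))]


# Hypothetical 1-D values: Z = [0, ln 3] -> S = [0.25, 0.75]
h_block = [[0.0], [1.0]]
c_block = [[4.0], [8.0]]
w_z = [math.log(3)]
print([round(x, 6) for x in compress_block(h_block, c_block, w_z)])  # [7.0]
```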
The architectural pipeline, mapped to the attention diagram:
```
[Input H]
    |
    +----> [Calculate Z Scores via W^Z]
    |                    |
    |      (1) [Compression Softmax]   <-- SEPARATE from core attention
    |                    |
    |        (Weighted Sum -> C^Comp)
    |                    |
    +------------------> | (Compressed K / V)
                         |
         (2) [Core Attention Softmax]   <-- THE familiar Softmax
                         ^
                [Q * C^T / sqrt(d)]
```
Why this is powerful: By using a separate learnable weight matrix $W^Z$ for compression, DeepSeek-V4 allows the network to learn its own compression policy. If adjectives consistently prove irrelevant for long-range dependencies, $W^Z$ will learn to output low $Z$ scores for adjectives, causing them to contribute near-zero weight to summaries. This emergent behavior isn't hand-coded — it's learned entirely through the cross-entropy loss signal.
In the PyTorch implementation, this separation is explicit:
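```python
import torch
import torch.nn.functional as F

def compress_kv(self, kv_tensor, m, weight_layer):
    # weight_layer is a dedicated nn.Linear, NOT q_proj or k_proj
    B, S, D = kv_tensor.shape
    grouped_kv = kv_tensor.view(B, S // m, m, D)  # group m consecutive tokens per block
    raw_scores = weight_layer(grouped_kv)  # Step A: compute Z (here from the K/V tensor itself)
    weights = F.softmax(raw_scores, dim=2)  # Step B: Compression Softmax within each block
    compressed_kv = torch.sum(weights * grouped_kv, dim=2)  # Step C: weighted sum
    return compressed_kv
```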
Dataset Preparation
To train this architecture, we use standard autoregressive language modeling data.
- Features (X): A sequence of integer token IDs representing the text seen so far.
- Labels (Y): The same sequence, shifted one position to the future. The task is next-token prediction.
Concrete example:
Sentence: "The cat sat on the mat"
Vocabulary: {"The": 1, "cat": 2, "sat": 3, "on": 4, "the": 5, "mat": 6}
- Input Feature Sequence: [1, 2, 3, 4, 5] ("The cat sat on the")
- Target Label Sequence: [2, 3, 4, 5, 6] ("cat sat on the mat")
When the model processes [1, 2, 3], it compresses them using CSA and HCA, attends via the hybrid KV cache, and must predict token 4.
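The feature/label shift is a pair of list slices. This sketch only reproduces the toy example with Python lists; a real pipeline would operate on tokenizer output and tensors.

```python
vocab = {"The": 1, "cat": 2, "sat": 3, "on": 4, "the": 5, "mat": 6}
tokens = [vocab[w] for w in "The cat sat on the mat".split()]  # [1, 2, 3, 4, 5, 6]

# Features drop the last token; labels drop the first (a shift of one position).
x, y = tokens[:-1], tokens[1:]
print(x)  # [1, 2, 3, 4, 5]
print(y)  # [2, 3, 4, 5, 6]

# Each prefix x[:t + 1] is trained to predict y[t], e.g. [1, 2, 3] -> 4.
for t in range(len(x)):
    print(x[: t + 1], "->", y[t])
```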
The Architecture: Macro View
```
+-------------------------------------------------------------+
|                      1. INPUT SEQUENCE                      |
+-------------------------------------------------------------+
                              |
                       (Hidden States)
                              v
+-------------------------------------------------------------+
|              2. QUERY / KEY / VALUE PROJECTION              |
+-------------------------------------------------------------+
        |                     |                     |
 (Recent Tokens)      (All Past Tokens)     (All Past Tokens)
        |                     |                     |
        v                     v                     v
+----------------+   +-----------------+   +------------------+
| 3. SLIDING     |   | 4A. CSA         |   | 4B. HCA          |
|   WINDOW (SWA) |   |   COMPRESSOR    |   |   COMPRESSOR     |
| (Keep last     |   | (Compress m=4)  |   | (Compress m=128) |
| n_win tokens)  |   +-----------------+   +------------------+
+----------------+            |                     |
        |            (Compressed Blocks)  (Heavily Compressed)
        |                     v                     |
        |            +-----------------+            |
        |            | 5. LIGHTNING    |            |
        |            |    INDEXER      |    (No indexer! All
        |            | (Score & Top-K) |     blocks kept.)
        |            +-----------------+            |
        |                     |                     |
        |              (Top-K Blocks)               |
        v                     v                     v
+-------------------------------------------------------------+
|                      6. CONCATENATION                       |
|     (Combine HCA Blocks + CSA Top-K + SWA Window Keys)      |
+-------------------------------------------------------------+
                              |
                     (Combined KV Cache)
                              v
+-------------------------------------------------------------+
|                      7. CORE ATTENTION                      |
|            (Query attends to combined KV cache)             |
+-------------------------------------------------------------+
                              |
                      (Output Vectors)
                              v
                  [ Next Token Prediction ]
```
Step 5: The Rigorous Manual Walkthrough (SWA + CSA + HCA)
We trace a sequence of exactly 8 tokens ($T_1$ through $T_8$). We are currently at $T_8$ and want to predict the next word.
Setup:
| Parameter | Value |
|---|---|
| Sequence Keys | $T_1=10, T_2=20, T_3=30, T_4=40, T_5=50, T_6=60, T_7=70, T_8=80$ |
| Query $Q_8$ | $5$ |
| SWA Window ($n_{win}$) | 2 tokens |
| CSA Compression ($m$) | 2 tokens → 1 |
| HCA Compression ($m'$) | 4 tokens → 1 |
| CSA Top-K Selection | 1 block |
Stage 1 — Sliding Window Attention (SWA)
Input: All 8 Keys.
Operation: Strictly slice the last $n_{win} = 2$ tokens.
Result ($K_{swa}$): $[70, 80]$
These are the exact, lossless keys that always enter the final KV cache regardless of what CSA or HCA select.
Stage 2 — CSA Compression
Input: All 8 Keys.
Operation: Split into blocks of $m = 2$. For each block, a learnable weight matrix computes importance scores, Softmax normalizes them within the block, and we compute the weighted sum.
- Block 1 ($T_1, T_2$): Softmax weights $[0.2, 0.8]$
$$K_{csa1} = (0.2 \times 10) + (0.8 \times 20) = 2 + 16 = \mathbf{18}$$
- Block 2 ($T_3, T_4$): Softmax weights $[0.5, 0.5]$
$$K_{csa2} = (0.5 \times 30) + (0.5 \times 40) = 15 + 20 = \mathbf{35}$$
- Block 3 ($T_5, T_6$): Softmax weights $[0.9, 0.1]$
$$K_{csa3} = (0.9 \times 50) + (0.1 \times 60) = 45 + 6 = \mathbf{51}$$
- Block 4 ($T_7, T_8$): Softmax weights $[0.0, 1.0]$ (rounded; a Softmax never outputs exactly zero)
$$K_{csa4} = (0.0 \times 70) + (1.0 \times 80) = 0 + 80 = \mathbf{80}$$
Result ($K_{csa}$): $[18, 35, 51, 80]$
Stage 3 — The Lightning Indexer (No Simplifications)
In the real model, the Query does not directly multiply the compressed keys for scoring. Both are first projected into a smaller "indexer dimension" $d_{indexer}$ using dedicated learnable projection matrices. This saves compute during the scanning phase.
Input: Query $Q_8 = 5$, Compressed Keys $K_{csa} = [18, 35, 51, 80]$
Operation 3a — Projection:
- Project Query: $Q^{idx} = Q_8 \times 0.4 = \mathbf{2}$
- Project CSA Keys: $K^{idx} = K_{csa} \times 0.5 = [\mathbf{9, 17.5, 25.5, 40}]$
Operation 3b — Scoring:
$$I_{t,s} = Q^{idx}_t \cdot K^{idx}_s$$
- Score 1: $2 \times 9 = 18$
- Score 2: $2 \times 17.5 = 35$
- Score 3: $2 \times 25.5 = \mathbf{51}$ (Highest!)
- (In this walkthrough, Block 4 is excluded from indexing because its tokens $T_7, T_8$ are already in the SWA window verbatim. Had it been scored, it would have won: $2 \times 40 = 80$.)
Operation 3c — Top-K Selection:
Top-$k=1$ winner is Block 3 (score 51).
Result ($K_{csa\_selected}$): We retrieve the original (non-projected) compressed key for Block 3: $[\mathbf{51}]$.
Key Takeaway: The indexer projections are used only for the scoring step. The final selected key that enters the KV cache is the full-dimensional original compressed block — not the projected version. This separation ensures the scoring is cheap while the attended value retains full information.
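Stage 3 can be replayed in plain Python using the illustrative scalar projections from the text (×0.4 for the Query, ×0.5 for the Keys). The eligibility rule, which keeps only blocks that end before the SWA window starts, is an assumption that generalizes the Block 4 exclusion described above.

```python
q8 = 5.0
k_csa = [18.0, 35.0, 51.0, 80.0]  # compressed CSA keys from Stage 2
w_q_idx, w_k_idx = 0.4, 0.5       # illustrative 1-D indexer projections

# Operation 3a: project into the (here 1-D) indexer space
q_idx = q8 * w_q_idx
k_idx = [k * w_k_idx for k in k_csa]

# Assumed rule: only blocks ending before the SWA window are eligible -> (8 - 2) // 2 = 3
seq_len, n_win, m = 8, 2, 2
eligible = (seq_len - n_win) // m

# Operation 3b: score only the eligible blocks
scores = [q_idx * k for k in k_idx[:eligible]]

# Operation 3c: Top-K, then return the ORIGINAL (unprojected) compressed keys
top_k = 1
winners = sorted(range(eligible), key=lambda s: scores[s], reverse=True)[:top_k]
selected = [k_csa[s] for s in sorted(winners)]
print(scores, selected)  # [18.0, 35.0, 51.0] [51.0]
```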
Stage 4 — HCA Compression
Input: All 8 Keys.
Operation: Split into blocks of $m' = 4$. Apply Softmax weights and compute the weighted sum: the same math as CSA, but with a compression ratio 2× larger than CSA's in this example (32× larger at production settings, $m = 128$ vs. $m = 4$).
- Block 1 ($T_1, T_2, T_3, T_4$): Weights $[0.1, 0.2, 0.3, 0.4]$
$$K_{hca1} = (0.1 \times 10) + (0.2 \times 20) + (0.3 \times 30) + (0.4 \times 40)$$ $$= 1 + 4 + 9 + 16 = \mathbf{30}$$
- Block 2 ($T_5, T_6, T_7, T_8$): Weights $[0.25, 0.25, 0.25, 0.25]$
$$K_{hca2} = (0.25 \times 50) + (0.25 \times 60) + (0.25 \times 70) + (0.25 \times 80)$$ $$= 12.5 + 15 + 17.5 + 20 = \mathbf{65}$$
Result ($K_{hca}$): $[\mathbf{30, 65}]$
Crucial difference from CSA: These 2 blocks go directly to the concatenation step. There is no Lightning Indexer for HCA. Because we compressed 128 (or in this case 4) tokens into 1, the number of resulting blocks is small enough that we attend to all of them without filtering. This guarantees the model always has a dense, zoomed-out summary of the entire context history.
Stage 5 — Final Concatenation and Core Attention
Input:
- HCA: $[\mathbf{30, 65}]$
- CSA Selected (Top-K=1): $[\mathbf{51}]$
- SWA (exact last 2 tokens): $[\mathbf{70, 80}]$
Operation: Concatenate all three into a single hybrid KV cache.
Final KV Cache: $[\mathbf{30, 65, 51, 70, 80}]$
Core Attention:
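Query $Q_8 = 5$ now attends over 5 keys instead of the original 8, a 37.5% reduction in the keys that core attention must score in this small example (the compression step itself adds some cost). At production scale the gap is far larger: with $n = 10^6$ and $m' = 128$, HCA contributes about 7,800 blocks, and the top-$k$ CSA blocks and the SWA window add only a fixed number of keys on top of that.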
The final output is computed as:
$$\text{Output}_8 = \text{Softmax}\!\left(\frac{Q_8 \cdot [30, 65, 51, 70, 80]^T}{\sqrt{d}}\right) \cdot [30, 65, 51, 70, 80]$$
The query attends to: the zoomed-out history chapters (HCA), the most relevant compressed page summary (CSA), and the exact recent tokens (SWA).
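The final core-attention step can be evaluated directly. This sketch treats the scalar keys as their own values, as the equation does, and takes $d = 1$ because the example keys are scalars.

```python
import math

q8 = 5.0
kv_cache = [30.0, 65.0, 51.0, 70.0, 80.0]  # HCA (30, 65) + CSA top-1 (51) + SWA (70, 80)
d = 1                                       # scalar keys, so sqrt(d) = 1

# Scores: Q_8 . K^T / sqrt(d)
logits = [q8 * k / math.sqrt(d) for k in kv_cache]

# Core attention Softmax (subtracting the max avoids overflow in exp)
peak = max(logits)
exps = [math.exp(z - peak) for z in logits]
total = sum(exps)
probs = [e / total for e in exps]

# Keys double as values, exactly as in the equation above
output = sum(p * v for p, v in zip(probs, kv_cache))
print([round(p, 4) for p in probs], round(output, 4))  # [0.0, 0.0, 0.0, 0.0, 1.0] 80.0
```

With these unscaled toy values the top two logits are 400 and 350, so the Softmax is effectively one-hot on the key 80 and the output is essentially 80; this saturation is an artifact of the toy magnitudes.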
Step 6: Full PyTorch Implementation (Hybrid Attention)
Here is a single-head PyTorch implementation of one autoregressive generation step (only the last token issues a Query). It includes explicit separate weight matrices for CSA and HCA compression, dedicated low-rank projections for the Lightning Indexer, and the full three-way concatenation (HCA + CSA + SWA).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FullHybridAttention(nn.Module):
    def __init__(self, d_model, d_indexer, swa_window, csa_m, hca_m, top_k):
        super().__init__()
        self.swa_window = swa_window
        self.csa_m = csa_m
        self.hca_m = hca_m
        self.top_k = top_k
        self.d_indexer = d_indexer

        # [ARCHITECTURE: 2. QUERY / KEY / VALUE PROJECTION]
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # [ARCHITECTURE: 4A. CSA COMPRESSION WEIGHTS]
        # A dedicated linear layer, completely separate from the q/k/v projections.
        # It plays the role of W^Z: it scores token importance for compression.
        self.csa_compress_weight = nn.Linear(d_model, 1)

        # [ARCHITECTURE: 4B. HCA COMPRESSION WEIGHTS]
        # Separate from the CSA weights: HCA learns a different compression policy
        # because it covers much longer spans.
        self.hca_compress_weight = nn.Linear(d_model, 1)

        # [ARCHITECTURE: 5. LIGHTNING INDEXER PROJECTIONS]
        # The indexer does NOT use the raw Q and K tensors.
        # It projects them to a smaller d_indexer dimension to reduce scanning cost.
        self.q_indexer_proj = nn.Linear(d_model, d_indexer)
        self.k_indexer_proj = nn.Linear(d_model, d_indexer)

    def compress_kv(self, kv_tensor, m, weight_layer):
        """
        [ARCHITECTURE: Corresponds to 4A/4B. CSA/HCA COMPRESSOR]
        Groups the sequence into blocks of 'm' tokens, applies a learnable
        Softmax weighting (the Compression Softmax, NOT the core attention Softmax),
        and collapses each block into a single summary token.
        This implements the formal equation:
            C^Comp_i = sum_{j=mi}^{m(i+1)-1} S_j * C_j
        where S_j = Softmax(Z). Simplification: here Z = kv_tensor * W^Z, i.e. the
        scores come from the projected K (or V) rather than directly from H.
        """
        # [SHAPE: (Batch, Seq_Len, d_model)]
        B, S, D = kv_tensor.shape
        num_blocks = S // m  # Integer division: caller must ensure divisibility

        # Reshape to group 'm' tokens per block
        # [SHAPE: (Batch, num_blocks, m, d_model)]
        grouped_kv = kv_tensor.view(B, num_blocks, m, D)

        # Step A: Compute Z scores via the dedicated weight matrix W^Z
        # This is NOT Q*K. It's a per-token importance score based only on content.
        # [SHAPE: (Batch, num_blocks, m, 1)]
        raw_scores = weight_layer(grouped_kv)

        # Step B: Compression Softmax, normalized over the 'm' dimension
        # Forces the importance of the 'm' tokens within a block to sum to 1.0
        # [SHAPE: (Batch, num_blocks, m, 1)]
        weights = F.softmax(raw_scores, dim=2)

        # Step C: Weighted sum, the actual mathematical compression
        # Collapses 'm' tokens into 1 compressed summary token
        # [SHAPE: (Batch, num_blocks, d_model)]
        compressed_kv = torch.sum(weights * grouped_kv, dim=2)
        return compressed_kv

    def lightning_indexer(self, query, compressed_k, compressed_v):
        """
        [ARCHITECTURE: Corresponds to 5. LIGHTNING INDEXER]
        Projects Q and compressed_K to a lower-rank indexer dimension,
        scores the blocks cheaply, then retrieves the ORIGINAL (full-dimension)
        top-K compressed Keys and Values.
        Implements: I_{t,s} = Q^{indexer}_t · K^{indexer}_s
        Then:       C^SprsComp_t = { C^Comp_s | I_{t,s} in Top-k(I_{t,:}) }
        """
        B, num_blocks, D = compressed_k.shape
        Q_len = query.shape[1]  # Typically 1 during autoregressive generation

        # Step 1: Project to indexer dimension (cheap scoring space)
        # [SHAPE: (Batch, Query_Len, d_indexer)]
        q_idx = self.q_indexer_proj(query)
        # [SHAPE: (Batch, num_blocks, d_indexer)]
        k_idx = self.k_indexer_proj(compressed_k)

        # Step 2: Dot-product scoring in the low-rank indexer space
        # [SHAPE: (Batch, Query_Len, num_blocks)]
        scores = torch.bmm(q_idx, k_idx.transpose(1, 2))

        # Step 3: Select Top-K highest-scoring block indices
        # [STATE: topk_indices contains indices of the most relevant compressed blocks]
        _, topk_indices = torch.topk(scores, self.top_k, dim=-1)

        # Step 4: Gather the ORIGINAL full-dimension compressed K and V
        # (We score in low-rank space but retrieve full-rank tensors)
        # [SHAPE: (Batch, Query_Len, top_k, d_model)]
        gathered_indices = topk_indices.unsqueeze(-1).expand(-1, -1, -1, D)
        # Expand compressed K/V to match the query length dimension for gathering
        expanded_k = compressed_k.unsqueeze(1).expand(-1, Q_len, -1, -1)
        expanded_v = compressed_v.unsqueeze(1).expand(-1, Q_len, -1, -1)
        # [SHAPE: (Batch, Query_Len, top_k, d_model)]
        selected_k = torch.gather(expanded_k, 2, gathered_indices)
        selected_v = torch.gather(expanded_v, 2, gathered_indices)
        return selected_k, selected_v

    def forward(self, hidden_states):
        """
        Full hybrid attention forward pass for one generation step.
        Combines HCA (dense global) + CSA (sparse page) + SWA (exact local).
        """
        # [SHAPE: (Batch, Seq_Len, d_model)]
        B, S, D = hidden_states.shape

        # ======================================================
        # 1. GENERATE BASE Q, K, V
        # ======================================================
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # For autoregressive generation, the Query is only the last token
        # [SHAPE: (Batch, 1, d_model)]
        q_final = q[:, -1:, :]

        # ======================================================
        # 2. SLIDING WINDOW ATTENTION (SWA)
        # [ARCHITECTURE: Corresponds to "3. SLIDING WINDOW (SWA)"]
        # ======================================================
        # Take the raw, exact, lossless final 'n_win' tokens
        # [SHAPE: (Batch, swa_window, d_model)]
        k_swa = k[:, -self.swa_window:, :]
        v_swa = v[:, -self.swa_window:, :]

        # ======================================================
        # 3. CSA COMPRESSION + LIGHTNING INDEXER
        # [ARCHITECTURE: Corresponds to "4A. CSA COMPRESSOR" + "5. LIGHTNING INDEXER"]
        # ======================================================
        # Compress using csa_m and its dedicated weight matrix
        # [SHAPE: (Batch, Seq_Len // csa_m, d_model)]
        csa_k = self.compress_kv(k, self.csa_m, self.csa_compress_weight)
        csa_v = self.compress_kv(v, self.csa_m, self.csa_compress_weight)
        # Index only blocks that end before the SWA window starts; later blocks
        # (like Block 4 in the manual walkthrough) repeat tokens SWA already holds.
        # [SHAPE: (Batch, (Seq_Len - swa_window) // csa_m, d_model)]
        num_eligible = (S - self.swa_window) // self.csa_m
        csa_k, csa_v = csa_k[:, :num_eligible], csa_v[:, :num_eligible]
        # Score in low-rank indexer space, retrieve full-rank Top-K
        # [SHAPE: (Batch, 1, top_k, d_model)]
        csa_k_selected, csa_v_selected = self.lightning_indexer(q_final, csa_k, csa_v)
        # Strip the Query_Len=1 dimension for concatenation
        # [SHAPE: (Batch, top_k, d_model)]
        csa_k_selected = csa_k_selected.squeeze(1)
        csa_v_selected = csa_v_selected.squeeze(1)

        # ======================================================
        # 4. HCA COMPRESSION (Strict: NO Lightning Indexer)
        # [ARCHITECTURE: Corresponds to "4B. HCA COMPRESSOR"]
        # ======================================================
        # Compress using the much larger hca_m and its own weight matrix
        # [SHAPE: (Batch, Seq_Len // hca_m, d_model)]
        hca_k = self.compress_kv(k, self.hca_m, self.hca_compress_weight)
        hca_v = self.compress_kv(v, self.hca_m, self.hca_compress_weight)
        # IMPORTANT: No indexer here. All HCA blocks are kept.
        # Because Seq_Len // hca_m is already small, attending to all of them is cheap.

        # ======================================================
        # 5. FULL CONCATENATION (The Three-Way Hybrid Memory)
        # [ARCHITECTURE: Corresponds to "6. CONCATENATION"]
        # ======================================================
        # Order: Dense Global (HCA) + Sparse Page (CSA) + Local Exact (SWA)
        # [SHAPE: (Batch, (Seq//hca_m) + top_k + swa_window, d_model)]
        combined_k = torch.cat([hca_k, csa_k_selected, k_swa], dim=1)
        combined_v = torch.cat([hca_v, csa_v_selected, v_swa], dim=1)

        # ======================================================
        # 6. CORE ATTENTION
        # [ARCHITECTURE: Corresponds to "7. CORE ATTENTION"]
        # ======================================================
        # Standard scaled dot-product attention over the hybrid KV cache
        attn_scores = torch.bmm(q_final, combined_k.transpose(1, 2)) / (D ** 0.5)
        attn_probs = F.softmax(attn_scores, dim=-1)
        # [SHAPE: (Batch, 1, d_model)]
        output = torch.bmm(attn_probs, combined_v)
        return output


# ======================================================
# EXECUTION TRACE
# ======================================================
if __name__ == "__main__":
    batch_size = 1
    seq_len = 16    # Must be divisible by both csa_m (2) and hca_m (4)
    d_model = 64
    d_indexer = 16  # Low-rank dimension for the Lightning Indexer

    model = FullHybridAttention(
        d_model=d_model,
        d_indexer=d_indexer,
        swa_window=2,
        csa_m=2,
        hca_m=4,
        top_k=1,
    )
    dummy_hidden_states = torch.randn(batch_size, seq_len, d_model)
    final_output = model(dummy_hidden_states)

    print(f"Original sequence length: {seq_len}")
    print(f"Final Output shape: {final_output.shape}")
    print("Hybrid KV cache size per step: "
          f"{seq_len // 4} (HCA) + 1 (CSA top-k) + 2 (SWA) = "
          f"{seq_len // 4 + 1 + 2} keys")
    print(f"Original naive attention would process: {seq_len} keys")
    print("Strict HCA + CSA + SWA processing complete.")
```
Summary of the three mechanisms in code:
| Mechanism | Code Path | Indexer? | KV Cache Contribution |
|---|---|---|---|
| SWA | k[:, -swa_window:, :] | No | Always swa_window exact keys |
| CSA | compress_kv(m=2) → lightning_indexer() | Yes (Top-K) | top_k selected summaries |
| HCA | compress_kv(m=4) | No | All Seq // hca_m dense summaries |
DeepSeek-V4's Secret Weapon: Multi-head Latent Attention
Everything above describes the compression pipeline — how the KV data gets summarized. But DeepSeek-V4 goes further. It also redesigns how the Query and the KV cache are projected and combined, replacing standard Multi-Head Attention with Multi-head Latent Attention (MLA) and Shared Key-Value Multi-Query Attention.
What Makes V4 Different from V2/V3
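DeepSeek-V2 and V3 introduced MLA, whose core idea is low-rank joint compression of Keys and Values into a small latent vector that is cached in place of the full K/V tensors; V2 also applied a low-rank bottleneck to the Queries. DeepSeek-V4 evolves this with three specific changes: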
1. Latent Queries via bottleneck projections
Instead of projecting the hidden state directly into $n_h$ large Query vectors, the model first compresses the hidden state into a tiny "latent" vector $c^Q$ via a down-projection matrix $W^{DQ}$, then "unpacks" it into multiple Query heads via an up-projection matrix $W^{UQ}$. This bottleneck is:
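- Down: d_model → d_c (where $d_c \ll d_{model}$)
- Up: d_c → n_heads × d_head
Because $d_c$ is small, the two matrices hold $d_c(d_{model} + n_h d_{head})$ weights instead of the $d_{model} \cdot n_h d_{head}$ of a direct projection, which reduces the compute cost of Query generation.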
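The savings can be made concrete by counting weights. The sizes below (d_model = 4096, 32 heads of dimension 128, d_c = 512) are hypothetical and are not DeepSeek-V4's configuration; per-token multiply-adds scale the same way as the weight counts.

```python
from typing import Optional


def query_proj_weights(d_model: int, n_heads: int, d_head: int, d_c: Optional[int] = None) -> int:
    """Weight count (biases ignored) for producing all Query heads from one hidden state."""
    if d_c is None:
        # Direct projection: d_model -> n_heads * d_head
        return d_model * n_heads * d_head
    # Bottleneck: d_model -> d_c (W^DQ), then d_c -> n_heads * d_head (W^UQ)
    return d_model * d_c + d_c * n_heads * d_head


# Hypothetical sizes, for illustration only.
direct = query_proj_weights(d_model=4096, n_heads=32, d_head=128)
latent = query_proj_weights(d_model=4096, n_heads=32, d_head=128, d_c=512)
print(direct, latent, direct / latent)  # 16777216 4194304 4.0
```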
2. Key = Value (Shared KV)
This is V4's most aggressive optimization. Standard Transformers maintain two separate KV caches: one for Keys (shape [n, d_head]) and one for Values (same shape). DeepSeek-V4 creates a single compressed block $C^{Comp}$ that simultaneously serves as both the Key and the Value.
When attending, the model:
- Computes attention scores: $Q \times (C^{Comp})^T$ — treating $C^{Comp}$ as the Key
- Computes the output: $\text{attn\_probs} \times C^{Comp}$ — treating $C^{Comp}$ as the Value
This halves the KV cache size. It adds no information loss beyond what the compression itself causes, but it does constrain the model: one vector must serve as both a good Key and a good Value.
3. Grouped Output Projection
With many attention heads, projecting the concatenated head outputs back to $d_{model}$ via a single large matrix becomes expensive. V4 splits the heads into $g$ groups, applies a down-projection per group (producing a vector of dimension $d_g$), concatenates the group intermediate outputs, and applies a single final up-projection back to $d_{model}$.
Shared Key-Value Multi-Query Attention — The Full Math
The formal equation for a single query head $i$ at token $t$ is:
$$o_{t,i} = \text{Softmax}\!\left(\frac{q_{t,i} \cdot (C^{Comp})^T}{\sqrt{d}}\right) C^{Comp}$$
Notice $C^{Comp}$ appears twice:
- Inside the Softmax argument: acting as the Key to score relevance
- Outside the Softmax: acting as the Value to provide content
Symbol definitions:
| Symbol | Meaning |
|---|---|
| $o_{t,i}$ | Output vector for query head $i$ at token $t$ |
| $q_{t,i}$ | Query vector for head $i$ at token $t$ |
| $C^{Comp}$ | Shared Key-Value matrix (compressed blocks from CSA/HCA) |
| $(C^{Comp})^T$ | Transposed $C^{Comp}$ for matrix multiplication |
| $d$ | Head dimension; dividing scores by $\sqrt{d}$ keeps their scale roughly independent of $d$, so the Softmax does not saturate |
| $\text{Softmax}$ | Normalizes scores to a probability distribution summing to 1 |
Why this can work: The Cross-Entropy loss pushes the model to learn $C^{Comp}$ vectors that are simultaneously good keys (distinguishable from each other, enabling precise retrieval) and good values (informationally rich, enabling accurate next-token prediction). The shared representation forces the model to encode both retrieval signals and content signals in a single vector, giving up some of the flexibility of separate K and V in exchange for half the cache.
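The shared-KV equation can be sketched as one function in which the same rows of $C^{Comp}$ are used first as Keys and then as Values. The two cached blocks and two query heads are hypothetical values chosen so that each head prefers a different block.

```python
import math
from typing import List

Vector = List[float]


def shared_kv_attention(q: Vector, c_comp: List[Vector]) -> Vector:
    """o = Softmax(q . C^T / sqrt(d)) C, using the same rows of C as Keys and as Values."""
    d = len(q)
    # C^Comp as Key: one relevance score per cached block
    logits = [sum(qi * ci for qi, ci in zip(q, row)) / math.sqrt(d) for row in c_comp]
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # C^Comp as Value: probability-weighted sum of the very same rows
    return [sum(p * row[j] for p, row in zip(probs, c_comp)) for j in range(d)]


# Hypothetical cache of two blocks and two query heads that favor different features.
c_comp = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
head_1 = [2.0, 0.0, 0.0]
head_2 = [0.0, 2.0, 0.0]
print(shared_kv_attention(head_1, c_comp))  # dominated by block 1
print(shared_kv_attention(head_2, c_comp))  # dominated by block 2
```

Each head puts roughly 99% of its weight on a different block, even though only one matrix is cached.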
Step 5 (Latent): Rigorous Manual Walkthrough
We trace the generation of a single token through the full Latent Attention architecture.
Setup:
- Hidden State $h_t$: $[10, 20, 30, 40]$ (dim = 4)
- Number of Query Heads ($n_h$): 2
- Query Bottleneck Dim ($d_c$): 2
- Head Dim ($d_{head}$): 3
- Compressed history block $C^{Comp}$: pre-computed as $[4, 4, 4]$
Stage 1 — Latent Query ($c^Q_t$)
Instead of generating 2 large Query heads directly from $h_t$, we first compress $h_t$ into a tiny latent vector.
Input: $h_t = [10, 20, 30, 40]$
Operation: Multiply by down-projection matrix $W^{DQ}$ (shape $4 \times 2$) to reduce to 2 numbers.
Result ($c^Q_t$): $[\mathbf{5, 5}]$
This bottleneck forces the model to encode the most predictive information into just 2 numbers before generating any Query heads.
Stage 2 — Unpacking the Query Heads ($q_t$)
Input: Latent Query $c^Q_t = [5, 5]$
Operation: Multiply by up-projection matrix $W^{UQ}$ (shape $2 \times n_h d_{head} = 2 \times 6$) to expand back into full multi-head queries.
Result:
- Head 1 ($q_{t,1}$): $[\mathbf{2, 0, 0}]$
- Head 2 ($q_{t,2}$): $[\mathbf{0, 2, 0}]$
The two heads "unpack" different aspects of the latent context vector into their respective query specializations.
Stage 3 — The Shared Key-Value ($C^{Comp}$)
In V4's architecture, there is no separate $K$ and $V$ matrix. There is one vector $C^{Comp}$ — produced by the CSA or HCA compressor — that serves both roles.
Result ($C^{Comp}$): $[\mathbf{4, 4, 4}]$ (one compressed historical block, dim = 3)
Stage 4 — Shared Key-Value MQA (Core Attention)
We use $C^{Comp}$ as the Key for scoring and then again as the Value for content retrieval. (Here we have one block and one query, so Softmax yields probability 1.0.)
Head 1:
- Score: $q_{t,1} \cdot C^{Comp} = [2, 0, 0] \cdot [4, 4, 4] = \mathbf{8}$
- Softmax(8) = $\mathbf{1.0}$ (with a single block, the probability is trivially 1.0)
- Output: $1.0 \times C^{Comp} = 1.0 \times [4, 4, 4] = [\mathbf{4, 4, 4}]$
Head 2:
- Score: $q_{t,2} \cdot C^{Comp} = [0, 2, 0] \cdot [4, 4, 4] = \mathbf{8}$
- Softmax(8) = $\mathbf{1.0}$
- Output: $1.0 \times C^{Comp} = [\mathbf{4, 4, 4}]$
Both heads retrieved the same $C^{Comp}$ block. In a real sequence with many blocks, Head 1 and Head 2 would score those blocks differently (because $q_{t,1} \neq q_{t,2}$) and attend to different blocks, enabling multi-head diversity without the memory cost of maintaining separate K and V matrices.
Stage 5 — Grouped Output Projection
Instead of one massive matrix multiplication over the full concatenated head output, V4 groups the heads and applies the projection in stages.
Input: Concatenated heads = $[Head_1 | Head_2] = [\mathbf{4, 4, 4, 4, 4, 4}]$ (length 6)
With $g = 1$ group (both heads in one group) and $d_g = 2$:
Operation (Down-projection): Project dim 6 → 2
- Result: $[\mathbf{8, 8}]$
Operation (Final Up-projection): Project $g \times d_g = 2 \to d_{model} = 4$
- Result: $[\mathbf{15, 25, 35, 45}]$
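The grouped projection can be sketched as "slice, down-project each slice, concatenate, up-project." This sketch assumes each group has its own down-projection matrix; with $g = 1$ there is only one. All weights are hypothetical and were chosen only to reproduce the walkthrough's $[8, 8]$ and $[15, 25, 35, 45]$.

```python
def vec_mat(v, W):
    """Row vector (length n) times matrix (n x m) -> row vector (length m)."""
    return [sum(v[i] * W[i][j] for i in range(len(v))) for j in range(len(W[0]))]

def grouped_output_projection(heads_concat, group_down_ws, W_up):
    """Split heads_concat into len(group_down_ws) equal slices, down-project each
    slice with its own matrix, concatenate the results, then up-project."""
    g = len(group_down_ws)
    size = len(heads_concat) // g
    intermediate = []
    for i, W in enumerate(group_down_ws):
        intermediate += vec_mat(heads_concat[i * size:(i + 1) * size], W)
    return intermediate, vec_mat(intermediate, W_up)

heads_concat = [4, 4, 4, 4, 4, 4]  # [Head_1 | Head_2]

# g = 1 group, d_g = 2: one hypothetical 6 x 2 down-projection
W_down = [[1, 0], [1, 0], [0, 0], [0, 0], [0, 1], [0, 1]]
# Hypothetical 2 x 4 up-projection to d_model = 4
W_up = [[15 / 16, 25 / 16, 35 / 16, 45 / 16],
        [15 / 16, 25 / 16, 35 / 16, 45 / 16]]

mid, out = grouped_output_projection(heads_concat, [W_down], W_up)
print(mid)  # [8, 8]
print(out)  # [15.0, 25.0, 35.0, 45.0]
```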
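Summary: The Latent Attention pipeline uses a bottleneck to generate Queries, uses the exact same vector for Keys and Values, and uses a bottleneck to project the output. The two bottlenecks replace large projection matrices with pairs of smaller ones, which cuts parameters and FLOPs. The shared Key-Value vector halves the KV cache relative to storing separate Keys and Values.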
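To make the savings concrete, the sketch below counts projection weights per layer, ignoring biases, for the dimensions used in the Step 6 (Latent) execution trace. It assumes each output group has its own down-projection matrix. In a vector-matrix product each weight costs one multiply-add per token, so these counts also track per-token projection compute.

```python
def query_proj_params(d_model, n_heads, d_head, d_c=None):
    """Weight count (biases ignored) for the query projection."""
    if d_c is None:  # direct: d_model -> n_heads * d_head
        return d_model * n_heads * d_head
    return d_model * d_c + d_c * n_heads * d_head  # low-rank: d_model -> d_c -> heads

def output_proj_params(d_model, n_heads, d_head, groups=None, d_g=None):
    """Weight count (biases ignored) for the output projection."""
    if groups is None:  # direct: n_heads * d_head -> d_model
        return n_heads * d_head * d_model
    per_group_in = (n_heads // groups) * d_head
    # one down-projection per group, then one shared up-projection
    return groups * per_group_in * d_g + groups * d_g * d_model

cfg = dict(d_model=128, n_heads=4, d_head=64)
print(query_proj_params(**cfg), query_proj_params(**cfg, d_c=32))              # 32768 12288
print(output_proj_params(**cfg), output_proj_params(**cfg, groups=2, d_g=16))  # 32768 8192
```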
Step 6 (Latent): Full PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepSeekV4LatentAttention(nn.Module):
    """
    DeepSeek-V4 Multi-head Latent Attention with Shared Key-Value MQA.
    Architecture:
    1. Compress hidden_state into latent query c_Q via W_DQ (down-projection)
    2. Expand c_Q into n_heads query vectors via W_UQ (up-projection)
    3. Project the compressed history blocks into a shared KV vector C via W_KV
    4. Use C as BOTH Key (for scoring) and Value (for content retrieval)
    5. Group head outputs and apply a staged down/up projection to d_model
    """

    def __init__(self, d_model, d_c, n_heads, d_head, d_g, groups):
        """
        Args:
            d_model : Full hidden dimension of the model
            d_c     : Query bottleneck dimension (d_c << d_model)
            n_heads : Number of query attention heads
            d_head  : Dimension of each attention head
            d_g     : Intermediate dimension for grouped output projection
            groups  : Number of groups for output projection (must divide n_heads)
        """
        super().__init__()
        assert n_heads % groups == 0, "n_heads must be divisible by groups"
        self.n_heads = n_heads
        self.d_head = d_head
        self.groups = groups
        self.d_g = d_g

        # ======================================================
        # [ARCHITECTURE: 1. LATENT QUERY PROJECTIONS]
        # Two-stage bottleneck: shrink the hidden state, then expand into heads.
        # This is the core of MLA: queries are cheap to generate.
        # ======================================================
        self.W_DQ = nn.Linear(d_model, d_c)           # Down-projection: d_model → d_c
        self.W_UQ = nn.Linear(d_c, n_heads * d_head)  # Up-projection: d_c → heads

        # ======================================================
        # [ARCHITECTURE: 2. SHARED KEY-VALUE PROJECTION]
        # Critical: we only define ONE projection for the compressed history.
        # There is no separate W_K and W_V. A single W_KV produces C,
        # which DeepSeek-V4 uses simultaneously as Key and Value.
        # ======================================================
        self.W_KV = nn.Linear(d_model, d_head)

        # ======================================================
        # [ARCHITECTURE: 3. GROUPED OUTPUT PROJECTION]
        # Instead of one large d_model projection, we:
        #   (a) Apply a separate down-projection to d_g for each group
        #   (b) Concatenate group outputs
        #   (c) Apply one final up-projection to d_model
        # ======================================================
        heads_per_group = n_heads // groups
        self.group_down_projs = nn.ModuleList(
            [nn.Linear(heads_per_group * d_head, d_g) for _ in range(groups)]
        )
        self.final_out_proj = nn.Linear(groups * d_g, d_model)

    def forward(self, hidden_states, compressed_history_blocks):
        """
        Args:
            hidden_states: (Batch, 1, d_model), the current query token
            compressed_history_blocks: (Batch, num_blocks, d_model)
                Output from the CSA/HCA compressor.
                This is C^Comp from the formal equations.
        Returns:
            final_output: (Batch, 1, d_model)
        """
        B = hidden_states.shape[0]

        # ======================================================
        # STAGE 1: LATENT QUERY GENERATION
        # Implements: c^Q_t = W^DQ * h_t
        #             q_t   = W^UQ * c^Q_t  (reshaped into n_heads)
        # ======================================================
        # Step 1a: Down-projection into latent query bottleneck
        # [SHAPE: (Batch, 1, d_c)]
        c_Q = self.W_DQ(hidden_states)
        # Step 1b: Up-projection into full multi-head queries
        # [SHAPE: (Batch, 1, n_heads * d_head)]
        q_heads_flat = self.W_UQ(c_Q)
        # Step 1c: Reshape for multi-head attention computation
        # [SHAPE: (Batch, n_heads, 1, d_head)]
        q = q_heads_flat.view(B, 1, self.n_heads, self.d_head).transpose(1, 2)

        # ======================================================
        # STAGE 2: SHARED KEY-VALUE (C)
        # Implements: C = W^KV * C^Comp
        # This single tensor acts as BOTH Key AND Value.
        # ======================================================
        # Project compressed blocks into the head dimension
        # [SHAPE: (Batch, num_blocks, d_head)]
        C = self.W_KV(compressed_history_blocks)
        # Unsqueeze to broadcast across all query heads
        # [SHAPE: (Batch, 1, num_blocks, d_head)]
        C = C.unsqueeze(1)

        # ======================================================
        # STAGE 3: CORE ATTENTION (C acts as both Key and Value)
        # Implements: o_{t,i} = Softmax(q_{t,i} * C^T / sqrt(d)) * C
        # ======================================================
        # Step 3a: Attention scores, with C acting as the KEY
        # [SHAPE: (Batch, n_heads, 1, num_blocks)]
        scores = torch.matmul(q, C.transpose(-1, -2)) / (self.d_head ** 0.5)
        attn_probs = F.softmax(scores, dim=-1)
        # Step 3b: Attention output, with C now acting as the VALUE
        # [SHAPE: (Batch, n_heads, 1, d_head)]
        attn_output = torch.matmul(attn_probs, C)

        # ======================================================
        # STAGE 4: GROUPED OUTPUT PROJECTION
        # ======================================================
        # Flatten all heads back together
        # [SHAPE: (Batch, 1, n_heads * d_head)]
        attn_output_flat = attn_output.transpose(1, 2).reshape(B, 1, self.n_heads * self.d_head)
        # Process each group independently through its own down-projection
        heads_per_group = self.n_heads // self.groups
        group_dim = heads_per_group * self.d_head
        intermediate_outputs = []
        for g in range(self.groups):
            # Extract the slice for group g
            start_idx = g * group_dim
            end_idx = start_idx + group_dim
            group_slice = attn_output_flat[:, :, start_idx:end_idx]
            # Apply this group's down-projection to d_g
            # [SHAPE: (Batch, 1, d_g)]
            o_intermediate = self.group_down_projs[g](group_slice)
            intermediate_outputs.append(o_intermediate)
        # Concatenate all group intermediates
        # [SHAPE: (Batch, 1, groups * d_g)]
        concat_intermediate = torch.cat(intermediate_outputs, dim=-1)
        # Final projection back to original model dimension
        # [SHAPE: (Batch, 1, d_model)]
        final_output = self.final_out_proj(concat_intermediate)
        return final_output


# ======================================================
# EXECUTION TRACE: same stages as the Step 5 manual walkthrough,
# at larger dimensions with randomly initialized weights
# ======================================================
if __name__ == "__main__":
    batch_size = 1
    d_model = 128  # Full model hidden dimension
    d_c = 32       # Latent query bottleneck (4x smaller than d_model)
    n_heads = 4    # Number of attention heads
    d_head = 64    # Dimension per head
    d_g = 16       # Grouped projection intermediate dimension
    groups = 2     # Number of output projection groups

    model = DeepSeekV4LatentAttention(d_model, d_c, n_heads, d_head, d_g, groups)

    # The current query token (1 token being generated)
    current_token = torch.randn(batch_size, 1, d_model)
    # Pre-computed compressed history: result of the CSA/HCA pipeline
    # In production, this is the output of FullHybridAttention's compress_kv()
    compressed_csa_blocks = torch.randn(batch_size, 10, d_model)

    output = model(current_token, compressed_csa_blocks)

    print(f"Input (current token) shape: {current_token.shape}")
    print(f"Compressed history shape: {compressed_csa_blocks.shape}")
    print(f"Final output shape: {output.shape}")
    print()
    print("KV cache per cached entry, one layer:")
    print(f"  Standard MHA: 2 * {n_heads} * {d_head} = {2 * n_heads * d_head} floats")
    print(f"  V4 shared KV: 1 * {d_head} = {d_head} floats "
          f"({1 - 1 / (2 * n_heads):.1%} smaller; at least 50% for any n_heads)")
    print()
    print("V4 Latent Attention trace complete: Latent Query → Shared MQA → Grouped Projection")
The Shared KV Library Analogy: A Complete Self-Test
To verify your understanding, trace through this scenario mentally:
Standard Multi-Head Attention (100 students, 100 librarians):
Every student (Query) gets their own librarian (Key) who fetches a unique, personalized stack of books (Value). For 100 students, that means 100 Keys and 100 Values in memory: 100× the storage of a single shared Key/Value pair.
Multi-Query Attention (100 students, 1 librarian, 1 stack):
All students share one master librarian (Key) and one master stack (Value). Students ask 100 different questions, but the library maintains a single memory footprint.
Shared Key-Value MQA (the library's index card is the book):
The index card that tells you what a book is about is the book itself. Students use the same vector to both find relevant books (scoring phase) and read from them (retrieval phase). The equation $o_{t,i} = \text{Softmax}(q_{t,i} \cdot C^T / \sqrt{d}) \cdot C$ is exactly this: $C$ appears once as the index and once as the content.
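The analogy maps directly onto cache sizes. The minimal sketch below counts the floats cached per history entry (a token or a compressed block) in one layer. It uses 100 heads for the 100 students and a hypothetical `d_head = 64`.

```python
def kv_floats_per_entry(n_heads, d_head, scheme):
    """Floats cached per history entry (token or compressed block) in one layer."""
    if scheme == "mha":        # every head stores its own Key and Value
        return 2 * n_heads * d_head
    if scheme == "mqa":        # one Key and one Value shared by all heads
        return 2 * d_head
    if scheme == "shared_kv":  # one vector C serves as both Key and Value
        return d_head
    raise ValueError(f"unknown scheme: {scheme}")

for scheme in ("mha", "mqa", "shared_kv"):
    print(scheme, kv_floats_per_entry(n_heads=100, d_head=64, scheme=scheme))
# mha 12800
# mqa 128
# shared_kv 64
```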
Beyond the attention mechanisms discussed here, DeepSeek-V4 also uses Manifold-Constrained Hyper-Connections (mHC) to stabilize gradient flow through deep layers, the Muon optimizer to make trillion-parameter training converge faster, Hash-Routed Mixture-of-Experts to keep expert load balanced, Anticipatory Routing to prevent the feedback loops that crash MoE training, and Full-Vocabulary On-Policy Distillation to absorb the reasoning of a dozen specialist teacher models into one unified network.
However, I noticed that the technical paper does not include Engram. In my previous article, I mentioned a nuance regarding that algorithm: if you try to implement Engram naively, it can be slower due to overhead.
That overhead could be why it was left out of V4. Hopefully, they can add it in the next release.
