The SSM Mamba Architecture: The Complete Guide



You know that every Transformer architecture has a quadratic attention bottleneck. Every new token must attend to every previous token, so a sequence of length $N$ requires $O(N^2)$ operations and, in the standard implementation, $O(N^2)$ memory. Double the sequence length, quadruple the cost. This is a fundamental architectural tax that no engineering trick fully eliminates.

Now flip the scenario. You switch to an RNN-based model. Memory is fine. But by the time you reach the end of Document 1, the model has forgotten the contract's key clause from page 2. Standard RNNs compress history into a fixed-size state—a "finite notepad"—that fills up and overwrites critical information with every new token.

Mamba breaks this tradeoff entirely.

It processes arbitrarily long sequences in linear time (like an RNN) while achieving Transformer-level performance on long-range dependencies (by fixing the forgetting problem). The mechanism it uses to do this is called the Selective State Space Model (S6)—and it's genuinely elegant.

This guide is the complete technical walkthrough. We cover the intuition, the math, the architecture, a step-by-step numerical simulation, and a fully annotated PyTorch implementation. By the end, you'll understand exactly why Mamba works—and when it's the wrong tool.




1. The Mental Model: Three Types of Secretaries

Before touching a single equation, let's internalize the core insight with an analogy. You are an executive (the model) reading a very long transcript (the input sequence). You need a secretary (the hidden state) to track context for you.

The RNN: The Finite Notepad

Your secretary has one small index card. For every new sentence spoken, they must erase something old to write down the new information. By the end of a 3-hour meeting, the critical decision from the opening remarks has been overwritten by the coffee order. Every sentence is treated with equal importance—there's no filtering.

The architectural consequence: Standard RNNs maintain a fixed-size hidden state. No matter how long the sequence, the state vector has the same number of dimensions, so information from early in the sequence is progressively overwritten as new tokens are written into it. Training makes this worse through the vanishing gradient problem: gradients shrink as they are back-propagated through many steps, so the model struggles to learn that early tokens mattered.

The Transformer: Photographic Memory That Slows to a Crawl

You don't have a secretary. Every time you want to understand the new sentence, you re-read the entire transcript from the beginning. This gives you perfect recall, but your reading speed degrades with every new page. Reading sentence 1,000 requires re-reading 999 sentences first.

The architectural consequence: The Attention mechanism computes a score between every pair of tokens. For a sequence of length $N$, the Attention Matrix has $N^2$ entries. Memory and compute scale quadratically with sequence length.

Mamba: The Selective Secretary

Your secretary has a Smart Notepad with a gating mechanism. When a sentence is filler ("um, let me think"), they clamp the gate closed: the notepad barely changes. When a sentence contains critical information ("the authorization code is 7743"), they open the gate wide: the information is written in permanent marker, and irrelevant old context is actively discarded.

The crucial insight: the gating decision is made by looking at the content of the current input. This is called the Selection Mechanism. It's what separates Mamba from all prior SSMs.


2. The Problem It Solves

Two failure modes defined the pre-Mamba landscape:

Failure Mode 1 — Quadratic Scaling. Transformers are $O(N^2)$ in both time and memory for sequence length $N$. A sequence of 32K tokens requires roughly 1,000x more attention computation than a sequence of 1K tokens. For genomics, legal documents, audio, or code repositories, this is a hard wall.
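To make the "roughly 1,000x" figure concrete, here is a minimal arithmetic sketch. It assumes "1K" means 1,024 tokens and counts only the entries of the attention score matrix; the function name is ours, not part of any library.

```python
def attention_pairs(seq_len: int) -> int:
    """Number of query-key score entries in a full attention matrix."""
    return seq_len * seq_len

short_len = 1_024          # assumed meaning of "1K"
long_len = 32 * 1_024      # assumed meaning of "32K"

ratio = attention_pairs(long_len) / attention_pairs(short_len)
print(ratio)  # 1024.0
```

A 32x longer sequence costs 32² = 1,024x more score entries, which is where the figure above comes from.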

Failure Mode 2 — Context-Insensitive State Updates. Earlier SSMs (like S4) were Linear Time Invariant (LTI): the system parameters ($A$, $B$, $C$) were fixed regardless of the current token. This means the model processed "fox" and "the" with the same memory gate settings—it couldn't learn to prioritize one over the other. It was structurally incapable of true content-based filtering.

Mamba solves both simultaneously:

  • Linear time by retaining the RNN-style recurrence (no $N^2$ matrix).
  • Content-aware memory by making the key parameters functions of the current input.

3. The Mathematical Foundation: State Space Models

Mamba is built on State Space Models (SSMs), a classical framework from control theory that describes how a system evolves over time. Understanding the math is essential—the architecture's behavior falls directly out of these equations.

3.1 Variable Dictionary

| Symbol | Meaning | Role |
|---|---|---|
| $x_t$ | Input at time $t$ | The current token/signal |
| $h_t$ | Hidden state at time $t$ | The model's "working memory" |
| $y_t$ | Output at time $t$ | The prediction |
| $A$ | State transition matrix | How the state evolves/decays |
| $B$ | Input matrix | How much new input enters the state |
| $C$ | Output matrix | How the state projects into a prediction |
| $\Delta$ | Time-step / step size | The "gate size": how much time the system simulates per input |

3.2 The Continuous System

SSMs begin in continuous time, modeled as a pair of differential equations. Consider a cup of coffee cooling:

  • State ($h$): Current temperature.
  • Input ($x$): Heat from a microwave.
  • Dynamics ($A$): Natural cooling rate (negative, so the coffee loses heat).

The rate of change of the state is: $$h'(t) = Ah(t) + Bx(t)$$

If $A = -0.5$ (natural cooling), the coffee cools. If the microwave is on ($x > 0$), the temperature rises proportionally to $B$. The output (is this drinkable?) reads from the state: $$y(t) = Ch(t)$$
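The coffee-cup dynamics can be simulated directly. The sketch below uses forward-Euler integration with a very small step as a stand-in for continuous time; treating $h$ as "degrees above room temperature", and the values of $B$, the input, and the starting temperature, are illustrative assumptions.

```python
import math

def simulate(A, B, x_of_t, h0, dt, steps):
    """Forward-Euler integration of h'(t) = A*h(t) + B*x(t)."""
    h = h0
    trajectory = [h]
    for k in range(steps):
        t = k * dt
        h = h + dt * (A * h + B * x_of_t(t))
        trajectory.append(h)
    return trajectory

# Microwave off: the state decays toward 0 at rate A.
cooling = simulate(A=-0.5, B=1.0, x_of_t=lambda t: 0.0, h0=60.0, dt=0.01, steps=200)

# Microwave on with constant x = 10: the state settles at -B*x/A = 20.
heating = simulate(A=-0.5, B=1.0, x_of_t=lambda t: 10.0, h0=0.0, dt=0.01, steps=2000)

print(cooling[-1], 60.0 * math.exp(-0.5 * 2.0))  # Euler vs. closed form at t = 2
print(heating[-1])                               # approaches 20
```

With the input off, the closed-form solution is $h(t) = h_0 e^{At}$, which the simulation tracks closely.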

This is elegant but useless for computers, which can't process continuous time.

3.3 Discretization and the Zero-Order Hold Rule

Computers process tokens one at a time, in discrete steps $t = 0, 1, 2, \ldots$. We need to convert the continuous parameters $(A, B)$ into discrete equivalents $(\bar{A}, \bar{B})$. The method Mamba uses is the Zero-Order Hold (ZOH) rule.

The ZOH assumption: Between two discrete samples, the input $x$ is assumed to hold its value constant. Like a thermostat that reads temperature once per hour—it assumes the room held at 70°F for the full hour until the next reading. Mathematically, a constant is a polynomial of degree zero, hence "Zero-Order Hold."

Integrating the continuous equation over a time step $\Delta$ under this assumption yields:

Discrete State Transition ($\bar{A}$): $$\bar{A} = \exp(\Delta \cdot A)$$

Intuition: If $A$ is the instantaneous decay rate and $\Delta$ is the step size, then $e^{\Delta A}$ is the total decay over that step. A large $\Delta$ means more time passes, so more decay.

Discrete Input ($\bar{B}$): $$\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$$

Intuition: This accumulates how much of the input $x$ enters the state over the entire duration $\Delta$, accounting for the fact that the input begins decaying according to $A$ the instant it enters.
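As a sanity check on the two ZOH formulas, the sketch below applies them to a scalar system and compares one discrete step against brute-force integration of the continuous equation with the input held constant. All numeric values and function names are illustrative assumptions.

```python
import math

def zoh_discretize(A: float, B: float, delta: float):
    """Scalar ZOH: A_bar = exp(delta*A), B_bar = (delta*A)^-1 (exp(delta*A) - 1) * delta*B."""
    A_bar = math.exp(delta * A)
    B_bar = (A_bar - 1.0) / (delta * A) * (delta * B)
    return A_bar, B_bar

def integrate_one_step(A, B, h, x, delta, substeps=10_000):
    """Reference: Euler-integrate h' = A*h + B*x over [0, delta] with x held constant."""
    dt = delta / substeps
    for _ in range(substeps):
        h += dt * (A * h + B * x)
    return h

A, B, delta, h0, x = -0.5, 1.0, 0.8, 2.0, 3.0
A_bar, B_bar = zoh_discretize(A, B, delta)

h_zoh = A_bar * h0 + B_bar * x                  # one discrete step
h_ref = integrate_one_step(A, B, h0, x, delta)  # fine-grained simulation

print(h_zoh, h_ref)  # the two values should agree to several decimals
```

The agreement is the point of ZOH: under the "input held constant" assumption, a single discrete step reproduces the continuous evolution over the whole interval.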

3.4 The Taylor Approximation for B̄

The exact ZOH formula for $\bar{B}$ involves a matrix inverse—expensive to compute. Mamba uses the first-order Taylor approximation.

Recall the full Taylor series: $$e^x = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \frac{x^4}{4!} + \cdots$$

Substituting $x = \Delta A$: $$\exp(\Delta A) - I = \Delta A + \frac{(\Delta A)^2}{2!} + \frac{(\Delta A)^3}{3!} + \cdots$$

Plugging into the ZOH formula and distributing the $(\Delta A)^{-1}$: $$\bar{B} = \left(I + \frac{\Delta A}{2!} + \frac{(\Delta A)^2}{3!} + \cdots\right) \cdot \Delta B$$

When $\Delta$ is small (as initialized in training), higher-order terms ($\Delta^2$, $\Delta^3$) become negligibly small. Truncating after the identity: $$\bar{B} \approx I \cdot \Delta B = \Delta B$$

This is the approximation in the code: discrete_B = delta * B. It is fast, GPU-friendly, and close to the exact formula whenever $\Delta A$ is small.
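To see how the first-order approximation behaves as $\Delta$ grows, here is a small numerical comparison for a scalar system. The values $A = -1$, $B = 1$ and the helper names are assumptions for illustration. It compares the formulas only and says nothing about training outcomes.

```python
import math

def b_bar_exact(A: float, B: float, delta: float) -> float:
    """Exact scalar ZOH input term: (exp(delta*A) - 1) / A * B."""
    return (math.exp(delta * A) - 1.0) / A * B

def b_bar_first_order(B: float, delta: float) -> float:
    """First-order Taylor approximation: delta * B."""
    return delta * B

def relative_error(A: float, B: float, delta: float) -> float:
    exact = b_bar_exact(A, B, delta)
    return abs(b_bar_first_order(B, delta) - exact) / abs(exact)

A, B = -1.0, 1.0
for delta in (0.001, 0.01, 0.1, 1.0, 2.0):
    print(f"delta={delta:<6} exact={b_bar_exact(A, B, delta):.5f} "
          f"approx={b_bar_first_order(B, delta):.5f} "
          f"rel_err={relative_error(A, B, delta):.2%}")
```

For small step sizes the two agree closely. For large $\Delta$ the simplified value overshoots the exact one, so results computed with it at large $\Delta$ are best read as illustrating the simplified formula rather than exact ZOH.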

3.5 The Selective (Input-Dependent) SSM

This is the core innovation of Mamba. In a standard Linear Time Invariant (LTI) SSM, the parameters $\Delta$, $B$, and $C$ are fixed constants regardless of input—every token is processed with the same gate settings.

Mamba makes them functions of the current input $x_t$:

$$\Delta_t = \text{Softplus}(\text{Linear}(x_t))$$ $$B_t = \text{Linear}(x_t)$$ $$C_t = \text{Linear}(x_t)$$

The recurrence then becomes: $$h_t = \bar{A}_t \, h_{t-1} + \bar{B}_t \, x_t$$ $$y_t = C_t \, h_t$$

The behavioral consequences are direct:

  • When $x_t$ is noise: The model learns to predict a small $\Delta_t$. From the discretization formula, a small $\Delta$ gives $\bar{A} = e^{\Delta A} \approx 1$ (retain history) and $\bar{B} \approx 0$ (ignore input). The gate is clamped shut.

  • When $x_t$ is signal: The model predicts a large $\Delta_t$. This gives $\bar{A} \approx 0$ (aggressively forget old context) and $\bar{B}$ large (absorb the new information). The gate is thrown open.

This is what replaces Attention. Instead of computing explicit pairwise scores like Transformers, Mamba learns to selectively admit or block information based on its own assessment of relevance.
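The two gate regimes can be reproduced with a few lines of scalar arithmetic. In this sketch, `pre_activation` is a hypothetical stand-in for the output of $\text{Linear}(x_t)$, and $A = -1$, $B = 1$ are assumed values.

```python
import math

def softplus(z: float) -> float:
    """Softplus(z) = ln(1 + e^z): smooth and always positive."""
    return math.log1p(math.exp(z))

def gate(A: float, B: float, pre_activation: float):
    """Return (delta, A_bar, B_bar) for one scalar channel."""
    delta = softplus(pre_activation)
    A_bar = math.exp(delta * A)   # fraction of the old state that is retained
    B_bar = delta * B             # first-order scaling applied to the new input
    return delta, A_bar, B_bar

closed = gate(A=-1.0, B=1.0, pre_activation=-5.0)  # small delta: gate shut
opened = gate(A=-1.0, B=1.0, pre_activation=4.0)   # large delta: gate open

print("closed:", closed)  # A_bar close to 1, B_bar close to 0
print("opened:", opened)  # A_bar close to 0, B_bar large
```

A single scalar $\Delta_t$ therefore controls both halves of the trade: how much history survives and how strongly the current input is written.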


4. The Architecture: A Full Layer Walkthrough

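A single Mamba layer wraps the SSM in a gated MLP structure. The data flow, as implemented in Section 9.2, is: input → in_proj (expand, then split into a main branch and a gating branch) → main branch: Conv1d → SiLU → S6Block; gating branch: SiLU → element-wise product of the two branches → out_proj.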



Each component has a specific job:

  • Projection & Expansion: Multiplies feature dimension by 2 and creates two parallel branches. More channels = more independent "heads" (see Section 8).
  • Conv1d: A lightweight depthwise 1D convolution over the time dimension. It captures purely local context (adjacent tokens) before the SSM handles long-range dependencies.
  • S6Block (The Core): The selective scan. This is Mamba's Attention equivalent. Full breakdown in Section 6.
  • Gating: The output of the SSM is element-wise multiplied by the gating branch. This acts as a learned filter: the gate branch learns to suppress irrelevant outputs.
  • Output Projection: Mixes all channels back to D_Model dimension. This is where independent channel insights are combined—the "Multi-Head Concat + Linear" equivalent.
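The "purely local, causal" behavior of the depthwise convolution is easy to check by hand. The sketch below is a plain-Python version of a causal depthwise 1D convolution with kernel size 4 and no bias; the function name and the toy kernels are ours, chosen so the effect is visible.

```python
def causal_depthwise_conv(x, kernels):
    """
    x:       shape (L, D), a list of time steps, each a list of channel values
    kernels: shape (D, K), one kernel per channel; kernels[d][-1] weights the current step
    Returns shape (L, D); output[t] depends only on x[t-K+1 .. t] (zero-padded on the left).
    """
    L, D = len(x), len(x[0])
    K = len(kernels[0])
    out = []
    for t in range(L):
        row = []
        for d in range(D):
            acc = 0.0
            for k in range(K):
                src = t - (K - 1) + k   # index into the past; negative means padding
                if src >= 0:
                    acc += kernels[d][k] * x[src][d]
            row.append(acc)
        out.append(row)
    return out

x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
kernels = [[0.0, 0.0, 0.0, 1.0],   # channel 0: passes the current token through
           [0.0, 0.0, 1.0, 0.0]]   # channel 1: one-step delay
y = causal_depthwise_conv(x, kernels)
print(y)  # [[1.0, 0.0], [2.0, 10.0], [3.0, 20.0], [4.0, 30.0]]
```

Each channel sees only its own recent history (no cross-channel mixing), and no output depends on a future token.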

5. HIPPO: Why Initialization Is Not Just a Detail

5.1 The Problem With Random Initialization

In the basic SSM code, matrix $A$ is initialized randomly. This is a trap. A randomly initialized $A$ will almost certainly cause one of two failure modes during training:

  • Exploding memory: If eigenvalues of $A$ are positive, the state grows exponentially. The system becomes numerically unstable.
  • Vanishing memory: If eigenvalues of $A$ are too negative, the state decays to near-zero within a few steps. The model loses all long-range memory before training can teach it to recover.

Neither scenario is easily recoverable by gradient descent alone. The model wastes thousands of training steps just learning how to remember—before it can learn what to remember.

5.2 The HIPPO Matrix and S4D Diagonal Approximation

HIPPO (High-order Polynomial Projection Operators) is a mathematically derived initialization that gives the model a principled memory structure from step one.

The core idea: Instead of storing raw token values, the hidden state $h$ stores the coefficients of polynomials that approximate the entire history of inputs. Think of it like a Fourier transform—you don't store the raw waveform, you store the frequency coefficients that let you reconstruct it.

The HIPPO-LegS (Legendre Polynomials, Scaled) matrix has a closed-form definition. For state size $N$, the entries are:

$$A_{nk} = -(2n+1)^{1/2}(2k+1)^{1/2} \quad \text{if } n > k$$

$$A_{nn} = -(n+1)$$

$$A_{nk} = 0 \quad \text{if } n < k$$

This creates a structured lower-triangular matrix where each dimension captures a different time scale—from long-range trends to immediate fluctuations.
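The closed form above translates directly into code. This sketch assumes zero-based indices ($n, k = 0, \ldots, N-1$), which the formulas do not state explicitly; the function name is ours.

```python
import math

def hippo_legs(N: int):
    """Build the N x N HIPPO-LegS matrix from the closed form (0-indexed n, k)."""
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n][k] = -math.sqrt(2 * n + 1) * math.sqrt(2 * k + 1)
            elif n == k:
                A[n][k] = -(n + 1.0)
            # n < k: entry stays 0, which makes the matrix lower-triangular
    return A

A = hippo_legs(4)
for row in A:
    print([round(v, 3) for v in row])
```

Under this indexing the diagonal reads $-1, -2, -3, -4$, and everything above it is zero.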

The practical problem: A full $N \times N$ matrix multiplication is $O(N^2)$—defeating Mamba's linear-time goal.

The S4D solution: Research showed that the HIPPO matrix can be approximated effectively by a diagonal matrix. The S4D-Real initialization keeps only the diagonal of HIPPO-LegS, $A_{nn} = -(n+1)$. This guide uses a half-scaled variant of that diagonal:

$$A_n = -\frac{1}{2} \cdot n \quad \text{for } n = 1, 2, \ldots, N$$

The resulting diagonal entries are $[-0.5, -1.0, -1.5, -2.0, \ldots, -N/2]$.

What this means structurally:

  • State dimension 0 has $A = -0.5$: decays slowly. Responsible for long-term memory.
  • State dimension $N-1$ has $A = -N/2$: decays rapidly. Responsible for fine-grained recent detail.
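One way to quantify this hierarchy is the number of steps it takes each state dimension to lose half of its content. The sketch below assumes a fixed, hypothetical step size $\Delta = 0.01$. In Mamba $\Delta$ varies per token, so these numbers are only indicative.

```python
import math

def s4d_diagonal(N: int):
    """Diagonal entries A_n = -n/2 for n = 1..N, as given above."""
    return [-0.5 * n for n in range(1, N + 1)]

def half_life_steps(a: float, delta: float) -> float:
    """Steps until the per-step retention exp(delta * a), compounded, falls to 0.5."""
    return math.log(0.5) / (delta * a)

delta = 0.01  # hypothetical fixed step size
diag = s4d_diagonal(16)
for n in (1, 4, 16):
    a = diag[n - 1]
    print(f"n={n:<2} A={a:<5} retention/step={math.exp(delta * a):.4f} "
          f"half-life={half_life_steps(a, delta):.1f} steps")
```

The slowest dimension holds information for over a hundred steps at this $\Delta$, while the fastest loses half of it in under ten.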

The model enters training with a built-in memory hierarchy—some neurons remembering centuries, others remembering microseconds (relatively speaking). It now uses training to learn what to store, not how to store it.


6. Manual Walkthrough: The Selective Scan Step-by-Step

Let's trace the exact computations inside the S6 core. We use a minimal example to make the math concrete.

Setup:

  • Single channel, state size $N = 1$.
  • Fixed learned parameter: $A = -1.0$.
  • Input sequence: $x = [0.1, 0.5]$ (two tokens).
  • Initial state: $h_0 = 0.0$.

Token 1: $x_1 = 0.1$ ("Noise")

① Selection: The model's Linear layers see $x_1 = 0.1$ (a small, unremarkable value) and output: $$\Delta_1 = 0.1 \quad B_1 = 0.5 \quad C_1 = 1.0$$

The small $\Delta_1$ is the model signaling: "This isn't interesting. Don't update the state much."

② Discretization: $$\bar{A}_1 = \exp(0.1 \times -1.0) \approx 0.90 \quad (\text{90\% history retention})$$ $$\bar{B}_1 \approx \Delta_1 \times B_1 = 0.1 \times 0.5 = 0.05$$

③ Recurrence: $$h_1 = \bar{A}_1 \cdot h_0 + \bar{B}_1 \cdot x_1 = (0.90 \times 0.0) + (0.05 \times 0.1) = 0.005$$

④ Output: $$y_1 = C_1 \cdot h_1 = 1.0 \times 0.005 = 0.005$$

The state barely moved. The noise token has almost no effect on memory.


Token 2: $x_2 = 0.5$ ("Signal")

① Selection: The model sees $x_2 = 0.5$ and recognizes it as important. It outputs: $$\Delta_2 = 2.0 \quad B_2 = 1.0 \quad C_2 = 1.0$$

The large $\Delta_2$ signals: "Open the gate. Write this in permanent marker."

② Discretization: $$\bar{A}_2 = \exp(2.0 \times -1.0) \approx 0.13 \quad (\text{only 13\% retention: aggressively forget})$$ $$\bar{B}_2 \approx 2.0 \times 1.0 = 2.0$$

③ Recurrence: $$h_2 = (0.13 \times 0.005) + (2.0 \times 0.5) = 0.00065 + 1.0 = 1.00065$$

④ Output: $$y_2 = 1.0 \times 1.00065 \approx 1.0$$

The signal dominates. The prior noise ($0.005$) has been multiplied by $0.13$ and is nearly gone. The output reflects the important token almost exclusively.

This is Selection. Not filtering by position, not filtering by learned static weights—filtering by the content of the input itself, computed on the fly.
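The two-token trace can be replayed in a few lines. This sketch hard-codes the $\Delta$, $B$, and $C$ values from the walkthrough instead of computing them from learned Linear layers, so it shows only the discretization and recurrence steps; the function name is ours.

```python
import math

def selective_step(h_prev, x, A, delta, B, C):
    """One scalar S6 step: discretize, update the state, read out."""
    A_bar = math.exp(delta * A)   # exact ZOH for A
    B_bar = delta * B             # first-order approximation for B
    h = A_bar * h_prev + B_bar * x
    return h, C * h

A = -1.0
h = 0.0
tokens = [
    (0.1, dict(delta=0.1, B=0.5, C=1.0)),  # token 1: "noise"
    (0.5, dict(delta=2.0, B=1.0, C=1.0)),  # token 2: "signal"
]

outputs = []
for x, params in tokens:
    h, y = selective_step(h, x, A, **params)
    outputs.append(y)

print(outputs)  # approximately [0.005, 1.00068]
```

The second output differs from the hand calculation in the fifth decimal place only because the walkthrough rounds $e^{-2}$ to $0.13$.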


The Full Scan Loop: Matrix Form

For the full multi-dimensional scan (the actual code), dimensions expand:

Shape conventions (D_Model=3, D_State=3):

Initial hidden state $h_0$ is a zero matrix: $$h = \begin{bmatrix} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix} \quad \text{(rows = channels, cols = state dimensions)}$$

For input $x_t = [10, 20, 30]$ (one value per channel), with example discrete parameters:

Step A — State Update: $h_t = \bar{A}_t \cdot h_{t-1} + \bar{B}_t \cdot x_t$

With $\bar{A} = 0.5$ everywhere and $\bar{B}$ as: $$ \begin{bmatrix} 1.0 & 0.0 & 0.0 \\ 0.0 & 1.0 & 0.0 \\ 0.5 & 0.5 & 0.5 \end{bmatrix}$$

Input $[10, 20, 30]$ writes into the state as: $$ \begin{bmatrix} 10 & 0 & 0 \\ 0 & 20 & 0 \\ 15 & 15 & 15 \end{bmatrix}$$

Channel 1 is "laser focused" on State 0. Channel 3 distributes evenly.

Step B — Output Read: $y_t = \sum(h_t \odot C_t)$ along the state dimension

With output gate $C$: $$ \begin{bmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$

Element-wise product then row-sum:

  • Channel 1: $(10 \cdot 1) + (0 \cdot 1) + (0 \cdot 1) = 10$
  • Channel 2: $(0 \cdot 1) + (20 \cdot 0) + (0 \cdot 0) = 0$
  • Channel 3: $(15 \cdot 0) + (15 \cdot 0) + (15 \cdot 1) = 15$

Output vector: $y_t = [10, 0, 15]$

Step C — Skip Connection: $y = y + x \cdot D$

The parameter $D$ (a learned passthrough) adds the raw input directly to the SSM output, acting as a residual connection inside the recurrence: $$y_{\text{final}} = [10, 0, 15] + [10, 20, 30] \cdot [1, 1, 1] = [20, 20, 45]$$
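Steps A through C can be checked with nested lists, using the same example matrices (rows are channels, columns are state dimensions). This is a plain-Python sketch of one time step, not the batched tensor code.

```python
x = [10.0, 20.0, 30.0]                       # one value per channel
h_prev = [[0.0] * 3 for _ in range(3)]       # zero initial state
A_bar = [[0.5] * 3 for _ in range(3)]
B_bar = [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0],
         [0.5, 0.5, 0.5]]
C = [[1.0, 1.0, 1.0],
     [1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0]]
D = [1.0, 1.0, 1.0]

# Step A: h[d][n] = A_bar[d][n] * h_prev[d][n] + B_bar[d][n] * x[d]
h = [[A_bar[d][n] * h_prev[d][n] + B_bar[d][n] * x[d] for n in range(3)]
     for d in range(3)]

# Step B: element-wise product with C, then sum along the state dimension
y = [sum(h[d][n] * C[d][n] for n in range(3)) for d in range(3)]

# Step C: skip connection
y_final = [y[d] + D[d] * x[d] for d in range(3)]

print(h)        # [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [15.0, 15.0, 15.0]]
print(y)        # [10.0, 0.0, 15.0]
print(y_final)  # [20.0, 20.0, 45.0]
```

Because the initial state is zero, $\bar{A}$ has no visible effect on this first step; it only matters from the second token onward.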


7. Concrete Sentence Example: Processing "the quick brown fox"

Let's trace Mamba predicting the word "jumps" from the input "the quick brown fox," using simplified but numerically grounded values.

Setup:

  • Single channel. State size $N = 2$. Fixed $A = -1.0$ everywhere. Fixed $B = [1, 1]$, $C = [1, 0]$.
  • Initial state: $h_0 = [0, 0]$.

Processing "the" (Stop Word → Small Gate)

The model sees the word "the." High frequency, low information. Selection outputs: $$\Delta_1 = 0.1$$ $$\bar{A}_1 = e^{0.1 \times -1} \approx 0.90 \quad \bar{B}_1 = 0.1 \times 1.0 = 0.1$$

Let's say word embedding value $= 0.5$: $$h_1 = (0.90 \times [0, 0]) + (0.1 \times 0.5) = [0.05, 0.05]$$

State is barely non-zero. Memory is essentially empty.

Processing "fox" (Subject → Open Gate)

Jumping forward after "quick" and "brown" have added moderate contributions, we arrive at "fox"—the grammatical subject. The model's Selection network recognizes this high-information token: $$\Delta_4 = 2.0$$ $$\bar{A}_4 = e^{2.0 \times -1} \approx 0.13 \quad \bar{B}_4 = 2.0 \times 1.0 = 2.0$$

Word embedding value $= 10.0$ (high semantic weight): $$h_4 = (0.13 \times h_3) + (2.0 \times 10.0)$$

Assuming $h_3 \approx [0.5, 0.5]$ from prior adjectives: $$h_4 = [0.065, 0.065] + [20.0, 20.0] = [20.065, 20.065]$$

Output: $$y_4 = C \cdot h_4 = [1.0, 0.0] \cdot [20.065, 20.065] = 20.065$$

The word "the" contributed at most about $0.0065$ (its $0.05$ multiplied by the $0.13$ retention at "fox", before counting the further decay applied at "quick" and "brown"). The word "fox" contributed $20.0$. The output is dominated by "fox", the correct subject for predicting "jumps."

The model has selectively compressed the sentence into a state that encodes "agile brown animal" and predicts "jumps" as the next token.
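The same arithmetic in code, as a sketch: the embedding values, the $\Delta$ values, and the intermediate state $h_3$ are the illustrative numbers assumed in the text, not outputs of a trained model.

```python
import math

A = -1.0
B = [1.0, 1.0]
C = [1.0, 0.0]

def step(h_prev, x, delta):
    """One step for a single channel with a 2-dimensional state."""
    A_bar = math.exp(delta * A)
    return [A_bar * hp + delta * b * x for hp, b in zip(h_prev, B)]

h1 = step([0.0, 0.0], x=0.5, delta=0.1)   # "the"
h3 = [0.5, 0.5]                           # assumed state after "quick" and "brown"
h4 = step(h3, x=10.0, delta=2.0)          # "fox"
y4 = sum(c * h for c, h in zip(C, h4))

carried_over = math.exp(2.0 * A) * h3[0]  # part of y4 inherited from earlier tokens
print(h1)                  # approximately [0.05, 0.05]
print(y4)                  # approximately 20.068
print(carried_over / y4)   # well under 1% of the output
```

The small gap to the hand-computed $20.065$ comes from rounding $e^{-2}$ to $0.13$ in the text.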


8. Mamba vs. Multi-Head Attention: The Real Comparison

In a Transformer, Multi-Head Attention allows different heads to specialize in different relationship types (grammatical, semantic, positional). Mamba has no Attention matrix—but it replicates this behavior through independent channels.

Mamba's Multi-Channel Architecture

After the in_proj expansion (factor $E = 2$), the model has D_Inner = D_Model × 2 channels, all running their own independent SSMs. Each channel has its own $\Delta$ and its own hidden state trajectory; in the implementation in Section 9, $B$ and $C$ are computed once per token and shared across channels.

Concrete example with $D_{Model} = 2$:

For input "jumps" with embedding $x = [5.0, 10.0]$:

Channel 1 (learns grammar — "is this a verb?"):

  • Input: $5.0$
  • Prior state: $h^{(1)} = [0.5, 0.5]$ (accumulated verb patterns)
  • Output $y^{(1)} = 3.0$ → Encoding: "This is a verb."

Channel 2 (learns semantics — "action relates to animal?"):

  • Input: $10.0$
  • Prior state: $h^{(2)} = [20.0, 20.0]$ (accumulated "fox" information)
  • Output $y^{(2)} = 50.0$ → Encoding: "High-energy action following animal subject."

The Mixing Step (The Output Projection)

Now $Y_{combined} = [3.0, 50.0]$. The out_proj Linear layer mixes these independent channel insights:

$$W_{out} = \begin{bmatrix} 0.5 & 0.1 \\ 0.1 & 0.8 \end{bmatrix}$$

$$[3.0, 50.0] \times W_{out} = \left[(3.0 \times 0.5 + 50.0 \times 0.1),\ (3.0 \times 0.1 + 50.0 \times 0.8)\right]$$

$$= [1.5 + 5.0,\ 0.3 + 40.0] = [6.5,\ 40.3]$$

The output vector $[6.5, 40.3]$ combines both the grammatical insight and the semantic insight into a single representation used to predict the next word.
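The mixing step is an ordinary row-vector-times-matrix product. A minimal sketch with the example numbers (the helper name is ours, and the bias term of a real Linear layer is omitted):

```python
def vec_mat(v, W):
    """Row vector times matrix: out[j] = sum_i v[i] * W[i][j]."""
    return [sum(v[i] * W[i][j] for i in range(len(v))) for j in range(len(W[0]))]

y_combined = [3.0, 50.0]
W_out = [[0.5, 0.1],
         [0.1, 0.8]]

mixed = vec_mat(y_combined, W_out)
print([round(v, 4) for v in mixed])  # [6.5, 40.3]
```

Every output dimension is a weighted blend of both channels, which is what lets the layer combine otherwise independent per-channel states.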

Side-by-Side Comparison

| Feature | Transformer (Multi-Head Attention) | Mamba (Multi-Channel SSM) |
|---|---|---|
| Mechanism | All tokens attend to all tokens ($N^2$) | Each channel tracks its own state ($N$) |
| Specialization | Fixed number of heads (e.g., 8 or 16) | Thousands of independent channels |
| Head Mixing | Concat outputs → Linear projection | Output projection (out_proj) |
| Context window | Exact KV cache (full recall) | Compressed state $h$ (lossy summary) |
| Selection | Softmax over $Q \cdot K^T$ scores | $\Delta$, $B$, $C$ gates from input |
| Compute | $O(N^2)$ | $O(N)$ |
| Memory | $O(N)$ KV cache (grows with sequence) | $O(1)$ constant state size |

The analogy: Transformer heads are a committee of 8 experts discussing the document together. Mamba channels are 1024 workers in separate cubicles—they read the document independently, write their own notes, and hand summaries to a manager (out_proj) who combines them.

The key difference: with thousands of independent channels, Mamba can capture extremely granular pattern specializations that a small number of attention heads might miss.
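The memory rows of the comparison can be made tangible with a back-of-the-envelope count. All sizes below are hypothetical. The sketch assumes a key and a value vector of width `d_model` per token per layer, and it ignores Mamba's small convolution buffer.

```python
def kv_cache_floats(seq_len: int, n_layers: int, d_model: int) -> int:
    """Keys + values stored for every past token in every layer."""
    return 2 * seq_len * n_layers * d_model

def ssm_state_floats(n_layers: int, d_inner: int, d_state: int) -> int:
    """One (d_inner x d_state) recurrent state per layer, independent of seq_len."""
    return n_layers * d_inner * d_state

# Hypothetical model sizes, for illustration only
N_LAYERS, D_MODEL, D_STATE = 24, 1024, 16
D_INNER = 2 * D_MODEL

for seq_len in (1_000, 10_000, 100_000):
    print(seq_len,
          kv_cache_floats(seq_len, N_LAYERS, D_MODEL),
          ssm_state_floats(N_LAYERS, D_INNER, D_STATE))
```

The first count grows linearly with the number of tokens processed; the second does not change at all, which is the practical meaning of "$O(1)$ constant state size".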


9. Full PyTorch Implementation (Annotated)

The implementation has two components: S6Block (the SSM core) and MambaBlock (the outer gated architecture). We use the HIPPO initialization for $A$.

9.1 S6Block with HIPPO Initialization

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class S6Block(nn.Module):
    """
    The SSM Core. Equivalent to "Self-Attention" in a Transformer.

    Corresponds to:
      - Section 3.5: The Selection Mechanism (input-dependent parameters)
      - Sections 3.3-3.4: The Discretization (ZOH approximation)
      - Section 6: The Selective Scan (the recurrence loop)
    """
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # --- HIPPO INITIALIZATION (Section 5) ---
        # Instead of random init, use the diagonal from Section 5.2: A_n = -0.5 * n
        # This gives the state a memory hierarchy from slow (long-term)
        # to fast (short-term) decay, right from step one of training.
        #
        # State 0:   A = -0.5  → slow decay, long-term memory neuron
        # State 1:   A = -1.0  → medium decay
        # State N-1: A = -N/2  → fast decay, short-term memory neuron

        # 1. Create [1, 2, ..., d_state]
        A_arange = torch.arange(1, d_state + 1).float()

        # 2. Apply the diagonal formula from Section 5.2
        A_init = -0.5 * A_arange

        # 3. Expand to (d_model, d_state): every channel gets the same structure
        A_init = A_init.repeat(d_model, 1)

        # 4. Store as log(|A|) for numerical stability.
        # During forward pass: A = -exp(A_log) ensures A is always negative,
        # so every discrete A_bar = exp(delta * A) stays in (0, 1) and the
        # state cannot blow up through A.
        self.A_log = nn.Parameter(torch.log(A_init.abs()))

        # D: The Skip/Residual Connection.
        # y_final = y_ssm + D * x_raw
        # This lets the model directly pass the raw input to the output,
        # complementing what the SSM tracks over time.
        self.D = nn.Parameter(torch.ones(d_model))

        # --- SELECTION MECHANISM: Combined Projector (Section 3.5) ---
        # Architectural insight: instead of 3 separate Linear layers for
        # delta, B, and C, we use ONE large projection for GPU efficiency.
        # We project to a low rank (dt_rank) for delta to save parameters.
        dt_rank = math.ceil(d_model / 16)

        # Single projection outputs all three: delta_rank | B | C
        self.x_proj = nn.Linear(d_model, dt_rank + d_state * 2, bias=False)

        # Expands low-rank delta back to full d_model size
        self.dt_proj = nn.Linear(dt_rank, d_model, bias=True)

        # Initialize dt_proj weights and bias for stable early training
        dt_init_std = 2 ** -4
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)

        # Bias initialized so that softplus(bias) is log-uniform in [0.001, 0.1]
        dt = torch.exp(
            torch.rand(d_model) * (math.log(0.001) - math.log(0.1)) + math.log(0.1)
        )
        # Inverse of softplus: softplus(inv_dt) == dt
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

    def forward(self, x):
        """
        x: (Batch, Seq_Len, D_Model)
        """
        (b, l, d) = x.shape

        # ① SELECTION MECHANISM
        # Project x to a combined vector [delta_rank | B | C]
        # Shape: (Batch, Seq_Len, dt_rank + 2 * d_state)
        x_dbl = self.x_proj(x)

        dt_rank = self.dt_proj.in_features

        # Slice the combined projection into its three parts
        # delta_rank: (B, L, dt_rank)  — raw step size signal
        # B:          (B, L, d_state)  — input dynamics gate
        # C:          (B, L, d_state)  — output readout gate
        delta_rank, B, C = torch.split(x_dbl, [dt_rank, self.d_state, self.d_state], dim=-1)

        # Expand delta from low-rank to d_model, enforce positivity with Softplus.
        # Softplus = ln(1 + e^x): smooth, always-positive version of ReLU.
        # Positive delta is a physical requirement: time must move forward.
        delta = F.softplus(self.dt_proj(delta_rank))  # (B, L, D_Model)

        # ② DISCRETIZATION (Section 3.3 and 3.4)
        # Reconstruct A from stored log(|A|) and enforce negativity
        A = -torch.exp(self.A_log)  # (D_Model, D_State), strictly negative

        # Compute discrete A_bar = exp(delta * A)  [ZOH exact formula for A]
        # Broadcasting: delta (B,L,D,1) * A (D,N) → (B,L,D,N)
        discrete_A = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D_Model, D_State)

        # Compute discrete B_bar ≈ delta * B  [First-order Taylor approx, Section 3.4]
        # Outer product via broadcasting: delta (B,L,D,1) * B (B,L,1,N) → (B,L,D,N)
        discrete_B = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D_Model, D_State)

        # ③ SELECTIVE SCAN — The Recurrence
        # This for-loop is the "Attention" equivalent (Section 8).
        # In production Mamba, this is replaced by a fused CUDA parallel scan
        # kernel (prefix-sum style), which achieves true O(N) parallelism.
        # We write the serial Python version here for clarity.
        h = torch.zeros(b, d, self.d_state, device=x.device)  # (B, D_Model, D_State)

        ys = []
        for t in range(l):
            # State update: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
            # discrete_A[:, t]: (B, D, N), how much of the past to retain
            # discrete_B[:, t]: (B, D, N), how much of the new input to absorb
            # x[:, t, :].unsqueeze(-1): (B, D, 1), current token value
            h = discrete_A[:, t] * h + discrete_B[:, t] * x[:, t, :].unsqueeze(-1)

            # Output: y_t = sum_over_states(h * C)
            # C[:, t, :].unsqueeze(1): (B, 1, N) broadcasts over D dimension
            # torch.sum(..., dim=-1) collapses state dim → (B, D_Model)
            y_t = torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1)
            ys.append(y_t)

        y = torch.stack(ys, dim=1)  # (B, Seq_Len, D_Model)

        # Skip connection: pass raw input directly to output
        y = y + x * self.D

        return y
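
The comment in the scan loop mentions that production Mamba replaces the serial loop with a parallel scan. The reason this is possible is that each step $h \mapsto a\,h + b$ is an affine map, and composing affine maps is associative. The sketch below demonstrates that property on scalars with a pairwise (tree-shaped) reduction. It is an illustration of the idea with made-up values and our own function names, not the fused CUDA kernel.

```python
def combine(left, right):
    """Compose two steps h -> a*h + b, applying `left` first and `right` second."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def sequential_scan(pairs, h0=0.0):
    """Reference: the serial recurrence, returning every intermediate state."""
    h, out = h0, []
    for a, b in pairs:
        h = a * h + b
        out.append(h)
    return out

def tree_reduce(pairs):
    """Combine neighbouring steps level by level (log-depth), preserving order."""
    while len(pairs) > 1:
        nxt = [combine(pairs[i], pairs[i + 1]) for i in range(0, len(pairs) - 1, 2)]
        if len(pairs) % 2:
            nxt.append(pairs[-1])
        pairs = nxt
    return pairs[0]

# (A_bar_t, B_bar_t * x_t) for five hypothetical time steps
pairs = [(0.9, 0.005), (0.5, 1.0), (0.8, -0.3), (0.2, 2.0), (0.7, 0.1)]

final_sequential = sequential_scan(pairs)[-1]
a_total, b_total = tree_reduce(pairs)
final_tree = a_total * 0.0 + b_total   # apply the composed map to h0 = 0

print(final_sequential, final_tree)    # the two values agree
```

Within one level of the tree every `combine` call is independent of the others, so those calls could run in parallel. A full prefix scan extends the same operator to return every intermediate state rather than just the last one.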

9.2 MambaBlock (The Outer Layer)

class MambaBlock(nn.Module):
    """
    The full Mamba layer: Gated MLP wrapper around the S6 core.
    Handles projections, local convolution, and gating.
    """
    def __init__(self, d_model, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_model * expand  # Feature expansion (default 2x)

        # ① Projection & Expansion
        # Creates both the Main Branch and Gating Branch at once.
        # Output is (Batch, Seq_Len, 2 * D_Inner), then split into two halves.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2)

        # ② Local Context (Convolution)
        # Depthwise 1D conv over the time dimension.
        # Groups = d_inner means each channel has its own kernel (no cross-channel mixing here).
        # This captures purely local token relationships before the long-range SSM.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=4,
            groups=self.d_inner,  # Depthwise: independent per channel
            padding=3              # Causal padding
        )

        self.activation = nn.SiLU()

        # ③ SSM Core
        self.ssm = S6Block(self.d_inner)

        # ⑤ Output Projection — the "Multi-Head Concat + Linear" equivalent
        # This is where all independent channel insights are mixed together.
        self.out_proj = nn.Linear(self.d_inner, d_model)

    def forward(self, x):
        # x: (Batch, Seq_Len, D_Model)

        # Create both branches in one projection
        x_and_res = self.in_proj(x)  # (B, L, 2 * D_Inner)
        (x, res) = x_and_res.split([self.d_inner, self.d_inner], dim=-1)

        # --- Main Branch ---
        x = x.transpose(1, 2)            # (B, D_Inner, L) — Conv1d expects channels first
        x = self.conv1d(x)[:, :, :x.shape[-1]]  # Crop causal padding
        x = x.transpose(1, 2)            # (B, L, D_Inner)

        x = self.activation(x)
        x = self.ssm(x)                  # Run the selective SSM

        # --- Gating Branch ---
        res = self.activation(res)

        # ④ Gating: element-wise product merges SSM output with gate branch
        x = x * res

        return self.out_proj(x)


# --- Test ---
if __name__ == "__main__":
    BATCH_SIZE = 2
    SEQ_LEN = 10
    D_MODEL = 32

    model = MambaBlock(D_MODEL)
    x = torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL)
    y = model(x)

    print(f"Input Shape:  {x.shape}")   # (2, 10, 32)
    print(f"Output Shape: {y.shape}")   # (2, 10, 32)
    print("Success! Input and output dimensions match.")

    # Verify HIPPO initialization
    A_reconstructed = -torch.exp(model.ssm.A_log)[0].detach()
    print(f"\nHIPPO A (first row, first 4 states): {A_reconstructed[:4]}")
    # Expected: tensor([-0.5000, -1.0000, -1.5000, -2.0000])

Key Takeaway: The code structure maps one-to-one to the architecture diagram. Every numbered section in the MambaBlock corresponds to a specific architectural role. When debugging or extending Mamba, always trace back through this mapping—the code is the math, not an approximation of it.


10. Comparison Matrix: Transformers vs. RNNs vs. Mamba

| Property | Transformer | Standard RNN | Linear SSM (S4) | Mamba (S6) |
|---|---|---|---|---|
| Time Complexity | $O(N^2)$ | $O(N)$ | $O(N)$ | $O(N)$ |
| Memory (Inference) | $O(N)$ KV cache | $O(1)$ | $O(1)$ | $O(1)$ |
| Parallelism (Training) | Full | None | Full (conv) | Full (parallel scan) |
| Long-range Memory | Exact recall | Degrades quickly | Good (HIPPO) | Excellent (HIPPO + Selection) |
| Content-based Filtering | Softmax Attention | None | None | Yes ($\Delta, B, C$ from input) |
| Causal by Default | No (masking needed) | Yes | Yes | Yes |
| State Parameters | Fixed | Fixed | Fixed | Input-dependent |
| Architecture Complexity | High (QKV + FFN) | Low | Medium | Medium |

12. Conclusion: Decision Criteria

The most interesting frontier right now is hybrid architectures—interleaving Mamba layers with full Attention layers (used in models like Jamba and Zamba). These combine Mamba's efficiency on bulk sequence processing with Transformer Attention's precision for tasks requiring exact retrieval. That is likely the near-term path for production-grade long-context models.
