Mamba-2 y state-space models (Checkpoint C5·b)
entender la principal alternativa al Transformer que sobrevivió la oleada de 2023–2026, por qué importa para contexto largo, y entrenar un Mamba-2 que compita con un Transformer equivalente. Cierra C5(b).
4.1El problema que ataca: la cuadrática de la atención
La atención (N5·L1) compara cada token con todos los demás → coste O(L²) en la longitud de secuencia L y un KV-cache que crece linealmente con L (N1·L1). A 1M de tokens, eso es prohibitivo. ¿Y si pudiéramos procesar secuencias en O(L) con estado de tamaño constante?
Los State-Space Models (SSM) lo hacen. Vienen de la teoría de control: un estado oculto h que evoluciona en el tiempo según la entrada:
h_t = A·h_{t-1} + B·x_t (recurrencia: estado se actualiza con cada token)
y_t = C·h_t (salida desde el estado)
Como una RNN, pero con estructura que permite paralelizar el entrenamiento (vía una convolución/scan) y inferencia recurrente O(1) por token (estado de tamaño fijo, sin KV-cache que crezca). Eso es lo atractivo para contexto larguísimo.
4.2Mamba y Mamba-2: las claves
- Mamba (Gu & Dao, 2023): hace los parámetros del SSM dependientes de la entrada (selective SSM / S6) → el modelo "decide" qué recordar y qué olvidar según el token, recuperando gran parte de la expresividad que las SSM lineales perdían frente a la atención. Implementado con un scan paralelo eficiente en hardware.
- Mamba-2 (Dao & Gu, 2024): introduce la SSD (Structured State Space Duality), que demuestra que los SSM selectivos y la atención son dos caras de la misma moneda matemática. Esto permite implementar Mamba-2 con operaciones de matmul (que las GPUs aman) → 2–8× más rápido que Mamba-1, manteniendo la calidad.
Conexión que cierra el curso: los modelos de frontera 2026 (Nemotron Nano, Qwen3.5 con kernels Mamba) son híbridos Mamba+atención — usan capas SSM para la eficiencia en contexto largo y capas de atención para el recall preciso. Entender Mamba-2 es entender por qué.
4.3Laboratorio L4.1 — Entrenar un Mamba-2 pequeño
1# Instala el paquete oficial (kernels CUDA; en Blackwell asegúrate de torch cu128, N0·L2)
2uv pip install mamba-ssm causal-conv1d # de state-spaces/mamba1# lab_n5l4_mamba.py — un modelo de lenguaje Mamba-2 ~200M entrenable en la 5090
2import torch, torch.nn as nn
3from mamba_ssm import Mamba2
4
5class MambaLM(nn.Module):
6 def __init__(self, vocab, d=1024, n_layers=24, d_state=128):
7 super().__init__()
8 self.emb = nn.Embedding(vocab, d)
9 self.layers = nn.ModuleList([
10 nn.ModuleDict({
11 "norm": nn.RMSNorm(d),
12 "mixer": Mamba2(d_model=d, d_state=d_state, headdim=64), # bloque Mamba-2
13 }) for _ in range(n_layers)])
14 self.norm_f = nn.RMSNorm(d)
15 self.head = nn.Linear(d, vocab, bias=False)
16 self.head.weight = self.emb.weight # weight tying (N5·L2)
17
18 def forward(self, idx, targets=None):
19 x = self.emb(idx)
20 for layer in self.layers:
21 x = x + layer["mixer"](layer["norm"](x)) # residual + pre-norm, igual que un GPT
22 logits = self.head(self.norm_f(x))
23 if targets is None: return logits, None
24 loss = torch.nn.functional.cross_entropy(
25 logits.view(-1, logits.size(-1)), targets.view(-1))
26 return logits, loss
27
28# Entrena con el MISMO pipeline de datos que tu GPT de N5·L2/L3 (FineWeb-Edu),
29# mismo nº de parámetros aproximado, para una comparación justa.Líneas no triviales explicadas:
- El bloque Mamba-2 sustituye a la atención, pero la estructura del modelo (embedding → capas con residual+pre-norm → head con weight tying) es idéntica a tu GPT. Eso es lo elegante: solo cambias el "mixer" (cómo se mezcla la información entre posiciones). Atención y SSM son intercambiables a nivel de bloque.
d_state: el tamaño del estado oculto del SSM (su "memoria"). Mayor = más capacidad de recall, más coste.headdim: Mamba-2 organiza el estado en cabezas (paralelo a multi-head); es lo que la SSD permite mapear a matmuls.- No hay KV-cache: en inferencia, Mamba mantiene un estado de tamaño fijo por capa → memoria constante en
L, a diferencia del Transformer.
4.4Laboratorio L4.2 — La comparación justa (el experimento de C5·b)
Entrena, con el mismo dataset, mismo nº de parámetros (~200M) y mismo presupuesto de tokens: tu GPT de N5·L2/L3 y este Mamba-2. Compara:
- Perplexity / val loss en un holdout (¿iguala Mamba-2 al Transformer?).
- Velocidad de entrenamiento (tok/s) y de inferencia.
- Comportamiento en contexto largo: evalúa ambos a longitudes crecientes (usa el needle-in-a-haystack de N3·D). Aquí Mamba debería brillar en memoria/velocidad, aunque la atención suele ganar en recall preciso a larga distancia — ver ese tradeoff con tus números es la lección.
4.5CHECKPOINT C5(b) — criterio de aprobado
- Un Mamba-2 ~200M entrenado por ti que iguala o supera la perplexity de un Transformer equivalente (mismos params/datos/tokens) en tu dataset.
- Análisis del comportamiento en contexto largo: velocidad, memoria (estado fijo vs KV-cache) y recall, con números.
- Sabes explicar la recurrencia SSM, por qué es O(L), y qué aporta la SSD de Mamba-2.
Rúbrica: Nivel 3 si reproduces la comparación; Nivel 4 si construyes un híbrido (algunas capas Mamba, algunas de atención) y muestras que combina lo mejor de ambos.
Combinado con C5(a), cierra el Checkpoint C5 y el último nivel del curso.
4.6Ejercicios
E1. Mide la memoria de inferencia de tu Mamba-2 vs tu GPT a 1K, 8K y 32K tokens. Verifica que la de Mamba es ~constante y la del GPT crece (KV-cache).
E2. Varía d_state (64, 128, 256). ¿Cómo afecta a la perplexity y a la velocidad? ¿Dónde está el punto de rendimientos decrecientes?
E3. Construye un híbrido: reemplaza 1 de cada 4 bloques Mamba por un bloque de atención. ¿Mejora el recall a larga distancia respecto al Mamba puro?
4.7Trampas comunes
- Comparar Mamba vs Transformer con tamaños/datos distintos → la comparación no dice nada.
- Esperar que Mamba gane en todo: suele perder en recall preciso a larga distancia (por eso los híbridos).
mamba-ssmsin los kernels correctos para Blackwell → recompila con torch cu128.
4.8Referencias
- Mamba (Gu & Dao, 2023). Mamba-2 / SSD (Dao & Gu, 2024). Repo state-spaces/mamba. Jamba (AI21, 2024) como híbrido. CS336 (arquitecturas). "The Annotated S4" (Sasha Rush) para la teoría de SSM.