SEXTANTEcursos técnicos de IA
BancoRTX 5090 · GB202
Rev2026.06
Entrar
N2 · Post-training + RL/L7

GRPO hands-on: reproducir el "aha moment" (Checkpoint C2b·a)

Objetivo de maestría

entrenar GRPO en tu 5090 y ver con tus ojos la emergencia de razonamiento: la longitud de respuesta y la recompensa subiendo juntas. Cierra C2b(a). Aplicas la teoría de L6 con la API real de TRL/Unsloth.


7.1El plan

Entrenaremos un modelo pequeño con GRPO + RLVR sobre GSM8K (problemas de matemáticas con respuesta numérica verificable), porque es el banco de pruebas canónico del "aha moment" y la recompensa es trivialmente verificable (¿el número final es correcto?).

Dos caminos, ambos en la 5090:

  • TRL GRPOTrainer + vLLM colocate — el "happy path" estándar, máxima transparencia.
  • Unsloth GRPO — máxima eficiencia en VRAM (Qwen3-4B cómodo; 30B-A3B QLoRA en 17.5 GB).

7.2Las funciones de recompensa (el diseño que decide todo)

En RLVR tú escribes la recompensa. Para GSM8K, una combinación típica:

python
1# lab_n2l7_rewards.py — funciones de recompensa para GSM8K (estilo willccbb, muy usado)
2import re
3
4def extract_answer(text: str):
5    # GSM8K marca la respuesta final tras "####"; aquí esperamos <answer>N</answer>
6    m = re.search(r"<answer>\s*(-?\d[\d,]*)\s*</answer>", text)
7    return m.group(1).replace(",", "") if m else None
8
9def correctness_reward(completions, answer, **kwargs):
10    """+2.0 si la respuesta numérica final es exacta. La señal principal."""
11    out = []
12    for comp, gold in zip(completions, answer):
13        pred = extract_answer(comp)
14        out.append(2.0 if (pred is not None and pred == gold) else 0.0)
15    return out
16
17def format_reward(completions, **kwargs):
18    """+0.5 si respeta la estructura <think>...</think><answer>...</answer>."""
19    pat = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
20    return [0.5 if pat.search(c) else 0.0 for c in completions]
21
22def soft_format_reward(completions, **kwargs):
23    """Recompensa parcial por tener al menos las etiquetas (evita el colapso de formato)."""
24    return [0.1*(("<think>" in c) + ("<answer>" in c)) for c in completions]

Líneas no triviales:

  • correctness_reward compara el número, no el texto: una respuesta correcta escrita distinto sigue siendo correcta.
  • Separar correctness (la señal real, peso alto) de format (peso bajo) evita que el modelo aprenda solo a formatear sin resolver — un reward hacking clásico (L5.3). Si solo premias formato, el modelo "razona" bonito y falla los números.
  • La firma (completions, answer, **kwargs): TRL pasa completions y las columnas extra del dataset (aquí answer) como kwargs. Mantén esa firma.

7.3Laboratorio L7.1 — GRPO con TRL + vLLM colocate

python
1# lab_n2l7_grpo_trl.py — GRPO sobre GSM8K con TRL + vLLM colocate en una 5090
2# entorno: trl[vllm]>=0.28, vllm>=0.12, transformers>=4.57
3from datasets import load_dataset
4from trl import GRPOConfig, GRPOTrainer
5from lab_n2l7_rewards import correctness_reward, format_reward, soft_format_reward
6
7SYSTEM = ("Resuelve el problema. Razona dentro de <think>...</think> "
8          "y da el número final dentro de <answer>...</answer>.")
9
10ds = load_dataset("openai/gsm8k", "main", split="train")
11def prep(ex):
12    gold = ex["answer"].split("####")[-1].strip().replace(",", "")
13    return {"prompt": [{"role":"system","content":SYSTEM},
14                       {"role":"user","content":ex["question"]}],
15            "answer": gold}
16ds = ds.map(prep)
17
18args = GRPOConfig(
19    output_dir="grpo_gsm8k",
20    use_vllm=True, vllm_mode="colocate",      # vLLM dentro del proceso, comparte la GPU (N6: por qué)
21    vllm_gpu_memory_utilization=0.3,          # deja VRAM para el training; sube/baja si OOM/infrautiliza
22    num_generations=8,                        # G de GRPO (L6): 8 respuestas por prompt
23    max_prompt_length=512, max_completion_length=1024,
24    per_device_train_batch_size=8,            # debe ser múltiplo de num_generations
25    gradient_accumulation_steps=4,
26    learning_rate=1e-6,                        # RL usa LR pequeñísimo
27    beta=0.04,                                 # peso de la KL hacia la referencia (L6)
28    logging_steps=1, max_steps=500,
29    report_to="wandb",
30)
31
32trainer = GRPOTrainer(
33    model="Qwen/Qwen3-4B-Instruct",
34    reward_funcs=[correctness_reward, format_reward, soft_format_reward],  # se suman
35    args=args, train_dataset=ds,
36)
37trainer.train()

Líneas no triviales:

  • vllm_mode="colocate" + vllm_gpu_memory_utilization=0.3: GRPO necesita generar (rollouts) en cada paso; vLLM hace esa generación rápida dentro del mismo proceso, compartiendo la GPU con el entrenamiento. El 0.3 reparte VRAM: ~30% para la generación vLLM, el resto para el training. Si ves OOM, baja este número; si vLLM va lento, súbelo. Este equilibrio es la clave de hacer GRPO en una sola GPU.
  • per_device_train_batch_size múltiplo de num_generations: cada prompt produce G completions que forman un grupo; el batch debe contener grupos enteros.
  • learning_rate=1e-6 y beta=0.04: RL es frágil; LR alto o KL baja → colapso o reward hacking.
  • reward_funcs=[...]: TRL suma las recompensas de la lista. La de correctness domina (peso 2.0).

7.4Laboratorio L7.2 — La misma receta con Unsloth (máxima eficiencia)

Si quieres modelos mayores o menos VRAM, Unsloth es el camino. Su notebook "Advanced GRPO" trae trucos (pre-finetuning para no aprender solo formato, reward de proximidad, Standby).

python
1# lab_n2l7_grpo_unsloth.py — GRPO con Unsloth (VRAM mínima; permite Qwen3-4B holgado)
2import os
3os.environ["UNSLOTH_VLLM_STANDBY"] = "1"   # libera VRAM de vLLM cuando no genera -> clave en 5090
4from unsloth import FastLanguageModel
5from trl import GRPOConfig, GRPOTrainer
6from datasets import load_dataset
7from lab_n2l7_rewards import correctness_reward, format_reward, soft_format_reward
8
9model, tok = FastLanguageModel.from_pretrained(
10    "unsloth/Qwen3-4B-Instruct", max_seq_length=2048, load_in_4bit=True, fast_inference=True)
11model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16,
12    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
13    use_gradient_checkpointing="unsloth")
14
15# (mismo dataset GSM8K preparado que en L7.1)
16trainer = GRPOTrainer(model=model, processing_class=tok,
17    reward_funcs=[correctness_reward, format_reward, soft_format_reward],
18    args=GRPOConfig(num_generations=8, max_prompt_length=512, max_completion_length=1024,
19                    per_device_train_batch_size=8, gradient_accumulation_steps=4,
20                    learning_rate=1e-6, beta=0.04, logging_steps=1, max_steps=500,
21                    output_dir="grpo_unsloth", report_to="wandb"),
22    train_dataset=ds)
23trainer.train()
24model.save_pretrained("adapters/qwen3-4b-gsm8k-grpo")

Línea clave: UNSLOTH_VLLM_STANDBY=1 — el "Standby" de Unsloth descarga la VRAM de vLLM mientras el paso de entrenamiento ocupa la GPU y la recarga para generar. Es lo que hace que GRPO + generación quepan cómodos en 32 GB sin pelearte con el reparto manual.


7.5Qué observar (la prueba del "aha moment")

Mira en wandb estas curvas a lo largo de los pasos:

  • reward (total) → debe subir.
  • completions/mean_length → debe crecer sola. Esta es la firma del aha moment (L6.6): el modelo descubre que razonar más largo resuelve más.
  • rewards/correctness_reward → la accuracy en GSM8K subiendo.
  • kl → debe mantenerse acotada (si se dispara, baja LR o sube beta).

Si la longitud crece y la correctness sube → has reproducido el fenómeno de R1-zero en tu 5090. Si solo sube el format reward y la correctness no → reward hacking de formato (baja el peso de format, sube el pre-finetuning).


7.6CHECKPOINT C2b(a) — criterio de aprobado

  • Entrenamiento GRPO que corre estable en tu 5090 (sin OOM, KL acotada).
  • Curvas que muestran reward ↑ y mean_length ↑ de forma sostenida.
  • Mejora medible de correctness en GSM8K (holdout) del modelo GRPO vs el base/SFT.
  • Sabes explicar cada función de recompensa y por qué separaste correctness de format (anti reward-hacking).

Rúbrica: Nivel 3 si reproduces el fenómeno y lo explicas; Nivel 4 si diseñaste tus propias reward functions y mostraste cómo evitan un reward hacking concreto.


7.7Ejercicios

E1. Entrena con num_generations=4 vs 16. ¿Cómo cambia la estabilidad y la velocidad de convergencia? (Conecta con L6: G mejor estima la baseline.)

E2. Quita correctness_reward y deja solo format_reward. Observa el reward hacking: el modelo formatea perfecto y falla los números. Documéntalo.

E3. Sube beta (KL) a 0.2 y bájalo a 0.0. ¿Qué le pasa a la longitud de respuesta y a la fluidez? Relaciónalo con L6.8.

7.8Trampas comunes

  • batch_size no múltiplo de num_generations → error de grupos.
  • vllm_gpu_memory_utilization mal ajustado → OOM o generación lentísima.
  • Premiar solo formato → reward hacking.
  • LR de SFT (2e-4) en GRPO → colapso inmediato. RL quiere ~1e-6.

7.9Referencias

  • Docs Unsloth (RL Guide; Advanced GRPO notebook; Standby). Docs TRL (GRPOTrainer, vLLM integration colocate). DeepSeek-R1 (Guo et al. 2025). willccbb GSM8K reward functions.