GRPO hands-on: reproducir el "aha moment" (Checkpoint C2b·a)
entrenar GRPO en tu 5090 y ver con tus ojos la emergencia de razonamiento: la longitud de respuesta y la recompensa subiendo juntas. Cierra C2b(a). Aplicas la teoría de L6 con la API real de TRL/Unsloth.
7.1El plan
Entrenaremos un modelo pequeño con GRPO + RLVR sobre GSM8K (problemas de matemáticas con respuesta numérica verificable), porque es el banco de pruebas canónico del "aha moment" y la recompensa es trivialmente verificable (¿el número final es correcto?).
Dos caminos, ambos en la 5090:
- TRL
GRPOTrainer+ vLLM colocate — el "happy path" estándar, máxima transparencia. - Unsloth GRPO — máxima eficiencia en VRAM (Qwen3-4B cómodo; 30B-A3B QLoRA en 17.5 GB).
7.2Las funciones de recompensa (el diseño que decide todo)
En RLVR tú escribes la recompensa. Para GSM8K, una combinación típica:
1# lab_n2l7_rewards.py — funciones de recompensa para GSM8K (estilo willccbb, muy usado)
2import re
3
4def extract_answer(text: str):
5 # GSM8K marca la respuesta final tras "####"; aquí esperamos <answer>N</answer>
6 m = re.search(r"<answer>\s*(-?\d[\d,]*)\s*</answer>", text)
7 return m.group(1).replace(",", "") if m else None
8
9def correctness_reward(completions, answer, **kwargs):
10 """+2.0 si la respuesta numérica final es exacta. La señal principal."""
11 out = []
12 for comp, gold in zip(completions, answer):
13 pred = extract_answer(comp)
14 out.append(2.0 if (pred is not None and pred == gold) else 0.0)
15 return out
16
17def format_reward(completions, **kwargs):
18 """+0.5 si respeta la estructura <think>...</think><answer>...</answer>."""
19 pat = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
20 return [0.5 if pat.search(c) else 0.0 for c in completions]
21
22def soft_format_reward(completions, **kwargs):
23 """Recompensa parcial por tener al menos las etiquetas (evita el colapso de formato)."""
24 return [0.1*(("<think>" in c) + ("<answer>" in c)) for c in completions]Líneas no triviales:
correctness_rewardcompara el número, no el texto: una respuesta correcta escrita distinto sigue siendo correcta.- Separar
correctness(la señal real, peso alto) deformat(peso bajo) evita que el modelo aprenda solo a formatear sin resolver — un reward hacking clásico (L5.3). Si solo premias formato, el modelo "razona" bonito y falla los números. - La firma
(completions, answer, **kwargs): TRL pasacompletionsy las columnas extra del dataset (aquíanswer) como kwargs. Mantén esa firma.
7.3Laboratorio L7.1 — GRPO con TRL + vLLM colocate
1# lab_n2l7_grpo_trl.py — GRPO sobre GSM8K con TRL + vLLM colocate en una 5090
2# entorno: trl[vllm]>=0.28, vllm>=0.12, transformers>=4.57
3from datasets import load_dataset
4from trl import GRPOConfig, GRPOTrainer
5from lab_n2l7_rewards import correctness_reward, format_reward, soft_format_reward
6
7SYSTEM = ("Resuelve el problema. Razona dentro de <think>...</think> "
8 "y da el número final dentro de <answer>...</answer>.")
9
10ds = load_dataset("openai/gsm8k", "main", split="train")
11def prep(ex):
12 gold = ex["answer"].split("####")[-1].strip().replace(",", "")
13 return {"prompt": [{"role":"system","content":SYSTEM},
14 {"role":"user","content":ex["question"]}],
15 "answer": gold}
16ds = ds.map(prep)
17
18args = GRPOConfig(
19 output_dir="grpo_gsm8k",
20 use_vllm=True, vllm_mode="colocate", # vLLM dentro del proceso, comparte la GPU (N6: por qué)
21 vllm_gpu_memory_utilization=0.3, # deja VRAM para el training; sube/baja si OOM/infrautiliza
22 num_generations=8, # G de GRPO (L6): 8 respuestas por prompt
23 max_prompt_length=512, max_completion_length=1024,
24 per_device_train_batch_size=8, # debe ser múltiplo de num_generations
25 gradient_accumulation_steps=4,
26 learning_rate=1e-6, # RL usa LR pequeñísimo
27 beta=0.04, # peso de la KL hacia la referencia (L6)
28 logging_steps=1, max_steps=500,
29 report_to="wandb",
30)
31
32trainer = GRPOTrainer(
33 model="Qwen/Qwen3-4B-Instruct",
34 reward_funcs=[correctness_reward, format_reward, soft_format_reward], # se suman
35 args=args, train_dataset=ds,
36)
37trainer.train()Líneas no triviales:
vllm_mode="colocate"+vllm_gpu_memory_utilization=0.3: GRPO necesita generar (rollouts) en cada paso; vLLM hace esa generación rápida dentro del mismo proceso, compartiendo la GPU con el entrenamiento. El 0.3 reparte VRAM: ~30% para la generación vLLM, el resto para el training. Si ves OOM, baja este número; si vLLM va lento, súbelo. Este equilibrio es la clave de hacer GRPO en una sola GPU.per_device_train_batch_sizemúltiplo denum_generations: cada prompt produce G completions que forman un grupo; el batch debe contener grupos enteros.learning_rate=1e-6ybeta=0.04: RL es frágil; LR alto o KL baja → colapso o reward hacking.reward_funcs=[...]: TRL suma las recompensas de la lista. La de correctness domina (peso 2.0).
7.4Laboratorio L7.2 — La misma receta con Unsloth (máxima eficiencia)
Si quieres modelos mayores o menos VRAM, Unsloth es el camino. Su notebook "Advanced GRPO" trae trucos (pre-finetuning para no aprender solo formato, reward de proximidad, Standby).
1# lab_n2l7_grpo_unsloth.py — GRPO con Unsloth (VRAM mínima; permite Qwen3-4B holgado)
2import os
3os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # libera VRAM de vLLM cuando no genera -> clave en 5090
4from unsloth import FastLanguageModel
5from trl import GRPOConfig, GRPOTrainer
6from datasets import load_dataset
7from lab_n2l7_rewards import correctness_reward, format_reward, soft_format_reward
8
9model, tok = FastLanguageModel.from_pretrained(
10 "unsloth/Qwen3-4B-Instruct", max_seq_length=2048, load_in_4bit=True, fast_inference=True)
11model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16,
12 target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
13 use_gradient_checkpointing="unsloth")
14
15# (mismo dataset GSM8K preparado que en L7.1)
16trainer = GRPOTrainer(model=model, processing_class=tok,
17 reward_funcs=[correctness_reward, format_reward, soft_format_reward],
18 args=GRPOConfig(num_generations=8, max_prompt_length=512, max_completion_length=1024,
19 per_device_train_batch_size=8, gradient_accumulation_steps=4,
20 learning_rate=1e-6, beta=0.04, logging_steps=1, max_steps=500,
21 output_dir="grpo_unsloth", report_to="wandb"),
22 train_dataset=ds)
23trainer.train()
24model.save_pretrained("adapters/qwen3-4b-gsm8k-grpo")Línea clave: UNSLOTH_VLLM_STANDBY=1 — el "Standby" de Unsloth descarga la VRAM de vLLM mientras el paso de entrenamiento ocupa la GPU y la recarga para generar. Es lo que hace que GRPO + generación quepan cómodos en 32 GB sin pelearte con el reparto manual.
7.5Qué observar (la prueba del "aha moment")
Mira en wandb estas curvas a lo largo de los pasos:
reward(total) → debe subir.completions/mean_length→ debe crecer sola. Esta es la firma del aha moment (L6.6): el modelo descubre que razonar más largo resuelve más.rewards/correctness_reward→ la accuracy en GSM8K subiendo.kl→ debe mantenerse acotada (si se dispara, baja LR o sube beta).
Si la longitud crece y la correctness sube → has reproducido el fenómeno de R1-zero en tu 5090. Si solo sube el format reward y la correctness no → reward hacking de formato (baja el peso de format, sube el pre-finetuning).
7.6CHECKPOINT C2b(a) — criterio de aprobado
- Entrenamiento GRPO que corre estable en tu 5090 (sin OOM, KL acotada).
- Curvas que muestran reward ↑ y mean_length ↑ de forma sostenida.
- Mejora medible de correctness en GSM8K (holdout) del modelo GRPO vs el base/SFT.
- Sabes explicar cada función de recompensa y por qué separaste correctness de format (anti reward-hacking).
Rúbrica: Nivel 3 si reproduces el fenómeno y lo explicas; Nivel 4 si diseñaste tus propias reward functions y mostraste cómo evitan un reward hacking concreto.
7.7Ejercicios
E1. Entrena con num_generations=4 vs 16. ¿Cómo cambia la estabilidad y la velocidad de convergencia? (Conecta con L6: G mejor estima la baseline.)
E2. Quita correctness_reward y deja solo format_reward. Observa el reward hacking: el modelo formatea perfecto y falla los números. Documéntalo.
E3. Sube beta (KL) a 0.2 y bájalo a 0.0. ¿Qué le pasa a la longitud de respuesta y a la fluidez? Relaciónalo con L6.8.
7.8Trampas comunes
batch_sizeno múltiplo denum_generations→ error de grupos.vllm_gpu_memory_utilizationmal ajustado → OOM o generación lentísima.- Premiar solo formato → reward hacking.
- LR de SFT (2e-4) en GRPO → colapso inmediato. RL quiere ~1e-6.
7.9Referencias
- Docs Unsloth (RL Guide; Advanced GRPO notebook; Standby). Docs TRL (GRPOTrainer, vLLM integration colocate). DeepSeek-R1 (Guo et al. 2025). willccbb GSM8K reward functions.