Add wall-clock checkpoints and full eval defaults

This commit is contained in:
Tobi Lütke 2026-02-22 15:02:02 -05:00
parent 5233e676d9
commit cbeeb1f89b
No known key found for this signature in database
7 changed files with 113 additions and 32 deletions

View File

@ -30,6 +30,10 @@ training:
learning_rate: 0.0000005
max_grad_norm: 0.5
max_steps: 200
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 50
grpo:
num_generations: 4

View File

@ -23,6 +23,11 @@ training:
max_length: 512
warmup_ratio: 0.03
lr_scheduler: "cosine"
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 200
save_total_limit: 3
lora:
rank: 16

View File

@ -21,6 +21,10 @@ training:
warmup_ratio: 0.03
lr_scheduler: "cosine"
ddp_find_unused_parameters: false
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 200
lora:
rank: 16

View File

@ -1,11 +1,9 @@
{
"dataset_name": "qmd-query-expansion",
"train_samples": 5440,
"val_samples": 605,
"short_query_pct": 11.1,
"train_samples": 2806,
"val_samples": 312,
"short_query_pct": 15.5,
"columns": [
"prompt",
"completion",
"text",
"messages"
]

View File

@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict:
tokenizer = get_tokenizer()
output_text = output_items_to_text(ex.output)
user_prompt = f"/no_think Expand this search query: {ex.query}"
if ex.intent:
user_prompt = (
f"/no_think Expand this search query: {ex.query}\n"
f"Query intent: {ex.intent.strip()}"
)
messages = [
{
"role": "user",
"content": f"/no_think Expand this search query: {ex.query}",
"content": user_prompt,
},
{"role": "assistant", "content": output_text},
]
@ -165,6 +172,7 @@ def main():
"train_samples": len(train_data),
"val_samples": len(val_data),
"short_query_pct": round(100 * short_final / len(all_examples), 1),
"columns": ["text", "messages"],
}
with open(output_dir / "dataset_info.json", "w") as f:
json.dump(dataset_info, f, indent=2)

View File

@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator.
Usage:
uv run eval.py ./outputs/sft
uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt
uv run eval.py ./outputs/sft --queries evals/queries.txt
By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set.
"""
import argparse
import json
import re
import sys
import os
from pathlib import Path
# Import reward scoring
@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent))
from reward import score_expansion_detailed
QUERIES = [
"how to configure authentication",
"docker compose networking",
"auth",
"who is TDS motorsports",
"React hooks tutorial",
"recent news about Shopify",
"how to implement caching with redis in nodejs",
"auth /only:lex",
"kubernetes pod deployment /only:vec",
"AWS Lambda cold start /only:hyde",
]
DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt"
def load_model(model_path: str):
@ -127,7 +118,11 @@ def generate_batch(
def main():
parser = argparse.ArgumentParser(description="Evaluate QMD model")
parser.add_argument("model", help="Model path (local or HF)")
parser.add_argument("--queries", help="Queries file (one per line)")
parser.add_argument(
"--queries",
default=str(DEFAULT_QUERY_FILE),
help="Queries file (one per line) [default: evals/queries.txt]",
)
parser.add_argument(
"--max-new-tokens",
type=int,
@ -154,11 +149,14 @@ def main():
)
args = parser.parse_args()
# Load queries
queries = QUERIES
if args.queries:
with open(args.queries) as f:
queries = [l.strip() for l in f if l.strip() and not l.startswith("#")]
# Load queries (default to full evals/queries.txt)
query_file = Path(args.queries)
if not query_file.exists():
raise FileNotFoundError(f"Queries file not found: {query_file}")
with query_file.open(encoding="utf-8") as f:
queries = [
l.strip() for l in f if l.strip() and not l.strip().startswith("#")
]
if args.max_queries and args.max_queries > 0:
queries = queries[: args.max_queries]

View File

@ -32,9 +32,11 @@ import argparse
import os
import subprocess
import sys
import time
from pathlib import Path
import yaml
from transformers import TrainerCallback
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
print(f"GGUF files saved to: {gguf_dir}")
class TimedSaveCallback(TrainerCallback):
"""Trigger periodic checkpoint saves based on elapsed wall-clock time."""
def __init__(self, interval_minutes: float):
self.interval_seconds = float(interval_minutes) * 60.0
self.last_save_time = time.time()
def on_step_end(self, args, state, control, **kwargs):
if not getattr(state, "is_world_process_zero", False):
return control
now = time.time()
if now - self.last_save_time >= self.interval_seconds:
control.should_save = True
self.last_save_time = now
return control
def run_eval(model_path: str) -> float | None:
"""Run eval.py on the trained model and return average score."""
print("\n" + "=" * 60)
@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
def cmd_sft(args):
"""Run supervised fine-tuning."""
import torch
import os
from datasets import load_dataset
import torch
import torch.distributed as dist
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
@ -276,6 +294,22 @@ def cmd_sft(args):
"{time}", now.strftime("%H:%M")
)
save_interval_minutes = cfg["training"].get("save_interval_minutes")
save_steps = cfg["training"].get("save_steps", 200)
save_total_limit = cfg["training"].get("save_total_limit", 2)
if save_interval_minutes:
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
save_steps = max(save_steps, 10_000_000)
callbacks = []
if save_interval_minutes:
try:
interval_value = float(save_interval_minutes)
except (TypeError, ValueError):
interval_value = None
if interval_value and interval_value > 0:
callbacks.append(TimedSaveCallback(interval_value))
config = SFTConfig(
output_dir=output_dir,
push_to_hub=push_to_hub,
@ -288,10 +322,10 @@ def cmd_sft(args):
max_length=cfg["training"]["max_length"],
logging_steps=10,
save_strategy="steps",
save_steps=200,
save_total_limit=2,
save_steps=save_steps,
save_total_limit=save_total_limit,
eval_strategy="steps",
eval_steps=200,
eval_steps=cfg["training"].get("eval_steps", 200),
warmup_ratio=cfg["training"]["warmup_ratio"],
lr_scheduler_type=cfg["training"]["lr_scheduler"],
ddp_find_unused_parameters=cfg["training"].get(
@ -329,6 +363,7 @@ def cmd_sft(args):
args=config,
peft_config=peft_config,
processing_class=tokenizer,
callbacks=callbacks,
)
print("Starting SFT training...")
@ -378,6 +413,7 @@ def cmd_sft(args):
def cmd_grpo(args):
"""Run GRPO reinforcement learning on top of merged SFT weights."""
import torch
import torch.distributed as dist
import os
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
@ -494,6 +530,7 @@ def cmd_grpo(args):
task_type="CAUSAL_LM",
target_modules=cfg["lora"]["target_modules"],
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
ensure_weight_tying=True,
)
model = get_peft_model(model, grpo_lora_config)
model.print_trainable_parameters()
@ -510,6 +547,24 @@ def cmd_grpo(args):
if isinstance(learning_rate, str):
learning_rate = float(learning_rate)
save_interval_minutes = cfg["training"].get("save_interval_minutes")
save_steps = cfg["training"].get("save_steps", 200)
save_total_limit = cfg["training"].get("save_total_limit", 2)
save_strategy = cfg["training"].get("save_strategy", "epoch")
if save_interval_minutes:
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
save_steps = max(save_steps, 10_000_000)
save_strategy = "steps"
callbacks = []
if save_interval_minutes:
try:
interval_value = float(save_interval_minutes)
except (TypeError, ValueError):
interval_value = None
if interval_value and interval_value > 0:
callbacks.append(TimedSaveCallback(interval_value))
config = GRPOConfig(
output_dir=output_dir,
push_to_hub=push_to_hub,
@ -524,7 +579,9 @@ def cmd_grpo(args):
max_grad_norm=cfg["training"]["max_grad_norm"],
max_steps=cfg["training"].get("max_steps", -1),
logging_steps=10,
save_strategy="epoch",
save_strategy=save_strategy,
save_steps=save_steps,
save_total_limit=save_total_limit,
bf16=True,
skip_memory_metrics=True,
report_to=report_to,
@ -539,11 +596,18 @@ def cmd_grpo(args):
args=config,
train_dataset=dataset,
reward_funcs=[QMDRewardFunction()],
callbacks=callbacks,
)
print("Starting GRPO training...")
trainer.train()
is_main = os.environ.get("RANK", "0") == "0"
if dist.is_available() and dist.is_initialized():
dist.barrier()
if not is_main:
return
if push_to_hub:
print("Pushing to Hub...")
trainer.push_to_hub()