Add wall-clock checkpoints and full eval defaults

This commit is contained in:
Tobi Lütke 2026-02-22 15:02:02 -05:00
parent 5233e676d9
commit cbeeb1f89b
No known key found for this signature in database
7 changed files with 113 additions and 32 deletions

View File

@ -30,6 +30,10 @@ training:
learning_rate: 0.0000005 learning_rate: 0.0000005
max_grad_norm: 0.5 max_grad_norm: 0.5
max_steps: 200 max_steps: 200
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 50
grpo: grpo:
num_generations: 4 num_generations: 4

View File

@ -23,6 +23,11 @@ training:
max_length: 512 max_length: 512
warmup_ratio: 0.03 warmup_ratio: 0.03
lr_scheduler: "cosine" lr_scheduler: "cosine"
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 200
save_total_limit: 3
lora: lora:
rank: 16 rank: 16

View File

@ -21,6 +21,10 @@ training:
warmup_ratio: 0.03 warmup_ratio: 0.03
lr_scheduler: "cosine" lr_scheduler: "cosine"
ddp_find_unused_parameters: false ddp_find_unused_parameters: false
# Save checkpoints every 30 minutes
save_interval_minutes: 30
# Fallback time-step save cadence if needed (not used for wall-clock mode)
save_steps: 200
lora: lora:
rank: 16 rank: 16

View File

@ -1,11 +1,9 @@
{ {
"dataset_name": "qmd-query-expansion", "dataset_name": "qmd-query-expansion",
"train_samples": 5440, "train_samples": 2806,
"val_samples": 605, "val_samples": 312,
"short_query_pct": 11.1, "short_query_pct": 15.5,
"columns": [ "columns": [
"prompt",
"completion",
"text", "text",
"messages" "messages"
] ]

View File

@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict:
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
output_text = output_items_to_text(ex.output) output_text = output_items_to_text(ex.output)
user_prompt = f"/no_think Expand this search query: {ex.query}"
if ex.intent:
user_prompt = (
f"/no_think Expand this search query: {ex.query}\n"
f"Query intent: {ex.intent.strip()}"
)
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": f"/no_think Expand this search query: {ex.query}", "content": user_prompt,
}, },
{"role": "assistant", "content": output_text}, {"role": "assistant", "content": output_text},
] ]
@ -165,6 +172,7 @@ def main():
"train_samples": len(train_data), "train_samples": len(train_data),
"val_samples": len(val_data), "val_samples": len(val_data),
"short_query_pct": round(100 * short_final / len(all_examples), 1), "short_query_pct": round(100 * short_final / len(all_examples), 1),
"columns": ["text", "messages"],
} }
with open(output_dir / "dataset_info.json", "w") as f: with open(output_dir / "dataset_info.json", "w") as f:
json.dump(dataset_info, f, indent=2) json.dump(dataset_info, f, indent=2)

View File

@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator.
Usage: Usage:
uv run eval.py ./outputs/sft uv run eval.py ./outputs/sft
uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt uv run eval.py ./outputs/sft --queries evals/queries.txt
By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set.
""" """
import argparse import argparse
import json import json
import re import re
import sys import sys
import os
from pathlib import Path from pathlib import Path
# Import reward scoring # Import reward scoring
@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent))
from reward import score_expansion_detailed from reward import score_expansion_detailed
QUERIES = [
"how to configure authentication", DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt"
"docker compose networking",
"auth",
"who is TDS motorsports",
"React hooks tutorial",
"recent news about Shopify",
"how to implement caching with redis in nodejs",
"auth /only:lex",
"kubernetes pod deployment /only:vec",
"AWS Lambda cold start /only:hyde",
]
def load_model(model_path: str): def load_model(model_path: str):
@ -127,7 +118,11 @@ def generate_batch(
def main(): def main():
parser = argparse.ArgumentParser(description="Evaluate QMD model") parser = argparse.ArgumentParser(description="Evaluate QMD model")
parser.add_argument("model", help="Model path (local or HF)") parser.add_argument("model", help="Model path (local or HF)")
parser.add_argument("--queries", help="Queries file (one per line)") parser.add_argument(
"--queries",
default=str(DEFAULT_QUERY_FILE),
help="Queries file (one per line) [default: evals/queries.txt]",
)
parser.add_argument( parser.add_argument(
"--max-new-tokens", "--max-new-tokens",
type=int, type=int,
@ -154,11 +149,14 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
# Load queries # Load queries (default to full evals/queries.txt)
queries = QUERIES query_file = Path(args.queries)
if args.queries: if not query_file.exists():
with open(args.queries) as f: raise FileNotFoundError(f"Queries file not found: {query_file}")
queries = [l.strip() for l in f if l.strip() and not l.startswith("#")] with query_file.open(encoding="utf-8") as f:
queries = [
l.strip() for l in f if l.strip() and not l.strip().startswith("#")
]
if args.max_queries and args.max_queries > 0: if args.max_queries and args.max_queries > 0:
queries = queries[: args.max_queries] queries = queries[: args.max_queries]

View File

@ -32,9 +32,11 @@ import argparse
import os import os
import subprocess import subprocess
import sys import sys
import time
from pathlib import Path from pathlib import Path
import yaml import yaml
from transformers import TrainerCallback
def export_gguf(model, tokenizer, output_dir: str, model_name: str): def export_gguf(model, tokenizer, output_dir: str, model_name: str):
@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
print(f"GGUF files saved to: {gguf_dir}") print(f"GGUF files saved to: {gguf_dir}")
class TimedSaveCallback(TrainerCallback):
"""Trigger periodic checkpoint saves based on elapsed wall-clock time."""
def __init__(self, interval_minutes: float):
self.interval_seconds = float(interval_minutes) * 60.0
self.last_save_time = time.time()
def on_step_end(self, args, state, control, **kwargs):
if not getattr(state, "is_world_process_zero", False):
return control
now = time.time()
if now - self.last_save_time >= self.interval_seconds:
control.should_save = True
self.last_save_time = now
return control
def run_eval(model_path: str) -> float | None: def run_eval(model_path: str) -> float | None:
"""Run eval.py on the trained model and return average score.""" """Run eval.py on the trained model and return average score."""
print("\n" + "=" * 60) print("\n" + "=" * 60)
@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
def cmd_sft(args): def cmd_sft(args):
"""Run supervised fine-tuning.""" """Run supervised fine-tuning."""
import torch import torch
import os
from datasets import load_dataset from datasets import load_dataset
import torch
import torch.distributed as dist import torch.distributed as dist
from peft import LoraConfig from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
@ -276,6 +294,22 @@ def cmd_sft(args):
"{time}", now.strftime("%H:%M") "{time}", now.strftime("%H:%M")
) )
save_interval_minutes = cfg["training"].get("save_interval_minutes")
save_steps = cfg["training"].get("save_steps", 200)
save_total_limit = cfg["training"].get("save_total_limit", 2)
if save_interval_minutes:
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
save_steps = max(save_steps, 10_000_000)
callbacks = []
if save_interval_minutes:
try:
interval_value = float(save_interval_minutes)
except (TypeError, ValueError):
interval_value = None
if interval_value and interval_value > 0:
callbacks.append(TimedSaveCallback(interval_value))
config = SFTConfig( config = SFTConfig(
output_dir=output_dir, output_dir=output_dir,
push_to_hub=push_to_hub, push_to_hub=push_to_hub,
@ -288,10 +322,10 @@ def cmd_sft(args):
max_length=cfg["training"]["max_length"], max_length=cfg["training"]["max_length"],
logging_steps=10, logging_steps=10,
save_strategy="steps", save_strategy="steps",
save_steps=200, save_steps=save_steps,
save_total_limit=2, save_total_limit=save_total_limit,
eval_strategy="steps", eval_strategy="steps",
eval_steps=200, eval_steps=cfg["training"].get("eval_steps", 200),
warmup_ratio=cfg["training"]["warmup_ratio"], warmup_ratio=cfg["training"]["warmup_ratio"],
lr_scheduler_type=cfg["training"]["lr_scheduler"], lr_scheduler_type=cfg["training"]["lr_scheduler"],
ddp_find_unused_parameters=cfg["training"].get( ddp_find_unused_parameters=cfg["training"].get(
@ -329,6 +363,7 @@ def cmd_sft(args):
args=config, args=config,
peft_config=peft_config, peft_config=peft_config,
processing_class=tokenizer, processing_class=tokenizer,
callbacks=callbacks,
) )
print("Starting SFT training...") print("Starting SFT training...")
@ -378,6 +413,7 @@ def cmd_sft(args):
def cmd_grpo(args): def cmd_grpo(args):
"""Run GRPO reinforcement learning on top of merged SFT weights.""" """Run GRPO reinforcement learning on top of merged SFT weights."""
import torch import torch
import torch.distributed as dist
import os import os
from datasets import load_dataset from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
@ -494,6 +530,7 @@ def cmd_grpo(args):
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
target_modules=cfg["lora"]["target_modules"], target_modules=cfg["lora"]["target_modules"],
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
ensure_weight_tying=True,
) )
model = get_peft_model(model, grpo_lora_config) model = get_peft_model(model, grpo_lora_config)
model.print_trainable_parameters() model.print_trainable_parameters()
@ -510,6 +547,24 @@ def cmd_grpo(args):
if isinstance(learning_rate, str): if isinstance(learning_rate, str):
learning_rate = float(learning_rate) learning_rate = float(learning_rate)
save_interval_minutes = cfg["training"].get("save_interval_minutes")
save_steps = cfg["training"].get("save_steps", 200)
save_total_limit = cfg["training"].get("save_total_limit", 2)
save_strategy = cfg["training"].get("save_strategy", "epoch")
if save_interval_minutes:
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
save_steps = max(save_steps, 10_000_000)
save_strategy = "steps"
callbacks = []
if save_interval_minutes:
try:
interval_value = float(save_interval_minutes)
except (TypeError, ValueError):
interval_value = None
if interval_value and interval_value > 0:
callbacks.append(TimedSaveCallback(interval_value))
config = GRPOConfig( config = GRPOConfig(
output_dir=output_dir, output_dir=output_dir,
push_to_hub=push_to_hub, push_to_hub=push_to_hub,
@ -524,7 +579,9 @@ def cmd_grpo(args):
max_grad_norm=cfg["training"]["max_grad_norm"], max_grad_norm=cfg["training"]["max_grad_norm"],
max_steps=cfg["training"].get("max_steps", -1), max_steps=cfg["training"].get("max_steps", -1),
logging_steps=10, logging_steps=10,
save_strategy="epoch", save_strategy=save_strategy,
save_steps=save_steps,
save_total_limit=save_total_limit,
bf16=True, bf16=True,
skip_memory_metrics=True, skip_memory_metrics=True,
report_to=report_to, report_to=report_to,
@ -539,11 +596,18 @@ def cmd_grpo(args):
args=config, args=config,
train_dataset=dataset, train_dataset=dataset,
reward_funcs=[QMDRewardFunction()], reward_funcs=[QMDRewardFunction()],
callbacks=callbacks,
) )
print("Starting GRPO training...") print("Starting GRPO training...")
trainer.train() trainer.train()
is_main = os.environ.get("RANK", "0") == "0"
if dist.is_available() and dist.is_initialized():
dist.barrier()
if not is_main:
return
if push_to_hub: if push_to_hub:
print("Pushing to Hub...") print("Pushing to Hub...")
trainer.push_to_hub() trainer.push_to_hub()