Add wall-clock checkpoints and full eval defaults
This commit is contained in:
parent
5233e676d9
commit
cbeeb1f89b
@ -30,6 +30,10 @@ training:
|
||||
learning_rate: 0.0000005
|
||||
max_grad_norm: 0.5
|
||||
max_steps: 200
|
||||
# Save checkpoints every 30 minutes
|
||||
save_interval_minutes: 30
|
||||
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||
save_steps: 50
|
||||
|
||||
grpo:
|
||||
num_generations: 4
|
||||
|
||||
@ -23,6 +23,11 @@ training:
|
||||
max_length: 512
|
||||
warmup_ratio: 0.03
|
||||
lr_scheduler: "cosine"
|
||||
# Save checkpoints every 30 minutes
|
||||
save_interval_minutes: 30
|
||||
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||
save_steps: 200
|
||||
save_total_limit: 3
|
||||
|
||||
lora:
|
||||
rank: 16
|
||||
|
||||
@ -21,6 +21,10 @@ training:
|
||||
warmup_ratio: 0.03
|
||||
lr_scheduler: "cosine"
|
||||
ddp_find_unused_parameters: false
|
||||
# Save checkpoints every 30 minutes
|
||||
save_interval_minutes: 30
|
||||
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||
save_steps: 200
|
||||
|
||||
lora:
|
||||
rank: 16
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
{
|
||||
"dataset_name": "qmd-query-expansion",
|
||||
"train_samples": 5440,
|
||||
"val_samples": 605,
|
||||
"short_query_pct": 11.1,
|
||||
"train_samples": 2806,
|
||||
"val_samples": 312,
|
||||
"short_query_pct": 15.5,
|
||||
"columns": [
|
||||
"prompt",
|
||||
"completion",
|
||||
"text",
|
||||
"messages"
|
||||
]
|
||||
|
||||
@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict:
|
||||
tokenizer = get_tokenizer()
|
||||
output_text = output_items_to_text(ex.output)
|
||||
|
||||
user_prompt = f"/no_think Expand this search query: {ex.query}"
|
||||
if ex.intent:
|
||||
user_prompt = (
|
||||
f"/no_think Expand this search query: {ex.query}\n"
|
||||
f"Query intent: {ex.intent.strip()}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"/no_think Expand this search query: {ex.query}",
|
||||
"content": user_prompt,
|
||||
},
|
||||
{"role": "assistant", "content": output_text},
|
||||
]
|
||||
@ -165,6 +172,7 @@ def main():
|
||||
"train_samples": len(train_data),
|
||||
"val_samples": len(val_data),
|
||||
"short_query_pct": round(100 * short_final / len(all_examples), 1),
|
||||
"columns": ["text", "messages"],
|
||||
}
|
||||
with open(output_dir / "dataset_info.json", "w") as f:
|
||||
json.dump(dataset_info, f, indent=2)
|
||||
|
||||
@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator.
|
||||
|
||||
Usage:
|
||||
uv run eval.py ./outputs/sft
|
||||
uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt
|
||||
uv run eval.py ./outputs/sft --queries evals/queries.txt
|
||||
|
||||
By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Import reward scoring
|
||||
@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent))
|
||||
from reward import score_expansion_detailed
|
||||
|
||||
|
||||
QUERIES = [
|
||||
"how to configure authentication",
|
||||
"docker compose networking",
|
||||
"auth",
|
||||
"who is TDS motorsports",
|
||||
"React hooks tutorial",
|
||||
"recent news about Shopify",
|
||||
"how to implement caching with redis in nodejs",
|
||||
"auth /only:lex",
|
||||
"kubernetes pod deployment /only:vec",
|
||||
"AWS Lambda cold start /only:hyde",
|
||||
]
|
||||
|
||||
DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt"
|
||||
|
||||
|
||||
def load_model(model_path: str):
|
||||
@ -127,7 +118,11 @@ def generate_batch(
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate QMD model")
|
||||
parser.add_argument("model", help="Model path (local or HF)")
|
||||
parser.add_argument("--queries", help="Queries file (one per line)")
|
||||
parser.add_argument(
|
||||
"--queries",
|
||||
default=str(DEFAULT_QUERY_FILE),
|
||||
help="Queries file (one per line) [default: evals/queries.txt]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
@ -154,11 +149,14 @@ def main():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load queries
|
||||
queries = QUERIES
|
||||
if args.queries:
|
||||
with open(args.queries) as f:
|
||||
queries = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
||||
# Load queries (default to full evals/queries.txt)
|
||||
query_file = Path(args.queries)
|
||||
if not query_file.exists():
|
||||
raise FileNotFoundError(f"Queries file not found: {query_file}")
|
||||
with query_file.open(encoding="utf-8") as f:
|
||||
queries = [
|
||||
l.strip() for l in f if l.strip() and not l.strip().startswith("#")
|
||||
]
|
||||
|
||||
if args.max_queries and args.max_queries > 0:
|
||||
queries = queries[: args.max_queries]
|
||||
|
||||
@ -32,9 +32,11 @@ import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
||||
@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
||||
print(f"GGUF files saved to: {gguf_dir}")
|
||||
|
||||
|
||||
class TimedSaveCallback(TrainerCallback):
|
||||
"""Trigger periodic checkpoint saves based on elapsed wall-clock time."""
|
||||
|
||||
def __init__(self, interval_minutes: float):
|
||||
self.interval_seconds = float(interval_minutes) * 60.0
|
||||
self.last_save_time = time.time()
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if not getattr(state, "is_world_process_zero", False):
|
||||
return control
|
||||
|
||||
now = time.time()
|
||||
if now - self.last_save_time >= self.interval_seconds:
|
||||
control.should_save = True
|
||||
self.last_save_time = now
|
||||
return control
|
||||
|
||||
|
||||
def run_eval(model_path: str) -> float | None:
|
||||
"""Run eval.py on the trained model and return average score."""
|
||||
print("\n" + "=" * 60)
|
||||
@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
|
||||
def cmd_sft(args):
|
||||
"""Run supervised fine-tuning."""
|
||||
import torch
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
@ -276,6 +294,22 @@ def cmd_sft(args):
|
||||
"{time}", now.strftime("%H:%M")
|
||||
)
|
||||
|
||||
save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
||||
save_steps = cfg["training"].get("save_steps", 200)
|
||||
save_total_limit = cfg["training"].get("save_total_limit", 2)
|
||||
if save_interval_minutes:
|
||||
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
||||
save_steps = max(save_steps, 10_000_000)
|
||||
|
||||
callbacks = []
|
||||
if save_interval_minutes:
|
||||
try:
|
||||
interval_value = float(save_interval_minutes)
|
||||
except (TypeError, ValueError):
|
||||
interval_value = None
|
||||
if interval_value and interval_value > 0:
|
||||
callbacks.append(TimedSaveCallback(interval_value))
|
||||
|
||||
config = SFTConfig(
|
||||
output_dir=output_dir,
|
||||
push_to_hub=push_to_hub,
|
||||
@ -288,10 +322,10 @@ def cmd_sft(args):
|
||||
max_length=cfg["training"]["max_length"],
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=200,
|
||||
save_total_limit=2,
|
||||
save_steps=save_steps,
|
||||
save_total_limit=save_total_limit,
|
||||
eval_strategy="steps",
|
||||
eval_steps=200,
|
||||
eval_steps=cfg["training"].get("eval_steps", 200),
|
||||
warmup_ratio=cfg["training"]["warmup_ratio"],
|
||||
lr_scheduler_type=cfg["training"]["lr_scheduler"],
|
||||
ddp_find_unused_parameters=cfg["training"].get(
|
||||
@ -329,6 +363,7 @@ def cmd_sft(args):
|
||||
args=config,
|
||||
peft_config=peft_config,
|
||||
processing_class=tokenizer,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
print("Starting SFT training...")
|
||||
@ -378,6 +413,7 @@ def cmd_sft(args):
|
||||
def cmd_grpo(args):
|
||||
"""Run GRPO reinforcement learning on top of merged SFT weights."""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
@ -494,6 +530,7 @@ def cmd_grpo(args):
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=cfg["lora"]["target_modules"],
|
||||
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
|
||||
ensure_weight_tying=True,
|
||||
)
|
||||
model = get_peft_model(model, grpo_lora_config)
|
||||
model.print_trainable_parameters()
|
||||
@ -510,6 +547,24 @@ def cmd_grpo(args):
|
||||
if isinstance(learning_rate, str):
|
||||
learning_rate = float(learning_rate)
|
||||
|
||||
save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
||||
save_steps = cfg["training"].get("save_steps", 200)
|
||||
save_total_limit = cfg["training"].get("save_total_limit", 2)
|
||||
save_strategy = cfg["training"].get("save_strategy", "epoch")
|
||||
if save_interval_minutes:
|
||||
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
||||
save_steps = max(save_steps, 10_000_000)
|
||||
save_strategy = "steps"
|
||||
|
||||
callbacks = []
|
||||
if save_interval_minutes:
|
||||
try:
|
||||
interval_value = float(save_interval_minutes)
|
||||
except (TypeError, ValueError):
|
||||
interval_value = None
|
||||
if interval_value and interval_value > 0:
|
||||
callbacks.append(TimedSaveCallback(interval_value))
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir=output_dir,
|
||||
push_to_hub=push_to_hub,
|
||||
@ -524,7 +579,9 @@ def cmd_grpo(args):
|
||||
max_grad_norm=cfg["training"]["max_grad_norm"],
|
||||
max_steps=cfg["training"].get("max_steps", -1),
|
||||
logging_steps=10,
|
||||
save_strategy="epoch",
|
||||
save_strategy=save_strategy,
|
||||
save_steps=save_steps,
|
||||
save_total_limit=save_total_limit,
|
||||
bf16=True,
|
||||
skip_memory_metrics=True,
|
||||
report_to=report_to,
|
||||
@ -539,11 +596,18 @@ def cmd_grpo(args):
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
reward_funcs=[QMDRewardFunction()],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
print("Starting GRPO training...")
|
||||
trainer.train()
|
||||
|
||||
is_main = os.environ.get("RANK", "0") == "0"
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.barrier()
|
||||
if not is_main:
|
||||
return
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing to Hub...")
|
||||
trainer.push_to_hub()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user