diff --git a/finetune/configs/grpo.yaml b/finetune/configs/grpo.yaml index ca717b4..db99207 100644 --- a/finetune/configs/grpo.yaml +++ b/finetune/configs/grpo.yaml @@ -30,6 +30,10 @@ training: learning_rate: 0.0000005 max_grad_norm: 0.5 max_steps: 200 + # Save checkpoints every 30 minutes + save_interval_minutes: 30 + # Fallback time-step save cadence if needed (not used for wall-clock mode) + save_steps: 50 grpo: num_generations: 4 diff --git a/finetune/configs/sft.yaml b/finetune/configs/sft.yaml index 830b593..b7d132e 100644 --- a/finetune/configs/sft.yaml +++ b/finetune/configs/sft.yaml @@ -23,6 +23,11 @@ training: max_length: 512 warmup_ratio: 0.03 lr_scheduler: "cosine" + # Save checkpoints every 30 minutes + save_interval_minutes: 30 + # Fallback time-step save cadence if needed (not used for wall-clock mode) + save_steps: 200 + save_total_limit: 3 lora: rank: 16 diff --git a/finetune/configs/sft_local.yaml b/finetune/configs/sft_local.yaml index 4d70a5c..43941ff 100644 --- a/finetune/configs/sft_local.yaml +++ b/finetune/configs/sft_local.yaml @@ -21,6 +21,10 @@ training: warmup_ratio: 0.03 lr_scheduler: "cosine" ddp_find_unused_parameters: false + # Save checkpoints every 30 minutes + save_interval_minutes: 30 + # Fallback time-step save cadence if needed (not used for wall-clock mode) + save_steps: 200 lora: rank: 16 diff --git a/finetune/data/train/dataset_info.json b/finetune/data/train/dataset_info.json index 5034ec5..34e381b 100644 --- a/finetune/data/train/dataset_info.json +++ b/finetune/data/train/dataset_info.json @@ -1,11 +1,9 @@ { "dataset_name": "qmd-query-expansion", - "train_samples": 5440, - "val_samples": 605, - "short_query_pct": 11.1, + "train_samples": 2806, + "val_samples": 312, + "short_query_pct": 15.5, "columns": [ - "prompt", - "completion", "text", "messages" ] diff --git a/finetune/dataset/prepare_data.py b/finetune/dataset/prepare_data.py index 2431a2f..3006be6 100644 --- a/finetune/dataset/prepare_data.py +++ 
b/finetune/dataset/prepare_data.py @@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict: tokenizer = get_tokenizer() output_text = output_items_to_text(ex.output) + user_prompt = f"/no_think Expand this search query: {ex.query}" + if ex.intent: + user_prompt = ( + f"/no_think Expand this search query: {ex.query}\n" + f"Query intent: {ex.intent.strip()}" + ) + messages = [ { "role": "user", - "content": f"/no_think Expand this search query: {ex.query}", + "content": user_prompt, }, {"role": "assistant", "content": output_text}, ] @@ -165,6 +172,7 @@ def main(): "train_samples": len(train_data), "val_samples": len(val_data), "short_query_pct": round(100 * short_final / len(all_examples), 1), + "columns": ["text", "messages"], } with open(output_dir / "dataset_info.json", "w") as f: json.dump(dataset_info, f, indent=2) diff --git a/finetune/eval.py b/finetune/eval.py index cfe20ac..cc91093 100644 --- a/finetune/eval.py +++ b/finetune/eval.py @@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator. Usage: uv run eval.py ./outputs/sft - uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt + uv run eval.py ./outputs/sft --queries evals/queries.txt + +By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set. 
""" import argparse import json import re import sys -import os from pathlib import Path # Import reward scoring @@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent)) from reward import score_expansion_detailed -QUERIES = [ - "how to configure authentication", - "docker compose networking", - "auth", - "who is TDS motorsports", - "React hooks tutorial", - "recent news about Shopify", - "how to implement caching with redis in nodejs", - "auth /only:lex", - "kubernetes pod deployment /only:vec", - "AWS Lambda cold start /only:hyde", -] + +DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt" def load_model(model_path: str): @@ -127,7 +118,11 @@ def generate_batch( def main(): parser = argparse.ArgumentParser(description="Evaluate QMD model") parser.add_argument("model", help="Model path (local or HF)") - parser.add_argument("--queries", help="Queries file (one per line)") + parser.add_argument( + "--queries", + default=str(DEFAULT_QUERY_FILE), + help="Queries file (one per line) [default: evals/queries.txt]", + ) parser.add_argument( "--max-new-tokens", type=int, @@ -154,11 +149,14 @@ def main(): ) args = parser.parse_args() - # Load queries - queries = QUERIES - if args.queries: - with open(args.queries) as f: - queries = [l.strip() for l in f if l.strip() and not l.startswith("#")] + # Load queries (default to full evals/queries.txt) + query_file = Path(args.queries) + if not query_file.exists(): + raise FileNotFoundError(f"Queries file not found: {query_file}") + with query_file.open(encoding="utf-8") as f: + queries = [ + l.strip() for l in f if l.strip() and not l.strip().startswith("#") + ] if args.max_queries and args.max_queries > 0: queries = queries[: args.max_queries] diff --git a/finetune/train.py b/finetune/train.py index a0a49a5..dc77ffb 100644 --- a/finetune/train.py +++ b/finetune/train.py @@ -32,9 +32,11 @@ import argparse import os import subprocess import sys +import time from pathlib import Path import yaml +from 
class TimedSaveCallback(TrainerCallback):
    """Request a checkpoint once a fixed amount of time has elapsed.

    At the end of each training step the callback compares the time elapsed
    since the last checkpoint against ``interval_minutes`` and, when the
    interval has passed, sets ``control.should_save`` so the trainer writes a
    checkpoint.

    The interval is measured with ``time.monotonic()`` so that system clock
    adjustments (NTP sync, manual changes) can neither trigger a spurious
    save nor postpone one.

    Args:
        interval_minutes: Minimum time between checkpoints, in minutes.
            Anything accepted by ``float()`` is allowed; must be > 0.

    Raises:
        ValueError: If ``interval_minutes`` is zero, negative, or NaN.
    """

    def __init__(self, interval_minutes: float):
        interval_seconds = float(interval_minutes) * 60.0
        # ``not x > 0`` (instead of ``x <= 0``) also rejects NaN, which would
        # otherwise make every elapsed-time comparison False and silently
        # disable checkpointing.
        if not interval_seconds > 0:
            raise ValueError(
                f"interval_minutes must be a positive number, got {interval_minutes!r}"
            )
        self.interval_seconds = interval_seconds
        self.last_save_time = time.monotonic()

    def on_train_begin(self, args, state, control, **kwargs):
        """Restart the timer so setup time is not counted as training time."""
        # The callback is constructed before the trainer starts; without this
        # reset the first interval would be shortened by whatever happens
        # between construction and the first step.
        self.last_save_time = time.monotonic()
        return control

    def on_save(self, args, state, control, **kwargs):
        """Restart the timer whenever any checkpoint is written."""
        # Covers checkpoints triggered by other means (e.g. step-based saves),
        # so a timed save is not requested right after one of those.
        self.last_save_time = time.monotonic()
        return control

    def on_step_end(self, args, state, control, **kwargs):
        """Set ``control.should_save`` when the interval has elapsed."""
        # Only the process flagged as world-process-zero requests the save.
        # NOTE(review): with sharded backends every rank may need to enter the
        # save path together -- confirm this gate is safe for the launch mode
        # in use.
        if not getattr(state, "is_world_process_zero", False):
            return control

        now = time.monotonic()
        if now - self.last_save_time >= self.interval_seconds:
            control.should_save = True
            self.last_save_time = now
        return control
cmd_sft(args): max_length=cfg["training"]["max_length"], logging_steps=10, save_strategy="steps", - save_steps=200, - save_total_limit=2, + save_steps=save_steps, + save_total_limit=save_total_limit, eval_strategy="steps", - eval_steps=200, + eval_steps=cfg["training"].get("eval_steps", 200), warmup_ratio=cfg["training"]["warmup_ratio"], lr_scheduler_type=cfg["training"]["lr_scheduler"], ddp_find_unused_parameters=cfg["training"].get( @@ -329,6 +363,7 @@ def cmd_sft(args): args=config, peft_config=peft_config, processing_class=tokenizer, + callbacks=callbacks, ) print("Starting SFT training...") @@ -378,6 +413,7 @@ def cmd_sft(args): def cmd_grpo(args): """Run GRPO reinforcement learning on top of merged SFT weights.""" import torch + import torch.distributed as dist import os from datasets import load_dataset from peft import LoraConfig, PeftModel, get_peft_model @@ -494,6 +530,7 @@ def cmd_grpo(args): task_type="CAUSAL_LM", target_modules=cfg["lora"]["target_modules"], modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens + ensure_weight_tying=True, ) model = get_peft_model(model, grpo_lora_config) model.print_trainable_parameters() @@ -510,6 +547,24 @@ def cmd_grpo(args): if isinstance(learning_rate, str): learning_rate = float(learning_rate) + save_interval_minutes = cfg["training"].get("save_interval_minutes") + save_steps = cfg["training"].get("save_steps", 200) + save_total_limit = cfg["training"].get("save_total_limit", 2) + save_strategy = cfg["training"].get("save_strategy", "epoch") + if save_interval_minutes: + # Prefer wall-clock checkpointing (for long jobs / preemption safety) + save_steps = max(save_steps, 10_000_000) + save_strategy = "steps" + + callbacks = [] + if save_interval_minutes: + try: + interval_value = float(save_interval_minutes) + except (TypeError, ValueError): + interval_value = None + if interval_value and interval_value > 0: + callbacks.append(TimedSaveCallback(interval_value)) + config = GRPOConfig( 
output_dir=output_dir, push_to_hub=push_to_hub, @@ -524,7 +579,9 @@ def cmd_grpo(args): max_grad_norm=cfg["training"]["max_grad_norm"], max_steps=cfg["training"].get("max_steps", -1), logging_steps=10, - save_strategy="epoch", + save_strategy=save_strategy, + save_steps=save_steps, + save_total_limit=save_total_limit, bf16=True, skip_memory_metrics=True, report_to=report_to, @@ -539,11 +596,18 @@ def cmd_grpo(args): args=config, train_dataset=dataset, reward_funcs=[QMDRewardFunction()], + callbacks=callbacks, ) print("Starting GRPO training...") trainer.train() + is_main = os.environ.get("RANK", "0") == "0" + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if not is_main: + return + if push_to_hub: print("Pushing to Hub...") trainer.push_to_hub()