Add wall-clock checkpoints and full eval defaults
This commit is contained in:
parent
5233e676d9
commit
cbeeb1f89b
@ -30,6 +30,10 @@ training:
|
|||||||
learning_rate: 0.0000005
|
learning_rate: 0.0000005
|
||||||
max_grad_norm: 0.5
|
max_grad_norm: 0.5
|
||||||
max_steps: 200
|
max_steps: 200
|
||||||
|
# Save checkpoints every 30 minutes
|
||||||
|
save_interval_minutes: 30
|
||||||
|
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||||
|
save_steps: 50
|
||||||
|
|
||||||
grpo:
|
grpo:
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
|
|||||||
@ -23,6 +23,11 @@ training:
|
|||||||
max_length: 512
|
max_length: 512
|
||||||
warmup_ratio: 0.03
|
warmup_ratio: 0.03
|
||||||
lr_scheduler: "cosine"
|
lr_scheduler: "cosine"
|
||||||
|
# Save checkpoints every 30 minutes
|
||||||
|
save_interval_minutes: 30
|
||||||
|
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||||
|
save_steps: 200
|
||||||
|
save_total_limit: 3
|
||||||
|
|
||||||
lora:
|
lora:
|
||||||
rank: 16
|
rank: 16
|
||||||
|
|||||||
@ -21,6 +21,10 @@ training:
|
|||||||
warmup_ratio: 0.03
|
warmup_ratio: 0.03
|
||||||
lr_scheduler: "cosine"
|
lr_scheduler: "cosine"
|
||||||
ddp_find_unused_parameters: false
|
ddp_find_unused_parameters: false
|
||||||
|
# Save checkpoints every 30 minutes
|
||||||
|
save_interval_minutes: 30
|
||||||
|
# Fallback time-step save cadence if needed (not used for wall-clock mode)
|
||||||
|
save_steps: 200
|
||||||
|
|
||||||
lora:
|
lora:
|
||||||
rank: 16
|
rank: 16
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
{
|
{
|
||||||
"dataset_name": "qmd-query-expansion",
|
"dataset_name": "qmd-query-expansion",
|
||||||
"train_samples": 5440,
|
"train_samples": 2806,
|
||||||
"val_samples": 605,
|
"val_samples": 312,
|
||||||
"short_query_pct": 11.1,
|
"short_query_pct": 15.5,
|
||||||
"columns": [
|
"columns": [
|
||||||
"prompt",
|
|
||||||
"completion",
|
|
||||||
"text",
|
"text",
|
||||||
"messages"
|
"messages"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict:
|
|||||||
tokenizer = get_tokenizer()
|
tokenizer = get_tokenizer()
|
||||||
output_text = output_items_to_text(ex.output)
|
output_text = output_items_to_text(ex.output)
|
||||||
|
|
||||||
|
user_prompt = f"/no_think Expand this search query: {ex.query}"
|
||||||
|
if ex.intent:
|
||||||
|
user_prompt = (
|
||||||
|
f"/no_think Expand this search query: {ex.query}\n"
|
||||||
|
f"Query intent: {ex.intent.strip()}"
|
||||||
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"/no_think Expand this search query: {ex.query}",
|
"content": user_prompt,
|
||||||
},
|
},
|
||||||
{"role": "assistant", "content": output_text},
|
{"role": "assistant", "content": output_text},
|
||||||
]
|
]
|
||||||
@ -165,6 +172,7 @@ def main():
|
|||||||
"train_samples": len(train_data),
|
"train_samples": len(train_data),
|
||||||
"val_samples": len(val_data),
|
"val_samples": len(val_data),
|
||||||
"short_query_pct": round(100 * short_final / len(all_examples), 1),
|
"short_query_pct": round(100 * short_final / len(all_examples), 1),
|
||||||
|
"columns": ["text", "messages"],
|
||||||
}
|
}
|
||||||
with open(output_dir / "dataset_info.json", "w") as f:
|
with open(output_dir / "dataset_info.json", "w") as f:
|
||||||
json.dump(dataset_info, f, indent=2)
|
json.dump(dataset_info, f, indent=2)
|
||||||
|
|||||||
@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator.
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
uv run eval.py ./outputs/sft
|
uv run eval.py ./outputs/sft
|
||||||
uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt
|
uv run eval.py ./outputs/sft --queries evals/queries.txt
|
||||||
|
|
||||||
|
By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Import reward scoring
|
# Import reward scoring
|
||||||
@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent))
|
|||||||
from reward import score_expansion_detailed
|
from reward import score_expansion_detailed
|
||||||
|
|
||||||
|
|
||||||
QUERIES = [
|
|
||||||
"how to configure authentication",
|
DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt"
|
||||||
"docker compose networking",
|
|
||||||
"auth",
|
|
||||||
"who is TDS motorsports",
|
|
||||||
"React hooks tutorial",
|
|
||||||
"recent news about Shopify",
|
|
||||||
"how to implement caching with redis in nodejs",
|
|
||||||
"auth /only:lex",
|
|
||||||
"kubernetes pod deployment /only:vec",
|
|
||||||
"AWS Lambda cold start /only:hyde",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: str):
|
def load_model(model_path: str):
|
||||||
@ -127,7 +118,11 @@ def generate_batch(
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Evaluate QMD model")
|
parser = argparse.ArgumentParser(description="Evaluate QMD model")
|
||||||
parser.add_argument("model", help="Model path (local or HF)")
|
parser.add_argument("model", help="Model path (local or HF)")
|
||||||
parser.add_argument("--queries", help="Queries file (one per line)")
|
parser.add_argument(
|
||||||
|
"--queries",
|
||||||
|
default=str(DEFAULT_QUERY_FILE),
|
||||||
|
help="Queries file (one per line) [default: evals/queries.txt]",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-new-tokens",
|
"--max-new-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
@ -154,11 +149,14 @@ def main():
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Load queries
|
# Load queries (default to full evals/queries.txt)
|
||||||
queries = QUERIES
|
query_file = Path(args.queries)
|
||||||
if args.queries:
|
if not query_file.exists():
|
||||||
with open(args.queries) as f:
|
raise FileNotFoundError(f"Queries file not found: {query_file}")
|
||||||
queries = [l.strip() for l in f if l.strip() and not l.startswith("#")]
|
with query_file.open(encoding="utf-8") as f:
|
||||||
|
queries = [
|
||||||
|
l.strip() for l in f if l.strip() and not l.strip().startswith("#")
|
||||||
|
]
|
||||||
|
|
||||||
if args.max_queries and args.max_queries > 0:
|
if args.max_queries and args.max_queries > 0:
|
||||||
queries = queries[: args.max_queries]
|
queries = queries[: args.max_queries]
|
||||||
|
|||||||
@ -32,9 +32,11 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
||||||
@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
|||||||
print(f"GGUF files saved to: {gguf_dir}")
|
print(f"GGUF files saved to: {gguf_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
class TimedSaveCallback(TrainerCallback):
|
||||||
|
"""Trigger periodic checkpoint saves based on elapsed wall-clock time."""
|
||||||
|
|
||||||
|
def __init__(self, interval_minutes: float):
|
||||||
|
self.interval_seconds = float(interval_minutes) * 60.0
|
||||||
|
self.last_save_time = time.time()
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
if not getattr(state, "is_world_process_zero", False):
|
||||||
|
return control
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
if now - self.last_save_time >= self.interval_seconds:
|
||||||
|
control.should_save = True
|
||||||
|
self.last_save_time = now
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
def run_eval(model_path: str) -> float | None:
|
def run_eval(model_path: str) -> float | None:
|
||||||
"""Run eval.py on the trained model and return average score."""
|
"""Run eval.py on the trained model and return average score."""
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
|
|||||||
def cmd_sft(args):
|
def cmd_sft(args):
|
||||||
"""Run supervised fine-tuning."""
|
"""Run supervised fine-tuning."""
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from peft import LoraConfig
|
from peft import LoraConfig
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
@ -276,6 +294,22 @@ def cmd_sft(args):
|
|||||||
"{time}", now.strftime("%H:%M")
|
"{time}", now.strftime("%H:%M")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
||||||
|
save_steps = cfg["training"].get("save_steps", 200)
|
||||||
|
save_total_limit = cfg["training"].get("save_total_limit", 2)
|
||||||
|
if save_interval_minutes:
|
||||||
|
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
||||||
|
save_steps = max(save_steps, 10_000_000)
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
if save_interval_minutes:
|
||||||
|
try:
|
||||||
|
interval_value = float(save_interval_minutes)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
interval_value = None
|
||||||
|
if interval_value and interval_value > 0:
|
||||||
|
callbacks.append(TimedSaveCallback(interval_value))
|
||||||
|
|
||||||
config = SFTConfig(
|
config = SFTConfig(
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
push_to_hub=push_to_hub,
|
push_to_hub=push_to_hub,
|
||||||
@ -288,10 +322,10 @@ def cmd_sft(args):
|
|||||||
max_length=cfg["training"]["max_length"],
|
max_length=cfg["training"]["max_length"],
|
||||||
logging_steps=10,
|
logging_steps=10,
|
||||||
save_strategy="steps",
|
save_strategy="steps",
|
||||||
save_steps=200,
|
save_steps=save_steps,
|
||||||
save_total_limit=2,
|
save_total_limit=save_total_limit,
|
||||||
eval_strategy="steps",
|
eval_strategy="steps",
|
||||||
eval_steps=200,
|
eval_steps=cfg["training"].get("eval_steps", 200),
|
||||||
warmup_ratio=cfg["training"]["warmup_ratio"],
|
warmup_ratio=cfg["training"]["warmup_ratio"],
|
||||||
lr_scheduler_type=cfg["training"]["lr_scheduler"],
|
lr_scheduler_type=cfg["training"]["lr_scheduler"],
|
||||||
ddp_find_unused_parameters=cfg["training"].get(
|
ddp_find_unused_parameters=cfg["training"].get(
|
||||||
@ -329,6 +363,7 @@ def cmd_sft(args):
|
|||||||
args=config,
|
args=config,
|
||||||
peft_config=peft_config,
|
peft_config=peft_config,
|
||||||
processing_class=tokenizer,
|
processing_class=tokenizer,
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Starting SFT training...")
|
print("Starting SFT training...")
|
||||||
@ -378,6 +413,7 @@ def cmd_sft(args):
|
|||||||
def cmd_grpo(args):
|
def cmd_grpo(args):
|
||||||
"""Run GRPO reinforcement learning on top of merged SFT weights."""
|
"""Run GRPO reinforcement learning on top of merged SFT weights."""
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import os
|
import os
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
@ -494,6 +530,7 @@ def cmd_grpo(args):
|
|||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
target_modules=cfg["lora"]["target_modules"],
|
target_modules=cfg["lora"]["target_modules"],
|
||||||
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
|
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
|
||||||
|
ensure_weight_tying=True,
|
||||||
)
|
)
|
||||||
model = get_peft_model(model, grpo_lora_config)
|
model = get_peft_model(model, grpo_lora_config)
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
@ -510,6 +547,24 @@ def cmd_grpo(args):
|
|||||||
if isinstance(learning_rate, str):
|
if isinstance(learning_rate, str):
|
||||||
learning_rate = float(learning_rate)
|
learning_rate = float(learning_rate)
|
||||||
|
|
||||||
|
save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
||||||
|
save_steps = cfg["training"].get("save_steps", 200)
|
||||||
|
save_total_limit = cfg["training"].get("save_total_limit", 2)
|
||||||
|
save_strategy = cfg["training"].get("save_strategy", "epoch")
|
||||||
|
if save_interval_minutes:
|
||||||
|
# Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
||||||
|
save_steps = max(save_steps, 10_000_000)
|
||||||
|
save_strategy = "steps"
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
if save_interval_minutes:
|
||||||
|
try:
|
||||||
|
interval_value = float(save_interval_minutes)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
interval_value = None
|
||||||
|
if interval_value and interval_value > 0:
|
||||||
|
callbacks.append(TimedSaveCallback(interval_value))
|
||||||
|
|
||||||
config = GRPOConfig(
|
config = GRPOConfig(
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
push_to_hub=push_to_hub,
|
push_to_hub=push_to_hub,
|
||||||
@ -524,7 +579,9 @@ def cmd_grpo(args):
|
|||||||
max_grad_norm=cfg["training"]["max_grad_norm"],
|
max_grad_norm=cfg["training"]["max_grad_norm"],
|
||||||
max_steps=cfg["training"].get("max_steps", -1),
|
max_steps=cfg["training"].get("max_steps", -1),
|
||||||
logging_steps=10,
|
logging_steps=10,
|
||||||
save_strategy="epoch",
|
save_strategy=save_strategy,
|
||||||
|
save_steps=save_steps,
|
||||||
|
save_total_limit=save_total_limit,
|
||||||
bf16=True,
|
bf16=True,
|
||||||
skip_memory_metrics=True,
|
skip_memory_metrics=True,
|
||||||
report_to=report_to,
|
report_to=report_to,
|
||||||
@ -539,11 +596,18 @@ def cmd_grpo(args):
|
|||||||
args=config,
|
args=config,
|
||||||
train_dataset=dataset,
|
train_dataset=dataset,
|
||||||
reward_funcs=[QMDRewardFunction()],
|
reward_funcs=[QMDRewardFunction()],
|
||||||
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Starting GRPO training...")
|
print("Starting GRPO training...")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
is_main = os.environ.get("RANK", "0") == "0"
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
if not is_main:
|
||||||
|
return
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
print("Pushing to Hub...")
|
print("Pushing to Hub...")
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user