qmd/finetune/train.py

291 lines
9.6 KiB
Python

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.45.0",
# "accelerate>=0.24.0",
# "huggingface_hub>=0.20.0",
# "trackio",
# "datasets",
# "bitsandbytes",
# "pyyaml",
# ]
# ///
"""
Unified training script for QMD query expansion models.
Supports two stages:
sft - Supervised fine-tuning on labeled examples
grpo - Group Relative Policy Optimization (RL) on top of merged SFT weights
Usage:
uv run train.py sft --config configs/sft.yaml
uv run train.py grpo --config configs/grpo.yaml
uv run train.py grpo --config configs/grpo.yaml --dry-run
"""
import argparse
import os
import sys
import yaml
def _looks_like_hub_repo_id(name):
    """Return True if *name* looks like a Hugging Face Hub repo id ("namespace/repo").

    Names that are clearly filesystem paths -- absolute ("/..."), dot-relative
    ("./...", "../..."), home-relative ("~/...") or containing more than one "/" --
    are treated as local output directories instead.
    """
    return name.count("/") == 1 and not name.startswith((".", "/", "~"))


def cmd_sft(args):
    """Run supervised fine-tuning (LoRA via TRL's SFTTrainer) from a YAML config.

    With ``args.dry_run`` the parsed config is printed and the function returns
    before any ML library is imported.

    Config keys read:
        dataset.name        Hub dataset id, a local ``.jsonl`` file, or a ``data/...``
                            path (a directory is expected to contain ``train.jsonl``).
        dataset.split       Split to load for Hub datasets.
        dataset.eval_split  Fraction/size passed to ``train_test_split``; absent, null
                            or 0 disables the held-out eval set.
        model.base          Base model passed to SFTTrainer.
        model.output        Hub repo id ("namespace/repo") or a local output directory.
        model.push_to_hub   Optional bool overriding the repo-id heuristic.
        training.*, lora.*  Forwarded to SFTConfig / LoraConfig.
    """
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if args.dry_run:
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    # Imported only after the dry-run check so `--dry-run` works without the ML stack.
    from pathlib import Path

    from datasets import load_dataset
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    dataset_name = cfg["dataset"]["name"]
    print(f"Loading dataset: {dataset_name}...")
    # Support local JSONL files: either a directory holding train.jsonl or a file path.
    if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
        data_path = Path(dataset_name)
        train_file = data_path / "train.jsonl" if data_path.is_dir() else dataset_name
        dataset = load_dataset("json", data_files=str(train_file), split="train")
    else:
        dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
    print(f"Dataset loaded: {len(dataset)} examples")

    eval_split = cfg["dataset"].get("eval_split")
    if eval_split:
        split = dataset.train_test_split(test_size=eval_split, seed=42)
        train_dataset = split["train"]
        eval_dataset = split["test"]
    else:
        # No held-out set requested: train on everything and skip evaluation.
        train_dataset = dataset
        eval_dataset = None
    eval_count = len(eval_dataset) if eval_dataset is not None else 0
    print(f" Train: {len(train_dataset)}, Eval: {eval_count}")

    output_name = cfg["model"]["output"]
    # An explicit config flag wins; otherwise infer from the shape of the output name.
    push_to_hub = cfg["model"].get("push_to_hub")
    if push_to_hub is None:
        push_to_hub = _looks_like_hub_repo_id(output_name)

    config = SFTConfig(
        # When pushing, checkpoints go to a local dir named after the repo.
        output_dir=output_name.split("/")[-1] if push_to_hub else output_name,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=cfg["training"]["epochs"],
        per_device_train_batch_size=cfg["training"]["batch_size"],
        gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
        learning_rate=cfg["training"]["learning_rate"],
        max_length=cfg["training"]["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        eval_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=200,
        warmup_ratio=cfg["training"]["warmup_ratio"],
        lr_scheduler_type=cfg["training"]["lr_scheduler"],
        report_to="none",  # Disable tracking for local training
    )
    peft_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
    )

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=cfg["model"]["base"],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
    )
    print("Starting SFT training...")
    trainer.train()
    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_name}")
def cmd_grpo(args):
    """Run GRPO reinforcement learning on top of merged SFT weights.

    Loads ``model.base``, merges the SFT LoRA adapter from ``model.sft`` into it,
    attaches a fresh LoRA adapter, and trains it with TRL's GRPOTrainer using
    ``reward.QMDRewardFunction``. The result is pushed to ``model.output`` on the Hub.

    With ``args.dry_run`` the config is printed and ``reward.score_expansion`` is run
    on a few hand-written expansions. No model or ML library is loaded in that mode.

    Optional config keys:
        dataset.split                 split to load (default "train")
        dataset.max_samples           cap on prompts; absent/null means all
        grpo.num_generations          default 4
        grpo.max_completion_length    default 200
        grpo.beta                     default 0.04
        grpo.temperature              forwarded only when set; TRL default otherwise
        training.max_steps            default -1 (use epochs)
        training.max_grad_norm        default 1.0
    """
    # Make the sibling `reward` module importable. abspath avoids inserting "" when
    # the script is invoked by bare filename (dirname("train.py") == "").
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    if args.dry_run:
        # Only the reward module is needed to sanity-check scoring.
        from reward import score_expansion

        print("GRPO Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        print("\nTesting reward function...")
        tests = [
            ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
            ("auth", "auth is important for security"),
            ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
            ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
        ]
        for query, expansion in tests:
            score = score_expansion(query, expansion)
            print(f" '{query}' -> {score:.2f}")
        return

    import torch
    import trackio
    from datasets import load_dataset
    from huggingface_hub import login
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import GRPOConfig, GRPOTrainer

    from reward import QMDRewardFunction

    # Log in explicitly only when HF_TOKEN is set in the environment.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to HuggingFace Hub...")
        login(token=hf_token)

    # Load tokenizer; fall back to EOS as the padding token when none is defined.
    base_model_name = cfg["model"]["base"]
    print(f"Loading tokenizer from {base_model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset into a single chat-templated "prompt" column.
    dataset_cfg = cfg["dataset"]
    print("Loading dataset...")
    dataset = load_dataset(dataset_cfg["name"], split=dataset_cfg.get("split", "train"))
    prompt_field = dataset_cfg["prompt_field"]

    def extract_prompt(example):
        # Uses the content of the first message in `prompt_field` as the user turn.
        # NOTE(review): assumes element 0 is the user message -- confirm dataset schema.
        content = example[prompt_field][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return {"prompt": formatted}

    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    # Absent or null max_samples means "use every prompt".
    max_samples = dataset_cfg.get("max_samples")
    if max_samples is None:
        max_samples = len(dataset)
    dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    # Load base model and fold the SFT adapter into its weights.
    sft_model_name = cfg["model"]["sft"]
    print(f"Loading SFT model from {sft_model_name}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, sft_model_name)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Add a fresh LoRA adapter that GRPO will train.
    grpo_lora_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
    )
    model = get_peft_model(model, grpo_lora_config)
    model.print_trainable_parameters()

    # Build GRPO config, including beta and (when configured) sampling temperature.
    grpo_cfg = cfg.get("grpo") or {}
    training_cfg = cfg["training"]
    optional_kwargs = {}
    if grpo_cfg.get("temperature") is not None:
        optional_kwargs["temperature"] = grpo_cfg["temperature"]
    config = GRPOConfig(
        output_dir=cfg["model"]["output"].split("/")[-1],
        push_to_hub=True,
        hub_model_id=cfg["model"]["output"],
        num_generations=grpo_cfg.get("num_generations", 4),
        max_completion_length=grpo_cfg.get("max_completion_length", 200),
        beta=grpo_cfg.get("beta", 0.04),
        num_train_epochs=training_cfg["epochs"],
        per_device_train_batch_size=training_cfg["batch_size"],
        gradient_accumulation_steps=training_cfg["gradient_accumulation_steps"],
        learning_rate=training_cfg["learning_rate"],
        max_grad_norm=training_cfg.get("max_grad_norm", 1.0),
        max_steps=training_cfg.get("max_steps", -1),
        logging_steps=10,
        save_strategy="epoch",
        report_to="trackio",
        project=cfg["tracking"]["project"],
        run_name=cfg["tracking"]["run_name"],
        **optional_kwargs,
    )

    # Train
    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )
    print("Starting GRPO training...")
    trainer.train()
    print("Pushing to Hub...")
    trainer.push_to_hub()
    trackio.finish()
    print(f"Done! Model: https://huggingface.co/{cfg['model']['output']}")
def main():
    """Parse the command line and hand off to the chosen training stage.

    Each stage (``sft`` / ``grpo``) is a subcommand that takes a required
    ``--config`` YAML path and an optional ``--dry-run`` flag.
    """
    parser = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
uv run train.py sft --config configs/sft.yaml
uv run train.py grpo --config configs/grpo.yaml
uv run train.py grpo --config configs/grpo.yaml --dry-run
""",
    )
    stages = parser.add_subparsers(dest="stage", required=True)

    # (subcommand, subcommand help, --config help, --dry-run help)
    stage_specs = (
        ("sft", "Supervised fine-tuning", "Path to SFT config YAML", "Print config and exit"),
        ("grpo", "GRPO reinforcement learning", "Path to GRPO config YAML", "Print config, test reward, and exit"),
    )
    for stage_name, stage_help, config_help, dry_run_help in stage_specs:
        stage_parser = stages.add_parser(stage_name, help=stage_help)
        stage_parser.add_argument("--config", required=True, help=config_help)
        stage_parser.add_argument("--dry-run", action="store_true", help=dry_run_help)

    parsed = parser.parse_args()
    # Handlers are looked up only after parsing; argparse guarantees stage is one of these.
    handlers = {"sft": cmd_sft, "grpo": cmd_grpo}
    handlers[parsed.stage](parsed)
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()