291 lines
9.6 KiB
Python
291 lines
9.6 KiB
Python
# /// script
|
|
# requires-python = ">=3.10"
|
|
# dependencies = [
|
|
# "torch",
|
|
# "trl>=0.14.0",
|
|
# "peft>=0.7.0",
|
|
# "transformers>=4.45.0",
|
|
# "accelerate>=0.24.0",
|
|
# "huggingface_hub>=0.20.0",
|
|
# "trackio",
|
|
# "datasets",
|
|
# "bitsandbytes",
|
|
# "pyyaml",
|
|
# ]
|
|
# ///
|
|
"""
|
|
Unified training script for QMD query expansion models.
|
|
|
|
Supports two stages:
|
|
sft - Supervised fine-tuning on labeled examples
|
|
grpo - Group Relative Policy Optimization (RL) on top of merged SFT weights
|
|
|
|
Usage:
|
|
uv run train.py sft --config configs/sft.yaml
|
|
uv run train.py grpo --config configs/grpo.yaml
|
|
uv run train.py grpo --config configs/grpo.yaml --dry-run
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import yaml
|
|
|
|
|
|
def cmd_sft(args):
    """Run supervised fine-tuning (SFT) of a LoRA adapter on labeled examples.

    Args:
        args: Parsed CLI namespace with ``config`` (path to a YAML config) and
            ``dry_run`` (if true, print the parsed config and return without
            importing the training stack or loading any data/models).

    Config keys read:
        model.base, model.output
        dataset.name, dataset.split (Hub datasets only),
        dataset.eval_split (optional; held-out size passed to
            ``train_test_split`` -- 0 or missing disables evaluation)
        training.epochs, training.batch_size,
        training.gradient_accumulation_steps, training.learning_rate,
        training.max_length, training.warmup_ratio, training.lr_scheduler
        lora.rank, lora.alpha, lora.dropout, lora.target_modules

    ``dataset.name`` is loaded as local JSON Lines when it starts with
    ``data/`` or ends with ``.jsonl`` (a directory means its ``train.jsonl``);
    otherwise it is treated as a Hub dataset id.

    ``model.output`` is pushed to the Hugging Face Hub when it looks like a
    repo id (``namespace/name``); any other value is used as a local output
    directory and the final model is saved there.
    """
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    if args.dry_run:
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    # Deferred so that --dry-run does not pay for importing the ML stack.
    from pathlib import Path

    from datasets import load_dataset
    from peft import LoraConfig
    from trl import SFTConfig, SFTTrainer

    dataset_cfg = cfg["dataset"]
    training_cfg = cfg["training"]
    lora_cfg = cfg["lora"]

    dataset_name = dataset_cfg["name"]
    print(f"Loading dataset: {dataset_name}...")

    # Support local JSONL files: either a file, or a directory holding train.jsonl.
    if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
        data_path = Path(dataset_name)
        data_file = str(data_path / "train.jsonl") if data_path.is_dir() else dataset_name
        dataset = load_dataset("json", data_files=data_file, split="train")
    else:
        dataset = load_dataset(dataset_name, split=dataset_cfg["split"])
    print(f"Dataset loaded: {len(dataset)} examples")

    # train_test_split rejects test_size=0, and the Trainer rejects
    # eval_strategy="steps" without an eval set, so a zero/missing eval_split
    # trains on everything with evaluation switched off instead of crashing.
    eval_split = dataset_cfg.get("eval_split")
    if eval_split:
        split = dataset.train_test_split(test_size=eval_split, seed=42)
        train_dataset, eval_dataset = split["train"], split["test"]
    else:
        train_dataset, eval_dataset = dataset, None
    eval_count = len(eval_dataset) if eval_dataset is not None else 0
    print(f" Train: {len(train_dataset)}, Eval: {eval_count}")

    output_name = cfg["model"]["output"]
    # Hub repo ids have the form "namespace/name": exactly one "/" and no
    # leading ".", "/" or "~" (those mark filesystem paths). Note that a
    # relative path such as "outputs/sft" still matches; write "./outputs/sft"
    # to force a local save.
    push_to_hub = output_name.count("/") == 1 and not output_name.startswith((".", "/", "~"))

    config = SFTConfig(
        # When pushing, checkpoints go to a local dir named after the repo.
        output_dir=output_name.split("/")[-1] if push_to_hub else output_name,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=training_cfg["epochs"],
        per_device_train_batch_size=training_cfg["batch_size"],
        gradient_accumulation_steps=training_cfg["gradient_accumulation_steps"],
        learning_rate=training_cfg["learning_rate"],
        max_length=training_cfg["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        eval_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=200,
        warmup_ratio=training_cfg["warmup_ratio"],
        lr_scheduler_type=training_cfg["lr_scheduler"],
        report_to="none",  # No experiment tracking for SFT runs.
    )

    peft_config = LoraConfig(
        r=lora_cfg["rank"],
        lora_alpha=lora_cfg["alpha"],
        lora_dropout=lora_cfg["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg["target_modules"],
    )

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=cfg["model"]["base"],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
    )

    print("Starting SFT training...")
    trainer.train()

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_name}")
|
|
|
|
|
|
def cmd_grpo(args):
    """Run GRPO reinforcement learning on top of merged SFT weights.

    Pipeline: load the base model, merge the SFT LoRA adapter into it, attach
    a fresh LoRA adapter, and train that adapter with ``GRPOTrainer`` using
    ``reward.QMDRewardFunction`` as the reward.

    Args:
        args: Parsed CLI namespace with ``config`` (path to a YAML config) and
            ``dry_run`` (if true, print the config, score a few sample
            expansions with ``reward.score_expansion``, and return).

    Config keys read:
        model.base, model.sft, model.output
        dataset.name, dataset.prompt_field,
        dataset.split (optional, default "train"),
        dataset.max_samples (optional; missing or null uses every row)
        grpo.num_generations, grpo.max_completion_length, grpo.beta,
        grpo.temperature (all optional; temperature is only forwarded when set)
        training.epochs, training.batch_size,
        training.gradient_accumulation_steps, training.learning_rate,
        training.max_grad_norm, training.max_steps (optional)
        lora.rank, lora.alpha, lora.dropout, lora.target_modules
        tracking.project, tracking.run_name

    ``model.output`` is pushed to the Hugging Face Hub when it looks like a
    repo id (``namespace/name``); any other value is used as a local output
    directory. ``HF_TOKEN``, if set in the environment, is used to log in.
    """
    # reward.py sits next to this script; make it importable whatever the cwd.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from reward import QMDRewardFunction, score_expansion

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    if args.dry_run:
        print("GRPO Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        print("\nTesting reward function...")

        # (query, expansion) samples used to eyeball the reward's scores.
        tests = [
            ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
            ("auth", "auth is important for security"),
            ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
            ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
        ]
        for query, expansion in tests:
            score = score_expansion(query, expansion)
            print(f" '{query}' -> {score:.2f}")
        return

    # Deferred so that --dry-run does not pay for importing the ML stack.
    import torch
    import trackio
    from datasets import load_dataset
    from huggingface_hub import login
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import GRPOConfig, GRPOTrainer

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to HuggingFace Hub...")
        login(token=hf_token)

    base_model_name = cfg["model"]["base"]
    print(f"Loading tokenizer from {base_model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # Fall back to EOS so padding has a token to use.
        tokenizer.pad_token = tokenizer.eos_token

    dataset_cfg = cfg["dataset"]
    prompt_field = dataset_cfg["prompt_field"]
    print("Loading dataset...")
    dataset = load_dataset(dataset_cfg["name"], split=dataset_cfg.get("split", "train"))

    def extract_prompt(example):
        """Turn one row into a chat-formatted ``prompt`` string."""
        # NOTE(review): takes the first message of the row's conversation as
        # the user turn; assumes it is the user query, not a system prompt.
        content = example[prompt_field][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return {"prompt": formatted}

    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    # A missing *or* explicitly null max_samples means "use every row".
    max_samples = dataset_cfg.get("max_samples")
    if max_samples is None:
        max_samples = len(dataset)
    dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    sft_model_name = cfg["model"]["sft"]
    print(f"Loading SFT model from {sft_model_name}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Bake the SFT adapter into the base weights so GRPO trains a new adapter
    # on top of the SFT model instead of continuing the SFT adapter itself.
    model = PeftModel.from_pretrained(base_model, sft_model_name)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    lora_cfg = cfg["lora"]
    grpo_lora_config = LoraConfig(
        r=lora_cfg["rank"],
        lora_alpha=lora_cfg["alpha"],
        lora_dropout=lora_cfg["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg["target_modules"],
    )
    model = get_peft_model(model, grpo_lora_config)
    model.print_trainable_parameters()

    # `grpo:` may be absent or present-but-empty (null) in the YAML.
    grpo_cfg = cfg.get("grpo") or {}
    training_cfg = cfg["training"]

    output_name = cfg["model"]["output"]
    # Hub repo ids have the form "namespace/name": exactly one "/" and no
    # leading ".", "/" or "~" (those mark filesystem paths).
    push_to_hub = output_name.count("/") == 1 and not output_name.startswith((".", "/", "~"))

    # Sampling temperature is forwarded only when configured, so TRL's own
    # default applies otherwise.
    optional_kwargs = {}
    if "temperature" in grpo_cfg:
        optional_kwargs["temperature"] = grpo_cfg["temperature"]

    config = GRPOConfig(
        output_dir=output_name.split("/")[-1] if push_to_hub else output_name,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        num_generations=grpo_cfg.get("num_generations", 4),
        max_completion_length=grpo_cfg.get("max_completion_length", 200),
        beta=grpo_cfg.get("beta", 0.04),
        num_train_epochs=training_cfg["epochs"],
        per_device_train_batch_size=training_cfg["batch_size"],
        gradient_accumulation_steps=training_cfg["gradient_accumulation_steps"],
        learning_rate=training_cfg["learning_rate"],
        max_grad_norm=training_cfg["max_grad_norm"],
        max_steps=training_cfg.get("max_steps", -1),
        logging_steps=10,
        save_strategy="epoch",
        report_to="trackio",
        project=cfg["tracking"]["project"],
        run_name=cfg["tracking"]["run_name"],
        **optional_kwargs,
    )

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        trackio.finish()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        trackio.finish()
        print(f"Done! Model saved to: {output_name}")
|
|
|
|
|
|
def main():
    """Parse the command line and hand off to the chosen training stage.

    Each stage subcommand takes a required ``--config`` YAML path and an
    optional ``--dry-run`` flag; the parsed namespace is passed unchanged to
    ``cmd_sft`` or ``cmd_grpo``.
    """
    parser = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
uv run train.py sft --config configs/sft.yaml
uv run train.py grpo --config configs/grpo.yaml
uv run train.py grpo --config configs/grpo.yaml --dry-run
""",
    )
    stage_subparsers = parser.add_subparsers(dest="stage", required=True)

    # (stage name, subcommand help, --config help, --dry-run help)
    stage_specs = (
        ("sft", "Supervised fine-tuning", "Path to SFT config YAML", "Print config and exit"),
        ("grpo", "GRPO reinforcement learning", "Path to GRPO config YAML", "Print config, test reward, and exit"),
    )
    for stage_name, stage_help, config_help, dry_run_help in stage_specs:
        stage_parser = stage_subparsers.add_parser(stage_name, help=stage_help)
        stage_parser.add_argument("--config", required=True, help=config_help)
        stage_parser.add_argument("--dry-run", action="store_true", help=dry_run_help)

    parsed = parser.parse_args()

    handlers = {"sft": cmd_sft, "grpo": cmd_grpo}
    handler = handlers.get(parsed.stage)
    if handler is not None:
        handler(parsed)
|
|
|
|
|
|
# Script entry point: parse CLI arguments and run the selected stage.
if __name__ == "__main__":
    main()
|