144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
"""
|
|
GRPO training for QMD query expansion (Qwen3-1.7B).

Experimental recipe run on top of merged SFT weights. Self-contained runner:

    uv run experiments/grpo/grpo.py

(If using HF Jobs, run this script as the job entrypoint.)
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from huggingface_hub import login
|
|
from peft import LoraConfig, PeftModel, get_peft_model
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from trl import GRPOTrainer, GRPOConfig
|
|
|
|
# Download eval_common.py if running as a standalone script (e.g. HF Jobs).
#
# The helper module is expected to sit next to this script. When it is missing
# it is fetched from the Hub, sending HF_TOKEN as a bearer token if one is set.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_eval_common_path = os.path.join(_script_dir, "eval_common.py")
if not os.path.exists(_eval_common_path):
    import urllib.request

    _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
    _opener = urllib.request.build_opener()
    _token = os.environ.get("HF_TOKEN", "")
    if _token:
        _opener.addheaders = [("Authorization", f"Bearer {_token}")]

    # Download completely before touching the destination path. Opening the
    # target for writing first would leave an empty/partial eval_common.py
    # behind on a failed request, and the existence check above would then
    # skip the download on every later run.
    with _opener.open(_url) as _resp:
        _payload = _resp.read()

    # Write to a sibling temp file and rename, so the final path only ever
    # appears with complete contents.
    _tmp_path = _eval_common_path + ".part"
    with open(_tmp_path, "wb") as _f:
        _f.write(_payload)
    os.replace(_tmp_path, _eval_common_path)

# Make the script directory importable so the local/downloaded helper resolves.
sys.path.insert(0, _script_dir)
from eval_common import QMDRewardFunction, run_eval
|
|
# --- Config (inlined from experiments/grpo/grpo.yaml) ---

# Checkpoint that supplies both the tokenizer and the base weights in main().
BASE_MODEL = "Qwen/Qwen3-1.7B"

# PEFT adapter that main() loads on top of BASE_MODEL and merges before GRPO.
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"

# Hub repo id handed to GRPOConfig(hub_model_id=...) for the trained result.
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"

# Hub dataset whose "train" split provides the prompts.
DATASET = "tobil/qmd-query-expansion-train"
|
|
|
|
# Upper bound on how many prompts are sampled from the dataset for GRPO.
_MAX_PROMPTS = 1000


def _load_prompts(tokenizer):
    """Load DATASET and return a dataset with a single chat-formatted "prompt" column.

    Args:
        tokenizer: Tokenizer whose chat template is used to render each prompt.

    Returns:
        A shuffled (seed 42) subset of at most ``_MAX_PROMPTS`` rows, each
        holding only the rendered ``"prompt"`` string.
    """
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def extract_prompt(example):
        # Only the content of the first message is used, re-wrapped as a user turn.
        # NOTE(review): presumably messages[0] is always the user turn — confirm
        # against the dataset schema.
        content = example["messages"][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return {"prompt": formatted}

    # Subsample before templating: the seeded shuffle depends only on the row
    # count, so this selects the same rows in the same order as templating
    # first, while rendering at most _MAX_PROMPTS examples instead of all.
    subset_size = min(_MAX_PROMPTS, len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(subset_size))
    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    print(f"Using {len(dataset)} prompts for GRPO")
    return dataset


def _load_policy_model():
    """Load BASE_MODEL, merge the SFT adapter into it, and attach a fresh GRPO LoRA.

    Returns:
        The PEFT-wrapped model to be trained.
    """
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
    )
    print(f"Merging SFT adapter {SFT_MODEL}...")
    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Fresh LoRA for GRPO (small: rank 4, q/v only)
    # NOTE(review): the adapter trained here sits on top of the *merged* SFT
    # weights, which are not saved by this script — confirm that consumers of
    # the pushed adapter apply it to the same merged weights.
    grpo_lora = LoraConfig(
        r=4, lora_alpha=8, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, grpo_lora)
    model.print_trainable_parameters()
    return model


def _build_grpo_config():
    """Return the GRPOConfig holding this recipe's hyperparameters."""
    return GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,

        # NOTE(review): some TRL releases require the effective batch size to
        # be divisible by num_generations — confirm for the installed version.
        num_generations=4,
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint

        # NOTE(review): per the TrainingArguments docs a positive max_steps
        # takes precedence over num_train_epochs; with save_strategy="epoch"
        # a checkpoint is only written if an epoch boundary is reached —
        # confirm which limit is intended.
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,

        logging_steps=10,
        save_strategy="epoch",
        bf16=True,

        report_to="none",
    )


def main():
    """Run GRPO on top of the merged SFT model, push the result, then evaluate.

    Reads the optional ``HF_TOKEN`` environment variable and logs in to the
    Hub when it is set. Takes no arguments and returns nothing; progress is
    reported on stdout.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Fall back to EOS so the tokenizer always has a padding token.
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset
    dataset = _load_prompts(tokenizer)

    # Load base model, merge SFT adapter, add the GRPO LoRA
    model = _load_policy_model()

    config = _build_grpo_config()

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # --- Automatic evaluation ---
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")
|
|
|
|
# Script entry point: run the training pipeline only when executed directly.
if __name__ == "__main__":
    main()