54 lines
1.4 KiB
YAML
54 lines
1.4 KiB
YAML
# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config configs/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, hyde quality, content quality, and named entity preservation.
# beta > 0 is critical to prevent drift from the SFT checkpoint.

# Model selection: which weights to start from, where results are written,
# and how the weights are loaded.
model:
  base: "Qwen/Qwen3-1.7B"
  sft: "outputs/sft"  # Use local SFT output (or HF path if uploaded)
  output: "outputs/grpo"  # Local training output (push to HF manually after eval)
  push_to_hub: false
  torch_dtype: "bfloat16"
  # Both quantized-loading switches are off.
  load_in_4bit: false
  load_in_8bit: false

# Training data source.
dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
  name: "data/train/"
  prompt_field: "messages"
  # NOTE(review): presumably an upper bound on the examples used - confirm
  # how train.py applies it.
  max_samples: 1000

# Optimizer and schedule settings.
training:
  # NOTE(review): both `epochs` and `max_steps` are set; presumably the step
  # limit ends the run first - confirm which one train.py honors.
  epochs: 1
  batch_size: 2
  gradient_accumulation_steps: 8
  # 5e-7, written in plain decimal form because exponent notation is read as
  # a string by some YAML loaders.
  learning_rate: 0.0000005
  max_grad_norm: 0.5
  max_steps: 200
  # Save checkpoints every 30 minutes
  save_interval_minutes: 30
  # Fallback time-step save cadence if needed (not used for wall-clock mode)
  save_steps: 50

# GRPO-specific settings.
grpo:
  # NOTE(review): if these values are forwarded to TRL's GRPOConfig, confirm
  # that the batch settings under `training` satisfy its divisibility
  # requirement with respect to num_generations.
  num_generations: 4
  max_completion_length: 200
  beta: 0.04  # KL regularization - prevents drift from SFT checkpoint

# LoRA adapter settings.
lora:
  rank: 4
  alpha: 8  # twice the rank
  dropout: 0.05
  # NOTE(review): presumably the module names that receive adapters - confirm
  # they match the layer names of the base model.
  target_modules:
    - "q_proj"
    - "v_proj"

# Experiment-tracking labels for this run.
tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"