qmd/finetune/prepare_data.py
Tobi Lutke 7cca164dd9
Add query expansion model finetuning infrastructure
- Training scripts for Qwen3-0.6B and 1.7B models
- Dataset generation from s-emanuilov/query-expansion
- Evaluation scripts comparing finetuned vs baseline models
- GRPO RL training script (optional improvement)
- Export script for GGUF conversion

Results:
- 0.6B finetuned: 95% format compliance (lex/vec/hyde)
- Baseline: 0% format compliance
- Dataset: 5,157 examples on HuggingFace Hub

Models available at:
- tobil/qmd-query-expansion-0.6B (recommended)
- tobil/qmd-query-expansion-train (dataset)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 19:47:06 -05:00

104 lines
3.3 KiB
Python

#!/usr/bin/env python3
"""Prepare QMD query expansion data for training."""
import argparse
import json
from pathlib import Path
# Prompt template matching QMD's llm.ts format (simplified for training).
# Only {query} is substituted by str.format(); the doubled braces render as
# literal "{keyword variation}" etc. placeholders in the final prompt.
PROMPT_TEMPLATE = """You are a search query optimization expert. Transform the query into retrieval-optimized outputs.
Query: {query}
Output format:
lex: {{keyword variation}}
vec: {{semantic reformulation}}
hyde: {{hypothetical document passage}}
Output:"""


def format_for_training(input_text: str, output_text: str) -> dict:
    """Build one SFT record from a query and its target expansion.

    The same example is emitted in three shapes so different trainers can
    consume whichever they expect:

    - ``prompt`` / ``completion``: the rendered PROMPT_TEMPLATE and the
      target text as separate fields.
    - ``text``: prompt and completion joined by a single newline.
    - ``messages``: a user/assistant chat pair. The user turn uses a short
      "Expand this search query:" instruction, not PROMPT_TEMPLATE.

    Args:
        input_text: The raw search query.
        output_text: The target expansion (presumably lex/vec/hyde lines as
            described by PROMPT_TEMPLATE -- not checked here).

    Returns:
        A dict with keys ``prompt``, ``completion``, ``text``, ``messages``
        (in that order).
    """
    rendered = PROMPT_TEMPLATE.format(query=input_text)
    user_turn = {"role": "user", "content": f"Expand this search query:\n\n{input_text}"}
    assistant_turn = {"role": "assistant", "content": output_text}
    record = dict(
        prompt=rendered,
        completion=output_text,
        # Single-string variant for trainers that take one "text" column.
        text=f"{rendered}\n{output_text}",
        messages=[user_turn, assistant_turn],
    )
    return record
def _read_jsonl(path: Path) -> list:
    """Return the parsed JSON value of every non-blank line in *path*.

    Raises SystemExit (status 1) naming the file and 1-based line number
    when a line is not valid JSON, instead of an unannotated traceback.
    """
    records = []
    # Explicit UTF-8: the default text encoding is locale-dependent, so
    # non-ASCII queries could otherwise fail to decode on some platforms.
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise SystemExit(f"Error: {path}:{lineno}: invalid JSON: {err}") from err
    return records


def _write_jsonl(path: Path, rows) -> None:
    """Write each item of *rows* to *path* as one JSON document per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def main():
    """Convert a raw query-expansion JSONL file into training splits.

    Reads ``--input`` (one JSON object per non-blank line, each with
    ``input`` and ``output`` keys), formats every example with
    ``format_for_training``, and writes into ``--output``:

    - ``train.jsonl`` / ``val.jsonl``: full records.
    - ``train_chat.jsonl``: only the ``messages`` field of the train split.
    - ``dataset_info.json``: split sizes and column names.

    The split is positional: examples keep file order, the head becomes the
    train set and the trailing ``--split`` fraction the validation set.

    Exits with status 1 if the input file is missing or contains a line
    that is not valid JSON, and with status 2 (argparse usage error) if
    ``--split`` is not within [0, 1].
    """
    parser = argparse.ArgumentParser(description="Prepare data for training")
    parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
    parser.add_argument("--output", type=str, default="data/train", help="Output directory")
    parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
    args = parser.parse_args()
    # Outside [0, 1] split_idx below goes negative or past the end and the
    # slices silently yield a meaningless split.
    if not 0.0 <= args.split <= 1.0:
        parser.error(f"--split must be between 0 and 1, got {args.split}")

    input_path = Path(args.input)
    output_dir = Path(args.output)
    # Validate the input before touching the filesystem so a mistyped path
    # does not leave an empty output directory behind.
    if not input_path.is_file():
        # A string SystemExit prints the message to stderr and exits with 1.
        raise SystemExit(f"Error: Input file not found: {input_path}")
    output_dir.mkdir(parents=True, exist_ok=True)

    examples = _read_jsonl(input_path)
    print(f"Loaded {len(examples)} examples from {input_path}")

    formatted = [format_for_training(ex["input"], ex["output"]) for ex in examples]

    # Positional split, no shuffling: head -> train, tail -> validation.
    split_idx = int(len(formatted) * (1 - args.split))
    train_data = formatted[:split_idx]
    val_data = formatted[split_idx:]

    train_path = output_dir / "train.jsonl"
    val_path = output_dir / "val.jsonl"
    chat_path = output_dir / "train_chat.jsonl"
    _write_jsonl(train_path, train_data)
    _write_jsonl(val_path, val_data)
    # Messages-only view of the train split (for TRL/Unsloth chat trainers).
    _write_jsonl(chat_path, ({"messages": item["messages"]} for item in train_data))

    print(f"Written {len(train_data)} train examples to {train_path}")
    print(f"Written {len(val_data)} validation examples to {val_path}")
    print(f"Written chat format to {chat_path}")

    # Summary metadata describing the generated splits.
    dataset_info = {
        "dataset_name": "qmd-query-expansion",
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "columns": ["prompt", "completion", "text", "messages"],
    }
    with open(output_dir / "dataset_info.json", "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)


if __name__ == "__main__":
    main()