- Training scripts for Qwen3-0.6B and 1.7B models
- Dataset generation from s-emanuilov/query-expansion
- Evaluation scripts comparing finetuned vs baseline models
- GRPO RL training script (optional improvement)
- Export script for GGUF conversion

Results:
- 0.6B finetuned: 95% format compliance (lex/vec/hyde)
- Baseline: 0% format compliance
- Dataset: 5,157 examples on HuggingFace Hub

Models available at:
- tobil/qmd-query-expansion-0.6B (recommended)
- tobil/qmd-query-expansion-train (dataset)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Prepare QMD query expansion data for training."""
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Training prompt: a simplified form of the template used in QMD's llm.ts.
# Doubled braces are str.format() escapes, so the lex/vec/hyde placeholders
# come out as literal "{...}" text; {query} is the only substituted field.
PROMPT_TEMPLATE = (
    "You are a search query optimization expert. "
    "Transform the query into retrieval-optimized outputs.\n"
    "\n"
    "Query: {query}\n"
    "\n"
    "Output format:\n"
    "lex: {{keyword variation}}\n"
    "vec: {{semantic reformulation}}\n"
    "hyde: {{hypothetical document passage}}\n"
    "\n"
    "Output:"
)
|
|
|
|
|
|
def format_for_training(input_text: str, output_text: str) -> dict:
    """Build one SFT record from a raw query and its target expansion.

    The returned dict carries the same example in three layouts:

    - "prompt" / "completion": PROMPT_TEMPLATE filled with the query, and
      the target output.
    - "text": the prompt and the target joined by a single newline.
    - "messages": a two-turn user/assistant chat. The user turn uses a short
      "Expand this search query" instruction rather than PROMPT_TEMPLATE.
    """
    rendered_prompt = PROMPT_TEMPLATE.format(query=input_text)

    user_turn = {
        "role": "user",
        "content": f"Expand this search query:\n\n{input_text}",
    }
    assistant_turn = {"role": "assistant", "content": output_text}

    # Keys are inserted in the same order as the serialized columns:
    # prompt, completion, text, messages.
    record = {"prompt": rendered_prompt, "completion": output_text}
    record["text"] = f"{rendered_prompt}\n{output_text}"
    record["messages"] = [user_turn, assistant_turn]
    return record
|
|
|
|
|
|
def _write_jsonl(path, rows) -> None:
    """Write each item of *rows* to *path* as one JSON object per line (UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)


def main(argv=None) -> None:
    """Convert a raw JSONL dataset into train/validation files for SFT.

    Reads one JSON object per non-blank line of ``--input`` (each object must
    have "input" and "output" keys), formats it with ``format_for_training``
    and writes the following files into the ``--output`` directory:

    - ``train.jsonl`` / ``val.jsonl``: full records, divided by ``--split``.
    - ``train_chat.jsonl``: only the "messages" field of each train record.
    - ``dataset_info.json``: sample counts and column names.

    The split is positional (no shuffling): the trailing ``--split`` fraction
    of the input becomes the validation set.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: Status 1 (message on stderr) if the input file does not
            exist; status 2 (via argparse) if ``--split`` is outside [0, 1].
    """
    parser = argparse.ArgumentParser(description="Prepare data for training")
    parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
    parser.add_argument("--output", type=str, default="data/train", help="Output directory")
    parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
    args = parser.parse_args(argv)

    # Reject ratios that cannot be a fraction of the data (this also catches
    # NaN, for which the chained comparison is False).
    if not 0.0 <= args.split <= 1.0:
        parser.error(f"--split must be between 0 and 1, got {args.split}")

    input_path = Path(args.input)
    output_dir = Path(args.output)

    # Validate the input before creating anything, so a wrong path does not
    # leave an empty output directory behind. SystemExit with a string prints
    # it to stderr and exits with status 1, without relying on the
    # site-provided exit() builtin.
    if not input_path.exists():
        raise SystemExit(f"Error: Input file not found: {input_path}")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Load examples. An explicit encoding keeps non-ASCII queries intact
    # regardless of the platform locale; "utf-8-sig" also tolerates a BOM.
    with open(input_path, encoding="utf-8-sig") as f:
        examples = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(examples)} examples from {input_path}")

    # Format for training
    formatted = [format_for_training(ex["input"], ex["output"]) for ex in examples]

    # Split into train/val: the first (1 - split) share is train, the rest val.
    split_idx = int(len(formatted) * (1 - args.split))
    train_data = formatted[:split_idx]
    val_data = formatted[split_idx:]

    train_path = output_dir / "train.jsonl"
    val_path = output_dir / "val.jsonl"
    chat_path = output_dir / "train_chat.jsonl"

    _write_jsonl(train_path, train_data)
    _write_jsonl(val_path, val_data)
    # Chat format (for TRL/Unsloth): only the "messages" column of the train set.
    _write_jsonl(chat_path, ({"messages": item["messages"]} for item in train_data))

    print(f"Written {len(train_data)} train examples to {train_path}")
    print(f"Written {len(val_data)} validation examples to {val_path}")
    print(f"Written chat format to {chat_path}")

    # Also save a small summary describing the generated dataset.
    dataset_info = {
        "dataset_name": "qmd-query-expansion",
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "columns": ["prompt", "completion", "text", "messages"],
    }
    with open(output_dir / "dataset_info.json", "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|