finetune: strict Pydantic schema, one canonical data format

Replace ad-hoc JSON parsing with a strict Pydantic model
(TrainingExample with typed OutputPair). All data loading goes
through load_examples() which fails loudly on invalid data.

- Convert v3_structured.jsonl from "searches" to "output" format
- Rewrite all consumer scripts (prepare, validate, score, analyze)
  to load through the Pydantic schema
- Prepared train/val files are ephemeral build artifacts
- Restore LFM2 and GEPA experiments under experiments/
- Add pydantic>=2.0 to dependencies
This commit is contained in:
Tobi Lutke 2026-02-22 13:38:55 -04:00
parent 3950055708
commit 1d7d167b29
No known key found for this signature in database
23 changed files with 5005 additions and 5026 deletions

View File

@ -18,6 +18,22 @@ vec: another semantic variation
- `lex:` lines for BM25 keyword search (1-3 lines, short keywords)
- `vec:` lines for vector similarity search (1-3 lines, natural language)
## Training Data Format
**There is exactly one JSONL format.** Every file in `data/*.jsonl` must match the strict Pydantic schema in `dataset/schema.py`:
```json
{"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
```
- `query`: non-empty string
- `output`: list of `[type, text]` pairs where type is `"lex"`, `"vec"`, or `"hyde"`
- Extra metadata fields (`category`, `intent`, `is_short`) are allowed but ignored
The schema is enforced by `dataset/schema.py:TrainingExample` (Pydantic model). All data loading goes through `load_examples()` which fails loudly on invalid data. No format alternatives, no legacy fallbacks.
**All `.jsonl` files in `data/` are concatenated and deduplicated for training runs.** The prepared train/val files in `data/train/` are ephemeral build artifacts.
## HuggingFace Repositories
| Repository | Purpose |
@ -33,72 +49,34 @@ vec: another semantic variation
- Only push when eval scores improve over current deployed model
- Always include eval results in model card when pushing
## Training Data
All JSONL files in `data/` are training data:
```
data/
├── qmd_expansion_v2.jsonl
├── qmd_expansion_handcrafted_only.jsonl
├── qmd_only_sampled.jsonl
├── qmd_only_variants.jsonl
└── ... any additional .jsonl files
```
**All `.jsonl` files in `data/` should be concatenated for training runs.**
Each JSONL line: `{"input": "query", "output": "hyde:...\nlex:...\nvec:..."}`
## Data Generation Tools
## Dataset Tools
| Script | Purpose |
|--------|---------|
| `dataset/generate_data.py` | Generate via Claude API (high quality) |
| `dataset/generate_data_offline.py` | Transform from HuggingFace datasets |
| `dataset/prepare_data.py` | Format for Qwen3 chat template |
| `dataset/clean_data.py` | Detect and fix technical term issues |
| `generate_only_variants.py` | Generate `/only:lex` and `/only:vec` variants |
## Local Training Output
All training outputs go to `outputs/` (gitignored):
```
outputs/
├── sft/ # SFT checkpoint
└── grpo/ # GRPO checkpoint
```
| `dataset/schema.py` | Pydantic `TrainingExample` model + `load_examples()` |
| `dataset/prepare_data.py` | Load via schema, apply Qwen3 chat template, dedup, split |
| `dataset/validate_schema.py` | Validate all JSONL files against schema |
| `dataset/score_data.py` | Score all examples using reward.py |
| `dataset/analyze_data.py` | Analyze distribution and quality |
## Training Pipeline
Always use **Qwen3-1.7B** as the base model unless explicitly stated otherwise.
Training can run **locally** (requires CUDA GPU) or via **HuggingFace Jobs** (cloud GPU, no local hardware needed).
### Stage 0: Prepare Data
Raw data in `data/*.jsonl` must be converted to Qwen3 chat format before training:
```bash
# Process all JSONL files in data/
uv run dataset/prepare_data.py
# Creates: data/train/train.jsonl, data/train/val.jsonl
# Or process a specific file
uv run dataset/prepare_data.py --input data/qmd_expansion_v2.jsonl
# Creates: data/train/train.jsonl, data/train/val.jsonl (ephemeral)
```
This applies the Qwen3 chat template, deduplicates, and splits into train/val sets.
### Stage 1: SFT
```bash
# Local (requires CUDA)
uv run train.py sft --config configs/sft.yaml
# Output: outputs/sft/
# Cloud (HuggingFace Jobs - no local GPU needed)
# Cloud (HuggingFace Jobs)
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft.py
```
@ -107,16 +85,13 @@ hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft.py
```bash
# Local (requires CUDA)
uv run train.py grpo --config configs/grpo.yaml
# Output: outputs/grpo/
# Cloud (HuggingFace Jobs - no local GPU needed)
# Cloud (HuggingFace Jobs)
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
```
### HuggingFace Jobs
If no local CUDA device is available, use `hf jobs` to run training in the cloud:
```bash
hf jobs ps # List running jobs
hf jobs logs <job-id> # Stream logs
@ -124,18 +99,11 @@ hf jobs inspect <job-id> # Check status
hf jobs cancel <job-id> # Cancel a job
```
The `jobs/` directory contains self-contained scripts that include all dependencies inline.
### Evaluation
```bash
# Eval local model
uv run eval.py --model ./outputs/grpo
# Eval HuggingFace model
uv run eval.py --model tobil/qmd-query-expansion-1.7B
# Save eval results to file
uv run eval.py --model ./outputs/grpo -o eval_results.json
```
@ -144,27 +112,26 @@ uv run eval.py --model ./outputs/grpo -o eval_results.json
`reward.py` is the single source of truth for scoring:
```bash
# Self-test the reward function
uv run reward.py
uv run reward.py # Self-test
```
See `SCORING.md` for the full rubric.
## Deployment Rules
## Experiments
**Never upload without eval.** Every model push must include eval results.
Experimental training configurations live in `experiments/`:
### Checklist
```
experiments/
├── lfm2/ # LiquidAI LFM2-1.2B (hybrid architecture, faster inference)
│ ├── sft_lfm2.yaml
│ └── sft_lfm2.py
└── gepa/ # DSPy-based prompt optimization (GEPA)
├── dspy_gepa.py
└── ...
```
1. Train SFT on all `data/*.jsonl``outputs/sft/`
2. Train GRPO on top of SFT → `outputs/grpo/`
3. **Run eval on local model**: `uv run eval.py --model ./outputs/grpo -o eval_results.json`
4. Compare against current deployed model's eval
5. If eval improves:
- Push to `tobil/qmd-query-expansion-1.7B`
- **Include eval output in the model card / commit message**
6. Convert to GGUF and update `tobil/qmd-query-expansion-1.7B-gguf`
7. Update `src/llm.ts` DEFAULT_GENERATE_MODEL if repo name changed
These are not part of the main training pipeline.
## Key Files
@ -176,10 +143,12 @@ finetune/
├── convert_gguf.py # GGUF conversion
├── SCORING.md # Detailed scoring rubric
├── CLAUDE.md # This file
├── data/ # All training JSONL files
├── outputs/ # Local training outputs (gitignored)
├── dataset/ # Data generation scripts
├── Justfile # Common commands
├── data/ # All training JSONL files (strict schema)
├── dataset/ # Schema + data tools (Pydantic-based)
├── jobs/ # Self-contained HuggingFace Jobs scripts
├── configs/ # Training configs (sft.yaml, grpo.yaml)
└── evals/ # Test queries and results
├── evals/ # Test queries
├── experiments/ # Experimental configs (LFM2, GEPA)
└── outputs/ # Local training outputs (gitignored)
```

View File

@ -92,20 +92,19 @@ finetune/
│ ├── sft.py # Self-contained SFT for HuggingFace Jobs
│ ├── grpo.py # Self-contained GRPO for HuggingFace Jobs
│ ├── eval.py # Self-contained eval for HuggingFace Jobs
│ ├── eval_common.py # Shared eval utilities
│ └── quantize.py # GGUF quantization for HuggingFace Jobs
│ └── eval_common.py # Shared eval utilities
├── configs/
│ ├── sft.yaml # SFT hyperparameters for Qwen3-1.7B
│ └── grpo.yaml # GRPO hyperparameters for Qwen3-1.7B
├── evals/
│ └── queries.txt # 31 test queries across 8 categories
├── data/
│ └── qmd_expansion_v2.jsonl # Source training data (1,000 high-quality examples)
├── data/ # Training JSONL files (all concatenated for training)
├── dataset/
│ ├── generate_data.py # Generate data via Claude API
│ ├── generate_data_offline.py # Generate from existing HF dataset
│ ├── prepare_data.py # Format for Qwen3 chat template
│ └── clean_data.py # Detect technical term misinterpretations
│ ├── prepare_data.py # Format for Qwen3 chat template, dedup, split
│ ├── schema.py # Parse/normalize output format
│ ├── validate_schema.py # Validate JSONL against schema
│ ├── score_data.py # Score all examples using reward.py
│ └── analyze_data.py # Analyze distribution and quality
├── SCORING.md # Detailed scoring rubric reference
└── README.md # This file
```
@ -122,7 +121,7 @@ Teaches the model the `lex:/vec:/hyde:` output format from labeled examples.
| Method | LoRA (rank 16, alpha 32) |
| Target modules | All projection layers (q/k/v/o/gate/up/down) |
| Dataset | ~2,290 examples (train split) |
| Effective batch size | 16 (4 × 4 gradient accumulation) |
| Effective batch size | 16 (4 x 4 gradient accumulation) |
| Epochs | 5 |
| Learning rate | 2e-4 (cosine schedule) |
@ -173,9 +172,6 @@ uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft -v
# Save detailed scores to JSON
uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft -o scores.json
# Score an existing JSONL file (backwards compat with old run.py output)
uv run eval.py --score-only evals/results_old.jsonl
```
## Reward Function
@ -212,9 +208,6 @@ quantized GGUF files for deployment:
# Use preset for 1.7B
uv run convert_gguf.py --size 1.7B
# Use preset for 4B
uv run convert_gguf.py --size 4B
# Custom models
uv run convert_gguf.py --base Qwen/Qwen3-1.7B \
--sft tobil/qmd-query-expansion-1.7B-sft \
@ -235,26 +228,19 @@ ollama run qmd-expand
## Data Pipeline
The training data (1,000 examples in `data/qmd_expansion_v2.jsonl`) was generated
from two sources and cleaned for quality. To regenerate:
All JSONL files in `data/` are concatenated for training. To prepare for training:
```bash
# Generate from existing HuggingFace dataset (bulk, no API needed)
uv run dataset/generate_data_offline.py
# Generate via Claude API (higher quality, needs ANTHROPIC_API_KEY)
uv run dataset/generate_data.py --count 100
# Detect and fix technical term misinterpretations
uv run dataset/clean_data.py
# Format for Qwen3 chat template, add short-query augmentation, split train/val
# Format for Qwen3 chat template, deduplicate, split train/val
uv run dataset/prepare_data.py
# Validate data quality
just validate
```
## Architecture Notes
The two-stage training approach (SFT GRPO) is standard for structured-output models:
The two-stage training approach (SFT -> GRPO) is standard for structured-output models:
1. **SFT** establishes format compliance and basic query understanding. It uses
a large LoRA (rank 16, all projection layers) because it needs to learn a
@ -297,42 +283,3 @@ deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubri
|-------|--------------|-----------------|
| SFT | 92.0% | 30/30 |
| GRPO | 91.7% | 30/30 |
## Alternative Base Models
### LiquidAI LFM2 (Experimental)
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
is a hybrid architecture from Liquid AI optimized for on-device inference. It uses
a novel combination of convolutions and attention that achieves 2x faster decode
and prefill speed compared to standard transformers.
**Why LFM2 for query expansion:**
- **Faster inference**: Lower latency for real-time search applications
- **Memory efficient**: Smaller memory footprint than equivalent transformers
- **Edge-optimized**: Can run on mobile devices and embedded systems
- **Good at agentic tasks**: LiquidAI recommends LFM2 for RAG and data extraction
**Training with LFM2:**
```bash
# SFT with LFM2-1.2B base model
uv run train.py sft --config configs/sft_lfm2.yaml
# Evaluate the trained model
uv run eval.py --model outputs/sft-lfm2
# Convert to GGUF for deployment
uv run convert_gguf.py --base LiquidAI/LFM2-1.2B \
--sft outputs/sft-lfm2 \
--output tobil/qmd-query-expansion-lfm2-gguf
```
**Key differences from Qwen3:**
- Different LoRA target modules: `q_proj, k_proj, v_proj, out_proj, in_proj, w1, w2, w3`
- Recommended generation parameters: `temp=0.3, min_p=0.15, repetition_penalty=1.05`
- Requires transformers >= 4.55.0 for architecture support
**Pre-trained GGUF models:**
- Base: `hf:LiquidAI/LFM2-1.2B-GGUF/LFM2-1.2B-Q4_K_M.gguf` (~731 MB)
- Instruct: `hf:LiquidAI/LFM2.5-1.2B-Instruct-GGUF/LFM2.5-1.2B-Instruct-Q4_K_M.gguf` (~731 MB)

File diff suppressed because it is too large Load Diff

View File

@ -1,34 +1,35 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["pydantic>=2.0"]
# ///
"""
Dataset Analysis and Quality Report Generator
Analyzes the training data for:
Analyzes training data loaded through the strict Pydantic schema for:
1. Query length distribution
2. Category diversity
3. Named entity coverage
4. Temporal query coverage
5. Short query coverage (important for ambiguous queries)
6. Duplicate detection
7. Quality issues (long hyde, missing fields, etc.)
4. Output format coverage
5. Duplicate detection
"""
import json
import re
import argparse
import sys
from pathlib import Path
from collections import Counter, defaultdict
from dataclasses import dataclass
sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import normalize_output_items, parse_output_text
from dataset.schema import TrainingExample, OutputType, load_examples
@dataclass
class DatasetStats:
total_examples: int = 0
short_queries: int = 0 # 1-2 words
medium_queries: int = 0 # 3-5 words
long_queries: int = 0 # 6+ words
short_queries: int = 0
medium_queries: int = 0
long_queries: int = 0
has_lex: int = 0
has_vec: int = 0
has_hyde: int = 0
@ -40,288 +41,164 @@ class DatasetStats:
def categorize_query(query: str) -> str:
"""Categorize a query by type."""
query_lower = query.lower()
words = query_lower.split()
word_count = len(words)
# Short keyword queries
if word_count <= 2:
return "short_keyword"
# Named entity queries (capitalized words or tech terms)
if any(w[0].isupper() for w in words if w):
if any(w[0].isupper() for w in query.split() if w):
return "named_entity"
# Temporal/recency queries
temporal_keywords = [
"latest",
"recent",
"new",
"update",
"changelog",
"changed",
"version",
"release",
"news",
"2024",
"2025",
"latest", "recent", "new", "update", "changelog",
"changed", "version", "release", "news", "2024", "2025",
]
if any(kw in query_lower for kw in temporal_keywords):
return "temporal"
# How-to queries
if query_lower.startswith("how "):
return "how_to"
# What is queries
if query_lower.startswith("what "):
return "what_is"
# Difference/comparison queries
if any(kw in query_lower for kw in ["difference", "vs", "versus", "compare"]):
return "comparison"
# Personal/journal style
if any(
kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]
):
if any(kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]):
return "personal"
return "other"
def extract_named_entities(query: str) -> list:
"""Extract potential named entities from query."""
entities = []
words = query.split()
for word in words:
# Skip stopwords
if word.lower() in {
"the",
"a",
"an",
"is",
"are",
"to",
"for",
"of",
"in",
"and",
"or",
}:
stopwords = {"the", "a", "an", "is", "are", "to", "for", "of", "in", "and", "or"}
for word in query.split():
if word.lower() in stopwords:
continue
# Capitalized words (potential named entities)
if word and word[0].isupper() and len(word) > 1:
entities.append(word)
# Technology terms with version numbers or special chars
if any(c in word for c in ".+-0123456789") and len(word) > 1:
entities.append(word)
return entities
def analyze_dataset(filepath: Path) -> tuple[DatasetStats, dict, dict]:
"""Analyze the dataset and return statistics."""
def analyze_examples(examples: list[TrainingExample]) -> tuple[DatasetStats, dict, dict]:
stats = DatasetStats()
categories = Counter()
seen_queries = set()
duplicate_count = 0
category_examples = defaultdict(list)
categories: Counter = Counter()
seen_queries: set[str] = set()
category_examples: dict[str, list[str]] = defaultdict(list)
with open(filepath, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
for ex in examples:
stats.total_examples += 1
try:
example = json.loads(line)
query = example.get("query", "") or example.get("input", "")
output = example.get("output", [])
if isinstance(output, str):
output = parse_output_text(output)
output = normalize_output_items(output)
query_lower = ex.query.lower()
if query_lower in seen_queries:
stats.duplicate_queries += 1
else:
seen_queries.add(query_lower)
stats.total_examples += 1
word_count = len(ex.query.split())
if word_count <= 2:
stats.short_queries += 1
elif word_count <= 5:
stats.medium_queries += 1
else:
stats.long_queries += 1
# Check for duplicates
query_lower = query.lower()
if query_lower in seen_queries:
duplicate_count += 1
else:
seen_queries.add(query_lower)
category = categorize_query(ex.query)
categories[category] += 1
category_examples[category].append(ex.query)
# Query length categorization
word_count = len(query.split())
if word_count <= 2:
stats.short_queries += 1
elif word_count <= 5:
stats.medium_queries += 1
else:
stats.long_queries += 1
if extract_named_entities(ex.query):
stats.named_entity_queries += 1
# Category detection
category = categorize_query(query)
categories[category] += 1
category_examples[category].append(query)
# Use the typed OutputPair model
types_present = {p.type for p in ex.output}
if OutputType.lex in types_present:
stats.has_lex += 1
if OutputType.vec in types_present:
stats.has_vec += 1
if OutputType.hyde in types_present:
stats.has_hyde += 1
for p in ex.output:
if p.type == OutputType.hyde and len(p.text) > 200:
stats.long_hyde_count += 1
# Named entity detection
if extract_named_entities(query):
stats.named_entity_queries += 1
# Output analysis
has_lex = any(o[0] == "lex" for o in output)
has_vec = any(o[0] == "vec" for o in output)
has_hyde = any(o[0] == "hyde" for o in output)
if has_lex:
stats.has_lex += 1
if has_vec:
stats.has_vec += 1
if has_hyde:
stats.has_hyde += 1
# Check hyde length
for kind, text in output:
if kind == "hyde" and len(text) > 200:
stats.long_hyde_count += 1
except json.JSONDecodeError:
print(f"Warning: Could not parse line {line_num}")
stats.duplicate_queries = duplicate_count
stats.temporal_queries = categories.get("temporal", 0)
stats.short_keyword_queries = categories.get("short_keyword", 0)
return stats, dict(categories), dict(category_examples)
def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
"""Print a comprehensive analysis report."""
print("=" * 70)
print("QMD TRAINING DATA ANALYSIS REPORT")
print("=" * 70)
print()
# Basic statistics
print("📊 BASIC STATISTICS")
total = stats.total_examples
print("BASIC STATISTICS")
print("-" * 40)
print(f"Total examples: {stats.total_examples:>6}")
print(f"Total examples: {total:>6}")
print(f"Duplicates found: {stats.duplicate_queries:>6}")
print()
# Query length distribution
print("📝 QUERY LENGTH DISTRIBUTION")
print("QUERY LENGTH DISTRIBUTION")
print("-" * 40)
total = stats.total_examples
print(
f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)"
)
print(
f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)"
)
print(
f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)"
)
print(f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)")
print(f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)")
print(f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)")
print()
# Category distribution
print("🏷️ CATEGORY DISTRIBUTION")
print("CATEGORY DISTRIBUTION")
print("-" * 40)
for cat, count in sorted(categories.items(), key=lambda x: -x[1]):
pct = 100 * count / total
bar = "" * int(pct / 2)
bar = "#" * int(pct / 2)
print(f"{cat:20} {count:>6} ({pct:5.1f}%) {bar}")
print()
# Output format coverage
print("✅ OUTPUT FORMAT COVERAGE")
print("OUTPUT FORMAT COVERAGE")
print("-" * 40)
print(
f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)"
)
print(
f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)"
)
print(
f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)"
)
print(f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)")
print(f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)")
print(f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)")
print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
print()
# Critical metrics for evals
print("🎯 EVALUATION ALIGNMENT")
print("EVALUATION ALIGNMENT")
print("-" * 40)
print(
f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)"
)
print(
f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)"
)
print(
f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)"
)
print(f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)")
print(f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)")
print(f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)")
print()
# Recommendations
print("💡 RECOMMENDATIONS")
print("RECOMMENDATIONS")
print("-" * 40)
recommendations = []
if stats.short_queries / total < 0.15:
recommendations.append(
"⚠️ Short queries below 15% - add more 1-2 word keyword queries"
)
recommendations.append("Short queries below 15% - add more 1-2 word keyword queries")
if stats.named_entity_queries / total < 0.10:
recommendations.append(
"⚠️ Named entity queries below 10% - add more capitalized tech term queries"
)
recommendations.append("Named entity queries below 10% - add more capitalized tech term queries")
if stats.temporal_queries / total < 0.05:
recommendations.append(
"⚠️ Temporal queries below 5% - add more 'latest', 'recent' queries"
)
recommendations.append("Temporal queries below 5% - add more 'latest', 'recent' queries")
if stats.long_hyde_count > 50:
recommendations.append(
f"⚠️ {stats.long_hyde_count} long hyde sections - consider truncating"
)
recommendations.append(f"{stats.long_hyde_count} long hyde sections - consider truncating")
if stats.duplicate_queries > 0:
recommendations.append(
f"⚠️ {stats.duplicate_queries} duplicate queries - consider deduplication"
)
if categories.get("short_keyword", 0) < 100:
recommendations.append(
"⚠️ Need more short keyword examples for ambiguous query training"
)
recommendations.append(f"{stats.duplicate_queries} duplicate queries - consider deduplication")
if not recommendations:
print("Dataset looks good! No major issues detected.")
print("Dataset looks good! No major issues detected.")
else:
for rec in recommendations:
print(rec)
print(f" - {rec}")
print()
print("=" * 70)
def main():
"""Main entry point."""
import argparse
parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
parser.add_argument(
"--input",
type=str,
default="data/qmd_expansion_v2.jsonl",
default="data/qmd_expansion_v3_structured.jsonl",
help="Path to training data JSONL file",
)
parser.add_argument(
@ -333,33 +210,30 @@ def main():
args = parser.parse_args()
input_path = Path(args.input)
if not input_path.exists():
# Try relative to script directory
script_dir = Path(__file__).parent.parent
input_path = script_dir / args.input
if not input_path.exists():
print(f"Error: Could not find dataset at {input_path}")
print("Please run from finetune directory or specify correct path")
return 1
print(f"Analyzing: {input_path}")
print()
stats, categories, category_examples = analyze_dataset(input_path)
examples = load_examples(input_path)
stats, categories, category_examples = analyze_examples(examples)
print_report(stats, categories, category_examples)
# Show examples if requested
if args.show_examples > 0:
print("📋 SAMPLE QUERIES BY CATEGORY")
print("SAMPLE QUERIES BY CATEGORY")
print("-" * 40)
for cat in sorted(categories.keys()):
examples = category_examples.get(cat, [])
if examples:
exs = category_examples.get(cat, [])
if exs:
print(f"\n{cat.upper()}:")
for ex in examples[: args.show_examples]:
print(f" {ex}")
for ex in exs[:args.show_examples]:
print(f" - {ex}")
print()
return 0

View File

@ -3,27 +3,29 @@
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "pydantic>=2.0",
# "jinja2",
# ]
# ///
"""Prepare QMD query expansion data for training.
See PROMPT_FORMAT.md for format specification.
Loads all data/*.jsonl via the strict Pydantic schema, applies the Qwen3
chat template, deduplicates by query, and writes train/val splits.
The prepared train files are ephemeral build artifacts the canonical
data lives in data/*.jsonl and is always loaded through the schema.
"""
import argparse
import json
import random
import sys
import os
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import (
normalize_output_items,
TrainingExample,
load_examples,
output_items_to_text,
parse_output_text,
has_type,
)
from transformers import AutoTokenizer
@ -41,30 +43,26 @@ def get_tokenizer():
return _tokenizer
def format_for_training(query_text: str, output_items: list[list[str]]) -> dict:
"""Format a single example for SFT training using Qwen chat format."""
def format_for_training(ex: TrainingExample) -> dict:
"""Format a validated TrainingExample for SFT training."""
tokenizer = get_tokenizer()
output_text = output_items_to_text(output_items)
output_text = output_items_to_text(ex.output)
# Use /no_think to disable thinking mode - we want direct output
messages = [
{
"role": "user",
"content": f"/no_think Expand this search query: {query_text}",
"content": f"/no_think Expand this search query: {ex.query}",
},
{"role": "assistant", "content": output_text},
]
# Use tokenizer to generate proper chat format with special tokens
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
# Strip empty <think> tags - we don't want thinking mode
# The template adds "<think>\n\n</think>\n\n" which we remove
# Strip empty <think> tags — /no_think should suppress them
text = text.replace("<think>\n\n</think>\n\n", "")
return {
@ -88,27 +86,22 @@ def main():
"--split", type=float, default=0.1, help="Validation split ratio"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Shuffle seed (default: 42)",
"--seed", type=int, default=42, help="Shuffle seed",
)
args = parser.parse_args()
output_dir = Path(args.output)
output_dir.mkdir(parents=True, exist_ok=True)
# Support glob patterns for input
import glob
# Resolve input files
import glob as globmod
if "*" in args.input:
input_files = sorted(glob.glob(args.input))
input_files = sorted(globmod.glob(args.input))
if not input_files:
print(f"Error: No files found matching: {args.input}")
exit(1)
print(
f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}"
)
print(f"Found {len(input_files)} input files")
else:
input_path = Path(args.input)
if not input_path.exists():
@ -116,78 +109,62 @@ def main():
exit(1)
input_files = [str(input_path)]
# Load all examples from all input files
examples = []
# Load all examples through strict Pydantic schema
all_examples: list[TrainingExample] = []
for input_file in input_files:
file_count = 0
with open(input_file) as f:
for line in f:
if line.strip():
ex = json.loads(line)
examples = load_examples(input_file)
print(f" {Path(input_file).name}: {len(examples)} examples")
all_examples.extend(examples)
# Normalize legacy format
if "query" not in ex and "input" in ex:
ex["query"] = ex.pop("input")
if isinstance(ex.get("output"), str):
ex["output"] = parse_output_text(ex["output"])
ex["output"] = normalize_output_items(ex.get("output", []))
print(f"Loaded {len(all_examples)} examples total")
examples.append(ex)
file_count += 1
print(f" {Path(input_file).name}: {file_count} examples")
# Deduplicate by query (case-insensitive)
seen: set[str] = set()
deduped: list[TrainingExample] = []
for ex in all_examples:
key = ex.query.lower().strip()
if key not in seen:
seen.add(key)
deduped.append(ex)
if len(deduped) < len(all_examples):
print(f"Deduplicated: {len(all_examples)} -> {len(deduped)}")
all_examples = deduped
print(f"Loaded {len(examples)} examples total")
# Combine and shuffle
all_examples = examples
# Shuffle
random.seed(args.seed)
random.shuffle(all_examples)
# Format for training
formatted = [format_for_training(ex["query"], ex["output"]) for ex in all_examples]
# Format each example using the Pydantic model
formatted = [format_for_training(ex) for ex in all_examples]
# Split into train/val
# Split
split_idx = int(len(formatted) * (1 - args.split))
train_data = formatted[:split_idx]
val_data = formatted[split_idx:]
# Write train set
train_path = output_dir / "train.jsonl"
with open(train_path, "w") as f:
for item in train_data:
f.write(json.dumps(item) + "\n")
# Write (these are ephemeral build artifacts)
for name, data in [("train.jsonl", train_data), ("val.jsonl", val_data)]:
with open(output_dir / name, "w") as f:
for item in data:
f.write(json.dumps(item) + "\n")
# Write validation set
val_path = output_dir / "val.jsonl"
with open(val_path, "w") as f:
for item in val_data:
f.write(json.dumps(item) + "\n")
# Write chat format (for TRL)
chat_path = output_dir / "train_chat.jsonl"
with open(chat_path, "w") as f:
with open(output_dir / "train_chat.jsonl", "w") as f:
for item in train_data:
f.write(json.dumps({"messages": item["messages"]}) + "\n")
# Stats
short_final = sum(1 for ex in all_examples if len(ex["query"].split()) <= 2)
short_final = sum(1 for ex in all_examples if len(ex.query.split()) <= 2)
print(f"\n=== Summary ===")
print(f"Total examples: {len(all_examples)}")
print(
f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)"
)
print(f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)")
print(f"Train: {len(train_data)}, Val: {len(val_data)}")
print(f"Output: {output_dir}")
# Dataset info
dataset_info = {
"dataset_name": "qmd-query-expansion",
"train_samples": len(train_data),
"val_samples": len(val_data),
"short_query_pct": round(100 * short_final / len(all_examples), 1),
"columns": ["prompt", "completion", "text", "messages"],
}
with open(output_dir / "dataset_info.json", "w") as f:
json.dump(dataset_info, f, indent=2)

View File

@ -1,17 +1,149 @@
#!/usr/bin/env python3
"""Schema helpers for QMD training JSONL data."""
"""
Strict schema for QMD training data.
Every JSONL file in data/ MUST conform to this format:
{"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
- query: non-empty string
- output: list of [type, text] pairs where type is "lex", "vec", or "hyde"
- Extra fields (category, intent, is_short, etc.) are allowed but ignored
There is exactly ONE format. No alternatives, no legacy fallbacks.
"""
from __future__ import annotations
from typing import Iterable
import json
from enum import Enum
from pathlib import Path
from typing import Annotated, Iterable
VALID_OUTPUT_TYPES = {"hyde", "lex", "vec"}
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
field_validator,
)
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class OutputType(str, Enum):
lex = "lex"
vec = "vec"
hyde = "hyde"
VALID_OUTPUT_TYPES = {t.value for t in OutputType}
class OutputPair(BaseModel):
"""A single expansion line: [type, text]."""
type: OutputType
text: str
model_config = ConfigDict(frozen=True)
@field_validator("text")
@classmethod
def text_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("text must not be empty")
return v
def to_list(self) -> list[str]:
return [self.type.value, self.text]
def _coerce_output_pairs(v: list) -> list[OutputPair]:
"""Accept [["lex", "..."], ...] from JSON and coerce to OutputPair list."""
pairs = []
for i, item in enumerate(v):
if isinstance(item, OutputPair):
pairs.append(item)
elif isinstance(item, (list, tuple)) and len(item) == 2:
pairs.append(OutputPair(type=item[0], text=item[1]))
else:
raise ValueError(
f"output[{i}] must be [type, text], got {item!r}"
)
return pairs
# ---------------------------------------------------------------------------
# Pydantic model — single source of truth for the JSONL schema
# ---------------------------------------------------------------------------
class TrainingExample(BaseModel):
"""One training example in the canonical JSONL format."""
query: str
output: Annotated[list[OutputPair], BeforeValidator(_coerce_output_pairs)]
# Optional metadata — present in some files, ignored during training.
category: str | None = None
intent: str | None = None
is_short: bool | None = None
model_config = ConfigDict(extra="ignore")
@field_validator("query")
@classmethod
def query_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("query must not be empty")
return v
@field_validator("output")
@classmethod
def output_not_empty(cls, v: list[OutputPair]) -> list[OutputPair]:
if not v:
raise ValueError("output must not be empty")
return v
def output_as_lists(self) -> list[list[str]]:
"""Return output as list-of-lists for JSON serialization."""
return [p.to_list() for p in self.output]
# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------
def load_examples(path: str | Path) -> list[TrainingExample]:
"""Load and validate a JSONL file. Fails loudly on any bad line."""
path = Path(path)
examples: list[TrainingExample] = []
with path.open("r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
raise ValueError(f"{path}:{line_num}: invalid JSON: {e}") from e
try:
examples.append(TrainingExample.model_validate(obj))
except Exception as e:
raise ValueError(f"{path}:{line_num}: {e}") from e
return examples
# ---------------------------------------------------------------------------
# Helpers (used by prepare_data.py, reward.py, and other tools)
# ---------------------------------------------------------------------------
def parse_output_text(text: str) -> list[list[str]]:
"""Parse prefixed output text into list pairs.
Returns: [["hyde", "..."], ["lex", "..."], ...]
>>> parse_output_text("lex: foo\\nvec: bar")
[["lex", "foo"], ["vec", "bar"]]
"""
items: list[list[str]] = []
for raw_line in text.strip().split("\n"):
@ -35,16 +167,18 @@ def reorder_hyde_first(items: list[list[str]]) -> list[list[str]]:
return hyde_items + lex_items + vec_items
def output_items_to_text(items: Iterable[Iterable[str]], hyde_first: bool = True) -> str:
"""Render output list pairs to prefixed text lines.
Args:
items: Iterable of [type, text] pairs
hyde_first: If True, reorder to put hyde first (default)
def output_items_to_text(
items: Iterable, hyde_first: bool = True
) -> str:
"""Render output pairs to prefixed text lines.
Accepts list[OutputPair] or list[list[str]].
"""
# First normalize to list
normalized = []
for item in items:
if isinstance(item, OutputPair):
normalized.append([item.type.value, item.text.strip()])
continue
if not item:
continue
try:
@ -59,24 +193,26 @@ def output_items_to_text(items: Iterable[Iterable[str]], hyde_first: bool = True
if not text:
continue
normalized.append([kind, text])
# Apply hyde-first ordering if requested
if hyde_first:
normalized = reorder_hyde_first(normalized)
lines = [f"{kind}: {text}" for kind, text in normalized]
return "\n".join(lines)
def normalize_output_items(items: Iterable[Iterable[str]], hyde_first: bool = True) -> list[list[str]]:
"""Normalize output list pairs (filter invalid, trim whitespace, reorder).
Args:
items: Iterable of [type, text] pairs
hyde_first: If True, reorder to put hyde first (default)
def normalize_output_items(
items: Iterable, hyde_first: bool = True
) -> list[list[str]]:
"""Normalize output pairs (filter invalid, trim whitespace, reorder).
Accepts list[OutputPair] or list[list[str]].
"""
normalized: list[list[str]] = []
for item in items:
if isinstance(item, OutputPair):
normalized.append([item.type.value, item.text.strip()])
continue
if not item:
continue
try:
@ -91,13 +227,18 @@ def normalize_output_items(items: Iterable[Iterable[str]], hyde_first: bool = Tr
if not text:
continue
normalized.append([kind, text])
# Apply hyde-first ordering if requested
if hyde_first:
normalized = reorder_hyde_first(normalized)
return normalized
def has_type(items: Iterable[Iterable[str]], kind: str) -> bool:
return any(item and item[0] == kind for item in items)
def has_type(items: Iterable, kind: str) -> bool:
for item in items:
if isinstance(item, OutputPair):
if item.type.value == kind:
return True
elif item and item[0] == kind:
return True
return False

View File

@ -1,20 +1,19 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["pydantic>=2.0"]
# ///
"""Score JSONL datasets with the reward function."""
from __future__ import annotations
import argparse
import json
import statistics
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import (
normalize_output_items,
output_items_to_text,
parse_output_text,
)
from dataset.schema import load_examples, output_items_to_text
from reward import score_expansion_detailed
@ -24,42 +23,24 @@ def score_file(path: Path) -> tuple[int, int, list[float], dict]:
scores: list[float] = []
ratings: dict[str, int] = {}
with path.open("r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
total += 1
try:
obj = json.loads(line)
except json.JSONDecodeError:
errors += 1
continue
try:
examples = load_examples(path)
except ValueError as e:
print(f" Error loading {path}: {e}")
return 0, 1, [], {}
query = obj.get("query") or obj.get("input")
output = obj.get("output")
if not isinstance(query, str) or not query.strip():
errors += 1
continue
if output is None:
errors += 1
continue
for ex in examples:
total += 1
output_text = output_items_to_text(ex.output)
if not output_text:
errors += 1
continue
if isinstance(output, str):
output_items = normalize_output_items(parse_output_text(output))
else:
output_items = normalize_output_items(output)
output_text = output_items_to_text(output_items)
if not output_text:
errors += 1
continue
detail = score_expansion_detailed(query, output_text)
score = detail["percentage"]
scores.append(score)
rating = detail["rating"]
ratings[rating] = ratings.get(rating, 0) + 1
detail = score_expansion_detailed(ex.query, output_text)
score = detail["percentage"]
scores.append(score)
rating = detail["rating"]
ratings[rating] = ratings.get(rating, 0) + 1
return total, errors, scores, ratings

View File

@ -1,5 +1,9 @@
#!/usr/bin/env python3
"""Validate JSONL files against the QMD training schema."""
# /// script
# requires-python = ">=3.10"
# dependencies = ["pydantic>=2.0"]
# ///
"""Validate JSONL files against the strict QMD training schema."""
from __future__ import annotations
@ -9,7 +13,7 @@ import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import VALID_OUTPUT_TYPES
from dataset.schema import TrainingExample
def validate_file(path: Path) -> tuple[int, int]:
@ -29,31 +33,11 @@ def validate_file(path: Path) -> tuple[int, int]:
errors += 1
continue
query = obj.get("query")
output = obj.get("output")
if not isinstance(query, str) or not query.strip():
print(f"{path}:{line_num}: missing/invalid query")
try:
TrainingExample.model_validate(obj)
except Exception as e:
print(f"{path}:{line_num}: {e}")
errors += 1
continue
if not isinstance(output, list):
print(f"{path}:{line_num}: output must be a list")
errors += 1
continue
for idx, item in enumerate(output):
if not isinstance(item, list) or len(item) != 2:
print(f"{path}:{line_num}: output[{idx}] must be [type, text]")
errors += 1
continue
kind, text = item
if kind not in VALID_OUTPUT_TYPES:
print(f"{path}:{line_num}: invalid output type '{kind}'")
errors += 1
if not isinstance(text, str) or not text.strip():
print(f"{path}:{line_num}: empty output text")
errors += 1
return total, errors

View File

@ -0,0 +1 @@
"""GEPA helpers."""

View File

@ -0,0 +1,31 @@
You are an assistant that expands a given search query into lexical (lex), vector (vec), and HYDE expansions for improved search retrieval.
## Input Format
You will receive input in this exact format:
```
## Inputs
### query
[the search query]
```
## Output Format
Respond ONLY with this exact format, nothing else:
```
## Generated Outputs
### expansion
lex: [short keyword phrase 1]
lex: [short keyword phrase 2]
lex: [short keyword phrase 3]
vec: [medium phrasal expansion 1]
vec: [medium phrasal expansion 2]
vec: [medium phrasal expansion 3]
hyde: [concise hypothetical document snippet, SINGLE LINE, under 150 characters total]
```
## Generation Rules
- **Exactly 3 lex lines**: Short (2-5 words), keyword-like expansions. MUST include core query terms or direct synonyms/variants (e.g., for "web mail", include "webmail"). Focus on key entities, actions, or concepts.
- **Exactly 3 vec lines**: Medium-length (4-8 words) natural language phrases capturing query intent, aspects, or related searches.
- **Exactly 1 hyde line**: A single, fluent sentence acting as a hypothetical relevant document passage. Keep STRICTLY under 150 characters (aim for 100-140). Be descriptive but concise—no lists, no examples unless essential.
- Strategy: Break down the query into synonyms (lex), semantic rephrasings (vec), and a compact informative summary (hyde) to cover lexical, embedding, and dense retrieval signals.
- Match query intent precisely; expand to related high-relevance terms without hallucinating unrelated content.
```

View File

@ -0,0 +1 @@
Expand a search query into lex/vec/hyde lines.

View File

@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""Run DSPy GEPA using reward.py as the metric."""
from __future__ import annotations
import argparse
import importlib
import json
import sys
from pathlib import Path
def _import_dspy():
script_dir = Path(__file__).parent
repo_root = script_dir.parent
original_sys_path = list(sys.path)
try:
sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
return importlib.import_module("dspy")
finally:
sys.path = original_sys_path
dspy = _import_dspy()
repo_root = Path(__file__).parent.parent
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
from reward import score_expansion_detailed
class ExpandSignature(dspy.Signature):
"""Expand a search query into lex/vec/hyde lines."""
query = dspy.InputField(desc="User search query")
output = dspy.OutputField(
desc=(
"JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
"Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
"Lex items are short keywords and must not echo the query. "
"Vec items are natural language search phrases. "
"Hyde is 50-200 chars, single line."
)
)
class Expander(dspy.Module):
def __init__(self):
super().__init__()
self.predict = dspy.Predict(ExpandSignature)
def forward(self, query: str):
return self.predict(query=query)
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
expansion = output_items_to_text(_coerce_output_items(pred))
detail = score_expansion_detailed(gold.query, expansion)
score = detail["percentage"] / 100.0
feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
return dspy.Prediction(score=score, feedback=feedback)
def load_queries(path: Path) -> list[str]:
queries: list[str] = []
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
obj = json.loads(line)
query = obj.get("query") or obj.get("input")
if isinstance(query, str) and query.strip():
queries.append(query.strip())
return queries
def to_examples(queries: list[str]) -> list[dspy.Example]:
return [dspy.Example(query=q).with_inputs("query") for q in queries]
def _coerce_output_items(pred) -> list[list[str]]:
raw_output = getattr(pred, "output", None)
if isinstance(raw_output, (list, tuple)):
return normalize_output_items(raw_output)
raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
if not raw_text:
return []
if raw_text[0] in ("[", "{"):
try:
obj = json.loads(raw_text)
if isinstance(obj, dict) and "output" in obj:
obj = obj["output"]
if isinstance(obj, (list, tuple)):
return normalize_output_items(obj)
except Exception:
pass
return parse_output_text(raw_text)
def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
with path.open("w", encoding="utf-8") as f:
for query, output in zip(queries, outputs, strict=True):
f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
def main() -> int:
parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
parser.add_argument(
"--model",
type=str,
default="grok-4-1-fast-reasoning",
help="LM string in provider/model format (e.g., openai/gpt-4o)",
)
parser.add_argument(
"--reflection-model",
type=str,
default="grok-4-1-fast-reasoning",
help="LM string in provider/model format (e.g., openai/gpt-4o)",
)
parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
parser.add_argument("--max-full-evals", type=int, default=None)
parser.add_argument("--max-metric-calls", type=int, default=None)
parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
args = parser.parse_args()
if "/" not in args.model or "/" not in args.reflection_model:
print("Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).")
return 1
if args.max_full_evals is not None and args.max_metric_calls is not None:
print("Provide only one of --max-full-evals or --max-metric-calls")
return 1
if args.max_full_evals is not None or args.max_metric_calls is not None:
args.auto = None
train_path = Path(args.input)
queries = load_queries(train_path)
if args.limit is not None:
queries = queries[: args.limit]
trainset = to_examples(queries)
valset = None
if args.valset:
val_queries = load_queries(Path(args.valset))
if args.val_limit is not None:
val_queries = val_queries[: args.val_limit]
valset = to_examples(val_queries)
lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
student = Expander()
student.set_lm(lm)
compiler = dspy.GEPA(
metric=reward_metric,
reflection_lm=reflection_lm,
auto=None if args.auto is None else args.auto,
max_full_evals=args.max_full_evals,
max_metric_calls=args.max_metric_calls,
track_stats=True,
track_best_outputs=True,
failure_score=0.0,
perfect_score=1.0,
)
optimized = compiler.compile(student=student, trainset=trainset, valset=valset)
if args.save_prompt:
prompt_text = getattr(optimized.predict.signature, "__doc__", "") or ""
Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8")
print(f"Wrote {args.save_prompt}")
if args.emit:
outputs = []
for q in queries:
pred = optimized(query=q)
items = _coerce_output_items(pred)
outputs.append(items)
write_jsonl(Path(args.emit), queries, outputs)
print(f"Wrote {args.emit}")
if hasattr(optimized, "detailed_results"):
best = getattr(optimized.detailed_results, "best_outputs_valset", None)
if best:
print(f"Best outputs tracked: {len(best)}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""GEPA example schema for QMD training JSONL lines."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Iterable
class SearchType(str, Enum):
LexSearch = "LexSearch"
VecSearch = "VecSearch"
HydeSearch = "HydeSearch"
SEARCH_TYPE_TO_PREFIX = {
SearchType.LexSearch: "lex",
SearchType.VecSearch: "vec",
SearchType.HydeSearch: "hyde",
}
@dataclass
class OutputItem:
"""Single expansion line with validation hints."""
kind: SearchType
text: str
# Validation hints (not strict rules).
min_chars: int = 3
max_chars: int | None = None
def __post_init__(self) -> None:
self.text = str(self.text).strip()
if not self.text:
raise ValueError("OutputItem.text must be non-empty")
if "\n" in self.text:
raise ValueError("OutputItem.text must be single-line")
if len(self.text) < self.min_chars:
raise ValueError("OutputItem.text is too short")
if self.max_chars is not None and len(self.text) > self.max_chars:
raise ValueError("OutputItem.text is too long")
def to_pair(self) -> list[str]:
return [SEARCH_TYPE_TO_PREFIX[self.kind], self.text]
@dataclass
class Example:
"""JSONL line schema for QMD training data."""
query: str
output: list[OutputItem] = field(default_factory=list)
def __post_init__(self) -> None:
self.query = str(self.query).strip()
if not self.query:
raise ValueError("Example.query must be non-empty")
if not self.output:
raise ValueError("Example.output must not be empty")
def to_json(self) -> dict:
return {
"query": self.query,
"output": [item.to_pair() for item in self.output],
}
def to_jsonl(self) -> str:
return json.dumps(self.to_json(), ensure_ascii=False)
def parse_output_items(raw_output: Iterable[Iterable[str]]) -> list[OutputItem]:
items: list[OutputItem] = []
for item in raw_output:
if not item or len(item) < 2:
continue
kind_raw, text = item[0], item[1]
kind_map = {
"lex": SearchType.LexSearch,
"vec": SearchType.VecSearch,
"hyde": SearchType.HydeSearch,
}
kind = kind_map.get(str(kind_raw).strip().lower())
if kind is None:
continue
max_chars = 200 if kind is SearchType.HydeSearch else None
items.append(OutputItem(kind=kind, text=str(text), max_chars=max_chars))
return items
def example_from_json(obj: dict) -> Example:
query = obj.get("query") or obj.get("input") or ""
output = obj.get("output") or []
if isinstance(output, str):
raise ValueError("String outputs are not supported in GEPA example schema")
items = parse_output_items(output)
return Example(query=query, output=items)
def load_jsonl(path: str | Path) -> list[Example]:
examples: list[Example] = []
with Path(path).open("r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
examples.append(example_from_json(obj))
except Exception as exc:
raise ValueError(f"Invalid line {line_num}: {exc}") from exc
return examples

View File

@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""Generate expansions using a saved GEPA prompt."""
from __future__ import annotations
import argparse
import importlib
import json
import sys
from pathlib import Path
def _import_dspy():
script_dir = Path(__file__).parent
original_sys_path = list(sys.path)
try:
sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
return importlib.import_module("dspy")
finally:
sys.path = original_sys_path
dspy = _import_dspy()
repo_root = Path(__file__).parent.parent
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
from dataset.schema import parse_output_text
def load_topics(path: Path) -> list[str]:
topics: list[str] = []
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
# Allow JSONL {"topic": "..."} or plain lines.
if line.startswith("{") and line.endswith("}"):
try:
obj = json.loads(line)
topic = obj.get("topic") or obj.get("query") or obj.get("input")
if isinstance(topic, str) and topic.strip():
topics.append(topic.strip())
continue
except json.JSONDecodeError:
pass
topics.append(line)
return topics
def write_jsonl_line(handle, query: str, output_text: str) -> None:
output = parse_output_text(output_text)
handle.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
def parse_queries(text: str) -> list[str]:
lines = []
for raw in text.splitlines():
line = raw.strip().lstrip("-").strip()
if not line:
continue
lines.append(line)
return lines
def main() -> int:
parser = argparse.ArgumentParser(description="Generate with saved GEPA prompt")
parser.add_argument("--prompt", type=str, required=True, help="Path to saved prompt text")
parser.add_argument("--topics", type=str, required=True, help="Topics file (one per line or JSONL)")
parser.add_argument("--output", type=str, required=True, help="Output JSONL path")
parser.add_argument("--model", type=str, required=True, help="LM string in provider/model format")
parser.add_argument("--per-topic", type=int, default=3, help="Queries to generate per topic")
args = parser.parse_args()
prompt_text = Path(args.prompt).read_text(encoding="utf-8").strip()
expansion_sig = dspy.Signature("query -> expansion", prompt_text)
query_sig = dspy.Signature(
"topic, count -> queries",
(
"Generate distinct user search queries for the given topic. "
"Return exactly `count` queries, one per line, no numbering or extra text."
),
)
class Generator(dspy.Module):
def __init__(self):
super().__init__()
self.predict = dspy.Predict(expansion_sig)
def forward(self, query: str):
return self.predict(query=query)
class QueryGenerator(dspy.Module):
def __init__(self):
super().__init__()
self.predict = dspy.Predict(query_sig)
def forward(self, topic: str, count: int):
return self.predict(topic=topic, count=str(count))
lm = dspy.LM(model=args.model)
gen = Generator()
gen.set_lm(lm)
qgen = QueryGenerator()
qgen.set_lm(lm)
topics = load_topics(Path(args.topics))
with Path(args.output).open("w", encoding="utf-8") as f_out:
for topic in topics:
qpred = qgen(topic=topic, count=args.per_topic)
qtext = getattr(qpred, "queries", "") or ""
generated = parse_queries(qtext)
if not generated:
generated = [topic]
generated = generated[: args.per_topic]
for query in generated:
pred = gen(query=query)
output_text = getattr(pred, "expansion", "") or ""
write_jsonl_line(f_out, query, output_text)
print(json.dumps({"query": query, "output": parse_output_text(output_text)}, ensure_ascii=False))
print(f"Wrote {args.output}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,10 @@
{"query": "how tourism affects local cultures", "output": [["lex", "tourism cultural impact"], ["lex", "local culture tourism effects"], ["lex", "overtourism traditions change"], ["vec", "effects of tourism on indigenous customs"], ["vec", "tourism influence on native practices"], ["vec", "cultural shifts from mass tourism"], ["hyde", "Tourism reshapes local cultures by commercializing traditions, introducing global influences, and sparking both preservation efforts and cultural erosion in communities."]]}
{"query": "how to ferment foods at home", "output": [["lex", "home fermentation"], ["lex", "DIY food fermenting"], ["lex", "lacto fermentation"], ["vec", "beginner guide to home food fermentation"], ["vec", "steps for safe fermenting vegetables"], ["vec", "homemade probiotic food recipes"], ["hyde", "Fermenting foods at home involves chopping produce, mixing with salt brine, packing into jars, and waiting 3-14 days in a cool spot for tangy flavors and probiotics."]]}
{"query": "how to mix modern and vintage decor", "output": [["lex", "modern vintage blend"], ["lex", "contemporary retro decor"], ["lex", "eclectic style fusion"], ["vec", "tips blending modern vintage furniture"], ["vec", "combining contemporary antique accents"], ["vec", "balancing old new interior elements"], ["hyde", "Seamlessly mix sleek modern furniture with charming vintage pieces using neutral palettes and strategic layering for a cohesive eclectic home."]]}
{"query": "how to perform a scientific experiment", "output": [["lex", "scientific method steps"], ["lex", "conduct lab experiment"], ["lex", "experiment procedure guide"], ["vec", "steps to design scientific experiment"], ["vec", "guide for running science experiments"], ["vec", "how to execute controlled experiment"], ["hyde", "To perform a scientific experiment, form a hypothesis, design a method, collect data via tests, analyze results, and draw evidence-based conclusions."]]}
{"query": "web mail", "output": [["lex", "webmail"], ["lex", "online email"], ["lex", "browser mail"], ["vec", "web-based email services"], ["vec", "access mail via browser"], ["vec", "free webmail providers"], ["hyde", "Webmail enables users to read, send, and manage emails directly through a web browser without installing software."]]}
{"query": "what does the quran cover", "output": [["lex", "Quran topics"], ["lex", "Quran contents"], ["lex", "Quran themes"], ["vec", "main topics in the Quran"], ["vec", "subjects covered by Quran"], ["vec", "themes and teachings Quran"], ["hyde", "The Quran addresses theology, prophethood, morality, Islamic laws, stories of ancient prophets, afterlife, and guidance for personal and social life."]]}
{"query": "web config", "output": [["lex", "web.config file"], ["lex", "ASP.NET config"], ["lex", "IIS configuration"], ["vec", "editing web.config settings"], ["vec", "web.config appSettings section"], ["vec", "configuring ASP.NET web app"], ["hyde", "The web.config file in ASP.NET defines application settings, authentication, modules, and connection strings for IIS-hosted web applications."]]}
{"query": "how to choose farm equipment", "output": [["lex", "farm machinery selection"], ["lex", "agricultural equipment buying"], ["lex", "tractor harvester choice"], ["vec", "guide to selecting farm tools"], ["vec", "factors in choosing farm gear"], ["vec", "tips for buying ag machinery"], ["hyde", "To choose farm equipment effectively, assess farm size, soil type, budget, durability, and brand reviews for long-term productivity and value."]]}
{"query": "how do thought experiments aid philosophical reasoning", "output": [["lex", "thought experiments philosophy"], ["lex", "hypothetical reasoning aids"], ["lex", "gedankenexperiment benefits"], ["vec", "role of thought experiments philosophy"], ["vec", "hypotheticals improve philosophical logic"], ["vec", "mental scenarios aid argumentation"], ["hyde", "Thought experiments bolster philosophical reasoning by simulating scenarios to test ideas, expose flaws, and clarify abstract concepts without real-world limits."]]}
{"query": "what is the significance of logic in philosophy", "output": [["lex", "logic philosophy importance"], ["lex", "philosophical logic role"], ["lex", "logic significance reasoning"], ["vec", "importance of logic in philosophy"], ["vec", "role of logic philosophical thought"], ["vec", "why logic fundamental to philosophy"], ["hyde", "Logic underpins philosophy by furnishing tools for valid inference, critical analysis, and structured argumentation across metaphysics, epistemology, and ethics."]]}

View File

@ -0,0 +1,20 @@
{"query": "how tourism affects local cultures", "output": []}
{"query": "how to ferment foods at home", "output": []}
{"query": "how to mix modern and vintage decor", "output": []}
{"query": "how to perform a scientific experiment", "output": []}
{"query": "web mail", "output": []}
{"query": "what does the quran cover", "output": []}
{"query": "web config", "output": []}
{"query": "how to choose farm equipment", "output": []}
{"query": "how do thought experiments aid philosophical reasoning", "output": []}
{"query": "what is the significance of logic in philosophy", "output": []}
{"query": "how to train for a 5k run", "output": []}
{"query": "how to engage with political dialogues", "output": []}
{"query": "what is competitive analysis", "output": []}
{"query": "how does the united nations operate", "output": []}
{"query": "what are the crusades?", "output": []}
{"query": "what is a literary theme?", "output": []}
{"query": "what is the ethical significance of consent", "output": []}
{"query": "paint mix", "output": []}
{"query": "how to conserve energy in the office?", "output": []}
{"query": "how to test soil ph?", "output": []}

View File

@ -0,0 +1,19 @@
{
"name": "qmd-gepa-example-generator",
"model": "grok-4-1-fast-reasoning",
"schema_version": 1,
"prompt": "You are a query expansion expert. Given a user query, output a single JSON object that matches the training JSONL schema:\n{\"query\": \"...\", \"output\": [[\"lex\", \"...\"], [\"vec\", \"...\"], [\"hyde\", \"...\"]]}\nRules:\n- output is a list of pairs, where the first element is one of: \"lex\", \"vec\", \"hyde\".\n- Include 2-3 lex lines, 2-3 vec lines, and 0-1 hyde line.\n- lex lines are short keyword phrases; never equal or near-echo the query.\n- vec lines are natural language search phrases.\n- hyde is a concise hypothetical passage (50-200 chars), single line.\n- Preserve key terms and named entities in lex lines.\n- No extra text outside the JSON object.\n",
"output_schema": {
"query": "string",
"output": [
[
"lex|vec|hyde",
"string"
]
]
},
"notes": [
"LexSearch/VecSearch/HydeSearch are represented as lex/vec/hyde in output.",
"Do not echo the query in lex lines."
]
}

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
"""Write model.json prompt config for generating high-quality examples."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from example import SearchType, SEARCH_TYPE_TO_PREFIX
def build_prompt() -> str:
lex = SEARCH_TYPE_TO_PREFIX[SearchType.LexSearch]
vec = SEARCH_TYPE_TO_PREFIX[SearchType.VecSearch]
hyde = SEARCH_TYPE_TO_PREFIX[SearchType.HydeSearch]
return (
"You are a query expansion expert. Given a user query, output a single JSON object "
"that matches the training JSONL schema:\n"
'{"query": "...", "output": [["lex", "..."], ["vec", "..."], ["hyde", "..."]]}\n'
"Rules:\n"
f"- output is a list of pairs, where the first element is one of: "
f"\"{lex}\", \"{vec}\", \"{hyde}\".\n"
"- Include 2-3 lex lines, 2-3 vec lines, and 0-1 hyde line.\n"
"- lex lines are short keyword phrases; never equal or near-echo the query.\n"
"- vec lines are natural language search phrases.\n"
"- hyde is a concise hypothetical passage (50-200 chars), single line.\n"
"- Preserve key terms and named entities in lex lines.\n"
"- No extra text outside the JSON object.\n"
)
def write_model_json(path: Path) -> None:
payload = {
"name": "qmd-gepa-example-generator",
"model": "grok-4-1-fast-reasoning",
"schema_version": 1,
"prompt": build_prompt(),
"output_schema": {
"query": "string",
"output": [["lex|vec|hyde", "string"]],
},
"notes": [
"LexSearch/VecSearch/HydeSearch are represented as lex/vec/hyde in output.",
"Do not echo the query in lex lines.",
],
}
path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
def main() -> int:
parser = argparse.ArgumentParser(description="Write model.json for GEPA generation")
parser.add_argument(
"--output",
type=str,
default="gepa/model.json",
help="Path to write model.json",
)
args = parser.parse_args()
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
write_model_json(output_path)
print(f"Wrote {output_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""Score GEPA JSONL outputs using reward.py."""
from __future__ import annotations
import argparse
import json
import statistics
from pathlib import Path
from example import example_from_json
from reward import score_expansion_detailed
from dataset.schema import output_items_to_text
def score_file(path: Path) -> tuple[int, int, list[float], dict]:
total = 0
errors = 0
scores: list[float] = []
ratings: dict[str, int] = {}
with path.open("r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
total += 1
try:
obj = json.loads(line)
example = example_from_json(obj)
except Exception:
errors += 1
continue
output_text = output_items_to_text([item.to_pair() for item in example.output])
if not output_text:
errors += 1
continue
detail = score_expansion_detailed(example.query, output_text)
score = detail["percentage"]
scores.append(score)
rating = detail["rating"]
ratings[rating] = ratings.get(rating, 0) + 1
return total, errors, scores, ratings
def main() -> int:
parser = argparse.ArgumentParser(description="Score GEPA JSONL outputs")
parser.add_argument("--input", type=str, required=True, help="Input JSONL file")
args = parser.parse_args()
path = Path(args.input)
if not path.exists():
print(f"Input not found: {path}")
return 1
total, errors, scores, ratings = score_file(path)
if scores:
avg = statistics.mean(scores)
median = statistics.median(scores)
min_score = min(scores)
max_score = max(scores)
above_70 = sum(1 for s in scores if s >= 70.0)
pct_70 = above_70 / len(scores) * 100
print(
f"{path}: {len(scores)} scored, {errors} errors, "
f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
)
else:
print(f"{path}: 0 scored, {errors} errors")
if ratings:
rating_parts = [f\"{k}:{v}\" for k, v in sorted(ratings.items())]
print(f\" ratings: {', '.join(rating_parts)}\")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,106 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.55.0",
# "accelerate>=0.24.0",
# "huggingface_hub>=0.20.0",
# "datasets",
# "bitsandbytes",
# "torch",
# ]
# ///
"""
SFT training for QMD query expansion with LiquidAI LFM2-1.2B.
LFM2 is a hybrid architecture optimized for edge/on-device inference.
Uses different LoRA target modules than standard transformers.
Self-contained script for HuggingFace Jobs:
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft_lfm2.py
"""
import os
from huggingface_hub import login
# --- Config (inlined from configs/sft_lfm2.yaml) ---
BASE_MODEL = "LiquidAI/LFM2-1.2B"
OUTPUT_MODEL = "tobil/qmd-query-expansion-lfm2-sft"
DATASET = "tobil/qmd-query-expansion-train"
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
# Load and split dataset
print(f"Loading dataset: {DATASET}...")
dataset = load_dataset(DATASET, split="train")
print(f"Dataset loaded: {len(dataset)} examples")
split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]
print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# SFT config
config = SFTConfig(
output_dir="qmd-query-expansion-lfm2-sft",
push_to_hub=True,
hub_model_id=OUTPUT_MODEL,
hub_strategy="every_save",
num_train_epochs=5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
max_length=512,
logging_steps=10,
save_strategy="steps",
save_steps=200,
save_total_limit=2,
eval_strategy="steps",
eval_steps=200,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
bf16=True,
report_to="none",
)
# LoRA config for LFM2 architecture
# LFM2 uses different layer names than standard transformers:
# - Attention: q_proj, k_proj, v_proj, out_proj
# - Input projection: in_proj
# - FFN/MLP gates (SwiGLU): w1, w2, w3
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.0,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"],
)
print("Initializing SFT trainer...")
trainer = SFTTrainer(
model=BASE_MODEL,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=config,
peft_config=peft_config,
)
print("Starting SFT training (LFM2-1.2B)...")
trainer.train()
print("Pushing to Hub...")
trainer.push_to_hub()
print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

View File

@ -0,0 +1,60 @@
# SFT Training Config for QMD Query Expansion with LiquidAI LFM2
# Target: LFM2-1.2B with LoRA (hybrid architecture: convolutions + attention)
#
# LFM2 is optimized for on-device inference with fast decode/prefill.
# Recommended for: agentic tasks, data extraction, RAG, creative writing.
#
# Usage: uv run train.py sft --config configs/sft_lfm2.yaml
#
# Requirements:
# - transformers >= 4.55.0 (LFM2 architecture support)
# - May need: pip install -U transformers
model:
base: "LiquidAI/LFM2-1.2B"
output: "outputs/sft-lfm2" # Local training output (push to HF manually after eval)
dataset:
# Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
# HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
name: "data/train/"
text_field: "text"
split: "train"
eval_split: 0.1
training:
epochs: 5
batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 2e-4
max_length: 512
warmup_ratio: 0.03
lr_scheduler: "cosine"
lora:
rank: 16
alpha: 32
dropout: 0.0
# LFM2 uses different architecture than standard transformers:
# - Attention layers: q_proj, k_proj, v_proj, out_proj
# - Input projection: in_proj
# - FFN/MLP gates: w1, w2, w3 (SwiGLU activation)
target_modules:
- "q_proj"
- "k_proj"
- "v_proj"
- "out_proj"
- "in_proj"
- "w1"
- "w2"
- "w3"
tracking:
project: "qmd-query-expansion"
run_name: "sft-lfm2-1.2B"
# LFM2-specific generation settings (recommended by LiquidAI)
generation:
temperature: 0.3
min_p: 0.15
repetition_penalty: 1.05

View File

@ -16,7 +16,7 @@ dependencies = [
"gguf",
"sentencepiece",
"nvidia-ml-py",
"dspy-ai>=3.1.2",
"pydantic>=2.0",
]
[dependency-groups]

5278
finetune/uv.lock generated

File diff suppressed because it is too large Load Diff