finetune: strict Pydantic schema, one canonical data format
Replace ad-hoc JSON parsing with a strict Pydantic model (TrainingExample with typed OutputPair). All data loading goes through load_examples() which fails loudly on invalid data.

- Convert v3_structured.jsonl from "searches" to "output" format
- Rewrite all consumer scripts (prepare, validate, score, analyze) to load through the Pydantic schema
- Prepared train/val files are ephemeral build artifacts
- Restore LFM2 and GEPA experiments under experiments/
- Add pydantic>=2.0 to dependencies
parent 3950055708
commit 1d7d167b29
@ -18,6 +18,22 @@ vec: another semantic variation
|
||||
- `lex:` lines for BM25 keyword search (1-3 lines, short keywords)
|
||||
- `vec:` lines for vector similarity search (1-3 lines, natural language)
|
||||
|
||||
## Training Data Format
|
||||
|
||||
**There is exactly one JSONL format.** Every file in `data/*.jsonl` must match the strict Pydantic schema in `dataset/schema.py`:
|
||||
|
||||
```json
|
||||
{"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
|
||||
```
|
||||
|
||||
- `query`: non-empty string
|
||||
- `output`: list of `[type, text]` pairs where type is `"lex"`, `"vec"`, or `"hyde"`
|
||||
- Extra metadata fields (`category`, `intent`, `is_short`) are allowed but ignored
|
||||
|
||||
The schema is enforced by `dataset/schema.py:TrainingExample` (Pydantic model). All data loading goes through `load_examples()` which fails loudly on invalid data. No format alternatives, no legacy fallbacks.
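
To make "fails loudly" concrete, here is a rough, dependency-free sketch of the validation rules listed above. It is not the contents of `dataset/schema.py`: the real model is Pydantic-based, and only the names `TrainingExample`, `OutputPair`, `OutputType`, and `load_examples` come from this commit. The helper `parse_example`, the error messages, and the skipping of blank lines are assumptions made for illustration.

```python
import json
from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class OutputType(str, Enum):
    lex = "lex"
    vec = "vec"
    hyde = "hyde"


@dataclass(frozen=True)
class OutputPair:
    type: OutputType
    text: str


@dataclass(frozen=True)
class TrainingExample:
    query: str
    output: list


def parse_example(raw: dict) -> TrainingExample:
    """Validate one decoded JSONL record (hypothetical helper).

    Raises ValueError on any violation instead of silently repairing it.
    """
    query = raw.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ValueError("query must be a non-empty string")
    output = raw.get("output")
    if not isinstance(output, list):
        raise ValueError("output must be a list of [type, text] pairs")
    pairs = []
    for item in output:
        well_formed = (
            isinstance(item, list)
            and len(item) == 2
            and all(isinstance(part, str) for part in item)
        )
        if not well_formed:
            raise ValueError(f"bad output item: {item!r}")
        # OutputType(...) raises ValueError for anything but lex/vec/hyde.
        pairs.append(OutputPair(OutputType(item[0]), item[1]))
    # Extra keys such as category/intent/is_short are ignored on purpose.
    return TrainingExample(query=query, output=pairs)


def load_examples(path: Path) -> list:
    """Load a JSONL file, aborting on the first invalid line."""
    examples = []
    with open(path, encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue  # assumption: blank lines are tolerated
            try:
                examples.append(parse_example(json.loads(line)))
            except ValueError as err:  # JSONDecodeError subclasses ValueError
                raise ValueError(f"{path}:{line_num}: {err}") from err
    return examples
```

The point of the design is that a malformed record stops the run with a file and line number, rather than being normalized or dropped.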
|
||||
|
||||
**All `.jsonl` files in `data/` are concatenated and deduplicated for training runs.** The prepared train/val files in `data/train/` are ephemeral build artifacts.
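
A minimal sketch of the concatenate-and-deduplicate step, for orientation only. The commit states that deduplication is by query; the case-insensitive key and the first-occurrence-wins rule are assumptions, and `read_jsonl_dir` and `dedup_by_query` are hypothetical helpers rather than functions from `prepare_data.py`.

```python
import json
from pathlib import Path


def read_jsonl_dir(data_dir: Path) -> list:
    """Concatenate records from every *.jsonl directly inside data_dir."""
    records = []
    # Non-recursive glob: files under data_dir/train/ are not picked up.
    for path in sorted(data_dir.glob("*.jsonl")):
        with open(path, encoding="utf-8") as f:
            records.extend(json.loads(line) for line in f if line.strip())
    return records


def dedup_by_query(records: list) -> list:
    """Keep the first record for each query, compared case-insensitively."""
    seen = set()
    unique = []
    for rec in records:
        key = rec["query"].strip().lower()
        if key not in seen:
            seen.add(key)
            unique.append(rec)
    return unique
```

Because the glob is non-recursive, the prepared files in `data/train/` would not be read back in as source data, which fits their role as ephemeral build artifacts.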
|
||||
|
||||
## HuggingFace Repositories
|
||||
|
||||
| Repository | Purpose |
|
||||
@ -33,72 +49,34 @@ vec: another semantic variation
|
||||
- Only push when eval scores improve over current deployed model
|
||||
- Always include eval results in model card when pushing
|
||||
|
||||
## Training Data
|
||||
|
||||
All JSONL files in `data/` are training data:
|
||||
|
||||
```
|
||||
data/
|
||||
├── qmd_expansion_v2.jsonl
|
||||
├── qmd_expansion_handcrafted_only.jsonl
|
||||
├── qmd_only_sampled.jsonl
|
||||
├── qmd_only_variants.jsonl
|
||||
└── ... any additional .jsonl files
|
||||
```
|
||||
|
||||
**All `.jsonl` files in `data/` should be concatenated for training runs.**
|
||||
|
||||
Each JSONL line: `{"input": "query", "output": "hyde:...\nlex:...\nvec:..."}`
|
||||
|
||||
## Data Generation Tools
|
||||
## Dataset Tools
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `dataset/generate_data.py` | Generate via Claude API (high quality) |
|
||||
| `dataset/generate_data_offline.py` | Transform from HuggingFace datasets |
|
||||
| `dataset/prepare_data.py` | Format for Qwen3 chat template |
|
||||
| `dataset/clean_data.py` | Detect and fix technical term issues |
|
||||
| `generate_only_variants.py` | Generate `/only:lex` and `/only:vec` variants |
|
||||
|
||||
## Local Training Output
|
||||
|
||||
All training outputs go to `outputs/` (gitignored):
|
||||
|
||||
```
|
||||
outputs/
|
||||
├── sft/ # SFT checkpoint
|
||||
└── grpo/ # GRPO checkpoint
|
||||
```
|
||||
| `dataset/schema.py` | Pydantic `TrainingExample` model + `load_examples()` |
|
||||
| `dataset/prepare_data.py` | Load via schema, apply Qwen3 chat template, dedup, split |
|
||||
| `dataset/validate_schema.py` | Validate all JSONL files against schema |
|
||||
| `dataset/score_data.py` | Score all examples using reward.py |
|
||||
| `dataset/analyze_data.py` | Analyze distribution and quality |
|
||||
|
||||
## Training Pipeline
|
||||
|
||||
Always use **Qwen3-1.7B** as the base model unless explicitly stated otherwise.
|
||||
|
||||
Training can run **locally** (requires CUDA GPU) or via **HuggingFace Jobs** (cloud GPU, no local hardware needed).
|
||||
|
||||
### Stage 0: Prepare Data
|
||||
|
||||
Raw data in `data/*.jsonl` must be converted to Qwen3 chat format before training:
|
||||
|
||||
```bash
|
||||
# Process all JSONL files in data/
|
||||
uv run dataset/prepare_data.py
|
||||
# Creates: data/train/train.jsonl, data/train/val.jsonl
|
||||
|
||||
# Or process a specific file
|
||||
uv run dataset/prepare_data.py --input data/qmd_expansion_v2.jsonl
|
||||
# Creates: data/train/train.jsonl, data/train/val.jsonl (ephemeral)
|
||||
```
|
||||
|
||||
This applies the Qwen3 chat template, deduplicates, and splits into train/val sets.
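
The split can be pictured as a seeded shuffle followed by a slice. The sketch below assumes the script's `--split` (0.1) and `--seed` (42) defaults; `split_train_val` is a hypothetical helper, and the exact shuffling and rounding used by `prepare_data.py` may differ.

```python
import random


def split_train_val(items: list, val_ratio: float = 0.1, seed: int = 42) -> tuple:
    """Shuffle a copy of items deterministically, then carve off the validation set."""
    shuffled = list(items)  # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_ratio)
    return shuffled[n_val:], shuffled[:n_val]
```

Using a private `random.Random(seed)` instance keeps the split reproducible without touching the global random state, so rerunning the preparation step on unchanged data yields the same train/val files.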
|
||||
|
||||
### Stage 1: SFT
|
||||
|
||||
```bash
|
||||
# Local (requires CUDA)
|
||||
uv run train.py sft --config configs/sft.yaml
|
||||
# Output: outputs/sft/
|
||||
|
||||
# Cloud (HuggingFace Jobs - no local GPU needed)
|
||||
# Cloud (HuggingFace Jobs)
|
||||
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft.py
|
||||
```
|
||||
|
||||
@ -107,16 +85,13 @@ hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft.py
|
||||
```bash
|
||||
# Local (requires CUDA)
|
||||
uv run train.py grpo --config configs/grpo.yaml
|
||||
# Output: outputs/grpo/
|
||||
|
||||
# Cloud (HuggingFace Jobs - no local GPU needed)
|
||||
# Cloud (HuggingFace Jobs)
|
||||
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
|
||||
```
|
||||
|
||||
### HuggingFace Jobs
|
||||
|
||||
If no local CUDA device is available, use `hf jobs` to run training in the cloud:
|
||||
|
||||
```bash
|
||||
hf jobs ps # List running jobs
|
||||
hf jobs logs <job-id> # Stream logs
|
||||
@ -124,18 +99,11 @@ hf jobs inspect <job-id> # Check status
|
||||
hf jobs cancel <job-id> # Cancel a job
|
||||
```
|
||||
|
||||
The `jobs/` directory contains self-contained scripts that include all dependencies inline.
|
||||
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
# Eval local model
|
||||
uv run eval.py --model ./outputs/grpo
|
||||
|
||||
# Eval HuggingFace model
|
||||
uv run eval.py --model tobil/qmd-query-expansion-1.7B
|
||||
|
||||
# Save eval results to file
|
||||
uv run eval.py --model ./outputs/grpo -o eval_results.json
|
||||
```
|
||||
|
||||
@ -144,27 +112,26 @@ uv run eval.py --model ./outputs/grpo -o eval_results.json
|
||||
`reward.py` is the single source of truth for scoring:
|
||||
|
||||
```bash
|
||||
# Self-test the reward function
|
||||
uv run reward.py
|
||||
uv run reward.py # Self-test
|
||||
```
|
||||
|
||||
See `SCORING.md` for the full rubric.
|
||||
|
||||
## Deployment Rules
|
||||
## Experiments
|
||||
|
||||
**Never upload without eval.** Every model push must include eval results.
|
||||
Experimental training configurations live in `experiments/`:
|
||||
|
||||
### Checklist
|
||||
```
|
||||
experiments/
|
||||
├── lfm2/ # LiquidAI LFM2-1.2B (hybrid architecture, faster inference)
|
||||
│ ├── sft_lfm2.yaml
|
||||
│ └── sft_lfm2.py
|
||||
└── gepa/ # DSPy-based prompt optimization (GEPA)
|
||||
├── dspy_gepa.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
1. Train SFT on all `data/*.jsonl` → `outputs/sft/`
|
||||
2. Train GRPO on top of SFT → `outputs/grpo/`
|
||||
3. **Run eval on local model**: `uv run eval.py --model ./outputs/grpo -o eval_results.json`
|
||||
4. Compare against current deployed model's eval
|
||||
5. If eval improves:
|
||||
- Push to `tobil/qmd-query-expansion-1.7B`
|
||||
- **Include eval output in the model card / commit message**
|
||||
6. Convert to GGUF and update `tobil/qmd-query-expansion-1.7B-gguf`
|
||||
7. Update `src/llm.ts` DEFAULT_GENERATE_MODEL if repo name changed
|
||||
These are not part of the main training pipeline.
|
||||
|
||||
## Key Files
|
||||
|
||||
@ -176,10 +143,12 @@ finetune/
|
||||
├── convert_gguf.py # GGUF conversion
|
||||
├── SCORING.md # Detailed scoring rubric
|
||||
├── CLAUDE.md # This file
|
||||
├── data/ # All training JSONL files
|
||||
├── outputs/ # Local training outputs (gitignored)
|
||||
├── dataset/ # Data generation scripts
|
||||
├── Justfile # Common commands
|
||||
├── data/ # All training JSONL files (strict schema)
|
||||
├── dataset/ # Schema + data tools (Pydantic-based)
|
||||
├── jobs/ # Self-contained HuggingFace Jobs scripts
|
||||
├── configs/ # Training configs (sft.yaml, grpo.yaml)
|
||||
└── evals/ # Test queries and results
|
||||
├── evals/ # Test queries
|
||||
├── experiments/ # Experimental configs (LFM2, GEPA)
|
||||
└── outputs/ # Local training outputs (gitignored)
|
||||
```
|
||||
|
||||
@ -92,20 +92,19 @@ finetune/
|
||||
│ ├── sft.py # Self-contained SFT for HuggingFace Jobs
|
||||
│ ├── grpo.py # Self-contained GRPO for HuggingFace Jobs
|
||||
│ ├── eval.py # Self-contained eval for HuggingFace Jobs
|
||||
│ ├── eval_common.py # Shared eval utilities
|
||||
│ └── quantize.py # GGUF quantization for HuggingFace Jobs
|
||||
│ └── eval_common.py # Shared eval utilities
|
||||
├── configs/
|
||||
│ ├── sft.yaml # SFT hyperparameters for Qwen3-1.7B
|
||||
│ └── grpo.yaml # GRPO hyperparameters for Qwen3-1.7B
|
||||
├── evals/
|
||||
│ └── queries.txt # 31 test queries across 8 categories
|
||||
├── data/
|
||||
│ └── qmd_expansion_v2.jsonl # Source training data (1,000 high-quality examples)
|
||||
├── data/ # Training JSONL files (all concatenated for training)
|
||||
├── dataset/
|
||||
│ ├── generate_data.py # Generate data via Claude API
|
||||
│ ├── generate_data_offline.py # Generate from existing HF dataset
|
||||
│ ├── prepare_data.py # Format for Qwen3 chat template
|
||||
│ └── clean_data.py # Detect technical term misinterpretations
|
||||
│ ├── prepare_data.py # Format for Qwen3 chat template, dedup, split
|
||||
│ ├── schema.py # Parse/normalize output format
|
||||
│ ├── validate_schema.py # Validate JSONL against schema
|
||||
│ ├── score_data.py # Score all examples using reward.py
|
||||
│ └── analyze_data.py # Analyze distribution and quality
|
||||
├── SCORING.md # Detailed scoring rubric reference
|
||||
└── README.md # This file
|
||||
```
|
||||
@ -122,7 +121,7 @@ Teaches the model the `lex:/vec:/hyde:` output format from labeled examples.
|
||||
| Method | LoRA (rank 16, alpha 32) |
|
||||
| Target modules | All projection layers (q/k/v/o/gate/up/down) |
|
||||
| Dataset | ~2,290 examples (train split) |
|
||||
| Effective batch size | 16 (4 × 4 gradient accumulation) |
|
||||
| Effective batch size | 16 (4 x 4 gradient accumulation) |
|
||||
| Epochs | 5 |
|
||||
| Learning rate | 2e-4 (cosine schedule) |
|
||||
|
||||
@ -173,9 +172,6 @@ uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft -v
|
||||
|
||||
# Save detailed scores to JSON
|
||||
uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft -o scores.json
|
||||
|
||||
# Score an existing JSONL file (backwards compat with old run.py output)
|
||||
uv run eval.py --score-only evals/results_old.jsonl
|
||||
```
|
||||
|
||||
## Reward Function
|
||||
@ -212,9 +208,6 @@ quantized GGUF files for deployment:
|
||||
# Use preset for 1.7B
|
||||
uv run convert_gguf.py --size 1.7B
|
||||
|
||||
# Use preset for 4B
|
||||
uv run convert_gguf.py --size 4B
|
||||
|
||||
# Custom models
|
||||
uv run convert_gguf.py --base Qwen/Qwen3-1.7B \
|
||||
--sft tobil/qmd-query-expansion-1.7B-sft \
|
||||
@ -235,26 +228,19 @@ ollama run qmd-expand
|
||||
|
||||
## Data Pipeline
|
||||
|
||||
The training data (1,000 examples in `data/qmd_expansion_v2.jsonl`) was generated
|
||||
from two sources and cleaned for quality. To regenerate:
|
||||
All JSONL files in `data/` are concatenated for training. To prepare for training:
|
||||
|
||||
```bash
|
||||
# Generate from existing HuggingFace dataset (bulk, no API needed)
|
||||
uv run dataset/generate_data_offline.py
|
||||
|
||||
# Generate via Claude API (higher quality, needs ANTHROPIC_API_KEY)
|
||||
uv run dataset/generate_data.py --count 100
|
||||
|
||||
# Detect and fix technical term misinterpretations
|
||||
uv run dataset/clean_data.py
|
||||
|
||||
# Format for Qwen3 chat template, add short-query augmentation, split train/val
|
||||
# Format for Qwen3 chat template, deduplicate, split train/val
|
||||
uv run dataset/prepare_data.py
|
||||
|
||||
# Validate data quality
|
||||
just validate
|
||||
```
|
||||
|
||||
## Architecture Notes
|
||||
|
||||
The two-stage training approach (SFT → GRPO) is standard for structured-output models:
|
||||
The two-stage training approach (SFT -> GRPO) is standard for structured-output models:
|
||||
|
||||
1. **SFT** establishes format compliance and basic query understanding. It uses
|
||||
a large LoRA (rank 16, all projection layers) because it needs to learn a
|
||||
@ -297,42 +283,3 @@ deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubri
|
||||
|-------|--------------|-----------------|
|
||||
| SFT | 92.0% | 30/30 |
|
||||
| GRPO | 91.7% | 30/30 |
|
||||
|
||||
## Alternative Base Models
|
||||
|
||||
### LiquidAI LFM2 (Experimental)
|
||||
|
||||
[LFM2](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
|
||||
is a hybrid architecture from Liquid AI optimized for on-device inference. It uses
|
||||
a novel combination of convolutions and attention that achieves 2x faster decode
|
||||
and prefill speed compared to standard transformers.
|
||||
|
||||
**Why LFM2 for query expansion:**
|
||||
- **Faster inference**: Lower latency for real-time search applications
|
||||
- **Memory efficient**: Smaller memory footprint than equivalent transformers
|
||||
- **Edge-optimized**: Can run on mobile devices and embedded systems
|
||||
- **Good at agentic tasks**: LiquidAI recommends LFM2 for RAG and data extraction
|
||||
|
||||
**Training with LFM2:**
|
||||
|
||||
```bash
|
||||
# SFT with LFM2-1.2B base model
|
||||
uv run train.py sft --config configs/sft_lfm2.yaml
|
||||
|
||||
# Evaluate the trained model
|
||||
uv run eval.py --model outputs/sft-lfm2
|
||||
|
||||
# Convert to GGUF for deployment
|
||||
uv run convert_gguf.py --base LiquidAI/LFM2-1.2B \
|
||||
--sft outputs/sft-lfm2 \
|
||||
--output tobil/qmd-query-expansion-lfm2-gguf
|
||||
```
|
||||
|
||||
**Key differences from Qwen3:**
|
||||
- Different LoRA target modules: `q_proj, k_proj, v_proj, out_proj, in_proj, w1, w2, w3`
|
||||
- Recommended generation parameters: `temp=0.3, min_p=0.15, repetition_penalty=1.05`
|
||||
- Requires transformers >= 4.55.0 for architecture support
|
||||
|
||||
**Pre-trained GGUF models:**
|
||||
- Base: `hf:LiquidAI/LFM2-1.2B-GGUF/LFM2-1.2B-Q4_K_M.gguf` (~731 MB)
|
||||
- Instruct: `hf:LiquidAI/LFM2.5-1.2B-Instruct-GGUF/LFM2.5-1.2B-Instruct-Q4_K_M.gguf` (~731 MB)
|
||||
|
||||
File diff suppressed because it is too large
@ -1,34 +1,35 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = ["pydantic>=2.0"]
|
||||
# ///
|
||||
"""
|
||||
Dataset Analysis and Quality Report Generator
|
||||
|
||||
Analyzes the training data for:
|
||||
Analyzes training data loaded through the strict Pydantic schema for:
|
||||
1. Query length distribution
|
||||
2. Category diversity
|
||||
3. Named entity coverage
|
||||
4. Temporal query coverage
|
||||
5. Short query coverage (important for ambiguous queries)
|
||||
6. Duplicate detection
|
||||
7. Quality issues (long hyde, missing fields, etc.)
|
||||
4. Output format coverage
|
||||
5. Duplicate detection
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dataset.schema import normalize_output_items, parse_output_text
|
||||
from dataset.schema import TrainingExample, OutputType, load_examples
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetStats:
|
||||
total_examples: int = 0
|
||||
short_queries: int = 0 # 1-2 words
|
||||
medium_queries: int = 0 # 3-5 words
|
||||
long_queries: int = 0 # 6+ words
|
||||
short_queries: int = 0
|
||||
medium_queries: int = 0
|
||||
long_queries: int = 0
|
||||
has_lex: int = 0
|
||||
has_vec: int = 0
|
||||
has_hyde: int = 0
|
||||
@ -40,288 +41,164 @@ class DatasetStats:
|
||||
|
||||
|
||||
def categorize_query(query: str) -> str:
|
||||
"""Categorize a query by type."""
|
||||
query_lower = query.lower()
|
||||
words = query_lower.split()
|
||||
word_count = len(words)
|
||||
|
||||
# Short keyword queries
|
||||
if word_count <= 2:
|
||||
return "short_keyword"
|
||||
|
||||
# Named entity queries (capitalized words or tech terms)
|
||||
if any(w[0].isupper() for w in words if w):
|
||||
if any(w[0].isupper() for w in query.split() if w):
|
||||
return "named_entity"
|
||||
|
||||
# Temporal/recency queries
|
||||
temporal_keywords = [
|
||||
"latest",
|
||||
"recent",
|
||||
"new",
|
||||
"update",
|
||||
"changelog",
|
||||
"changed",
|
||||
"version",
|
||||
"release",
|
||||
"news",
|
||||
"2024",
|
||||
"2025",
|
||||
"latest", "recent", "new", "update", "changelog",
|
||||
"changed", "version", "release", "news", "2024", "2025",
|
||||
]
|
||||
if any(kw in query_lower for kw in temporal_keywords):
|
||||
return "temporal"
|
||||
|
||||
# How-to queries
|
||||
if query_lower.startswith("how "):
|
||||
return "how_to"
|
||||
|
||||
# What is queries
|
||||
if query_lower.startswith("what "):
|
||||
return "what_is"
|
||||
|
||||
# Difference/comparison queries
|
||||
if any(kw in query_lower for kw in ["difference", "vs", "versus", "compare"]):
|
||||
return "comparison"
|
||||
|
||||
# Personal/journal style
|
||||
if any(
|
||||
kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]
|
||||
):
|
||||
if any(kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]):
|
||||
return "personal"
|
||||
|
||||
return "other"
|
||||
|
||||
|
||||
def extract_named_entities(query: str) -> list:
|
||||
"""Extract potential named entities from query."""
|
||||
entities = []
|
||||
words = query.split()
|
||||
|
||||
for word in words:
|
||||
# Skip stopwords
|
||||
if word.lower() in {
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"is",
|
||||
"are",
|
||||
"to",
|
||||
"for",
|
||||
"of",
|
||||
"in",
|
||||
"and",
|
||||
"or",
|
||||
}:
|
||||
stopwords = {"the", "a", "an", "is", "are", "to", "for", "of", "in", "and", "or"}
|
||||
for word in query.split():
|
||||
if word.lower() in stopwords:
|
||||
continue
|
||||
|
||||
# Capitalized words (potential named entities)
|
||||
if word and word[0].isupper() and len(word) > 1:
|
||||
entities.append(word)
|
||||
|
||||
# Technology terms with version numbers or special chars
|
||||
if any(c in word for c in ".+-0123456789") and len(word) > 1:
|
||||
entities.append(word)
|
||||
|
||||
return entities
|
||||
|
||||
|
||||
def analyze_dataset(filepath: Path) -> tuple[DatasetStats, dict, dict]:
|
||||
"""Analyze the dataset and return statistics."""
|
||||
def analyze_examples(examples: list[TrainingExample]) -> tuple[DatasetStats, dict, dict]:
|
||||
stats = DatasetStats()
|
||||
categories = Counter()
|
||||
seen_queries = set()
|
||||
duplicate_count = 0
|
||||
category_examples = defaultdict(list)
|
||||
categories: Counter = Counter()
|
||||
seen_queries: set[str] = set()
|
||||
category_examples: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
for ex in examples:
|
||||
stats.total_examples += 1
|
||||
|
||||
try:
|
||||
example = json.loads(line)
|
||||
query = example.get("query", "") or example.get("input", "")
|
||||
output = example.get("output", [])
|
||||
if isinstance(output, str):
|
||||
output = parse_output_text(output)
|
||||
output = normalize_output_items(output)
|
||||
query_lower = ex.query.lower()
|
||||
if query_lower in seen_queries:
|
||||
stats.duplicate_queries += 1
|
||||
else:
|
||||
seen_queries.add(query_lower)
|
||||
|
||||
stats.total_examples += 1
|
||||
word_count = len(ex.query.split())
|
||||
if word_count <= 2:
|
||||
stats.short_queries += 1
|
||||
elif word_count <= 5:
|
||||
stats.medium_queries += 1
|
||||
else:
|
||||
stats.long_queries += 1
|
||||
|
||||
# Check for duplicates
|
||||
query_lower = query.lower()
|
||||
if query_lower in seen_queries:
|
||||
duplicate_count += 1
|
||||
else:
|
||||
seen_queries.add(query_lower)
|
||||
category = categorize_query(ex.query)
|
||||
categories[category] += 1
|
||||
category_examples[category].append(ex.query)
|
||||
|
||||
# Query length categorization
|
||||
word_count = len(query.split())
|
||||
if word_count <= 2:
|
||||
stats.short_queries += 1
|
||||
elif word_count <= 5:
|
||||
stats.medium_queries += 1
|
||||
else:
|
||||
stats.long_queries += 1
|
||||
if extract_named_entities(ex.query):
|
||||
stats.named_entity_queries += 1
|
||||
|
||||
# Category detection
|
||||
category = categorize_query(query)
|
||||
categories[category] += 1
|
||||
category_examples[category].append(query)
|
||||
# Use the typed OutputPair model
|
||||
types_present = {p.type for p in ex.output}
|
||||
if OutputType.lex in types_present:
|
||||
stats.has_lex += 1
|
||||
if OutputType.vec in types_present:
|
||||
stats.has_vec += 1
|
||||
if OutputType.hyde in types_present:
|
||||
stats.has_hyde += 1
|
||||
for p in ex.output:
|
||||
if p.type == OutputType.hyde and len(p.text) > 200:
|
||||
stats.long_hyde_count += 1
|
||||
|
||||
# Named entity detection
|
||||
if extract_named_entities(query):
|
||||
stats.named_entity_queries += 1
|
||||
|
||||
# Output analysis
|
||||
has_lex = any(o[0] == "lex" for o in output)
|
||||
has_vec = any(o[0] == "vec" for o in output)
|
||||
has_hyde = any(o[0] == "hyde" for o in output)
|
||||
|
||||
if has_lex:
|
||||
stats.has_lex += 1
|
||||
if has_vec:
|
||||
stats.has_vec += 1
|
||||
if has_hyde:
|
||||
stats.has_hyde += 1
|
||||
|
||||
# Check hyde length
|
||||
for kind, text in output:
|
||||
if kind == "hyde" and len(text) > 200:
|
||||
stats.long_hyde_count += 1
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Could not parse line {line_num}")
|
||||
|
||||
stats.duplicate_queries = duplicate_count
|
||||
stats.temporal_queries = categories.get("temporal", 0)
|
||||
stats.short_keyword_queries = categories.get("short_keyword", 0)
|
||||
|
||||
return stats, dict(categories), dict(category_examples)
|
||||
|
||||
|
||||
def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
|
||||
"""Print a comprehensive analysis report."""
|
||||
print("=" * 70)
|
||||
print("QMD TRAINING DATA ANALYSIS REPORT")
|
||||
print("=" * 70)
|
||||
print()
|
||||
|
||||
# Basic statistics
|
||||
print("📊 BASIC STATISTICS")
|
||||
total = stats.total_examples
|
||||
print("BASIC STATISTICS")
|
||||
print("-" * 40)
|
||||
print(f"Total examples: {stats.total_examples:>6}")
|
||||
print(f"Total examples: {total:>6}")
|
||||
print(f"Duplicates found: {stats.duplicate_queries:>6}")
|
||||
print()
|
||||
|
||||
# Query length distribution
|
||||
print("📝 QUERY LENGTH DISTRIBUTION")
|
||||
print("QUERY LENGTH DISTRIBUTION")
|
||||
print("-" * 40)
|
||||
total = stats.total_examples
|
||||
print(
|
||||
f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)"
|
||||
)
|
||||
print(f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)")
|
||||
print(f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)")
|
||||
print(f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)")
|
||||
print()
|
||||
|
||||
# Category distribution
|
||||
print("🏷️ CATEGORY DISTRIBUTION")
|
||||
print("CATEGORY DISTRIBUTION")
|
||||
print("-" * 40)
|
||||
for cat, count in sorted(categories.items(), key=lambda x: -x[1]):
|
||||
pct = 100 * count / total
|
||||
bar = "█" * int(pct / 2)
|
||||
bar = "#" * int(pct / 2)
|
||||
print(f"{cat:20} {count:>6} ({pct:5.1f}%) {bar}")
|
||||
print()
|
||||
|
||||
# Output format coverage
|
||||
print("✅ OUTPUT FORMAT COVERAGE")
|
||||
print("OUTPUT FORMAT COVERAGE")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)"
|
||||
)
|
||||
print(f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)")
|
||||
print(f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)")
|
||||
print(f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)")
|
||||
print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
|
||||
print()
|
||||
|
||||
# Critical metrics for evals
|
||||
print("🎯 EVALUATION ALIGNMENT")
|
||||
print("EVALUATION ALIGNMENT")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)"
|
||||
)
|
||||
print(
|
||||
f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)"
|
||||
)
|
||||
print(f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)")
|
||||
print(f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)")
|
||||
print(f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)")
|
||||
print()
|
||||
|
||||
# Recommendations
|
||||
print("💡 RECOMMENDATIONS")
|
||||
print("RECOMMENDATIONS")
|
||||
print("-" * 40)
|
||||
|
||||
recommendations = []
|
||||
|
||||
if stats.short_queries / total < 0.15:
|
||||
recommendations.append(
|
||||
"⚠️ Short queries below 15% - add more 1-2 word keyword queries"
|
||||
)
|
||||
|
||||
recommendations.append("Short queries below 15% - add more 1-2 word keyword queries")
|
||||
if stats.named_entity_queries / total < 0.10:
|
||||
recommendations.append(
|
||||
"⚠️ Named entity queries below 10% - add more capitalized tech term queries"
|
||||
)
|
||||
|
||||
recommendations.append("Named entity queries below 10% - add more capitalized tech term queries")
|
||||
if stats.temporal_queries / total < 0.05:
|
||||
recommendations.append(
|
||||
"⚠️ Temporal queries below 5% - add more 'latest', 'recent' queries"
|
||||
)
|
||||
|
||||
recommendations.append("Temporal queries below 5% - add more 'latest', 'recent' queries")
|
||||
if stats.long_hyde_count > 50:
|
||||
recommendations.append(
|
||||
f"⚠️ {stats.long_hyde_count} long hyde sections - consider truncating"
|
||||
)
|
||||
|
||||
recommendations.append(f"{stats.long_hyde_count} long hyde sections - consider truncating")
|
||||
if stats.duplicate_queries > 0:
|
||||
recommendations.append(
|
||||
f"⚠️ {stats.duplicate_queries} duplicate queries - consider deduplication"
|
||||
)
|
||||
|
||||
if categories.get("short_keyword", 0) < 100:
|
||||
recommendations.append(
|
||||
"⚠️ Need more short keyword examples for ambiguous query training"
|
||||
)
|
||||
|
||||
recommendations.append(f"{stats.duplicate_queries} duplicate queries - consider deduplication")
|
||||
if not recommendations:
|
||||
print("✅ Dataset looks good! No major issues detected.")
|
||||
print("Dataset looks good! No major issues detected.")
|
||||
else:
|
||||
for rec in recommendations:
|
||||
print(rec)
|
||||
|
||||
print(f" - {rec}")
|
||||
print()
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="data/qmd_expansion_v2.jsonl",
|
||||
default="data/qmd_expansion_v3_structured.jsonl",
|
||||
help="Path to training data JSONL file",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -333,33 +210,30 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = Path(args.input)
|
||||
|
||||
if not input_path.exists():
|
||||
# Try relative to script directory
|
||||
script_dir = Path(__file__).parent.parent
|
||||
input_path = script_dir / args.input
|
||||
|
||||
if not input_path.exists():
|
||||
print(f"Error: Could not find dataset at {input_path}")
|
||||
print("Please run from finetune directory or specify correct path")
|
||||
return 1
|
||||
|
||||
print(f"Analyzing: {input_path}")
|
||||
print()
|
||||
|
||||
stats, categories, category_examples = analyze_dataset(input_path)
|
||||
examples = load_examples(input_path)
|
||||
stats, categories, category_examples = analyze_examples(examples)
|
||||
print_report(stats, categories, category_examples)
|
||||
|
||||
# Show examples if requested
|
||||
if args.show_examples > 0:
|
||||
print("📋 SAMPLE QUERIES BY CATEGORY")
|
||||
print("SAMPLE QUERIES BY CATEGORY")
|
||||
print("-" * 40)
|
||||
for cat in sorted(categories.keys()):
|
||||
examples = category_examples.get(cat, [])
|
||||
if examples:
|
||||
exs = category_examples.get(cat, [])
|
||||
if exs:
|
||||
print(f"\n{cat.upper()}:")
|
||||
for ex in examples[: args.show_examples]:
|
||||
print(f" • {ex}")
|
||||
for ex in exs[:args.show_examples]:
|
||||
print(f" - {ex}")
|
||||
print()
|
||||
|
||||
return 0
|
||||
|
||||
@ -3,27 +3,29 @@
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "transformers>=4.45.0",
|
||||
# "pydantic>=2.0",
|
||||
# "jinja2",
|
||||
# ]
|
||||
# ///
|
||||
"""Prepare QMD query expansion data for training.
|
||||
|
||||
See PROMPT_FORMAT.md for format specification.
|
||||
Loads all data/*.jsonl via the strict Pydantic schema, applies the Qwen3
|
||||
chat template, deduplicates by query, and writes train/val splits.
|
||||
|
||||
The prepared train files are ephemeral build artifacts — the canonical
|
||||
data lives in data/*.jsonl and is always loaded through the schema.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dataset.schema import (
|
||||
normalize_output_items,
|
||||
TrainingExample,
|
||||
load_examples,
|
||||
output_items_to_text,
|
||||
parse_output_text,
|
||||
has_type,
|
||||
)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
@ -41,30 +43,26 @@ def get_tokenizer():
|
||||
return _tokenizer
|
||||
|
||||
|
||||
|
||||
def format_for_training(query_text: str, output_items: list[list[str]]) -> dict:
|
||||
"""Format a single example for SFT training using Qwen chat format."""
|
||||
def format_for_training(ex: TrainingExample) -> dict:
|
||||
"""Format a validated TrainingExample for SFT training."""
|
||||
tokenizer = get_tokenizer()
|
||||
output_text = output_items_to_text(output_items)
|
||||
output_text = output_items_to_text(ex.output)
|
||||
|
||||
# Use /no_think to disable thinking mode - we want direct output
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"/no_think Expand this search query: {query_text}",
|
||||
"content": f"/no_think Expand this search query: {ex.query}",
|
||||
},
|
||||
{"role": "assistant", "content": output_text},
|
||||
]
|
||||
|
||||
# Use tokenizer to generate proper chat format with special tokens
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
|
||||
# Strip empty <think> tags - we don't want thinking mode
|
||||
# The template adds "<think>\n\n</think>\n\n" which we remove
|
||||
# Strip empty <think> tags — /no_think should suppress them
|
||||
text = text.replace("<think>\n\n</think>\n\n", "")
|
||||
|
||||
return {
|
||||
@ -88,27 +86,22 @@ def main():
|
||||
"--split", type=float, default=0.1, help="Validation split ratio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Shuffle seed (default: 42)",
|
||||
"--seed", type=int, default=42, help="Shuffle seed",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Support glob patterns for input
|
||||
import glob
|
||||
# Resolve input files
|
||||
import glob as globmod
|
||||
|
||||
if "*" in args.input:
|
||||
input_files = sorted(glob.glob(args.input))
|
||||
input_files = sorted(globmod.glob(args.input))
|
||||
if not input_files:
|
||||
print(f"Error: No files found matching: {args.input}")
|
||||
sys.exit(1)
|
||||
print(
|
||||
f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}"
|
||||
)
|
||||
print(f"Found {len(input_files)} input files")
|
||||
else:
|
||||
input_path = Path(args.input)
|
||||
if not input_path.exists():
|
||||
@ -116,78 +109,62 @@ def main():
|
||||
sys.exit(1)
|
||||
input_files = [str(input_path)]
|
||||
|
||||
# Load all examples from all input files
|
||||
examples = []
|
||||
|
||||
# Load all examples through strict Pydantic schema
|
||||
all_examples: list[TrainingExample] = []
|
||||
for input_file in input_files:
|
||||
file_count = 0
|
||||
with open(input_file) as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
ex = json.loads(line)
|
||||
examples = load_examples(input_file)
|
||||
print(f" {Path(input_file).name}: {len(examples)} examples")
|
||||
all_examples.extend(examples)
|
||||
|
||||
# Normalize legacy format
|
||||
if "query" not in ex and "input" in ex:
|
||||
ex["query"] = ex.pop("input")
|
||||
if isinstance(ex.get("output"), str):
|
||||
ex["output"] = parse_output_text(ex["output"])
|
||||
ex["output"] = normalize_output_items(ex.get("output", []))
|
||||
print(f"Loaded {len(all_examples)} examples total")
|
||||
|
||||
examples.append(ex)
|
||||
file_count += 1
|
||||
print(f" {Path(input_file).name}: {file_count} examples")
|
||||
# Deduplicate by query (case-insensitive)
|
||||
seen: set[str] = set()
|
||||
deduped: list[TrainingExample] = []
|
||||
for ex in all_examples:
|
||||
key = ex.query.lower().strip()
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(ex)
|
||||
if len(deduped) < len(all_examples):
|
||||
print(f"Deduplicated: {len(all_examples)} -> {len(deduped)}")
|
||||
all_examples = deduped
|
||||
|
||||
print(f"Loaded {len(examples)} examples total")
|
||||
|
||||
# Combine and shuffle
|
||||
all_examples = examples
|
||||
# Shuffle
|
||||
random.seed(args.seed)
|
||||
random.shuffle(all_examples)
|
||||
|
||||
# Format for training
|
||||
formatted = [format_for_training(ex["query"], ex["output"]) for ex in all_examples]
|
||||
# Format each example using the Pydantic model
|
||||
formatted = [format_for_training(ex) for ex in all_examples]
|
||||
|
||||
# Split into train/val
|
||||
# Split
|
||||
split_idx = int(len(formatted) * (1 - args.split))
|
||||
train_data = formatted[:split_idx]
|
||||
val_data = formatted[split_idx:]
|
||||
|
||||
# Write train set
|
||||
train_path = output_dir / "train.jsonl"
|
||||
with open(train_path, "w") as f:
|
||||
for item in train_data:
|
||||
f.write(json.dumps(item) + "\n")
|
||||
# Write (these are ephemeral build artifacts)
|
||||
for name, data in [("train.jsonl", train_data), ("val.jsonl", val_data)]:
|
||||
with open(output_dir / name, "w") as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item) + "\n")
|
||||
|
||||
# Write validation set
|
||||
val_path = output_dir / "val.jsonl"
|
||||
with open(val_path, "w") as f:
|
||||
for item in val_data:
|
||||
f.write(json.dumps(item) + "\n")
|
||||
|
||||
# Write chat format (for TRL)
|
||||
chat_path = output_dir / "train_chat.jsonl"
|
||||
with open(chat_path, "w") as f:
|
||||
with open(output_dir / "train_chat.jsonl", "w") as f:
|
||||
for item in train_data:
|
||||
f.write(json.dumps({"messages": item["messages"]}) + "\n")
|
||||
|
||||
# Stats
|
||||
short_final = sum(1 for ex in all_examples if len(ex["query"].split()) <= 2)
|
||||
|
||||
short_final = sum(1 for ex in all_examples if len(ex.query.split()) <= 2)
|
||||
print("\n=== Summary ===")
|
||||
print(f"Total examples: {len(all_examples)}")
|
||||
print(
|
||||
f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)"
|
||||
)
|
||||
print(f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)")
|
||||
print(f"Train: {len(train_data)}, Val: {len(val_data)}")
|
||||
print(f"Output: {output_dir}")
|
||||
|
||||
# Dataset info
|
||||
dataset_info = {
|
||||
"dataset_name": "qmd-query-expansion",
|
||||
"train_samples": len(train_data),
|
||||
"val_samples": len(val_data),
|
||||
"short_query_pct": round(100 * short_final / len(all_examples), 1),
|
||||
"columns": ["prompt", "completion", "text", "messages"],
|
||||
}
|
||||
with open(output_dir / "dataset_info.json", "w") as f:
|
||||
json.dump(dataset_info, f, indent=2)
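The dedupe, shuffle, and split steps in `main()` above can be sketched in isolation. This is a minimal illustration, not the script itself: plain query strings stand in for `TrainingExample` objects, the helper name `dedupe_shuffle_split` is hypothetical, and a local `random.Random` is used where the script seeds the global generator.

```python
import random


def dedupe_shuffle_split(queries, split=0.1, seed=42):
    """Illustrative stand-in for the dedupe/shuffle/split steps in main()."""
    seen = set()
    deduped = []
    for q in queries:
        key = q.lower().strip()  # same case-insensitive key as the script
        if key not in seen:
            seen.add(key)
            deduped.append(q)

    rng = random.Random(seed)  # assumption: local RNG instead of random.seed()
    rng.shuffle(deduped)

    # The last `split` fraction becomes the validation set.
    split_idx = int(len(deduped) * (1 - split))
    return deduped[:split_idx], deduped[split_idx:]


queries = [
    "auth config", "Auth Config ", "docker compose", "git rebase",
    "ssh keys", "python venv", "nginx proxy", "cron syntax",
    "jwt expiry", "redis ttl", "k8s ingress",
]
train, val = dedupe_shuffle_split(queries)
print(len(train), len(val))
```

Because the first occurrence of a query wins, the order in which input files are loaded decides which duplicate survives; sorting the glob results, as the script does, keeps that choice stable between runs.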
|
||||
|
||||
@ -1,17 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Schema helpers for QMD training JSONL data."""
|
||||
"""
|
||||
Strict schema for QMD training data.
|
||||
|
||||
Every JSONL file in data/ MUST conform to this format:
|
||||
|
||||
{"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
|
||||
|
||||
- query: non-empty string
|
||||
- output: list of [type, text] pairs where type is "lex", "vec", or "hyde"
|
||||
- Extra fields (category, intent, is_short, etc.) are allowed but ignored
|
||||
|
||||
There is exactly ONE format. No alternatives, no legacy fallbacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
import json
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Iterable
|
||||
|
||||
VALID_OUTPUT_TYPES = {"hyde", "lex", "vec"}
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OutputType(str, Enum):
|
||||
lex = "lex"
|
||||
vec = "vec"
|
||||
hyde = "hyde"
|
||||
|
||||
|
||||
VALID_OUTPUT_TYPES = {t.value for t in OutputType}
|
||||
|
||||
|
||||
class OutputPair(BaseModel):
|
||||
"""A single expansion line: [type, text]."""
|
||||
|
||||
type: OutputType
|
||||
text: str
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@field_validator("text")
|
||||
@classmethod
|
||||
def text_not_empty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("text must not be empty")
|
||||
return v
|
||||
|
||||
def to_list(self) -> list[str]:
|
||||
return [self.type.value, self.text]
|
||||
|
||||
|
||||
def _coerce_output_pairs(v: list) -> list[OutputPair]:
|
||||
"""Accept [["lex", "..."], ...] from JSON and coerce to OutputPair list."""
|
||||
pairs = []
|
||||
for i, item in enumerate(v):
|
||||
if isinstance(item, OutputPair):
|
||||
pairs.append(item)
|
||||
elif isinstance(item, (list, tuple)) and len(item) == 2:
|
||||
pairs.append(OutputPair(type=item[0], text=item[1]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"output[{i}] must be [type, text], got {item!r}"
|
||||
)
|
||||
return pairs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic model — single source of truth for the JSONL schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TrainingExample(BaseModel):
|
||||
"""One training example in the canonical JSONL format."""
|
||||
|
||||
query: str
|
||||
output: Annotated[list[OutputPair], BeforeValidator(_coerce_output_pairs)]
|
||||
|
||||
# Optional metadata — present in some files, ignored during training.
|
||||
category: str | None = None
|
||||
intent: str | None = None
|
||||
is_short: bool | None = None
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@field_validator("query")
|
||||
@classmethod
|
||||
def query_not_empty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("query must not be empty")
|
||||
return v
|
||||
|
||||
@field_validator("output")
|
||||
@classmethod
|
||||
def output_not_empty(cls, v: list[OutputPair]) -> list[OutputPair]:
|
||||
if not v:
|
||||
raise ValueError("output must not be empty")
|
||||
return v
|
||||
|
||||
def output_as_lists(self) -> list[list[str]]:
|
||||
"""Return output as list-of-lists for JSON serialization."""
|
||||
return [p.to_list() for p in self.output]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_examples(path: str | Path) -> list[TrainingExample]:
|
||||
"""Load and validate a JSONL file. Fails loudly on any bad line."""
|
||||
path = Path(path)
|
||||
examples: list[TrainingExample] = []
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{path}:{line_num}: invalid JSON: {e}") from e
|
||||
try:
|
||||
examples.append(TrainingExample.model_validate(obj))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{path}:{line_num}: {e}") from e
|
||||
return examples
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (used by prepare_data.py, reward.py, and other tools)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_output_text(text: str) -> list[list[str]]:
|
||||
"""Parse prefixed output text into list pairs.
|
||||
|
||||
Returns: [["hyde", "..."], ["lex", "..."], ...]
|
||||
>>> parse_output_text("lex: foo\\nvec: bar")
|
||||
[['lex', 'foo'], ['vec', 'bar']]
|
||||
"""
|
||||
items: list[list[str]] = []
|
||||
for raw_line in text.strip().split("\n"):
|
||||
@ -35,16 +167,18 @@ def reorder_hyde_first(items: list[list[str]]) -> list[list[str]]:
|
||||
return hyde_items + lex_items + vec_items
|
||||
|
||||
|
||||
def output_items_to_text(items: Iterable[Iterable[str]], hyde_first: bool = True) -> str:
|
||||
"""Render output list pairs to prefixed text lines.
|
||||
|
||||
Args:
|
||||
items: Iterable of [type, text] pairs
|
||||
hyde_first: If True, reorder to put hyde first (default)
|
||||
def output_items_to_text(
|
||||
items: Iterable, hyde_first: bool = True
|
||||
) -> str:
|
||||
"""Render output pairs to prefixed text lines.
|
||||
|
||||
Accepts list[OutputPair] or list[list[str]].
|
||||
"""
|
||||
# First normalize to list
|
||||
normalized = []
|
||||
for item in items:
|
||||
if isinstance(item, OutputPair):
|
||||
normalized.append([item.type.value, item.text.strip()])
|
||||
continue
|
||||
if not item:
|
||||
continue
|
||||
try:
|
||||
@ -59,24 +193,26 @@ def output_items_to_text(items: Iterable[Iterable[str]], hyde_first: bool = True
|
||||
if not text:
|
||||
continue
|
||||
normalized.append([kind, text])
|
||||
|
||||
# Apply hyde-first ordering if requested
|
||||
|
||||
if hyde_first:
|
||||
normalized = reorder_hyde_first(normalized)
|
||||
|
||||
|
||||
lines = [f"{kind}: {text}" for kind, text in normalized]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def normalize_output_items(items: Iterable[Iterable[str]], hyde_first: bool = True) -> list[list[str]]:
|
||||
"""Normalize output list pairs (filter invalid, trim whitespace, reorder).
|
||||
|
||||
Args:
|
||||
items: Iterable of [type, text] pairs
|
||||
hyde_first: If True, reorder to put hyde first (default)
|
||||
def normalize_output_items(
|
||||
items: Iterable, hyde_first: bool = True
|
||||
) -> list[list[str]]:
|
||||
"""Normalize output pairs (filter invalid, trim whitespace, reorder).
|
||||
|
||||
Accepts list[OutputPair] or list[list[str]].
|
||||
"""
|
||||
normalized: list[list[str]] = []
|
||||
for item in items:
|
||||
if isinstance(item, OutputPair):
|
||||
normalized.append([item.type.value, item.text.strip()])
|
||||
continue
|
||||
if not item:
|
||||
continue
|
||||
try:
|
||||
@ -91,13 +227,18 @@ def normalize_output_items(items: Iterable[Iterable[str]], hyde_first: bool = Tr
|
||||
if not text:
|
||||
continue
|
||||
normalized.append([kind, text])
|
||||
|
||||
# Apply hyde-first ordering if requested
|
||||
|
||||
if hyde_first:
|
||||
normalized = reorder_hyde_first(normalized)
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def has_type(items: Iterable[Iterable[str]], kind: str) -> bool:
|
||||
return any(item and item[0] == kind for item in items)
|
||||
def has_type(items: Iterable, kind: str) -> bool:
|
||||
for item in items:
|
||||
if isinstance(item, OutputPair):
|
||||
if item.type.value == kind:
|
||||
return True
|
||||
elif item and item[0] == kind:
|
||||
return True
|
||||
return False
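The hyde-first rendering used by the helpers above can be illustrated with a small round trip. This is a sketch under stated assumptions: `render` and `parse` are hypothetical stand-ins written for this example, not the repository's `output_items_to_text` and `parse_output_text`, whose full bodies are not shown in this diff.

```python
ORDER = ("hyde", "lex", "vec")  # hyde first, then lex, then vec


def render(pairs):
    """Render [type, text] pairs as 'type: text' lines, hyde first."""
    # Grouping by type keeps the original relative order within each type.
    ordered = [p for kind in ORDER for p in pairs if p[0] == kind]
    return "\n".join(f"{kind}: {text.strip()}" for kind, text in ordered)


def parse(text):
    """Parse 'type: text' lines back into [type, text] pairs."""
    pairs = []
    for line in text.strip().split("\n"):
        kind, sep, rest = line.partition(":")  # split on the first colon only
        if sep and kind.strip() in ORDER and rest.strip():
            pairs.append([kind.strip(), rest.strip()])
    return pairs


pairs = [
    ["lex", "auth config"],
    ["vec", "how to configure authentication"],
    ["hyde", "Authentication is configured in settings.yaml."],
]
text = render(pairs)
print(text)
print(parse(text))
```

Splitting on the first colon only means the text part may itself contain colons, which matters for hyde passages that quote config keys or URLs.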
|
||||
|
||||
@ -1,20 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = ["pydantic>=2.0"]
|
||||
# ///
|
||||
"""Score JSONL datasets with the reward function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dataset.schema import (
|
||||
normalize_output_items,
|
||||
output_items_to_text,
|
||||
parse_output_text,
|
||||
)
|
||||
from dataset.schema import load_examples, output_items_to_text
|
||||
from reward import score_expansion_detailed
|
||||
|
||||
|
||||
@ -24,42 +23,24 @@ def score_file(path: Path) -> tuple[int, int, list[float], dict]:
|
||||
scores: list[float] = []
|
||||
ratings: dict[str, int] = {}
|
||||
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
total += 1
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
try:
|
||||
examples = load_examples(path)
|
||||
except ValueError as e:
|
||||
print(f" Error loading {path}: {e}")
|
||||
return 0, 1, [], {}
|
||||
|
||||
query = obj.get("query") or obj.get("input")
|
||||
output = obj.get("output")
|
||||
if not isinstance(query, str) or not query.strip():
|
||||
errors += 1
|
||||
continue
|
||||
if output is None:
|
||||
errors += 1
|
||||
continue
|
||||
for ex in examples:
|
||||
total += 1
|
||||
output_text = output_items_to_text(ex.output)
|
||||
if not output_text:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
if isinstance(output, str):
|
||||
output_items = normalize_output_items(parse_output_text(output))
|
||||
else:
|
||||
output_items = normalize_output_items(output)
|
||||
|
||||
output_text = output_items_to_text(output_items)
|
||||
if not output_text:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
detail = score_expansion_detailed(query, output_text)
|
||||
score = detail["percentage"]
|
||||
scores.append(score)
|
||||
rating = detail["rating"]
|
||||
ratings[rating] = ratings.get(rating, 0) + 1
|
||||
detail = score_expansion_detailed(ex.query, output_text)
|
||||
score = detail["percentage"]
|
||||
scores.append(score)
|
||||
rating = detail["rating"]
|
||||
ratings[rating] = ratings.get(rating, 0) + 1
|
||||
|
||||
return total, errors, scores, ratings
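`score_file` returns raw per-example scores and a rating histogram. A caller might summarize them roughly as follows. This is only a sketch: the `summarize` helper, the score values, and the rating labels are all hypothetical, since the real values come from `score_expansion_detailed()` in `reward.py`, which is not part of this diff.

```python
import statistics


def summarize(scores, ratings):
    """Summarize per-example scores (percentages) and rating counts."""
    if not scores:
        # An unreadable file yields no scores; avoid statistics errors.
        return {"count": 0, "mean": None, "median": None, "ratings": {}}
    return {
        "count": len(scores),
        "mean": statistics.mean(scores),
        "median": statistics.median(scores),
        "ratings": dict(sorted(ratings.items())),
    }


# Hypothetical values for illustration only.
scores = [92.0, 75.0, 88.0, 60.0]
ratings = {}
for label in ["Excellent", "Good", "Good", "Acceptable"]:
    ratings[label] = ratings.get(label, 0) + 1

print(summarize(scores, ratings))
```

The empty-input branch mirrors the case where `load_examples` fails and `score_file` returns no scores at all.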
|
||||
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Validate JSONL files against the QMD training schema."""
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = ["pydantic>=2.0"]
|
||||
# ///
|
||||
"""Validate JSONL files against the strict QMD training schema."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -9,7 +13,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dataset.schema import VALID_OUTPUT_TYPES
|
||||
from dataset.schema import TrainingExample
|
||||
|
||||
|
||||
def validate_file(path: Path) -> tuple[int, int]:
|
||||
@ -29,31 +33,11 @@ def validate_file(path: Path) -> tuple[int, int]:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
query = obj.get("query")
|
||||
output = obj.get("output")
|
||||
|
||||
if not isinstance(query, str) or not query.strip():
|
||||
print(f"{path}:{line_num}: missing/invalid query")
|
||||
try:
|
||||
TrainingExample.model_validate(obj)
|
||||
except Exception as e:
|
||||
print(f"{path}:{line_num}: {e}")
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
if not isinstance(output, list):
|
||||
print(f"{path}:{line_num}: output must be a list")
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
for idx, item in enumerate(output):
|
||||
if not isinstance(item, list) or len(item) != 2:
|
||||
print(f"{path}:{line_num}: output[{idx}] must be [type, text]")
|
||||
errors += 1
|
||||
continue
|
||||
kind, text = item
|
||||
if kind not in VALID_OUTPUT_TYPES:
|
||||
print(f"{path}:{line_num}: invalid output type '{kind}'")
|
||||
errors += 1
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
print(f"{path}:{line_num}: empty output text")
|
||||
errors += 1
|
||||
|
||||
return total, errors
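For readers who want the rules without the Pydantic machinery, the checks that `TrainingExample.model_validate` now performs correspond roughly to the plain-Python sketch below. It is an illustration only: `check_example` is a hypothetical helper, its error messages are made up for this example, and the actual enforcement lives in `dataset/schema.py`.

```python
VALID_TYPES = {"lex", "vec", "hyde"}


def check_example(obj):
    """Return a list of problems with one decoded JSONL object (empty if valid)."""
    if not isinstance(obj, dict):
        return ["line must be a JSON object"]

    problems = []
    query = obj.get("query")
    if not isinstance(query, str) or not query.strip():
        problems.append("query must be a non-empty string")

    output = obj.get("output")
    if not isinstance(output, list) or not output:
        problems.append("output must be a non-empty list")
        return problems

    for i, item in enumerate(output):
        if not isinstance(item, (list, tuple)) or len(item) != 2:
            problems.append(f"output[{i}] must be [type, text]")
            continue
        kind, text = item
        if not isinstance(kind, str) or kind not in VALID_TYPES:
            problems.append(f"output[{i}] has invalid type {kind!r}")
        if not isinstance(text, str) or not text.strip():
            problems.append(f"output[{i}] has empty text")
    return problems


good = {"query": "auth config",
        "output": [["lex", "auth settings"], ["vec", "how to configure auth"]]}
bad = {"query": " ", "output": [["bm25", "x"], ["lex", ""]]}
print(check_example(good))
print(check_example(bad))
```

A record still in the old `"searches"` layout has no `"output"` key and is rejected by the same non-empty-list check, which matches the "no legacy fallbacks" rule.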
|
||||
|
||||
|
||||
1
finetune/experiments/gepa/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""GEPA helpers."""
|
||||
31
finetune/experiments/gepa/best_prompt.txt
Normal file
@ -0,0 +1,31 @@
You are an assistant that expands a given search query into lexical (lex), vector (vec), and HYDE expansions for improved search retrieval.

## Input Format
You will receive input in this exact format:
```
## Inputs
### query
[the search query]
```

## Output Format
Respond ONLY with this exact format, nothing else:
```
## Generated Outputs
### expansion
lex: [short keyword phrase 1]
lex: [short keyword phrase 2]
lex: [short keyword phrase 3]
vec: [medium phrasal expansion 1]
vec: [medium phrasal expansion 2]
vec: [medium phrasal expansion 3]
hyde: [concise hypothetical document snippet, SINGLE LINE, under 150 characters total]
```

## Generation Rules
- **Exactly 3 lex lines**: Short (2-5 words), keyword-like expansions. MUST include core query terms or direct synonyms/variants (e.g., for "web mail", include "webmail"). Focus on key entities, actions, or concepts.
- **Exactly 3 vec lines**: Medium-length (4-8 words) natural language phrases capturing query intent, aspects, or related searches.
- **Exactly 1 hyde line**: A single, fluent sentence acting as a hypothetical relevant document passage. Keep STRICTLY under 150 characters (aim for 100-140). Be descriptive but concise: no lists, no examples unless essential.
- Strategy: Break down the query into synonyms (lex), semantic rephrasings (vec), and a compact informative summary (hyde) to cover lexical, embedding, and dense retrieval signals.
- Match query intent precisely; expand to related high-relevance terms without hallucinating unrelated content.
```
|
||||
1
finetune/experiments/gepa/best_prompt_glm.txt
Normal file
@ -0,0 +1 @@
|
||||
Expand a search query into lex/vec/hyde lines.
|
||||
204
finetune/experiments/gepa/dspy_gepa.py
Normal file
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Run DSPy GEPA using reward.py as the metric."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _import_dspy():
|
||||
script_dir = Path(__file__).parent
|
||||
repo_root = script_dir.parent
|
||||
original_sys_path = list(sys.path)
|
||||
try:
|
||||
sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
|
||||
return importlib.import_module("dspy")
|
||||
finally:
|
||||
sys.path = original_sys_path
|
||||
|
||||
|
||||
dspy = _import_dspy()
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]  # finetune/, since this file lives in finetune/experiments/gepa/
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
|
||||
from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
|
||||
from reward import score_expansion_detailed
|
||||
|
||||
|
||||
class ExpandSignature(dspy.Signature):
|
||||
"""Expand a search query into lex/vec/hyde lines."""
|
||||
|
||||
query = dspy.InputField(desc="User search query")
|
||||
output = dspy.OutputField(
|
||||
desc=(
|
||||
"JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
|
||||
"Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
|
||||
"Lex items are short keywords and must not echo the query. "
|
||||
"Vec items are natural language search phrases. "
|
||||
"Hyde is 50-200 chars, single line."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Expander(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.predict = dspy.Predict(ExpandSignature)
|
||||
|
||||
def forward(self, query: str):
|
||||
return self.predict(query=query)
|
||||
|
||||
|
||||
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
|
||||
expansion = output_items_to_text(_coerce_output_items(pred))
|
||||
detail = score_expansion_detailed(gold.query, expansion)
|
||||
score = detail["percentage"] / 100.0
|
||||
feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
|
||||
return dspy.Prediction(score=score, feedback=feedback)
|
||||
|
||||
|
||||
def load_queries(path: Path) -> list[str]:
|
||||
queries: list[str] = []
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
query = obj.get("query") or obj.get("input")
|
||||
if isinstance(query, str) and query.strip():
|
||||
queries.append(query.strip())
|
||||
return queries
|
||||
|
||||
|
||||
def to_examples(queries: list[str]) -> list[dspy.Example]:
|
||||
return [dspy.Example(query=q).with_inputs("query") for q in queries]
|
||||
|
||||
|
||||
def _coerce_output_items(pred) -> list[list[str]]:
|
||||
raw_output = getattr(pred, "output", None)
|
||||
if isinstance(raw_output, (list, tuple)):
|
||||
return normalize_output_items(raw_output)
|
||||
|
||||
raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
|
||||
if not raw_text:
|
||||
return []
|
||||
|
||||
if raw_text[0] in ("[", "{"):
|
||||
try:
|
||||
obj = json.loads(raw_text)
|
||||
if isinstance(obj, dict) and "output" in obj:
|
||||
obj = obj["output"]
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return normalize_output_items(obj)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return parse_output_text(raw_text)
|
||||
|
||||
|
||||
def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
|
||||
with path.open("w", encoding="utf-8") as f:
|
||||
for query, output in zip(queries, outputs, strict=True):
|
||||
f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
|
||||
parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
|
||||
    parser.add_argument(
        "--model",
        type=str,
        # Default must pass the provider/model check below.
        default="xai/grok-4-1-fast-reasoning",
        help="Student LM string in provider/model format (e.g., openai/gpt-4o)",
    )
    parser.add_argument(
        "--reflection-model",
        type=str,
        default="xai/grok-4-1-fast-reasoning",
        help="Reflection LM string in provider/model format (e.g., openai/gpt-4o)",
    )
|
||||
parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
|
||||
parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
|
||||
parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
|
||||
parser.add_argument("--max-full-evals", type=int, default=None)
|
||||
parser.add_argument("--max-metric-calls", type=int, default=None)
|
||||
parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
|
||||
parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
|
||||
parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
|
||||
parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
|
||||
parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
|
||||
args = parser.parse_args()
|
||||
|
||||
if "/" not in args.model or "/" not in args.reflection_model:
|
||||
print("Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).")
|
||||
return 1
|
||||
|
||||
if args.max_full_evals is not None and args.max_metric_calls is not None:
|
||||
print("Provide only one of --max-full-evals or --max-metric-calls")
|
||||
return 1
|
||||
if args.max_full_evals is not None or args.max_metric_calls is not None:
|
||||
args.auto = None
|
||||
|
||||
train_path = Path(args.input)
|
||||
queries = load_queries(train_path)
|
||||
if args.limit is not None:
|
||||
queries = queries[: args.limit]
|
||||
trainset = to_examples(queries)
|
||||
valset = None
|
||||
if args.valset:
|
||||
val_queries = load_queries(Path(args.valset))
|
||||
if args.val_limit is not None:
|
||||
val_queries = val_queries[: args.val_limit]
|
||||
valset = to_examples(val_queries)
|
||||
|
||||
lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
|
||||
reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
|
||||
|
||||
student = Expander()
|
||||
student.set_lm(lm)
|
||||
|
||||
compiler = dspy.GEPA(
|
||||
metric=reward_metric,
|
||||
reflection_lm=reflection_lm,
|
||||
auto=args.auto,
|
||||
max_full_evals=args.max_full_evals,
|
||||
max_metric_calls=args.max_metric_calls,
|
||||
track_stats=True,
|
||||
track_best_outputs=True,
|
||||
failure_score=0.0,
|
||||
perfect_score=1.0,
|
||||
)
|
||||
|
||||
optimized = compiler.compile(student=student, trainset=trainset, valset=valset)
|
||||
|
||||
if args.save_prompt:
|
||||
prompt_text = getattr(optimized.predict.signature, "__doc__", "") or ""
|
||||
Path(args.save_prompt).write_text(prompt_text.strip() + "\n", encoding="utf-8")
|
||||
print(f"Wrote {args.save_prompt}")
|
||||
|
||||
if args.emit:
|
||||
outputs = []
|
||||
for q in queries:
|
||||
pred = optimized(query=q)
|
||||
items = _coerce_output_items(pred)
|
||||
outputs.append(items)
|
||||
write_jsonl(Path(args.emit), queries, outputs)
|
||||
print(f"Wrote {args.emit}")
|
||||
|
||||
if hasattr(optimized, "detailed_results"):
|
||||
best = getattr(optimized.detailed_results, "best_outputs_valset", None)
|
||||
if best:
|
||||
print(f"Best outputs tracked: {len(best)}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
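The fallback order in `_coerce_output_items` above (try JSON first, then fall back to prefixed `kind: text` lines) can be sketched without DSPy or the schema helpers. This is a simplified illustration: `coerce` is a hypothetical stand-in that takes a raw string rather than a prediction object, and its line parser is written for this example.

```python
import json

VALID = {"lex", "vec", "hyde"}


def coerce(raw):
    """Turn raw model output into [kind, text] pairs."""
    raw = (raw or "").strip()
    if not raw:
        return []

    # First attempt: a JSON array of pairs, optionally wrapped in {"output": ...}.
    if raw[0] in "[{":
        try:
            obj = json.loads(raw)
            if isinstance(obj, dict) and "output" in obj:
                obj = obj["output"]
            if isinstance(obj, list):
                return [[k, t.strip()] for k, t in obj if k in VALID and t.strip()]
        except (ValueError, TypeError, AttributeError):
            pass  # not valid JSON pairs; fall through to line parsing

    # Fallback: prefixed lines such as "lex: keyword phrase".
    pairs = []
    for line in raw.split("\n"):
        kind, sep, text = line.partition(":")
        if sep and kind.strip() in VALID and text.strip():
            pairs.append([kind.strip(), text.strip()])
    return pairs


print(coerce('[["lex", "auth settings"], ["vec", "how to configure auth"]]'))
print(coerce("lex: auth settings\nvec: how to configure auth"))
```

Accepting both shapes keeps the metric usable when the model ignores the "JSON array" instruction in the signature and answers in the plain prefixed format instead.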
|
||||
117
finetune/experiments/gepa/example.py
Normal file
@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
"""GEPA example schema for QMD training JSONL lines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
LexSearch = "LexSearch"
|
||||
VecSearch = "VecSearch"
|
||||
HydeSearch = "HydeSearch"
|
||||
|
||||
|
||||
SEARCH_TYPE_TO_PREFIX = {
|
||||
SearchType.LexSearch: "lex",
|
||||
SearchType.VecSearch: "vec",
|
||||
SearchType.HydeSearch: "hyde",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputItem:
|
||||
"""Single expansion line with validation hints."""
|
||||
|
||||
kind: SearchType
|
||||
text: str
|
||||
|
||||
# Validation hints (not strict rules).
|
||||
min_chars: int = 3
|
||||
max_chars: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.text = str(self.text).strip()
|
||||
if not self.text:
|
||||
raise ValueError("OutputItem.text must be non-empty")
|
||||
if "\n" in self.text:
|
||||
raise ValueError("OutputItem.text must be single-line")
|
||||
if len(self.text) < self.min_chars:
|
||||
raise ValueError("OutputItem.text is too short")
|
||||
if self.max_chars is not None and len(self.text) > self.max_chars:
|
||||
raise ValueError("OutputItem.text is too long")
|
||||
|
||||
def to_pair(self) -> list[str]:
|
||||
return [SEARCH_TYPE_TO_PREFIX[self.kind], self.text]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Example:
|
||||
"""JSONL line schema for QMD training data."""
|
||||
|
||||
query: str
|
||||
output: list[OutputItem] = field(default_factory=list)
|
||||
|
||||
    def __post_init__(self) -> None:
        self.query = str(self.query).strip()
        if not self.query:
            raise ValueError("Example.query must be non-empty")
        if not self.output:
            raise ValueError("Example.output must not be empty")

    def to_json(self) -> dict:
        return {
            "query": self.query,
            "output": [item.to_pair() for item in self.output],
        }

    def to_jsonl(self) -> str:
        return json.dumps(self.to_json(), ensure_ascii=False)


def parse_output_items(raw_output: Iterable[Iterable[str]]) -> list[OutputItem]:
    items: list[OutputItem] = []
    for item in raw_output:
        if not item or len(item) < 2:
            continue
        kind_raw, text = item[0], item[1]
        kind_map = {
            "lex": SearchType.LexSearch,
            "vec": SearchType.VecSearch,
            "hyde": SearchType.HydeSearch,
        }
        kind = kind_map.get(str(kind_raw).strip().lower())
        if kind is None:
            continue
        max_chars = 200 if kind is SearchType.HydeSearch else None
        items.append(OutputItem(kind=kind, text=str(text), max_chars=max_chars))
    return items


def example_from_json(obj: dict) -> Example:
    query = obj.get("query") or obj.get("input") or ""
    output = obj.get("output") or []
    if isinstance(output, str):
        raise ValueError("String outputs are not supported in GEPA example schema")
    items = parse_output_items(output)
    return Example(query=query, output=items)


def load_jsonl(path: str | Path) -> list[Example]:
    examples: list[Example] = []
    with Path(path).open("r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                examples.append(example_from_json(obj))
            except Exception as exc:
                raise ValueError(f"Invalid line {line_num}: {exc}") from exc
    return examples
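A minimal, self-contained sketch of the pair parsing performed by `parse_output_items` above. `SearchType` and `OutputItem` are defined earlier in `example.py` and are not shown in this hunk, so plain strings stand in for them here; the helper name `parse_pairs` is hypothetical. It illustrates that malformed or unknown pairs are skipped silently rather than raising.

```python
from __future__ import annotations

from typing import Iterable

# Stand-in for the SearchType mapping; the real enum is not shown in this hunk.
KNOWN_KINDS = {"lex", "vec", "hyde"}


def parse_pairs(raw_output: Iterable[Iterable[str]]) -> list[tuple[str, str]]:
    """Keep well-formed [kind, text] pairs; drop everything else silently."""
    pairs: list[tuple[str, str]] = []
    for item in raw_output:
        item = list(item)
        if len(item) < 2:
            continue  # too short to be a [kind, text] pair
        kind = str(item[0]).strip().lower()
        if kind not in KNOWN_KINDS:
            continue  # unknown kind: skipped rather than rejected
        pairs.append((kind, str(item[1])))
    return pairs


raw = [["LEX", "webmail"], ["vec", "web-based email services"], ["bogus", "x"], ["hyde"]]
parsed = parse_pairs(raw)
print(parsed)
```

Because invalid pairs are dropped, a record whose pairs are all invalid reaches `Example.__post_init__` with an empty list and fails there with "Example.output must not be empty".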
129
finetune/experiments/gepa/generate.py
Normal file
@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""Generate expansions using a saved GEPA prompt."""

from __future__ import annotations

import argparse
import importlib
import json
import sys
from pathlib import Path


def _import_dspy():
    script_dir = Path(__file__).parent
    original_sys_path = list(sys.path)
    try:
        sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
        return importlib.import_module("dspy")
    finally:
        sys.path = original_sys_path


dspy = _import_dspy()

# This script lives in finetune/experiments/gepa/, so the directory that
# contains dataset/ is two levels above the script directory.
repo_root = Path(__file__).resolve().parents[2]
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from dataset.schema import parse_output_text


def load_topics(path: Path) -> list[str]:
    topics: list[str] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Allow JSONL {"topic": "..."} or plain lines.
            if line.startswith("{") and line.endswith("}"):
                try:
                    obj = json.loads(line)
                    topic = obj.get("topic") or obj.get("query") or obj.get("input")
                    if isinstance(topic, str) and topic.strip():
                        topics.append(topic.strip())
                        continue
                except json.JSONDecodeError:
                    pass
            topics.append(line)
    return topics


def write_jsonl_line(handle, query: str, output_text: str) -> None:
    output = parse_output_text(output_text)
    handle.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")


def parse_queries(text: str) -> list[str]:
    lines = []
    for raw in text.splitlines():
        line = raw.strip().lstrip("-").strip()
        if not line:
            continue
        lines.append(line)
    return lines


def main() -> int:
    parser = argparse.ArgumentParser(description="Generate with saved GEPA prompt")
    parser.add_argument("--prompt", type=str, required=True, help="Path to saved prompt text")
    parser.add_argument("--topics", type=str, required=True, help="Topics file (one per line or JSONL)")
    parser.add_argument("--output", type=str, required=True, help="Output JSONL path")
    parser.add_argument("--model", type=str, required=True, help="LM string in provider/model format")
    parser.add_argument("--per-topic", type=int, default=3, help="Queries to generate per topic")
    args = parser.parse_args()

    prompt_text = Path(args.prompt).read_text(encoding="utf-8").strip()
    expansion_sig = dspy.Signature("query -> expansion", prompt_text)
    query_sig = dspy.Signature(
        "topic, count -> queries",
        (
            "Generate distinct user search queries for the given topic. "
            "Return exactly `count` queries, one per line, no numbering or extra text."
        ),
    )

    class Generator(dspy.Module):
        def __init__(self):
            super().__init__()
            self.predict = dspy.Predict(expansion_sig)

        def forward(self, query: str):
            return self.predict(query=query)

    class QueryGenerator(dspy.Module):
        def __init__(self):
            super().__init__()
            self.predict = dspy.Predict(query_sig)

        def forward(self, topic: str, count: int):
            return self.predict(topic=topic, count=str(count))

    lm = dspy.LM(model=args.model)
    gen = Generator()
    gen.set_lm(lm)
    qgen = QueryGenerator()
    qgen.set_lm(lm)

    topics = load_topics(Path(args.topics))
    with Path(args.output).open("w", encoding="utf-8") as f_out:
        for topic in topics:
            qpred = qgen(topic=topic, count=args.per_topic)
            qtext = getattr(qpred, "queries", "") or ""
            generated = parse_queries(qtext)
            if not generated:
                generated = [topic]
            generated = generated[: args.per_topic]
            for query in generated:
                pred = gen(query=query)
                output_text = getattr(pred, "expansion", "") or ""
                write_jsonl_line(f_out, query, output_text)
                print(json.dumps({"query": query, "output": parse_output_text(output_text)}, ensure_ascii=False))

    print(f"Wrote {args.output}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
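`parse_output_text` is imported from `dataset.schema` and is not part of this diff. As a rough sketch only, assuming it converts the model's `lex:` / `vec:` / `hyde:` prefixed lines into `[type, text]` pairs, the conversion might look like the following. The function name `parse_prefixed_lines` and its exact behavior are assumptions, not the project's implementation.

```python
from __future__ import annotations

# Assumed prefixes, taken from the lex/vec/hyde output format.
PREFIXES = ("lex", "vec", "hyde")


def parse_prefixed_lines(text: str) -> list[list[str]]:
    """Turn 'lex: foo' style lines into [['lex', 'foo'], ...] (hypothetical helper)."""
    pairs: list[list[str]] = []
    for raw in text.splitlines():
        kind, sep, rest = raw.strip().partition(":")
        kind = kind.strip().lower()
        rest = rest.strip()
        # Keep only lines with a known prefix, a colon, and non-empty text.
        if sep and kind in PREFIXES and rest:
            pairs.append([kind, rest])
    return pairs


sample = (
    "lex: webmail\n"
    "vec: web-based email services\n"
    "not a valid line\n"
    "hyde: Webmail runs in a browser."
)
pairs = parse_prefixed_lines(sample)
print(pairs)
```

Under this assumption, a model reply that contains no prefixed lines yields an empty list, which `write_jsonl_line` writes out unchanged as `"output": []`.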
10
finetune/experiments/gepa/gepa_outputs.jsonl
Normal file
@ -0,0 +1,10 @@
{"query": "how tourism affects local cultures", "output": [["lex", "tourism cultural impact"], ["lex", "local culture tourism effects"], ["lex", "overtourism traditions change"], ["vec", "effects of tourism on indigenous customs"], ["vec", "tourism influence on native practices"], ["vec", "cultural shifts from mass tourism"], ["hyde", "Tourism reshapes local cultures by commercializing traditions, introducing global influences, and sparking both preservation efforts and cultural erosion in communities."]]}
{"query": "how to ferment foods at home", "output": [["lex", "home fermentation"], ["lex", "DIY food fermenting"], ["lex", "lacto fermentation"], ["vec", "beginner guide to home food fermentation"], ["vec", "steps for safe fermenting vegetables"], ["vec", "homemade probiotic food recipes"], ["hyde", "Fermenting foods at home involves chopping produce, mixing with salt brine, packing into jars, and waiting 3-14 days in a cool spot for tangy flavors and probiotics."]]}
{"query": "how to mix modern and vintage decor", "output": [["lex", "modern vintage blend"], ["lex", "contemporary retro decor"], ["lex", "eclectic style fusion"], ["vec", "tips blending modern vintage furniture"], ["vec", "combining contemporary antique accents"], ["vec", "balancing old new interior elements"], ["hyde", "Seamlessly mix sleek modern furniture with charming vintage pieces using neutral palettes and strategic layering for a cohesive eclectic home."]]}
{"query": "how to perform a scientific experiment", "output": [["lex", "scientific method steps"], ["lex", "conduct lab experiment"], ["lex", "experiment procedure guide"], ["vec", "steps to design scientific experiment"], ["vec", "guide for running science experiments"], ["vec", "how to execute controlled experiment"], ["hyde", "To perform a scientific experiment, form a hypothesis, design a method, collect data via tests, analyze results, and draw evidence-based conclusions."]]}
{"query": "web mail", "output": [["lex", "webmail"], ["lex", "online email"], ["lex", "browser mail"], ["vec", "web-based email services"], ["vec", "access mail via browser"], ["vec", "free webmail providers"], ["hyde", "Webmail enables users to read, send, and manage emails directly through a web browser without installing software."]]}
{"query": "what does the quran cover", "output": [["lex", "Quran topics"], ["lex", "Quran contents"], ["lex", "Quran themes"], ["vec", "main topics in the Quran"], ["vec", "subjects covered by Quran"], ["vec", "themes and teachings Quran"], ["hyde", "The Quran addresses theology, prophethood, morality, Islamic laws, stories of ancient prophets, afterlife, and guidance for personal and social life."]]}
{"query": "web config", "output": [["lex", "web.config file"], ["lex", "ASP.NET config"], ["lex", "IIS configuration"], ["vec", "editing web.config settings"], ["vec", "web.config appSettings section"], ["vec", "configuring ASP.NET web app"], ["hyde", "The web.config file in ASP.NET defines application settings, authentication, modules, and connection strings for IIS-hosted web applications."]]}
{"query": "how to choose farm equipment", "output": [["lex", "farm machinery selection"], ["lex", "agricultural equipment buying"], ["lex", "tractor harvester choice"], ["vec", "guide to selecting farm tools"], ["vec", "factors in choosing farm gear"], ["vec", "tips for buying ag machinery"], ["hyde", "To choose farm equipment effectively, assess farm size, soil type, budget, durability, and brand reviews for long-term productivity and value."]]}
{"query": "how do thought experiments aid philosophical reasoning", "output": [["lex", "thought experiments philosophy"], ["lex", "hypothetical reasoning aids"], ["lex", "gedankenexperiment benefits"], ["vec", "role of thought experiments philosophy"], ["vec", "hypotheticals improve philosophical logic"], ["vec", "mental scenarios aid argumentation"], ["hyde", "Thought experiments bolster philosophical reasoning by simulating scenarios to test ideas, expose flaws, and clarify abstract concepts without real-world limits."]]}
{"query": "what is the significance of logic in philosophy", "output": [["lex", "logic philosophy importance"], ["lex", "philosophical logic role"], ["lex", "logic significance reasoning"], ["vec", "importance of logic in philosophy"], ["vec", "role of logic philosophical thought"], ["vec", "why logic fundamental to philosophy"], ["hyde", "Logic underpins philosophy by furnishing tools for valid inference, critical analysis, and structured argumentation across metaphysics, epistemology, and ethics."]]}
20
finetune/experiments/gepa/gepa_outputs_glm.jsonl
Normal file
@ -0,0 +1,20 @@
{"query": "how tourism affects local cultures", "output": []}
{"query": "how to ferment foods at home", "output": []}
{"query": "how to mix modern and vintage decor", "output": []}
{"query": "how to perform a scientific experiment", "output": []}
{"query": "web mail", "output": []}
{"query": "what does the quran cover", "output": []}
{"query": "web config", "output": []}
{"query": "how to choose farm equipment", "output": []}
{"query": "how do thought experiments aid philosophical reasoning", "output": []}
{"query": "what is the significance of logic in philosophy", "output": []}
{"query": "how to train for a 5k run", "output": []}
{"query": "how to engage with political dialogues", "output": []}
{"query": "what is competitive analysis", "output": []}
{"query": "how does the united nations operate", "output": []}
{"query": "what are the crusades?", "output": []}
{"query": "what is a literary theme?", "output": []}
{"query": "what is the ethical significance of consent", "output": []}
{"query": "paint mix", "output": []}
{"query": "how to conserve energy in the office?", "output": []}
{"query": "how to test soil ph?", "output": []}
19
finetune/experiments/gepa/model.json
Normal file
@ -0,0 +1,19 @@
{
  "name": "qmd-gepa-example-generator",
  "model": "grok-4-1-fast-reasoning",
  "schema_version": 1,
  "prompt": "You are a query expansion expert. Given a user query, output a single JSON object that matches the training JSONL schema:\n{\"query\": \"...\", \"output\": [[\"lex\", \"...\"], [\"vec\", \"...\"], [\"hyde\", \"...\"]]}\nRules:\n- output is a list of pairs, where the first element is one of: \"lex\", \"vec\", \"hyde\".\n- Include 2-3 lex lines, 2-3 vec lines, and 0-1 hyde line.\n- lex lines are short keyword phrases; never equal or near-echo the query.\n- vec lines are natural language search phrases.\n- hyde is a concise hypothetical passage (50-200 chars), single line.\n- Preserve key terms and named entities in lex lines.\n- No extra text outside the JSON object.\n",
  "output_schema": {
    "query": "string",
    "output": [
      [
        "lex|vec|hyde",
        "string"
      ]
    ]
  },
  "notes": [
    "LexSearch/VecSearch/HydeSearch are represented as lex/vec/hyde in output.",
    "Do not echo the query in lex lines."
  ]
}
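The count and length rules stated in the prompt can be checked mechanically. The sketch below is an illustration only: `check_prompt_rules` is a hypothetical helper, not part of the repository, and it covers just the count rules (2-3 `lex`, 2-3 `vec`, 0-1 `hyde`) and the 50-200 character `hyde` range.

```python
from __future__ import annotations

from collections import Counter


def check_prompt_rules(output: list[list[str]]) -> list[str]:
    """Return a list of rule violations; an empty list means the record passes."""
    problems: list[str] = []
    counts = Counter(kind for kind, _ in output)
    if not 2 <= counts["lex"] <= 3:
        problems.append(f"expected 2-3 lex lines, got {counts['lex']}")
    if not 2 <= counts["vec"] <= 3:
        problems.append(f"expected 2-3 vec lines, got {counts['vec']}")
    if counts["hyde"] > 1:
        problems.append(f"expected 0-1 hyde lines, got {counts['hyde']}")
    for kind, text in output:
        if kind == "hyde" and not 50 <= len(text) <= 200:
            problems.append(f"hyde length {len(text)} outside 50-200 chars")
    return problems


record = [
    ["lex", "webmail"],
    ["lex", "online email"],
    ["vec", "web-based email services"],
    ["vec", "access mail via browser"],
    ["hyde", "Webmail enables users to read, send, and manage emails directly through a web browser."],
]
print(check_prompt_rules(record))  # no violations expected for this record
print(check_prompt_rules([]))      # an empty output violates the lex and vec count rules
```

A check like this would flag records with an empty `output` list, since those contain no `lex` or `vec` lines at all.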
70
finetune/experiments/gepa/optimizer.py
Normal file
@ -0,0 +1,70 @@
#!/usr/bin/env python3
"""Write model.json prompt config for generating high-quality examples."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from example import SearchType, SEARCH_TYPE_TO_PREFIX


def build_prompt() -> str:
    lex = SEARCH_TYPE_TO_PREFIX[SearchType.LexSearch]
    vec = SEARCH_TYPE_TO_PREFIX[SearchType.VecSearch]
    hyde = SEARCH_TYPE_TO_PREFIX[SearchType.HydeSearch]

    return (
        "You are a query expansion expert. Given a user query, output a single JSON object "
        "that matches the training JSONL schema:\n"
        '{"query": "...", "output": [["lex", "..."], ["vec", "..."], ["hyde", "..."]]}\n'
        "Rules:\n"
        f"- output is a list of pairs, where the first element is one of: "
        f"\"{lex}\", \"{vec}\", \"{hyde}\".\n"
        "- Include 2-3 lex lines, 2-3 vec lines, and 0-1 hyde line.\n"
        "- lex lines are short keyword phrases; never equal or near-echo the query.\n"
        "- vec lines are natural language search phrases.\n"
        "- hyde is a concise hypothetical passage (50-200 chars), single line.\n"
        "- Preserve key terms and named entities in lex lines.\n"
        "- No extra text outside the JSON object.\n"
    )


def write_model_json(path: Path) -> None:
    payload = {
        "name": "qmd-gepa-example-generator",
        "model": "grok-4-1-fast-reasoning",
        "schema_version": 1,
        "prompt": build_prompt(),
        "output_schema": {
            "query": "string",
            "output": [["lex|vec|hyde", "string"]],
        },
        "notes": [
            "LexSearch/VecSearch/HydeSearch are represented as lex/vec/hyde in output.",
            "Do not echo the query in lex lines.",
        ],
    }
    path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")


def main() -> int:
    parser = argparse.ArgumentParser(description="Write model.json for GEPA generation")
    parser.add_argument(
        "--output",
        type=str,
        default="gepa/model.json",
        help="Path to write model.json",
    )
    args = parser.parse_args()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    write_model_json(output_path)
    print(f"Wrote {output_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
84
finetune/experiments/gepa/score.py
Normal file
@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""Score GEPA JSONL outputs using reward.py."""

from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path

from example import example_from_json

from reward import score_expansion_detailed
from dataset.schema import output_items_to_text


def score_file(path: Path) -> tuple[int, int, list[float], dict]:
    total = 0
    errors = 0
    scores: list[float] = []
    ratings: dict[str, int] = {}

    with path.open("r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                obj = json.loads(line)
                example = example_from_json(obj)
            except Exception:
                errors += 1
                continue

            output_text = output_items_to_text([item.to_pair() for item in example.output])
            if not output_text:
                errors += 1
                continue

            detail = score_expansion_detailed(example.query, output_text)
            score = detail["percentage"]
            scores.append(score)
            rating = detail["rating"]
            ratings[rating] = ratings.get(rating, 0) + 1

    return total, errors, scores, ratings


def main() -> int:
    parser = argparse.ArgumentParser(description="Score GEPA JSONL outputs")
    parser.add_argument("--input", type=str, required=True, help="Input JSONL file")
    args = parser.parse_args()

    path = Path(args.input)
    if not path.exists():
        print(f"Input not found: {path}")
        return 1

    total, errors, scores, ratings = score_file(path)
    if scores:
        avg = statistics.mean(scores)
        median = statistics.median(scores)
        min_score = min(scores)
        max_score = max(scores)
        above_70 = sum(1 for s in scores if s >= 70.0)
        pct_70 = above_70 / len(scores) * 100
        print(
            f"{path}: {len(scores)} scored, {errors} errors, "
            f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
            f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
        )
    else:
        print(f"{path}: 0 scored, {errors} errors")

    if ratings:
        # The stray backslashes before the quotes were a syntax error; plain f-strings are intended.
        rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
        print(f" ratings: {', '.join(rating_parts)}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
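As a small illustration of the summary that `main` prints, the sketch below computes the same statistics on a hand-made list of percentages. The scores are invented for the example and `summarize` is a hypothetical helper name; real values come from `score_expansion_detailed`, which is not shown in this diff.

```python
from __future__ import annotations

import statistics


def summarize(scores: list[float], threshold: float = 70.0) -> dict[str, float]:
    """Mirror the summary fields printed by score.py for a list of percentages."""
    above = sum(1 for s in scores if s >= threshold)
    return {
        "avg": statistics.mean(scores),
        "median": statistics.median(scores),
        "min": min(scores),
        "max": max(scores),
        "pct_above": above / len(scores) * 100,
    }


# Made-up scores, for illustration only.
summary = summarize([55.0, 70.0, 85.0, 90.0])
print(summary)
```

Lines that fail to parse, or whose output renders to empty text, are counted in `errors` and excluded from these statistics.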
106
finetune/experiments/lfm2/sft_lfm2.py
Normal file
@ -0,0 +1,106 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.55.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
"""
SFT training for QMD query expansion with LiquidAI LFM2-1.2B.

LFM2 is a hybrid architecture optimized for edge/on-device inference.
Uses different LoRA target modules than standard transformers.

Self-contained script for HuggingFace Jobs:
    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft_lfm2.py
"""

import os
from huggingface_hub import login

# --- Config (inlined from configs/sft_lfm2.yaml) ---
BASE_MODEL = "LiquidAI/LFM2-1.2B"
OUTPUT_MODEL = "tobil/qmd-query-expansion-lfm2-sft"
DATASET = "tobil/qmd-query-expansion-train"

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig

# Load and split dataset
print(f"Loading dataset: {DATASET}...")
dataset = load_dataset(DATASET, split="train")
print(f"Dataset loaded: {len(dataset)} examples")

split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]
print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# SFT config
config = SFTConfig(
    output_dir="qmd-query-expansion-lfm2-sft",
    push_to_hub=True,
    hub_model_id=OUTPUT_MODEL,
    hub_strategy="every_save",

    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_length=512,

    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=200,

    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=True,

    report_to="none",
)

# LoRA config for LFM2 architecture
# LFM2 uses different layer names than standard transformers:
#   - Attention: q_proj, k_proj, v_proj, out_proj
#   - Input projection: in_proj
#   - FFN/MLP gates (SwiGLU): w1, w2, w3
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"],
)

print("Initializing SFT trainer...")
trainer = SFTTrainer(
    model=BASE_MODEL,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
)

print("Starting SFT training (LFM2-1.2B)...")
trainer.train()

print("Pushing to Hub...")
trainer.push_to_hub()
print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
60
finetune/experiments/lfm2/sft_lfm2.yaml
Normal file
@ -0,0 +1,60 @@
# SFT Training Config for QMD Query Expansion with LiquidAI LFM2
# Target: LFM2-1.2B with LoRA (hybrid architecture: convolutions + attention)
#
# LFM2 is optimized for on-device inference with fast decode/prefill.
# Recommended for: agentic tasks, data extraction, RAG, creative writing.
#
# Usage: uv run train.py sft --config configs/sft_lfm2.yaml
#
# Requirements:
#   - transformers >= 4.55.0 (LFM2 architecture support)
#   - May need: pip install -U transformers

model:
  base: "LiquidAI/LFM2-1.2B"
  output: "outputs/sft-lfm2"  # Local training output (push to HF manually after eval)

dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
  name: "data/train/"
  text_field: "text"
  split: "train"
  eval_split: 0.1

training:
  epochs: 5
  batch_size: 4
  gradient_accumulation_steps: 4
  learning_rate: 2e-4
  max_length: 512
  warmup_ratio: 0.03
  lr_scheduler: "cosine"

lora:
  rank: 16
  alpha: 32
  dropout: 0.0
  # LFM2 uses different architecture than standard transformers:
  # - Attention layers: q_proj, k_proj, v_proj, out_proj
  # - Input projection: in_proj
  # - FFN/MLP gates: w1, w2, w3 (SwiGLU activation)
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "out_proj"
    - "in_proj"
    - "w1"
    - "w2"
    - "w3"

tracking:
  project: "qmd-query-expansion"
  run_name: "sft-lfm2-1.2B"

# LFM2-specific generation settings (recommended by LiquidAI)
generation:
  temperature: 0.3
  min_p: 0.15
  repetition_penalty: 1.05
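The `lora:` keys in this config correspond to the `LoraConfig` arguments used in `sft_lfm2.py` (`rank` to `r`, `alpha` to `lora_alpha`, `dropout` to `lora_dropout`). How `train.py` performs this translation is not shown in the diff; the following is a minimal sketch of one way to do it, with the hypothetical helper `lora_kwargs` building a plain dict so it runs without `peft` installed.

```python
from __future__ import annotations

# Mirrors the `lora:` block of the YAML config above.
lora_section = {
    "rank": 16,
    "alpha": 32,
    "dropout": 0.0,
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"],
}


def lora_kwargs(section: dict) -> dict:
    """Translate YAML key names into LoraConfig keyword arguments (hypothetical helper)."""
    return {
        "r": section["rank"],
        "lora_alpha": section["alpha"],
        "lora_dropout": section["dropout"],
        "target_modules": list(section["target_modules"]),
        # bias and task_type do not appear in the YAML; these are the
        # values used in sft_lfm2.py.
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }


kwargs = lora_kwargs(lora_section)
print(kwargs)
```

With `peft` installed, the resulting dict could be passed as `LoraConfig(**kwargs)`, which matches the arguments the standalone script passes explicitly.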
@ -16,7 +16,7 @@ dependencies = [
    "gguf",
    "sentencepiece",
    "nvidia-ml-py",
    "dspy-ai>=3.1.2",
    "pydantic>=2.0",
]

[dependency-groups]
5278
finetune/uv.lock
generated
File diff suppressed because it is too large