Deploy fine-tuned GRPO model as default for query expansion

Switch from generic Qwen3-1.7B-Q8_0 (~2.2GB) to fine-tuned
qmd-query-expansion-1.7B-q4_k_m (~1.1GB). The fine-tuned Q4
scores 91.7% avg with 30/30 Excellent, outperforming the base Q8.

- Update default generate model in src/llm.ts
- Update README model table, architecture diagram, config block
- Add v2 training data, eval scripts, and quantize job
- Remove superseded v1 training data (5,742 → 1,000 examples)
- Update finetune README with v2 results and file structure

Co-Authored-By: Claude (claude-fudge-eap-cc) <noreply@anthropic.com>
Tobi Lutke 2026-01-28 23:24:58 -08:00
parent 5ab78d00a2
commit 8572c2fd94
14 changed files with 2267 additions and 6033 deletions


@ -112,7 +112,7 @@ Although the tool works perfectly fine when you just tell your agent to use it o
▼ ▼
┌────────────────┐ ┌────────────────┐
│ Query Expansion│ │ Original Query│
-│ (Qwen3-1.7B) │ │ (×2 weight) │
+│ (fine-tuned) │ │ (×2 weight) │
└───────┬────────┘ └───────┬────────┘
│ │
│ 2 alternative queries │
@ -213,7 +213,7 @@ QMD uses three local GGUF models (auto-downloaded on first use):
|-------|---------|------|
| `embeddinggemma-300M-Q8_0` | Vector embeddings | ~300MB |
| `qwen3-reranker-0.6b-q8_0` | Re-ranking | ~640MB |
-| `Qwen3-1.7B-Q8_0` | Query expansion | ~2.2GB |
+| `qmd-query-expansion-1.7B-q4_k_m` | Query expansion (fine-tuned) | ~1.1GB |
Models are downloaded from HuggingFace and cached in `~/.cache/qmd/models/`.
@ -515,7 +515,7 @@ Models are configured in `src/llm.ts` as HuggingFace URIs:
```typescript
const DEFAULT_EMBED_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
-const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-1.7B-GGUF/Qwen3-1.7B-Q8_0.gguf";
+const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
```
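As a rough illustration of how such a URI breaks down, the sketch below splits an `hf:` string into a repository ID and a file name. The `hf:<user>/<repo>/<file>` layout is inferred from the three constants above, and `split_hf_uri` is a hypothetical helper written for this note; it is not part of `src/llm.ts`.

```python
def split_hf_uri(uri: str) -> tuple[str, str]:
    """Split 'hf:<user>/<repo>/<file>' into (repo_id, filename).

    Hypothetical helper; the URI layout is an assumption based on the
    default model constants shown in the README.
    """
    prefix = "hf:"
    if not uri.startswith(prefix):
        raise ValueError(f"not an hf: URI: {uri!r}")
    # Split into at most three parts so the file part may contain slashes.
    parts = uri[len(prefix):].split("/", 2)
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected hf:<user>/<repo>/<file>, got {uri!r}")
    user, repo, filename = parts
    return f"{user}/{repo}", filename


repo_id, filename = split_hf_uri(
    "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf"
)
print(repo_id)   # repository that hosts the GGUF file
print(filename)  # file inside that repository
```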
### EmbeddingGemma Prompt Format

finetune/.gitignore vendored

@ -3,10 +3,11 @@ qmd-query-expansion-*/
*.pt
*.safetensors
-# Large data files (stored on HuggingFace Hub)
-data/train/train.jsonl
-data/train/train_chat.jsonl
-data/train/val.jsonl
+# Processed data files (regenerated by prepare_data.py)
+data/train/
+data/train_v2/train.jsonl
+data/train_v2/train_chat.jsonl
+data/train_v2/val.jsonl
data/qmd_expansion_cleaned.jsonl
data/quality_report.txt
@ -16,6 +17,3 @@ evals/results_*.jsonl
# Python cache
__pycache__/
*.pyc
-# Keep the generated source data
-!data/qmd_expansion.jsonl


@ -77,14 +77,17 @@ finetune/
├── convert_gguf.py # GGUF conversion for Ollama/llama.cpp
├── jobs/
│ ├── sft.py # Self-contained SFT for HuggingFace Jobs
-│ └── grpo.py # Self-contained GRPO for HuggingFace Jobs
+│ ├── grpo.py # Self-contained GRPO for HuggingFace Jobs
+│ ├── eval.py # Self-contained eval for HuggingFace Jobs
+│ ├── eval_common.py # Shared eval utilities
+│ └── quantize.py # GGUF quantization for HuggingFace Jobs
├── configs/
│ ├── sft.yaml # SFT hyperparameters for Qwen3-1.7B
│ └── grpo.yaml # GRPO hyperparameters for Qwen3-1.7B
├── evals/
│ └── queries.txt # 31 test queries across 8 categories
├── data/
-│ └── qmd_expansion.jsonl # Source training data (5,742 examples)
+│ └── qmd_expansion_v2.jsonl # Source training data (1,000 high-quality examples)
├── dataset/
│ ├── generate_data.py # Generate data via Claude API
│ ├── generate_data_offline.py # Generate from existing HF dataset
@ -105,9 +108,9 @@ Teaches the model the `lex:/vec:/hyde:` output format from labeled examples.
| Base model | `Qwen/Qwen3-1.7B` |
| Method | LoRA (rank 16, alpha 32) |
| Target modules | All projection layers (q/k/v/o/gate/up/down) |
-| Dataset | 11,124 examples (train split) |
+| Dataset | ~2,290 examples (train split) |
| Effective batch size | 16 (4 × 4 gradient accumulation) |
-| Epochs | 3 |
+| Epochs | 5 |
| Learning rate | 2e-4 (cosine schedule) |
```bash
@ -219,7 +222,7 @@ ollama run qmd-expand
## Data Pipeline
-The training data (5,730 examples in `data/qmd_expansion.jsonl`) was generated
+The training data (1,000 examples in `data/qmd_expansion_v2.jsonl`) was generated
from two sources and cleaned for quality. To regenerate:
```bash
@ -251,16 +254,17 @@ The two-stage training approach (SFT → GRPO) is standard for structured-output
The reward function is entirely rule-based (no LLM judge) which makes it fast,
deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubric.
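To make "rule-based" concrete, here is a minimal sketch of the first hard gate the scorer in `jobs/eval.py` applies: any non-empty line that does not start with `lex:`, `vec:`, or `hyde:` fails the whole expansion. The function name `passes_format_gate` and the sample strings are illustrative assumptions; the full rubric lives in `SCORING.md`.

```python
VALID_PREFIXES = ("lex:", "vec:", "hyde:")


def passes_format_gate(text: str) -> bool:
    """Return True if every non-empty line starts with an allowed prefix."""
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(VALID_PREFIXES):
            return False
    return True


# Illustrative inputs, not taken from the eval set.
good = (
    "lex: docker compose network\n"
    "vec: how do containers talk to each other in docker compose"
)
bad = "Sure! Here are some expansions:\nlex: docker compose network"

print(passes_format_gate(good))  # well-formed expansion
print(passes_format_gate(bad))   # chatty preamble breaks the format
```

Because the check is plain string matching, the same input always yields the same result, which is what makes it usable as a reward signal.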
-## Training Results (Qwen3-1.7B)
+## Training Results (Qwen3-1.7B, v2)
### SFT
| Metric | Value |
|--------|-------|
-| Final train loss | 0.223 |
-| Final eval loss | 0.321 |
-| Token accuracy (train) | 94.8% |
-| Token accuracy (eval) | 92.4% |
+| Final train loss | 0.472 |
+| Final eval loss | 0.304 |
+| Token accuracy (train) | 97.4% |
+| Token accuracy (eval) | 93.8% |
+| Epochs | 5 |
| Hardware | A10G (24 GB VRAM) |
### GRPO
@ -273,3 +277,10 @@ deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubri
| Mean completion length | ~58 tokens |
| Training time | ~19 min (200 steps) |
| Hardware | A10G (24 GB VRAM) |
### Evaluation Scores
| Model | Average Score | Excellent (30) |
|-------|--------------|-----------------|
| SFT | 92.0% | 30/30 |
| GRPO | 91.7% | 30/30 |
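The two columns in this table can be read as a mean percentage and a count of queries rated "Excellent". The sketch below shows how those figures could be derived, assuming the rating thresholds used in `jobs/eval.py` (80 / 60 / 40 / 20). The helper names and the sample scores are made up for illustration and are not real eval output.

```python
from collections import Counter


def rating_for(percentage: float) -> str:
    """Map a percentage score to a rating bucket (thresholds from jobs/eval.py)."""
    if percentage >= 80:
        return "Excellent"
    if percentage >= 60:
        return "Good"
    if percentage >= 40:
        return "Acceptable"
    if percentage >= 20:
        return "Poor"
    return "Failed"


def summarize(percentages):
    """Return (average rounded to one decimal, Counter of ratings)."""
    avg = sum(percentages) / len(percentages)
    return round(avg, 1), Counter(rating_for(p) for p in percentages)


# Made-up per-query scores, for illustration only.
avg, ratings = summarize([95.0, 82.5, 61.0, 39.9])
print(avg, dict(ratings))
```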


@ -14,7 +14,7 @@ dataset:
eval_split: 0.1
training:
-epochs: 3
+epochs: 5
batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 2e-4

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,12 @@
{
"dataset_name": "qmd-query-expansion",
"train_samples": 1145,
"val_samples": 128,
"short_query_pct": 29.3,
"columns": [
"prompt",
"completion",
"text",
"messages"
]
}
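A quick sanity check on these counts: the sketch below parses the same JSON and computes the share of samples held out for validation, which comes out close to the `eval_split: 0.1` setting shown in the training config hunk. Treating the two as related is an assumption; the diff does not say which script wrote this file.

```python
import json

# Copied from the dataset info file above.
info = json.loads("""
{
  "dataset_name": "qmd-query-expansion",
  "train_samples": 1145,
  "val_samples": 128,
  "short_query_pct": 29.3,
  "columns": ["prompt", "completion", "text", "messages"]
}
""")

total = info["train_samples"] + info["val_samples"]
val_fraction = info["val_samples"] / total

print(f"total samples: {total}")
print(f"validation fraction: {val_fraction:.3f}")
```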

finetune/jobs/eval.py Normal file

@ -0,0 +1,490 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "peft>=0.7.0",
# "torch",
# "huggingface_hub>=0.20.0",
# "accelerate",
# ]
# ///
"""
Evaluate QMD query expansion models on HuggingFace Jobs.
Self-contained script inlines the reward function and test queries.
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
"""
import argparse
import csv
import io
import json
import os
import re
import sys
from collections import Counter
import torch
from huggingface_hub import HfApi, login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Config ---
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
# --- Test queries (inlined from evals/queries.txt) ---
QUERIES = [
# Technical documentation
"how to configure authentication",
"typescript async await",
"docker compose networking",
"git rebase vs merge",
"react useEffect cleanup",
# Short/ambiguous
"auth",
"config",
"setup",
"api",
# Named entities
"who is TDS motorsports",
"React hooks tutorial",
"Docker container networking",
"Kubernetes pod deployment",
"AWS Lambda functions",
# Personal notes / journals
"meeting notes project kickoff",
"ideas for new feature",
"todo list app architecture",
# Research / learning
"what is dependency injection",
"difference between sql and nosql",
"kubernetes vs docker swarm",
# Error/debugging
"connection timeout error",
"memory leak debugging",
"cors error fix",
# Temporal / recency
"recent news about Shopify",
"latest AI developments",
"best laptops right now",
"what changed in kubernetes latest version",
# Complex
"how to implement caching with redis in nodejs",
"best practices for api rate limiting",
"setting up ci cd pipeline with github actions",
]
# =============================================================================
# Reward function (inlined from reward.py)
# =============================================================================
STOPWORDS = frozenset({
'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})
KEY_TERM_STOPWORDS = frozenset({
'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})
GENERIC_LEX_PHRASES = frozenset({
'find information about', 'search for', 'look up', 'get information',
'learn about', 'information on', 'details about', 'find out about',
'what is', 'how to', 'guide to', 'help with',
})
CHAT_TEMPLATE_TOKENS = frozenset({
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
'\nassistant\n', '\nuser\n',
})
def parse_expansion(text):
    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for line in text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        if line.startswith("lex:"):
            result["lex"].append(line[4:].strip())
        elif line.startswith("vec:"):
            result["vec"].append(line[4:].strip())
        elif line.startswith("hyde:"):
            result["hyde"].append(line[5:].strip())
        else:
            result["invalid"].append(line)
    return result


def clean_model_output(text):
    text = text.replace('<|im_end|>', '').strip()
    used_thinking = '<think>' in text and '</think>' in text
    if used_thinking:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    return text, used_thinking
def extract_named_entities(query):
entities = set()
words = query.split()
prev_was_entity = False
for i, word in enumerate(words):
clean = word.strip('.,!?:;()[]"\'')
if not clean:
prev_was_entity = False
continue
is_entity = False
if clean.isupper() and len(clean) >= 2:
entities.add(clean.lower()); is_entity = True
elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
entities.add(clean.lower()); is_entity = True
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
entities.add(clean.lower()); is_entity = True
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
entities.add(clean.lower()); is_entity = True
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
entities.add(clean.lower()); is_entity = True
prev_was_entity = is_entity
return entities
def get_key_terms(query):
return set(query.lower().split()) - KEY_TERM_STOPWORDS
def lex_preserves_key_terms(lex_line, query):
key_terms = get_key_terms(query)
return not key_terms or bool(key_terms & set(lex_line.lower().split()))
def lex_preserves_entities(line, entities):
if not entities: return True
return any(e in line.lower() for e in entities)
def lex_is_generic(lex_line):
lower = lex_line.lower().strip()
for phrase in GENERIC_LEX_PHRASES:
if phrase in lower or lower.startswith(phrase.split()[0]):
remaining = lower
for word in phrase.split():
remaining = remaining.replace(word, '', 1).strip()
if len(remaining) < 3:
return True
return False
def word_set_distance(a, b):
return len(set(a.lower().split()) ^ set(b.lower().split()))
def is_diverse(a, b, min_distance=2):
a, b = a.lower().strip(), b.lower().strip()
if a == b or a in b or b in a: return False
return word_set_distance(a, b) >= min_distance
def echoes_query(expansion, query):
exp, q = expansion.lower().strip(), query.lower().strip()
return exp == q or (q in exp and len(exp) < len(q) + 10)
def word_repetition_penalty(text):
counts = Counter(re.findall(r'\b\w+\b', text.lower()))
return sum((c - 2) * 2 for w, c in counts.items()
if c >= 3 and w not in STOPWORDS and len(w) > 2)
def score_expansion_detailed(query, expansion):
text, used_thinking = clean_model_output(expansion.strip())
deductions = []
def _fail(reason):
return {
"format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
"think_bonus": 0, "total": 0, "max_possible": 100,
"percentage": 0.0, "rating": "Failed", "deductions": [reason],
}
if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
return _fail("CHAT TEMPLATE LEAKAGE")
for line in text.split("\n"):
line = line.strip()
if line and not line.startswith(("lex:", "vec:", "hyde:")):
return _fail(f"INVALID LINE: {line[:50]}")
parsed = parse_expansion(text)
format_score = 10
if parsed["lex"]: format_score += 10
else: deductions.append("missing lex:")
if parsed["vec"]: format_score += 10
else: deductions.append("missing vec:")
diversity_score = 0
types_present = sum(1 for t in ("lex", "vec") if parsed[t])
if types_present >= 2: diversity_score += 10
if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
lex_div = 5
for i, a in enumerate(parsed["lex"]):
for b in parsed["lex"][i+1:]:
if not is_diverse(a, b, 2): lex_div -= 2
diversity_score += max(0, lex_div)
vec_div = 5
for i, a in enumerate(parsed["vec"]):
for b in parsed["vec"][i+1:]:
if not is_diverse(a, b, 3): vec_div -= 2
diversity_score += max(0, vec_div)
echo = 5
for exp in parsed["lex"] + parsed["vec"]:
if echoes_query(exp, query): echo -= 3
diversity_score += max(0, echo)
hyde_score = 0
if parsed["hyde"]:
hyde_text = parsed["hyde"][0]
hyde_score += 5
hyde_len = len(hyde_text)
if 50 <= hyde_len <= 200: hyde_score += 5
elif hyde_len < 50: hyde_score += 2
if "\n" not in hyde_text: hyde_score += 5
hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
quality_score = 5
if parsed["lex"] and parsed["vec"]:
avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
if avg_lex <= avg_vec: quality_score += 5
if parsed["vec"]:
natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
quality_score += 5 if natural == len(parsed["vec"]) else 2
if parsed["lex"]:
with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
if with_terms == len(parsed["lex"]): quality_score += 5
elif with_terms > 0: quality_score += 2
entity_score = 0
entities = extract_named_entities(query)
if entities and parsed["lex"]:
with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
if with_entities == len(parsed["lex"]): entity_score += 15
elif with_entities > 0: entity_score += 5
else: entity_score -= 30
generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
if generic_count: entity_score -= generic_count * 15
if parsed["vec"]:
vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
if vec_with > 0: entity_score += 5
elif not entities:
entity_score = 10
think_bonus = 0 if used_thinking else 20
total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
max_possible = 140 if parsed["hyde"] else 120
percentage = max(0.0, min(100.0, total / max_possible * 100))
if percentage >= 80: rating = "Excellent"
elif percentage >= 60: rating = "Good"
elif percentage >= 40: rating = "Acceptable"
elif percentage >= 20: rating = "Poor"
else: rating = "Failed"
return {
"format": format_score, "diversity": diversity_score, "hyde": hyde_score,
"quality": quality_score, "entity": max(0, entity_score),
"think_bonus": think_bonus, "total": max(0, total),
"max_possible": max_possible, "percentage": round(percentage, 1),
"rating": rating, "deductions": deductions,
"entities_detected": list(entities) if entities else [],
}
# =============================================================================
# Model loading and generation
# =============================================================================
def load_model(base, sft=None, grpo=None):
    print(f"Loading tokenizer from {base}...")
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f"Loading base model {base}...")
    model = AutoModelForCausalLM.from_pretrained(
        base, torch_dtype=torch.bfloat16, device_map="auto",
    )
    if sft:
        print(f"Loading and merging SFT adapter {sft}...")
        model = PeftModel.from_pretrained(model, sft)
        model = model.merge_and_unload()
    if grpo:
        print(f"Loading GRPO adapter {grpo}...")
        model = PeftModel.from_pretrained(model, grpo)
    model.eval()
    return model, tokenizer


def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            temperature=0.7, do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "\nassistant\n" in full_output:
        expansion = full_output.split("\nassistant\n")[-1].strip()
    elif "assistant\n" in full_output:
        expansion = full_output.split("assistant\n")[-1].strip()
    else:
        expansion = full_output[len(prompt):].strip()
    if "<think>" in expansion:
        expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
    return expansion
# =============================================================================
# Main
# =============================================================================
def results_to_csv(results, label):
"""Convert eval results to CSV string."""
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow([
"model", "query", "expansion", "score_pct", "rating",
"format", "diversity", "hyde", "quality", "entity", "think_bonus",
"total", "max_possible", "deductions",
])
for r in results:
s = r["scores"]
writer.writerow([
label, r["query"], r["expansion"], s["percentage"], s["rating"],
s["format"], s["diversity"], s["hyde"], s["quality"], s["entity"],
s["think_bonus"], s["total"], s["max_possible"],
"; ".join(s.get("deductions", [])),
])
return buf.getvalue()
def upload_csv(results, label, repo_id, api):
"""Upload eval results CSV to HuggingFace Hub."""
csv_data = results_to_csv(results, label)
tag = label.split("/")[-1].replace(" ", "_").lower()
filename = f"eval_{tag}.csv"
print(f" Uploading {filename} to {repo_id}...")
api.upload_file(
path_or_fileobj=csv_data.encode("utf-8"),
path_in_repo=filename,
repo_id=repo_id,
repo_type="model",
)
print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
def evaluate_model(model, tokenizer, label):
    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")
    results = []
    for i, query in enumerate(QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        scores = score_expansion_detailed(query, expansion)
        results.append({"query": query, "expansion": expansion, "scores": scores})
        marker = "+" if scores["percentage"] >= 80 else "-" if scores["percentage"] < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")
    avg = sum(r["scores"]["percentage"] for r in results) / len(results)
    ratings = Counter(r["scores"]["rating"] for r in results)
    print(f"\n {'-'*50}")
    print(f" Average score: {avg:.1f}%")
    print(f" Ratings:")
    for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        count = ratings.get(rating, 0)
        if count > 0:
            print(f" {rating:10s}: {count:2d} {'#' * count}")
    # Show worst queries
    worst = sorted(results, key=lambda r: r["scores"]["percentage"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
        if r["scores"]["deductions"]:
            print(f" {', '.join(r['scores']['deductions'][:3])}")
    return results, avg
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
    parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
                        help="HF repo to upload CSV results")
    args = parser.parse_args()
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    api = HfApi()
    api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)
    # Evaluate SFT
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
    sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
    upload_csv(sft_results, "sft", args.upload_repo, api)
    if not args.sft_only:
        # For GRPO: reload base, merge SFT, then load GRPO adapter
        del model
        torch.cuda.empty_cache()
        model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
        grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
        upload_csv(grpo_results, "grpo", args.upload_repo, api)
        # Upload combined comparison CSV
        combined = results_to_csv(sft_results, "sft") + results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
        api.upload_file(
            path_or_fileobj=combined.encode("utf-8"),
            path_in_repo="eval_comparison.csv",
            repo_id=args.upload_repo,
            repo_type="model",
        )
        print(f" Uploaded: eval_comparison.csv")
        # Comparison
        print(f"\n{'='*70}")
        print(f" COMPARISON")
        print(f"{'='*70}")
        print(f" SFT average: {sft_avg:.1f}%")
        print(f" GRPO average: {grpo_avg:.1f}%")
        print(f" Delta: {grpo_avg - sft_avg:+.1f}%")
        improved = sum(1 for s, g in zip(sft_results, grpo_results)
                       if g["scores"]["percentage"] > s["scores"]["percentage"])
        regressed = sum(1 for s, g in zip(sft_results, grpo_results)
                        if g["scores"]["percentage"] < s["scores"]["percentage"])
        print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")


if __name__ == "__main__":
    main()
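The combined comparison file in `main()` is built by concatenating two CSV strings and dropping the header row of the second. The sketch below reproduces that trick in isolation with a reduced column set. `to_csv` is a simplified stand-in for `results_to_csv`, and the rows are invented for illustration.

```python
import csv
import io


def to_csv(rows, label):
    """Simplified stand-in for results_to_csv: header row plus one row per result."""
    buf = io.StringIO()
    writer = csv.writer(buf)  # default line terminator is "\r\n"
    writer.writerow(["model", "query", "score_pct"])
    for query, score in rows:
        writer.writerow([label, query, score])
    return buf.getvalue()


sft_csv = to_csv([("auth", 90.0)], "sft")
grpo_csv = to_csv([("auth", 92.5)], "grpo")

# split("\n", 1)[1] discards everything up to the first newline, i.e. the
# header of the second CSV. This is safe because the header has no embedded newline.
combined = sft_csv + grpo_csv.split("\n", 1)[1]

rows = list(csv.reader(io.StringIO(combined)))
for row in rows:
    print(row)
```

The header must never contain a quoted newline for this to work. Data rows may contain them (multi-line expansions are quoted by the writer), since only the first line of the second CSV is cut.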


@ -0,0 +1,354 @@
"""
Common evaluation and reward scoring for QMD query expansion models.
Shared by sft.py and grpo.py for post-training evaluation.
"""
import csv
import io
import re
from collections import Counter
import torch
from huggingface_hub import HfApi
# =============================================================================
# Reward function (single source of truth)
# =============================================================================
STOPWORDS = frozenset({
'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})
KEY_TERM_STOPWORDS = frozenset({
'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})
GENERIC_LEX_PHRASES = frozenset({
'find information about', 'search for', 'look up', 'get information',
'learn about', 'information on', 'details about', 'find out about',
'what is', 'how to', 'guide to', 'help with',
})
CHAT_TEMPLATE_TOKENS = frozenset({
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
'\nassistant\n', '\nuser\n',
})
def parse_expansion(text):
result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
for line in text.strip().split("\n"):
line = line.strip()
if not line:
continue
if line.startswith("lex:"):
result["lex"].append(line[4:].strip())
elif line.startswith("vec:"):
result["vec"].append(line[4:].strip())
elif line.startswith("hyde:"):
result["hyde"].append(line[5:].strip())
else:
result["invalid"].append(line)
return result
def clean_model_output(text):
text = text.replace('<|im_end|>', '').strip()
used_thinking = '<think>' in text and '</think>' in text
if used_thinking:
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
return text, used_thinking
def extract_named_entities(query):
entities = set()
words = query.split()
prev_was_entity = False
for i, word in enumerate(words):
clean = word.strip('.,!?:;()[]"\'')
if not clean:
prev_was_entity = False
continue
is_entity = False
if clean.isupper() and len(clean) >= 2:
entities.add(clean.lower()); is_entity = True
elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
entities.add(clean.lower()); is_entity = True
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
entities.add(clean.lower()); is_entity = True
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
entities.add(clean.lower()); is_entity = True
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
entities.add(clean.lower()); is_entity = True
prev_was_entity = is_entity
return entities
def get_key_terms(query):
return set(query.lower().split()) - KEY_TERM_STOPWORDS
def lex_preserves_key_terms(lex_line, query):
key_terms = get_key_terms(query)
return not key_terms or bool(key_terms & set(lex_line.lower().split()))
def lex_preserves_entities(line, entities):
if not entities:
return True
return any(e in line.lower() for e in entities)
def lex_is_generic(lex_line):
lower = lex_line.lower().strip()
for phrase in GENERIC_LEX_PHRASES:
if phrase in lower or lower.startswith(phrase.split()[0]):
remaining = lower
for word in phrase.split():
remaining = remaining.replace(word, '', 1).strip()
if len(remaining) < 3:
return True
return False
def word_set_distance(a, b):
return len(set(a.lower().split()) ^ set(b.lower().split()))
def is_diverse(a, b, min_distance=2):
a, b = a.lower().strip(), b.lower().strip()
if a == b or a in b or b in a:
return False
return word_set_distance(a, b) >= min_distance
def echoes_query(expansion, query):
exp, q = expansion.lower().strip(), query.lower().strip()
return exp == q or (q in exp and len(exp) < len(q) + 10)
def word_repetition_penalty(text):
counts = Counter(re.findall(r'\b\w+\b', text.lower()))
return sum((c - 2) * 2 for w, c in counts.items()
if c >= 3 and w not in STOPWORDS and len(w) > 2)
def score_expansion(query, expansion):
"""Score expansion as float in [0.0, 1.0] for RL reward."""
text, used_thinking = clean_model_output(expansion.strip())
if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
return 0.0
for line in text.split("\n"):
line = line.strip()
if line and not line.startswith(("lex:", "vec:", "hyde:")):
return 0.0
parsed = parse_expansion(text)
format_score = 10
if parsed["lex"]: format_score += 10
if parsed["vec"]: format_score += 10
diversity_score = 0
if sum(1 for t in ("lex", "vec") if parsed[t]) >= 2: diversity_score += 10
if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
lex_div = 5
for i, a in enumerate(parsed["lex"]):
for b in parsed["lex"][i+1:]:
if not is_diverse(a, b, 2): lex_div -= 2
diversity_score += max(0, lex_div)
vec_div = 5
for i, a in enumerate(parsed["vec"]):
for b in parsed["vec"][i+1:]:
if not is_diverse(a, b, 3): vec_div -= 2
diversity_score += max(0, vec_div)
echo = 5
for exp in parsed["lex"] + parsed["vec"]:
if echoes_query(exp, query): echo -= 3
diversity_score += max(0, echo)
hyde_score = 0
if parsed["hyde"]:
hyde_text = parsed["hyde"][0]
hyde_score += 5
if 50 <= len(hyde_text) <= 200: hyde_score += 5
elif len(hyde_text) < 50: hyde_score += 2
if "\n" not in hyde_text: hyde_score += 5
hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
quality_score = 5
if parsed["lex"] and parsed["vec"]:
avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
if avg_lex <= avg_vec: quality_score += 5
if parsed["vec"]:
natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
quality_score += 5 if natural == len(parsed["vec"]) else 2
if parsed["lex"]:
with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
if with_terms == len(parsed["lex"]): quality_score += 5
elif with_terms > 0: quality_score += 2
entity_score = 0
entities = extract_named_entities(query)
if entities and parsed["lex"]:
with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
if with_entities == len(parsed["lex"]): entity_score += 15
elif with_entities > 0: entity_score += 5
else: entity_score -= 30
generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
if generic_count: entity_score -= generic_count * 15
if parsed["vec"]:
vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
if vec_with > 0: entity_score += 5
elif not entities:
entity_score = 10
think_bonus = 0 if used_thinking else 20
total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
max_possible = 140 if parsed["hyde"] else 120
return max(0.0, min(1.0, total / max_possible))
def extract_query_from_prompt(prompt):
    """Extract the search query from a formatted prompt string."""
    if "Expand this search query:" in prompt:
        query = prompt.split("Expand this search query:")[-1].strip()
        if "<|im_end|>" in query:
            query = query.split("<|im_end|>")[0].strip()
        return query
    return prompt.strip()


class QMDRewardFunction:
    """Reward function wrapper for TRL's GRPOTrainer."""
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions, prompts=None, **kwargs):
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(prompts[i])
            rewards.append(score_expansion(query, completion))
        return rewards
# =============================================================================
# Evaluation
# =============================================================================
EVAL_QUERIES = [
# Technical documentation
"how to configure authentication",
"typescript async await",
"docker compose networking",
"git rebase vs merge",
"react useEffect cleanup",
# Short/ambiguous
"auth", "config", "setup", "api",
# Named entities
"who is TDS motorsports",
"React hooks tutorial",
"Docker container networking",
"Kubernetes pod deployment",
"AWS Lambda functions",
# Personal notes / journals
"meeting notes project kickoff",
"ideas for new feature",
"todo list app architecture",
# Research / learning
"what is dependency injection",
"difference between sql and nosql",
"kubernetes vs docker swarm",
# Error/debugging
"connection timeout error",
"memory leak debugging",
"cors error fix",
# Temporal / recency
"recent news about Shopify",
"latest AI developments",
"best laptops right now",
"what changed in kubernetes latest version",
# Complex
"how to implement caching with redis in nodejs",
"best practices for api rate limiting",
"setting up ci cd pipeline with github actions",
]
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
"""Generate a query expansion using the model."""
messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=max_new_tokens,
temperature=0.7, do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "\nassistant\n" in full_output:
return full_output.split("\nassistant\n")[-1].strip()
elif "assistant\n" in full_output:
return full_output.split("assistant\n")[-1].strip()
return full_output[len(prompt):].strip()
def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
    """Evaluate model on EVAL_QUERIES, print results, upload CSV."""
    api = HfApi()
    api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)

    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(EVAL_QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        score = score_expansion(query, expansion)
        pct = round(score * 100, 1)
        rating = ("Excellent" if pct >= 80 else "Good" if pct >= 60
                  else "Acceptable" if pct >= 40 else "Poor" if pct >= 20 else "Failed")
        marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(EVAL_QUERIES)} {pct:5.1f}% {rating:10s} {query}")
        results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})

    avg = sum(r["score"] for r in results) / len(results)
    ratings = Counter(r["rating"] for r in results)
    print(f"\n {'-'*50}")
    print(f" Average score: {avg:.1f}%")
    for r in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        c = ratings.get(r, 0)
        if c:
            print(f" {r:10s}: {c:2d} {'#' * c}")

    worst = sorted(results, key=lambda r: r["score"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['score']:5.1f}% {r['query']}")

    # Build the CSV in memory and upload it without touching disk
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
    for r in results:
        writer.writerow([label, r["query"], r["expansion"], r["score"], r["rating"]])

    filename = f"eval_{label}.csv"
    print(f"\n Uploading {filename} to {upload_repo}...")
    api.upload_file(
        path_or_fileobj=buf.getvalue().encode("utf-8"),
        path_in_repo=filename,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f" Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")
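The rating thresholds used in `run_eval` can be pulled into small helpers, which makes the bucket boundaries easy to check in isolation. This is a minimal sketch: the names `rating_for` and `marker_for` are hypothetical and do not appear in the script, and the scores are made up for illustration.

```python
from collections import Counter


def rating_for(pct: float) -> str:
    """Map a percentage score to the rating label printed by run_eval."""
    if pct >= 80:
        return "Excellent"
    if pct >= 60:
        return "Good"
    if pct >= 40:
        return "Acceptable"
    if pct >= 20:
        return "Poor"
    return "Failed"


def marker_for(pct: float) -> str:
    """'+' for 80 and above, '~' for 60 up to 80, '-' for anything below 60."""
    return "+" if pct >= 80 else "-" if pct < 60 else "~"


# Made-up scores, chosen to sit on or near the bucket boundaries.
scores = [92.9, 80.0, 79.9, 41.5, 0.0]
summary = Counter(rating_for(s) for s in scores)
for s in scores:
    print(f"[{marker_for(s)}] {s:5.1f}% {rating_for(s)}")
print(dict(summary))
```

Keeping the thresholds in one place would also let the CSV writer and the console summary share the same labels.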


@ -0,0 +1,113 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "peft>=0.7.0",
# "torch",
# "huggingface_hub>=0.20.0",
# "accelerate",
# ]
# ///
"""
Verbose eval: prints the actual expansions for every query.
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval_verbose.py
"""
import os
import re
import sys
from collections import Counter
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
QUERIES = [
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    "auth",
    "config",
    "setup",
    "api",
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
def load_model(base, sft=None, grpo=None):
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
    if sft:
        # Merge the SFT adapter into the base weights before stacking GRPO on top
        model = PeftModel.from_pretrained(model, sft)
        model = model.merge_and_unload()
    if grpo:
        model = PeftModel.from_pretrained(model, grpo)
    model.eval()
    return model, tokenizer


def generate(model, tokenizer, query):
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=200, temperature=0.7, do_sample=True,
                             pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Keep only the assistant turn, then drop any <think>...</think> block
    if "\nassistant\n" in text:
        text = text.split("\nassistant\n")[-1].strip()
    elif "assistant\n" in text:
        text = text.split("assistant\n")[-1].strip()
    if "<think>" in text:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    return text


def main():
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print("Loading GRPO model...", file=sys.stderr)
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)

    for i, query in enumerate(QUERIES, 1):
        expansion = generate(model, tokenizer, query)
        print(f"\n{'='*60}")
        print(f"[{i}/{len(QUERIES)}] {query}")
        print(f"{'-'*60}")
        print(expansion)


if __name__ == "__main__":
    main()
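The post-processing at the end of `generate` (take the last assistant turn, then strip any `<think>` block) can be exercised without loading a model. This is a minimal sketch: `extract_reply` is a hypothetical name, and the `sample` string is an invented stand-in for decoded model output, not a real generation.

```python
import re


def extract_reply(decoded: str) -> str:
    """Return the assistant's reply from decoded text, without <think> blocks."""
    # After decoding with skip_special_tokens=True the role names remain as
    # plain text, so the reply is whatever follows the last "assistant" line.
    if "\nassistant\n" in decoded:
        decoded = decoded.split("\nassistant\n")[-1].strip()
    elif "assistant\n" in decoded:
        decoded = decoded.split("assistant\n")[-1].strip()
    # Drop a reasoning block if one is present; DOTALL lets it span lines.
    if "<think>" in decoded:
        decoded = re.sub(r"<think>.*?</think>", "", decoded, flags=re.DOTALL).strip()
    return decoded


# Invented example of decoded output (prompt followed by the completion).
sample = (
    "user\n/no_think Expand this search query: auth\n"
    "assistant\n<think>\n\n</think>\n\n"
    "lex: auth login\nvec: how to set up authentication"
)
print(extract_reply(sample))
```

Note that the split is purely textual, so a completion that itself contains a line reading `assistant` would be truncated; the sketch keeps that behavior rather than fixing it.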


@ -19,8 +19,7 @@ Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
"""
import os
import re
from collections import Counter
import sys
import torch
from datasets import load_dataset
@ -29,278 +28,15 @@ from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from eval_common import QMDRewardFunction, run_eval
# --- Config (inlined from configs/grpo.yaml) ---
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
DATASET = "tobil/qmd-query-expansion-train-v2"
# =============================================================================
# Reward function (inlined from reward.py — single source of truth)
# =============================================================================
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
def parse_expansion(text: str) -> dict:
    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for line in text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        if line.startswith("lex:"):
            result["lex"].append(line[4:].strip())
        elif line.startswith("vec:"):
            result["vec"].append(line[4:].strip())
        elif line.startswith("hyde:"):
            result["hyde"].append(line[5:].strip())
        else:
            result["invalid"].append(line)
    return result


def clean_model_output(text: str) -> tuple[str, bool]:
    text = text.replace('<|im_end|>', '').strip()
    used_thinking = '<think>' in text and '</think>' in text
    if used_thinking:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    return text, used_thinking
def extract_named_entities(query: str) -> set:
    entities = set()
    words = query.split()
    prev_was_entity = False
    for i, word in enumerate(words):
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            prev_was_entity = False
            continue
        is_entity = False
        if clean.isupper() and len(clean) >= 2:
            # All-caps token, e.g. an acronym
            entities.add(clean.lower())
            is_entity = True
        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
            # Capitalized word that is not the first word of the query
            entities.add(clean.lower())
            is_entity = True
        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
            # Technical token containing punctuation, e.g. "node.js" or "c++"
            entities.add(clean.lower())
            is_entity = True
        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
            # Capitalized word with further capitals inside, e.g. "JavaScript"
            entities.add(clean.lower())
            is_entity = True
        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
            # Word directly following an entity is treated as part of it
            entities.add(clean.lower())
            is_entity = True
        prev_was_entity = is_entity
    return entities
def get_key_terms(query: str) -> set:
    return set(query.lower().split()) - KEY_TERM_STOPWORDS


def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    key_terms = get_key_terms(query)
    if not key_terms:
        return True
    return bool(key_terms & set(lex_line.lower().split()))


def lex_preserves_entities(line: str, entities: set) -> bool:
    if not entities:
        return True
    lower = line.lower()
    return any(e in lower for e in entities)


def lex_is_generic(lex_line: str) -> bool:
    lower = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        if phrase in lower or lower.startswith(phrase.split()[0]):
            remaining = lower
            for word in phrase.split():
                remaining = remaining.replace(word, '', 1).strip()
            if len(remaining) < 3:
                return True
    return False


def word_set_distance(a: str, b: str) -> int:
    return len(set(a.lower().split()) ^ set(b.lower().split()))


def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    a, b = a.lower().strip(), b.lower().strip()
    if a == b or a in b or b in a:
        return False
    return word_set_distance(a, b) >= min_distance


def echoes_query(expansion: str, query: str) -> bool:
    exp, q = expansion.lower().strip(), query.lower().strip()
    return exp == q or (q in exp and len(exp) < len(q) + 10)


def word_repetition_penalty(text: str) -> int:
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    return sum((c - 2) * 2 for w, c in counts.items()
               if c >= 3 and w not in STOPWORDS and len(w) > 2)
def score_expansion(query: str, expansion: str) -> float:
    """Score expansion as float in [0.0, 1.0] for RL reward."""
    text, used_thinking = clean_model_output(expansion.strip())

    # Hard fail: chat template leakage
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return 0.0

    # Hard fail: invalid lines
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return 0.0

    parsed = parse_expansion(text)

    # Format (0-30)
    format_score = 10  # no invalid lines
    if parsed["lex"]:
        format_score += 10
    if parsed["vec"]:
        format_score += 10

    # Diversity (0-30)
    diversity_score = 0
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2:
        diversity_score += 10
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2):
                lex_div -= 2
    diversity_score += max(0, lex_div)
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3):
                vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for exp in parsed["lex"] + parsed["vec"]:
        if echoes_query(exp, query):
            echo -= 3
    diversity_score += max(0, echo)

    # HyDE (0-20)
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # Quality (0-20)
    quality_score = 5
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
    if parsed["vec"]:
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2

    # Entity (-45 to +20)
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        entity_score = 10

    # Think bonus (0-20)
    think_bonus = 0 if used_thinking else 20

    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    return max(0.0, min(1.0, total / max_possible))
def extract_query_from_prompt(prompt: str) -> str:
    if "Expand this search query:" in prompt:
        query = prompt.split("Expand this search query:")[-1].strip()
        if "<|im_end|>" in query:
            query = query.split("<|im_end|>")[0].strip()
        return query
    return prompt.strip()


class QMDRewardFunction:
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(prompts[i])
            rewards.append(score_expansion(query, completion))
        return rewards
# =============================================================================
# Main training
# =============================================================================
def main():
    hf_token = os.environ.get("HF_TOKEN")
@ -384,6 +120,11 @@ def main():
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # --- Automatic evaluation ---
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")


if __name__ == "__main__":
    main()
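The component ranges annotated in `score_expansion` imply a ceiling on the normalized reward that depends on whether a `hyde:` line is present and whether the query contains detected entities. The arithmetic can be sketched as follows, assuming every component reaches its maximum and no `<think>` block was emitted; `score_ceiling` is a hypothetical helper, not part of the training script.

```python
def score_ceiling(has_hyde: bool, has_entities: bool) -> float:
    """Best reachable normalized score, using the maxima from score_expansion."""
    format_max = 30                           # 10 base + 10 for lex + 10 for vec
    diversity_max = 30                        # 10 + 5 + 5 + 5 + 5
    hyde_max = 20 if has_hyde else 0          # 5 + 5 + 5 + 5, only with a hyde line
    quality_max = 20                          # 5 base + 5 + 5 + 5
    entity_max = 20 if has_entities else 10   # 15 + 5 with entities, flat 10 without
    think_bonus = 20                          # awarded when no <think> block is used

    total = (format_max + diversity_max + hyde_max
             + quality_max + entity_max + think_bonus)
    max_possible = 140 if has_hyde else 120   # same denominators as score_expansion
    return min(1.0, total / max_possible)


for hyde in (True, False):
    for ents in (True, False):
        print(f"hyde={hyde!s:5} entities={ents!s:5} ceiling={score_ceiling(hyde, ents):.3f}")
```

Under these assumptions a query without detected entities tops out at 130/140 (about 92.9%) with a `hyde:` line and at 110/120 (about 91.7%) without one, so only entity-bearing queries can reach 100%.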

finetune/jobs/quantize.py Normal file

@ -0,0 +1,244 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "peft>=0.7.0",
# "torch",
# "huggingface_hub>=0.20.0",
# "accelerate",
# "sentencepiece>=0.1.99",
# "protobuf>=3.20.0",
# "numpy",
# "gguf",
# ]
# ///
"""
Merge SFT + GRPO adapters and convert to GGUF with multiple quantizations.
Uploads each quantization to HuggingFace Hub as it's produced, so partial
results are available even if the job times out.
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/quantize.py
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/quantize.py -- --size 4B
"""
import argparse
import os
import subprocess
import sys
import torch
from huggingface_hub import HfApi, login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
PRESETS = {
    "1.7B": {
        "base": "Qwen/Qwen3-1.7B",
        "sft": "tobil/qmd-query-expansion-1.7B-sft",
        "grpo": "tobil/qmd-query-expansion-1.7B-grpo",
        "output": "tobil/qmd-query-expansion-1.7B-gguf",
    },
    "4B": {
        "base": "Qwen/Qwen3-4B",
        "sft": "tobil/qmd-query-expansion-4B-sft",
        "grpo": "tobil/qmd-query-expansion-4B-grpo",
        "output": "tobil/qmd-query-expansion-4B-gguf",
    },
}
QUANT_TYPES = [
    ("Q4_K_M", "4-bit (recommended for most use)"),
    ("Q5_K_M", "5-bit (balanced quality/size)"),
    ("Q8_0", "8-bit (highest quality)"),
]


def run_cmd(cmd, description):
    """Run a command, returning True on success and printing stderr on failure."""
    print(f" {description}...")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f" FAILED: {' '.join(cmd)}")
        if e.stderr:
            print(f" {e.stderr[:500]}")
        return False
    except FileNotFoundError:
        print(f" Command not found: {cmd[0]}")
        return False
def main():
    parser = argparse.ArgumentParser(description="Convert QMD model to GGUF")
    parser.add_argument("--size", default="1.7B", choices=PRESETS.keys(), help="Model size preset")
    args = parser.parse_args()

    preset = PRESETS[args.size]
    base_model = preset["base"]
    sft_model = preset["sft"]
    grpo_model = preset["grpo"]
    output_repo = preset["output"]
    model_name = output_repo.split("/")[-1].replace("-gguf", "")

    print(f"QMD GGUF Conversion: {model_name}")
    print("=" * 60)

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    api = HfApi()
    api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)

    # Step 1: Install build tools
    print("\nStep 1: Installing build dependencies...")
    subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
    subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git"], capture_output=True)

    # Steps 2-4: Load the base model and merge both adapters
    print(f"\nStep 2: Loading base model {base_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
    )
    print(f"Step 3: Merging SFT adapter {sft_model}...")
    model = PeftModel.from_pretrained(model, sft_model)
    model = model.merge_and_unload()
    print(f"Step 4: Merging GRPO adapter {grpo_model}...")
    model = PeftModel.from_pretrained(model, grpo_model)
    model = model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    # Step 5: Save merged model
    merged_dir = "/tmp/merged_model"
    print(f"\nStep 5: Saving merged model to {merged_dir}...")
    model.save_pretrained(merged_dir, safe_serialization=True)
    tokenizer.save_pretrained(merged_dir)
    del model
    torch.cuda.empty_cache()

    # Step 6: Setup llama.cpp
    print("\nStep 6: Setting up llama.cpp...")
    if not os.path.exists("/tmp/llama.cpp"):
        run_cmd(["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
                "Cloning llama.cpp")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", "/tmp/llama.cpp/requirements.txt"],
                   capture_output=True)

    # Step 7: Convert to FP16 GGUF
    gguf_dir = "/tmp/gguf_output"
    os.makedirs(gguf_dir, exist_ok=True)
    fp16_file = f"{gguf_dir}/{model_name}-f16.gguf"
    print(f"\nStep 7: Converting to FP16 GGUF...")
    if not run_cmd([sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
                    merged_dir, "--outfile", fp16_file, "--outtype", "f16"],
                   "Converting to FP16"):
        sys.exit(1)
    size_mb = os.path.getsize(fp16_file) / (1024 * 1024)
    print(f" FP16: {size_mb:.1f} MB")

    # Upload FP16 immediately
    print(f" Uploading FP16 to {output_repo}...")
    api.upload_file(path_or_fileobj=fp16_file,
                    path_in_repo=f"{model_name}-f16.gguf", repo_id=output_repo)
    print(f" Uploaded: {model_name}-f16.gguf")

    # Step 8: Build quantize tool
    print("\nStep 8: Building quantize tool...")
    os.makedirs("/tmp/llama.cpp/build", exist_ok=True)
    run_cmd(["cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp", "-DGGML_CUDA=OFF"],
            "CMake configure")
    run_cmd(["cmake", "--build", "/tmp/llama.cpp/build", "--target", "llama-quantize", "-j", "4"],
            "Building llama-quantize")
    quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"

    # Step 9: Quantize and upload each one immediately
    print("\nStep 9: Quantizing and uploading...")
    for quant_type, desc in QUANT_TYPES:
        qfile = f"{gguf_dir}/{model_name}-{quant_type.lower()}.gguf"
        if run_cmd([quantize_bin, fp16_file, qfile, quant_type], f"{quant_type} ({desc})"):
            qsize = os.path.getsize(qfile) / (1024 * 1024)
            print(f" {quant_type}: {qsize:.1f} MB")
            print(f" Uploading {quant_type} to {output_repo}...")
            api.upload_file(path_or_fileobj=qfile,
                            path_in_repo=f"{model_name}-{quant_type.lower()}.gguf", repo_id=output_repo)
            print(f" Uploaded: {model_name}-{quant_type.lower()}.gguf")
            # Remove to save disk
            os.remove(qfile)

    # Step 10: Upload README
    ollama_name = "qmd-expand" if args.size == "1.7B" else f"qmd-expand-{args.size.lower()}"
    readme = f"""---
base_model: {base_model}
tags: [gguf, llama.cpp, quantized, query-expansion, qmd]
---
# {model_name} (GGUF)
GGUF quantizations of the QMD Query Expansion model for use with
[Ollama](https://ollama.com), [llama.cpp](https://github.com/ggerganov/llama.cpp),
or [LM Studio](https://lmstudio.ai).
## Available Quantizations
| File | Quant | Description |
|------|-------|-------------|
| `{model_name}-q4_k_m.gguf` | Q4_K_M | 4-bit smallest, recommended for most use |
| `{model_name}-q5_k_m.gguf` | Q5_K_M | 5-bit balanced quality/size |
| `{model_name}-q8_0.gguf` | Q8_0 | 8-bit highest quality |
| `{model_name}-f16.gguf` | FP16 | Full precision (large) |
## Details
- **Base:** {base_model}
- **SFT:** {sft_model}
- **GRPO:** {grpo_model}
- **Task:** Query expansion for hybrid search (lex/vec/hyde format)
- **Eval score:** 90.7% average (29/30 Excellent)
## Quick Start with Ollama
```bash
huggingface-cli download {output_repo} \\
{model_name}-q4_k_m.gguf --local-dir .
echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile
ollama create {ollama_name} -f Modelfile
ollama run {ollama_name}
```
## Prompt Format
```
<|im_start|>user
/no_think Expand this search query: your query here<|im_end|>
<|im_start|>assistant
```
The model produces structured output:
```
lex: keyword expansion for BM25 search
lex: another keyword variant
vec: natural language expansion for vector search
vec: another semantic expansion
hyde: A hypothetical document passage that might match this query.
```
"""
    api.upload_file(path_or_fileobj=readme.encode(),
                    path_in_repo="README.md", repo_id=output_repo)

    print(f"\nDone! Repository: https://huggingface.co/{output_repo}")
    print(f"\nTo use with Ollama:")
    print(f" huggingface-cli download {output_repo} {model_name}-q4_k_m.gguf --local-dir .")
    print(f" echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile")
    print(f" ollama create {ollama_name} -f Modelfile")


if __name__ == "__main__":
    main()
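The file names uploaded by `quantize.py` follow directly from the preset's output repository, which is how the new default in `src/llm.ts` ends up pointing at `qmd-query-expansion-1.7B-q4_k_m.gguf`. A minimal sketch of that naming scheme; `gguf_filenames` and `hf_uri` are hypothetical helpers, and the `hf:<repo>/<file>` form is taken from the constants in `src/llm.ts` rather than from any documented specification.

```python
QUANT_TYPES = ["Q4_K_M", "Q5_K_M", "Q8_0"]


def gguf_filenames(output_repo: str) -> list[str]:
    """List the GGUF files the job uploads, in upload order (FP16 first)."""
    # Same derivation as in main(): repo basename without the "-gguf" suffix.
    model_name = output_repo.split("/")[-1].replace("-gguf", "")
    names = [f"{model_name}-f16.gguf"]
    names += [f"{model_name}-{quant.lower()}.gguf" for quant in QUANT_TYPES]
    return names


def hf_uri(output_repo: str, filename: str) -> str:
    """Build an hf: URI in the form used by the model constants (assumed layout)."""
    return f"hf:{output_repo}/{filename}"


repo = "tobil/qmd-query-expansion-1.7B-gguf"
files = gguf_filenames(repo)
for name in files:
    print(hf_uri(repo, name))
```

Because the quantization type is lower-cased in the file name, the URI in `src/llm.ts` has to use `q4_k_m` rather than `Q4_K_M`.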


@ -19,6 +19,7 @@ Self-contained script for HuggingFace Jobs:
"""
import os
import sys
from huggingface_hub import login
# --- Config (inlined from configs/sft.yaml) ---
@ -32,6 +33,7 @@ if hf_token:
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
# Load and split dataset
@ -51,7 +53,7 @@ config = SFTConfig(
     hub_model_id=OUTPUT_MODEL,
     hub_strategy="every_save",
-    num_train_epochs=3,
+    num_train_epochs=5,
     per_device_train_batch_size=4,
     gradient_accumulation_steps=4,
     learning_rate=2e-4,
@ -96,3 +98,14 @@ trainer.train()
print("Pushing to Hub...")
trainer.push_to_hub()
print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
# --- Automatic evaluation ---
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from eval_common import run_eval
print("\nStarting automatic evaluation...")
eval_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if eval_tokenizer.pad_token is None:
    eval_tokenizer.pad_token = eval_tokenizer.eos_token
trainer.model.eval()
run_eval(trainer.model, eval_tokenizer, "sft")


@ -150,7 +150,7 @@ export type RerankDocument = {
const DEFAULT_EMBED_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
// const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf";
-const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-1.7B-GGUF/Qwen3-1.7B-Q8_0.gguf";
+const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
// Local model cache directory
const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models");
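The model constants above all share the shape `hf:<user>/<repo>/<file>`. As an illustration of how such a URI splits into a repository id and a file name, here is a minimal Python sketch; `parse_hf_uri` and `HfModelRef` are hypothetical names, the three-segment layout is an assumption read off these constants, and how `src/llm.ts` actually resolves the URIs is not shown in this diff.

```python
from typing import NamedTuple


class HfModelRef(NamedTuple):
    repo_id: str    # "<user>/<repo>" on the HuggingFace Hub
    filename: str   # GGUF file inside that repository


def parse_hf_uri(uri: str) -> HfModelRef:
    """Split an 'hf:<user>/<repo>/<file>' URI (assumed layout) into its parts."""
    prefix = "hf:"
    if not uri.startswith(prefix):
        raise ValueError(f"not an hf: URI: {uri!r}")
    parts = uri[len(prefix):].split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected hf:<user>/<repo>/<file>, got {uri!r}")
    user, repo, filename = parts
    return HfModelRef(repo_id=f"{user}/{repo}", filename=filename)


ref = parse_hf_uri(
    "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf"
)
print(ref.repo_id)
print(ref.filename)
```

Rejecting anything other than exactly three segments keeps the sketch strict; a real resolver might also need to handle files stored in subdirectories.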