Switch from generic Qwen3-1.7B-Q8_0 (~2.2GB) to fine-tuned qmd-query-expansion-1.7B-q4_k_m (~1.1GB).

The fine-tuned Q4 scores 91.7% avg with 30/30 Excellent, outperforming the base Q8.

- Update default generate model in src/llm.ts
- Update README model table, architecture diagram, config block
- Add v2 training data, eval scripts, and quantize job
- Remove superseded v1 training data (5,742 → 1,000 examples)
- Update finetune README with v2 results and file structure

Co-Authored-By: Claude (claude-fudge-eap-cc) <noreply@anthropic.com>
491 lines
17 KiB
Python
491 lines
17 KiB
Python
# /// script
|
|
# requires-python = ">=3.10"
|
|
# dependencies = [
|
|
# "transformers>=4.45.0",
|
|
# "peft>=0.7.0",
|
|
# "torch",
|
|
# "huggingface_hub>=0.20.0",
|
|
# "accelerate",
|
|
# ]
|
|
# ///
|
|
"""
|
|
Evaluate QMD query expansion models on HuggingFace Jobs.
|
|
|
|
Self-contained script — inlines the reward function and test queries.
|
|
|
|
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
|
|
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
|
|
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import io
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
from collections import Counter
|
|
|
|
import torch
|
|
from huggingface_hub import HfApi, login
|
|
from peft import PeftModel
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# --- Config ---
# Hub ids used by main(): the base checkpoint, the SFT adapter that is merged
# into it, and the GRPO adapter that is loaded on top of the merged weights.
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
|
|
|
|
# --- Test queries (inlined from evals/queries.txt) ---
# Every model is scored on this fixed list, in this order, so per-query results
# from different models can be compared position by position.
QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",

    # Short/ambiguous
    "auth",
    "config",
    "setup",
    "api",

    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",

    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",

    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",

    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",

    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",

    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
|
|
|
|
# =============================================================================
# Reward function (inlined from reward.py)
# =============================================================================

# Words ignored by word_repetition_penalty() when counting repeated tokens.
STOPWORDS = frozenset(
    "the a an is are to for of in "
    "and or it this that be with as on by".split()
)

# Words removed from a query before deciding which terms a lex: line should
# keep; also excluded from named-entity detection.
KEY_TERM_STOPWORDS = frozenset(
    "what is how to the a an in on for of "
    "and or with my your do does can i me we "
    "who where when why which find get show tell".split()
)

# Filler phrases checked by lex_is_generic().
GENERIC_LEX_PHRASES = frozenset((
    "find information about", "search for", "look up", "get information",
    "learn about", "information on", "details about", "find out about",
    "what is", "how to", "guide to", "help with",
))

# Any of these substrings in a model output is scored as chat-template leakage.
CHAT_TEMPLATE_TOKENS = frozenset((
    "<|im_start|>", "<|im_end|>", "<|endoftext|>",
    "\nassistant\n", "\nuser\n",
))
|
|
|
|
|
|
def parse_expansion(text):
    """Sort the lines of a model output into buckets by their prefix.

    Returns a dict with the keys "lex", "vec", "hyde" and "invalid". Each
    non-blank line that starts with "lex:", "vec:" or "hyde:" is stored in the
    matching bucket with the prefix removed and whitespace trimmed; any other
    non-blank line is stored (trimmed) under "invalid".
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw_line in text.strip().split("\n"):
        entry = raw_line.strip()
        if not entry:
            continue
        for kind in ("lex", "vec", "hyde"):
            prefix = kind + ":"
            if entry.startswith(prefix):
                buckets[kind].append(entry[len(prefix):].strip())
                break
        else:
            # No known prefix matched.
            buckets["invalid"].append(entry)
    return buckets
|
|
|
|
|
|
def clean_model_output(text):
    """Strip the end-of-turn marker and any complete <think> block.

    Returns a ``(cleaned_text, used_thinking)`` tuple. ``used_thinking`` is
    True only when both an opening and a closing think tag are present; an
    unterminated ``<think>`` is left in the text untouched.
    """
    stripped = text.replace('<|im_end|>', '').strip()
    has_think_block = '<think>' in stripped and '</think>' in stripped
    if not has_think_block:
        return stripped, False
    without_think = re.sub(r'<think>.*?</think>', '', stripped, flags=re.DOTALL)
    return without_think.strip(), True
|
|
|
|
|
|
def extract_named_entities(query):
    """Heuristically collect entity-like words from a query.

    The query is split on whitespace and surrounding punctuation is trimmed
    from each word. A word counts as an entity, and is added in lowercase to
    the returned set, when the first of these rules matches:

    1. it is all uppercase and at least two characters long;
    2. it is not the first word, starts with an uppercase letter and is not a
       key-term stopword;
    3. it contains one of ``. + - # @`` and is at least two characters long;
    4. it starts with an uppercase letter and has another uppercase letter
       later in the word;
    5. the previous word was an entity and this word is not a key-term
       stopword.
    """
    found = set()
    previous_hit = False
    for position, raw_word in enumerate(query.split()):
        token = raw_word.strip('.,!?:;()[]"\'')
        if not token:
            # A punctuation-only word breaks an entity run.
            previous_hit = False
            continue

        lowered = token.lower()
        # Clause order matters: short-circuiting mirrors a first-match rule chain.
        hit = (
            (token.isupper() and len(token) >= 2)
            or (position > 0 and token[0].isupper() and lowered not in KEY_TERM_STOPWORDS)
            or (any(ch in token for ch in '.+-#@') and len(token) >= 2)
            or (len(token) > 1 and any(ch.isupper() for ch in token[1:]) and token[0].isupper())
            or (previous_hit and lowered not in KEY_TERM_STOPWORDS)
        )
        if hit:
            found.add(lowered)
        previous_hit = hit
    return found
|
|
|
|
|
|
def get_key_terms(query):
    """Return the lowercase, whitespace-split query words that are not key-term stopwords."""
    return {
        word
        for word in query.lower().split()
        if word not in KEY_TERM_STOPWORDS
    }
|
|
|
|
|
|
def lex_preserves_key_terms(lex_line, query):
    """Return True when the lex line shares at least one key term with the query.

    A query with no key terms (only stopwords, or empty) is always satisfied.
    Matching is on whole whitespace-separated words, case-insensitively.
    """
    required = get_key_terms(query)
    if not required:
        return True
    return not required.isdisjoint(lex_line.lower().split())
|
|
|
|
|
|
def lex_preserves_entities(line, entities):
    """Return True when any entity occurs as a substring of the lowercased line.

    An empty entity collection is treated as trivially preserved.
    """
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
|
|
|
|
|
|
def lex_is_generic(lex_line):
    """Return True when a lex line is little more than a generic filler phrase.

    For every phrase in GENERIC_LEX_PHRASES that either occurs in the line or
    whose first word is a prefix of the line, the first occurrence of each
    phrase word is removed from the lowercased line. If fewer than three
    characters are left afterwards, the line is considered generic.
    """
    normalized = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        phrase_words = phrase.split()
        if phrase not in normalized and not normalized.startswith(phrase_words[0]):
            continue
        leftover = normalized
        for phrase_word in phrase_words:
            # Substring (not whole-word) removal, first occurrence only.
            leftover = leftover.replace(phrase_word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
|
|
|
|
|
|
def word_set_distance(a, b):
    """Count the lowercase words that appear in exactly one of the two strings."""
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len(words_a.symmetric_difference(words_b))
|
|
|
|
|
|
def is_diverse(a, b, min_distance=2):
    """Return True when two expansions differ enough to count as distinct.

    After lowercasing and trimming, the strings are not diverse if one
    contains the other (which includes being equal). Otherwise they are
    diverse when their word-set distance is at least ``min_distance``.
    """
    first = a.lower().strip()
    second = b.lower().strip()
    # Equality is a special case of containment, so one check covers both.
    if first in second or second in first:
        return False
    return word_set_distance(first, second) >= min_distance
|
|
|
|
|
|
def echoes_query(expansion, query):
    """Return True when an expansion merely repeats the query.

    Comparison is case-insensitive on trimmed strings. An expansion echoes the
    query when it equals it, or when it contains it and is fewer than ten
    characters longer.
    """
    candidate = expansion.lower().strip()
    original = query.lower().strip()
    if candidate == original:
        return True
    return original in candidate and len(candidate) < len(original) + 10
|
|
|
|
|
|
def word_repetition_penalty(text):
    """Return a penalty for words that are repeated three or more times.

    Words are matched case-insensitively with ``\\b\\w+\\b``. Each word that
    occurs at least three times, is longer than two characters and is not in
    STOPWORDS adds ``(occurrences - 2) * 2`` to the penalty.
    """
    frequencies = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for token, seen in frequencies.items():
        if seen < 3:
            continue
        if token in STOPWORDS or len(token) <= 2:
            continue
        penalty += (seen - 2) * 2
    return penalty
|
|
|
|
|
|
def _failed_score(reason):
    """Build the all-zero result returned for a hard failure."""
    return {
        "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
        "think_bonus": 0, "total": 0, "max_possible": 100,
        "percentage": 0.0, "rating": "Failed", "deductions": [reason],
        # Present so failed and successful results expose the same keys.
        "entities_detected": [],
    }


def _score_format(parsed, deductions):
    """Format score: 10 base plus 10 each for having lex: and vec: lines (max 30)."""
    score = 10
    for kind in ("lex", "vec"):
        if parsed[kind]:
            score += 10
        else:
            deductions.append(f"missing {kind}:")
    return score


def _pairwise_diversity(items, min_distance):
    """Start from 5 and lose 2 for every pair of items that is not diverse (floor 0)."""
    budget = 5
    for index, first in enumerate(items):
        for second in items[index + 1:]:
            if not is_diverse(first, second, min_distance):
                budget -= 2
    return max(0, budget)


def _score_diversity(parsed, query):
    """Diversity score: type coverage, line count, pairwise diversity, echoes (max 30)."""
    lex, vec = parsed["lex"], parsed["vec"]
    score = 0
    if lex and vec:
        score += 10
    if len(lex) + len(vec) >= 2:
        score += 5
    score += _pairwise_diversity(lex, 2)
    score += _pairwise_diversity(vec, 3)
    # 5 points, minus 3 for every expansion that just echoes the query.
    echoes = sum(1 for exp in lex + vec if echoes_query(exp, query))
    score += max(0, 5 - 3 * echoes)
    return score


def _score_hyde(parsed):
    """HyDE score for the first hyde: line: presence, length, repetition (max 20)."""
    if not parsed["hyde"]:
        return 0
    hyde_text = parsed["hyde"][0]
    score = 5
    hyde_len = len(hyde_text)
    if 50 <= hyde_len <= 200:
        score += 5
    elif hyde_len < 50:
        score += 2
    if "\n" not in hyde_text:
        score += 5
    score += max(0, 5 - word_repetition_penalty(hyde_text))
    return score


def _score_quality(parsed, query):
    """Quality score: lex shorter than vec, natural vec lines, key terms kept (max 20)."""
    lex, vec = parsed["lex"], parsed["vec"]
    score = 5
    if lex and vec:
        avg_lex = sum(len(line) for line in lex) / len(lex)
        avg_vec = sum(len(line) for line in vec) / len(vec)
        if avg_lex <= avg_vec:
            score += 5
    if vec:
        natural = sum(1 for line in vec if " " in line and len(line) > 15)
        score += 5 if natural == len(vec) else 2
    if lex:
        with_terms = sum(1 for line in lex if lex_preserves_key_terms(line, query))
        if with_terms == len(lex):
            score += 5
        elif with_terms > 0:
            score += 2
    return score


def _score_entity(parsed, entities):
    """Entity score; may be negative when lex lines drop entities or are generic."""
    if not entities:
        return 10
    lex, vec = parsed["lex"], parsed["vec"]
    if not lex:
        return 0
    score = 0
    with_entities = sum(1 for line in lex if lex_preserves_entities(line, entities))
    if with_entities == len(lex):
        score += 15
    elif with_entities > 0:
        score += 5
    else:
        score -= 30
    generic_count = sum(1 for line in lex if lex_is_generic(line))
    if generic_count:
        score -= generic_count * 15
    if vec and any(lex_preserves_entities(line, entities) for line in vec):
        score += 5
    return score


def score_expansion_detailed(query, expansion):
    """Score one query expansion and return a per-category breakdown.

    The output is cleaned with clean_model_output() and fails outright
    (all-zero result, rating "Failed") when it contains a chat-template token
    or any non-blank line that does not start with "lex:", "vec:" or "hyde:".

    Otherwise the result dict holds the category scores ("format",
    "diversity", "hyde", "quality", "entity", "think_bonus"), their "total",
    "max_possible" (140 with a hyde: line, 120 without), the "percentage"
    clamped to 0-100 and rounded to one decimal, a "rating", the "deductions"
    list and "entities_detected". The reported "entity" and "total" are
    floored at zero, but the percentage is computed from the unfloored sum, so
    a negative entity score still lowers it. "think_bonus" is 20 when the
    output contained no complete <think> block.
    """
    text, used_thinking = clean_model_output(expansion.strip())

    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _failed_score("CHAT TEMPLATE LEAKAGE")

    parsed = parse_expansion(text)
    if parsed["invalid"]:
        # parse_expansion keeps invalid lines in order, so [0] is the first offender.
        first_invalid = parsed["invalid"][0]
        return _failed_score(f"INVALID LINE: {first_invalid[:50]}")

    deductions = []
    format_score = _score_format(parsed, deductions)
    diversity_score = _score_diversity(parsed, query)
    hyde_score = _score_hyde(parsed)
    quality_score = _score_quality(parsed, query)
    entities = extract_named_entities(query)
    entity_score = _score_entity(parsed, entities)
    think_bonus = 0 if used_thinking else 20

    total = (format_score + diversity_score + hyde_score
             + quality_score + entity_score + think_bonus)
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score, "diversity": diversity_score, "hyde": hyde_score,
        "quality": quality_score, "entity": max(0, entity_score),
        "think_bonus": think_bonus, "total": max(0, total),
        "max_possible": max_possible, "percentage": round(percentage, 1),
        "rating": rating, "deductions": deductions,
        "entities_detected": list(entities) if entities else [],
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Model loading and generation
|
|
# =============================================================================
|
|
|
|
def load_model(base, sft=None, grpo=None):
    """Load the tokenizer and base model, optionally stacking adapters.

    Args:
        base: model id passed to the tokenizer and causal-LM loaders.
        sft: optional adapter id; when given it is loaded and merged into the
            base weights.
        grpo: optional adapter id; when given it is loaded on top of the
            (possibly SFT-merged) model and left unmerged.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    print(f"Loading tokenizer from {base}...")
    tok = AutoTokenizer.from_pretrained(base)
    if tok.pad_token is None:
        # Reuse the EOS token for padding when the tokenizer defines none.
        tok.pad_token = tok.eos_token

    print(f"Loading base model {base}...")
    net = AutoModelForCausalLM.from_pretrained(
        base,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    if sft:
        print(f"Loading and merging SFT adapter {sft}...")
        net = PeftModel.from_pretrained(net, sft).merge_and_unload()

    if grpo:
        print(f"Loading GRPO adapter {grpo}...")
        net = PeftModel.from_pretrained(net, grpo)

    net.eval()
    return net, tok
|
|
|
|
|
|
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Generate one query expansion and return only the model's completion.

    The query is wrapped in a single user message prefixed with ``/no_think``,
    rendered through the tokenizer's chat template, and sampled at temperature
    0.7 for up to ``max_new_tokens`` new tokens.

    Returns:
        The decoded completion with special tokens skipped, whitespace
        trimmed and any complete ``<think>...</think>`` block removed.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            temperature=0.7, do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them rather than splitting decoded text on "assistant\n": a text
    # split is thrown off when the completion itself contains that marker
    # (hiding the leakage the scorer checks for), and slicing by len(prompt)
    # is wrong because the prompt string still contains special tokens that
    # the decoded text does not.
    prompt_length = inputs["input_ids"].shape[1]
    expansion = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True,
    ).strip()

    if "<think>" in expansion:
        expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
    return expansion
|
|
|
|
|
|
# =============================================================================
|
|
# Main
|
|
# =============================================================================
|
|
|
|
def results_to_csv(results, label):
    """Render evaluation results as CSV text.

    Args:
        results: iterable of dicts with "query", "expansion" and "scores"
            keys, where "scores" is a score breakdown dict.
        label: value written to the "model" column of every row.

    Returns:
        The CSV document (header row plus one row per result) as a string.
        Deductions are joined with "; " into a single cell.
    """
    header = [
        "model", "query", "expansion", "score_pct", "rating",
        "format", "diversity", "hyde", "quality", "entity", "think_bonus",
        "total", "max_possible", "deductions",
    ]
    # Score keys copied verbatim, in column order, after the rating column.
    numeric_fields = (
        "format", "diversity", "hyde", "quality", "entity",
        "think_bonus", "total", "max_possible",
    )

    out = io.StringIO()
    table = csv.writer(out)
    table.writerow(header)
    for item in results:
        breakdown = item["scores"]
        row = [label, item["query"], item["expansion"],
               breakdown["percentage"], breakdown["rating"]]
        row.extend(breakdown[field] for field in numeric_fields)
        row.append("; ".join(breakdown.get("deductions", [])))
        table.writerow(row)
    return out.getvalue()
|
|
|
|
|
|
def upload_csv(results, label, repo_id, api):
    """Upload eval results CSV to HuggingFace Hub.

    The file is named ``eval_<tag>.csv``, where the tag is the part of
    ``label`` after the last "/", lowercased, with spaces replaced by
    underscores. It is uploaded to ``repo_id`` as a model-repo file.
    """
    payload = results_to_csv(results, label).encode("utf-8")
    suffix = label.split("/")[-1].replace(" ", "_").lower()
    filename = f"eval_{suffix}.csv"

    print(f" Uploading {filename} to {repo_id}...")
    api.upload_file(
        path_or_fileobj=payload,
        path_in_repo=filename,
        repo_id=repo_id,
        repo_type="model",
    )
    print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
|
|
|
|
|
|
def evaluate_model(model, tokenizer, label):
    """Run every test query through the model, score it, and print a report.

    Prints one line per query, then the average percentage, a histogram of
    ratings and the five lowest-scoring queries with up to three deductions
    each.

    Returns:
        ``(results, avg)`` where ``results`` is a list of dicts with "query",
        "expansion" and "scores" keys (in QUERIES order) and ``avg`` is the
        mean percentage.
    """
    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        scores = score_expansion_detailed(query, expansion)
        results.append({"query": query, "expansion": expansion, "scores": scores})

        # "+" for >= 80, "-" for < 60, "~" in between.
        pct = scores["percentage"]
        if pct >= 80:
            marker = "+"
        elif pct < 60:
            marker = "-"
        else:
            marker = "~"
        print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")

    percentages = [entry["scores"]["percentage"] for entry in results]
    avg = sum(percentages) / len(results)
    ratings = Counter(entry["scores"]["rating"] for entry in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    print(f" Ratings:")
    for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        count = ratings.get(rating, 0)
        if count <= 0:
            continue
        print(f" {rating:10s}: {count:2d} {'█' * count}")

    # Show worst queries
    by_score = sorted(results, key=lambda entry: entry["scores"]["percentage"])
    print(f"\n Bottom 5:")
    for r in by_score[:5]:
        print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
        if r["scores"]["deductions"]:
            print(f" {', '.join(r['scores']['deductions'][:3])}")

    return results, avg
|
|
|
|
|
|
def main():
    """Evaluate the SFT model (and, unless --sft-only, the GRPO model) and upload CSVs.

    Logs in to the Hub when HF_TOKEN is set, makes sure the upload repo
    exists, evaluates the SFT-merged model and uploads its CSV. Without
    --sft-only it then evaluates the GRPO adapter on top of the SFT-merged
    model, uploads that CSV plus a combined comparison CSV, and prints a
    side-by-side summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
    parser.add_argument(
        "--upload-repo",
        default="tobil/qmd-query-expansion-evals",
        help="HF repo to upload CSV results",
    )
    args = parser.parse_args()

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    api = HfApi()
    api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)

    # Evaluate SFT
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
    sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
    upload_csv(sft_results, "sft", args.upload_repo, api)

    if args.sft_only:
        return

    # For GRPO: reload base, merge SFT, then load GRPO adapter
    del model
    torch.cuda.empty_cache()
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
    grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
    upload_csv(grpo_results, "grpo", args.upload_repo, api)

    # Upload combined comparison CSV: the SFT table followed by the GRPO rows,
    # with the GRPO header line dropped.
    sft_table = results_to_csv(sft_results, "sft")
    grpo_rows = results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
    combined = sft_table + grpo_rows
    api.upload_file(
        path_or_fileobj=combined.encode("utf-8"),
        path_in_repo="eval_comparison.csv",
        repo_id=args.upload_repo,
        repo_type="model",
    )
    print(f" Uploaded: eval_comparison.csv")

    # Comparison
    print(f"\n{'='*70}")
    print(f" COMPARISON")
    print(f"{'='*70}")
    print(f" SFT average: {sft_avg:.1f}%")
    print(f" GRPO average: {grpo_avg:.1f}%")
    print(f" Delta: {grpo_avg - sft_avg:+.1f}%")

    # Per-query (sft, grpo) percentages, paired by position in QUERIES.
    paired = [
        (s["scores"]["percentage"], g["scores"]["percentage"])
        for s, g in zip(sft_results, grpo_results)
    ]
    improved = sum(1 for before, after in paired if after > before)
    regressed = sum(1 for before, after in paired if after < before)
    print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")


if __name__ == "__main__":
    main()
|