Merge origin/main into feat/ast-aware-chunking
Resolve conflicts: combine AST chunking args (filepath, chunkStrategy) with abort signal parameter from #458. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
1fb2e2819e
@ -16,6 +16,7 @@
|
||||
|
||||
### Fixes
|
||||
|
||||
- Fix paths in nix flake
|
||||
- Sync stale `bun.lock` (`better-sqlite3` 11.x → 12.x). CI and release
|
||||
script now use `--frozen-lockfile` to prevent recurrence. #386
|
||||
(thanks @Mic92)
|
||||
|
||||
2
bun.lock
2
bun.lock
@ -12,7 +12,7 @@
|
||||
"picomatch": "^4.0.0",
|
||||
"sqlite-vec": "^0.1.7-alpha.2",
|
||||
"yaml": "^2.8.2",
|
||||
"zod": "^4.2.1",
|
||||
"zod": "4.2.1",
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/better-sqlite3": "^7.6.0",
|
||||
|
||||
182
finetune/benchmark.py
Normal file
182
finetune/benchmark.py
Normal file
@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark QMD query expansion: LFM2.5 vs Qwen3 finetuned models."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from peft import PeftModel
|
||||
|
||||
QUERIES = [
|
||||
"kubernetes pod networking",
|
||||
"best practices for React server components",
|
||||
"how to optimize PostgreSQL queries for large tables",
|
||||
"what is retrieval augmented generation",
|
||||
"python async await concurrency patterns",
|
||||
"nginx reverse proxy load balancing",
|
||||
"git rebase vs merge workflow",
|
||||
"rust ownership and borrowing explained",
|
||||
"docker compose multi-stage builds",
|
||||
"elasticsearch full text search performance",
|
||||
"shopify liquid template customization",
|
||||
"machine learning feature engineering techniques",
|
||||
"aws lambda cold start optimization",
|
||||
"typescript generics and utility types",
|
||||
"redis caching strategies for web apps",
|
||||
]
|
||||
|
||||
def load_model(base_name, adapter_dir, device, trust_remote=False):
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_name, trust_remote_code=trust_remote)
|
||||
base = AutoModelForCausalLM.from_pretrained(
|
||||
base_name, dtype=torch.bfloat16, device_map=device, trust_remote_code=trust_remote
|
||||
)
|
||||
model = PeftModel.from_pretrained(base, adapter_dir, local_files_only=True)
|
||||
model = model.merge_and_unload()
|
||||
model.eval()
|
||||
|
||||
gen_config_path = Path(adapter_dir) / "generation_config.json"
|
||||
if gen_config_path.exists():
|
||||
gen_config = GenerationConfig.from_pretrained(adapter_dir)
|
||||
else:
|
||||
gen_config = GenerationConfig(
|
||||
temperature=0.1, top_k=50, top_p=0.1,
|
||||
repetition_penalty=1.05, do_sample=True, max_new_tokens=300,
|
||||
)
|
||||
return model, tokenizer, gen_config
|
||||
|
||||
def run_inference(model, tokenizer, gen_config, query, device):
|
||||
messages = [{"role": "user", "content": f"Expand this search query: {query}"}]
|
||||
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
inputs = tokenizer(text, return_tensors="pt").to(device)
|
||||
|
||||
start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inputs, generation_config=gen_config, max_new_tokens=300)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
|
||||
result = tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
|
||||
return result, elapsed, new_tokens
|
||||
|
||||
def score_output(output):
|
||||
"""Simple quality scoring: check for lex/vec/hyde presence and specificity."""
|
||||
score = 0
|
||||
lines = output.strip().split("\n")
|
||||
has_lex = has_vec = has_hyde = False
|
||||
hyde_text = ""
|
||||
|
||||
for line in lines:
|
||||
l = line.strip()
|
||||
if l.startswith("lex:"):
|
||||
has_lex = True
|
||||
score += 1
|
||||
elif l.startswith("vec:"):
|
||||
has_vec = True
|
||||
score += 1
|
||||
elif l.startswith("hyde:"):
|
||||
has_hyde = True
|
||||
hyde_text = l[5:].strip()
|
||||
score += 2 # hyde is worth more
|
||||
|
||||
# Bonus for hyde length in sweet spot (80-200 chars)
|
||||
if hyde_text:
|
||||
hlen = len(hyde_text)
|
||||
if 80 <= hlen <= 200:
|
||||
score += 2
|
||||
elif 50 <= hlen <= 250:
|
||||
score += 1
|
||||
|
||||
# Penalty for generic/template hyde
|
||||
generic_phrases = ["comprehensive guide", "everything you need to know", "beginners and advanced users"]
|
||||
for phrase in generic_phrases:
|
||||
if phrase in hyde_text.lower():
|
||||
score -= 1
|
||||
|
||||
return score, {"has_lex": has_lex, "has_vec": has_vec, "has_hyde": has_hyde, "hyde_len": len(hyde_text)}
|
||||
|
||||
def main():
|
||||
device = "cuda:0"
|
||||
|
||||
models = {
|
||||
"LFM2.5-1.2B (finetuned)": {
|
||||
"base": "LiquidAI/LFM2.5-1.2B-Instruct",
|
||||
"adapter": "outputs/sft-lfm2",
|
||||
"trust_remote": True,
|
||||
},
|
||||
"Qwen3-1.7B (finetuned)": {
|
||||
"base": "Qwen/Qwen3-1.7B",
|
||||
"adapter": "outputs/sft",
|
||||
"trust_remote": False,
|
||||
},
|
||||
}
|
||||
|
||||
results = {}
|
||||
|
||||
for name, cfg in models.items():
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Loading {name}...")
|
||||
model, tokenizer, gen_config = load_model(
|
||||
cfg["base"], cfg["adapter"], device, cfg["trust_remote"]
|
||||
)
|
||||
|
||||
model_results = []
|
||||
total_time = 0
|
||||
total_tokens = 0
|
||||
total_score = 0
|
||||
|
||||
for query in QUERIES:
|
||||
output, elapsed, n_tokens = run_inference(model, tokenizer, gen_config, query, device)
|
||||
score, details = score_output(output)
|
||||
|
||||
model_results.append({
|
||||
"query": query,
|
||||
"output": output,
|
||||
"time_s": round(elapsed, 3),
|
||||
"tokens": n_tokens,
|
||||
"score": score,
|
||||
"details": details,
|
||||
})
|
||||
total_time += elapsed
|
||||
total_tokens += n_tokens
|
||||
total_score += score
|
||||
|
||||
tok_s = n_tokens / elapsed if elapsed > 0 else 0
|
||||
print(f" [{score:2d}] {query[:40]:<40} {elapsed:.2f}s {n_tokens:3d}tok {tok_s:.0f}tok/s")
|
||||
|
||||
avg_time = total_time / len(QUERIES)
|
||||
avg_score = total_score / len(QUERIES)
|
||||
avg_toks = total_tokens / total_time if total_time > 0 else 0
|
||||
|
||||
results[name] = {
|
||||
"queries": model_results,
|
||||
"avg_time_s": round(avg_time, 3),
|
||||
"avg_score": round(avg_score, 2),
|
||||
"avg_tok_s": round(avg_toks, 1),
|
||||
"total_score": total_score,
|
||||
}
|
||||
|
||||
print(f"\n Summary: avg_score={avg_score:.2f} avg_time={avg_time:.2f}s avg_tok/s={avg_toks:.0f}")
|
||||
|
||||
# Free GPU memory
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Print comparison
|
||||
print(f"\n{'='*60}")
|
||||
print("COMPARISON")
|
||||
print(f"{'='*60}")
|
||||
for name, r in results.items():
|
||||
print(f"\n{name}:")
|
||||
print(f" Total Score: {r['total_score']} / {len(QUERIES) * 8}") # max ~8 per query
|
||||
print(f" Avg Score: {r['avg_score']}")
|
||||
print(f" Avg Time: {r['avg_time_s']}s")
|
||||
print(f" Throughput: {r['avg_tok_s']} tok/s")
|
||||
|
||||
# Save full results
|
||||
with open("outputs/benchmark_results.json", "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print("\nFull results saved to outputs/benchmark_results.json")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
58
finetune/configs/sft-lfm2.yaml
Normal file
58
finetune/configs/sft-lfm2.yaml
Normal file
@ -0,0 +1,58 @@
|
||||
# SFT Training Config for QMD Query Expansion
|
||||
# Target: LiquidAI LFM2.5-1.2B-Instruct with LoRA
|
||||
#
|
||||
# LFM2.5 is a hybrid model: 10 conv blocks + 6 GQA attention blocks
|
||||
# Uses ChatML template: <|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n
|
||||
# No /no_think needed (not Qwen3)
|
||||
#
|
||||
# Usage: uv run train.py sft --config configs/sft-lfm2.yaml
|
||||
|
||||
model:
|
||||
base: "LiquidAI/LFM2.5-1.2B-Instruct"
|
||||
output: "outputs/sft-lfm2"
|
||||
trust_remote_code: true
|
||||
|
||||
dataset:
|
||||
name: "data/train-lfm2/"
|
||||
text_field: "text"
|
||||
split: "train"
|
||||
eval_split: 0.1
|
||||
|
||||
training:
|
||||
epochs: 5
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 4
|
||||
learning_rate: 2e-4
|
||||
max_length: 512
|
||||
warmup_ratio: 0.03
|
||||
lr_scheduler: "cosine"
|
||||
|
||||
lora:
|
||||
rank: 16
|
||||
alpha: 32
|
||||
dropout: 0.0
|
||||
target_modules:
|
||||
# Convolution blocks (layers 0,1,3,4,6,7,9,11,13,15)
|
||||
- "conv.in_proj"
|
||||
- "conv.out_proj"
|
||||
# Attention blocks (layers 2,5,8,10,12,14)
|
||||
- "q_proj"
|
||||
- "k_proj"
|
||||
- "v_proj"
|
||||
- "out_proj"
|
||||
# FFN (all 16 layers)
|
||||
- "feed_forward.w1"
|
||||
- "feed_forward.w2"
|
||||
- "feed_forward.w3"
|
||||
|
||||
generation:
|
||||
temperature: 0.1
|
||||
top_k: 50
|
||||
top_p: 0.1
|
||||
repetition_penalty: 1.05
|
||||
|
||||
gguf: false # LFM2.5 hybrid arch not supported by llama.cpp
|
||||
|
||||
tracking:
|
||||
project: "qmd-query-expansion"
|
||||
run_name: "sft-lfm2-1.2B"
|
||||
1
finetune/data/fix_hyde_checkpoint.json
Normal file
1
finetune/data/fix_hyde_checkpoint.json
Normal file
File diff suppressed because one or more lines are too long
@ -73,6 +73,8 @@ def format_for_training(ex: TrainingExample) -> dict:
|
||||
text = text.replace("<think>\n\n</think>\n\n", "")
|
||||
|
||||
return {
|
||||
"query": ex.query,
|
||||
"output": ex.output_as_lists(),
|
||||
"text": text,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
85
finetune/dataset/prepare_data_lfm2.py
Normal file
85
finetune/dataset/prepare_data_lfm2.py
Normal file
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare QMD query expansion data for LFM2.5-1.2B-Instruct training.
|
||||
|
||||
LFM2.5 uses ChatML format:
|
||||
<|startoftext|><|im_start|>user
|
||||
Expand this search query: {query}<|im_end|>
|
||||
<|im_start|>assistant
|
||||
{output}<|im_end|>
|
||||
|
||||
No /no_think needed (that's Qwen3-specific).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dataset.schema import normalize_output_items, output_items_to_text
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def format_for_training(query_text: str, output_items: list[list[str]], tokenizer) -> dict:
|
||||
"""Format a single example for SFT training using LFM2.5 chat format."""
|
||||
output_text = output_items_to_text(output_items)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"Expand this search query: {query_text}"},
|
||||
{"role": "assistant", "content": output_text},
|
||||
]
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
|
||||
return {"text": text}
|
||||
|
||||
|
||||
def main():
|
||||
input_path = Path("data/qmd_expansion_v2.jsonl")
|
||||
output_dir = Path("data/train-lfm2")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("Loading LFM2.5 tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"LiquidAI/LFM2.5-1.2B-Instruct", trust_remote_code=True
|
||||
)
|
||||
|
||||
examples = []
|
||||
with open(input_path) as f:
|
||||
for line in f:
|
||||
row = json.loads(line)
|
||||
items = normalize_output_items(row["output"])
|
||||
example = format_for_training(row["query"], items, tokenizer)
|
||||
examples.append(example)
|
||||
|
||||
# Shuffle and split
|
||||
random.seed(42)
|
||||
random.shuffle(examples)
|
||||
|
||||
split_idx = int(len(examples) * 0.9)
|
||||
train = examples[:split_idx]
|
||||
val = examples[split_idx:]
|
||||
|
||||
# Write as JSONL
|
||||
train_path = output_dir / "train.jsonl"
|
||||
val_path = output_dir / "val.jsonl"
|
||||
|
||||
with open(train_path, "w") as f:
|
||||
for ex in train:
|
||||
f.write(json.dumps(ex) + "\n")
|
||||
|
||||
with open(val_path, "w") as f:
|
||||
for ex in val:
|
||||
f.write(json.dumps(ex) + "\n")
|
||||
|
||||
print(f"Written {len(train)} train, {len(val)} val examples to {output_dir}")
|
||||
print(f"\nSample formatted text:")
|
||||
print(train[0]["text"][:500])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
488
finetune/eval_retrieval.py
Normal file
488
finetune/eval_retrieval.py
Normal file
@ -0,0 +1,488 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "transformers>=4.45.0",
|
||||
# "peft>=0.7.0",
|
||||
# "torch",
|
||||
# "accelerate",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
QMD Retrieval-Based Evaluation with Precision & Recall
|
||||
|
||||
Evaluates model outputs against golden data (training set).
|
||||
Measures how well the model reproduces the expected expansions.
|
||||
|
||||
Metrics:
|
||||
- Precision: Of model-generated expansions, how many match golden?
|
||||
- Recall: Of golden expansions, how many did the model generate?
|
||||
- F1: Harmonic mean of precision and recall
|
||||
|
||||
Matching is done via token overlap (Jaccard similarity) with a threshold.
|
||||
|
||||
Usage:
|
||||
uv run eval_retrieval.py ./outputs/sft
|
||||
uv run eval_retrieval.py tobil/qmd-query-expansion-1.7B --golden data/qmd_expansion_v3_structured.jsonl
|
||||
uv run eval_retrieval.py ./outputs/sft --threshold 0.5 --sample 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# =============================================================================
|
||||
# Matching Functions
|
||||
# =============================================================================
|
||||
|
||||
def tokenize(text: str) -> set[str]:
|
||||
"""Tokenize text into lowercase word set, removing stopwords."""
|
||||
stopwords = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and',
|
||||
'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
|
||||
'how', 'what', 'do', 'does', 'can', 'you', 'your', 'i'}
|
||||
words = re.findall(r'\b\w+\b', text.lower())
|
||||
return {w for w in words if w not in stopwords and len(w) > 1}
|
||||
|
||||
|
||||
def jaccard_similarity(a: str, b: str) -> float:
|
||||
"""Jaccard similarity between two strings based on token overlap."""
|
||||
tokens_a = tokenize(a)
|
||||
tokens_b = tokenize(b)
|
||||
if not tokens_a or not tokens_b:
|
||||
return 0.0
|
||||
intersection = len(tokens_a & tokens_b)
|
||||
union = len(tokens_a | tokens_b)
|
||||
return intersection / union if union > 0 else 0.0
|
||||
|
||||
|
||||
def find_best_match(pred: str, golden_list: list[str], threshold: float) -> tuple[str | None, float]:
|
||||
"""Find best matching golden expansion for a prediction."""
|
||||
best_match = None
|
||||
best_score = 0.0
|
||||
for golden in golden_list:
|
||||
score = jaccard_similarity(pred, golden)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = golden
|
||||
if best_score >= threshold:
|
||||
return best_match, best_score
|
||||
return None, best_score
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Parsing
|
||||
# =============================================================================
|
||||
|
||||
def parse_model_output(text: str) -> dict[str, list[str]]:
|
||||
"""Parse model output into {lex: [...], vec: [...], hyde: [...]}."""
|
||||
# Clean thinking tags
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
|
||||
result = {"lex": [], "vec": [], "hyde": []}
|
||||
for line in text.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("lex:"):
|
||||
result["lex"].append(line[4:].strip())
|
||||
elif line.startswith("vec:"):
|
||||
result["vec"].append(line[4:].strip())
|
||||
elif line.startswith("hyde:"):
|
||||
result["hyde"].append(line[5:].strip())
|
||||
return result
|
||||
|
||||
|
||||
def parse_golden_data(searches: list[dict] | str) -> dict[str, list[str]]:
|
||||
"""Parse golden data format into {lex: [...], vec: [...], hyde: [...]}."""
|
||||
# If it's a string (from messages format), parse it
|
||||
if isinstance(searches, str):
|
||||
return parse_model_output(searches)
|
||||
|
||||
# Otherwise it's the structured format [{type, query}, ...]
|
||||
result = {"lex": [], "vec": [], "hyde": []}
|
||||
for item in searches:
|
||||
exp_type = item.get("type", "")
|
||||
value = item.get("query", "") or item.get("value", "")
|
||||
if exp_type in result:
|
||||
result[exp_type].append(value)
|
||||
return result
|
||||
|
||||
|
||||
def load_golden_data(filepath: Path) -> list[dict]:
|
||||
"""Load golden data from JSONL, supporting both structured and messages formats."""
|
||||
data = []
|
||||
with open(filepath) as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
item = json.loads(line)
|
||||
|
||||
# Structured format: {query, searches}
|
||||
if "query" in item and "searches" in item:
|
||||
data.append({
|
||||
"query": item["query"],
|
||||
"searches": item["searches"]
|
||||
})
|
||||
# Messages format: {messages: [{role, content}, ...]}
|
||||
elif "messages" in item:
|
||||
messages = item["messages"]
|
||||
query = None
|
||||
searches = None
|
||||
for msg in messages:
|
||||
if msg["role"] == "user":
|
||||
# Extract query from "/no_think Expand this search query: ..."
|
||||
content = msg["content"]
|
||||
if "Expand this search query:" in content:
|
||||
query = content.split("Expand this search query:")[-1].strip()
|
||||
else:
|
||||
query = content.strip()
|
||||
elif msg["role"] == "assistant":
|
||||
# The assistant content IS the expected output
|
||||
searches = msg["content"]
|
||||
if query and searches:
|
||||
data.append({
|
||||
"query": query,
|
||||
"searches": searches # Will be parsed as string
|
||||
})
|
||||
return data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics Calculation
|
||||
# =============================================================================
|
||||
|
||||
# Different thresholds by type - lex needs strict matching, hyde is more flexible
|
||||
DEFAULT_THRESHOLDS = {
|
||||
"lex": 0.5, # Keywords should overlap well
|
||||
"vec": 0.35, # Semantic sentences have more variation
|
||||
"hyde": 0.25, # Passages have the most variation
|
||||
}
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
predictions: dict[str, list[str]],
|
||||
golden: dict[str, list[str]],
|
||||
threshold: float | dict[str, float] = 0.4,
|
||||
return_mismatches: bool = False
|
||||
) -> dict:
|
||||
"""Calculate precision, recall, F1 per type and overall.
|
||||
|
||||
Args:
|
||||
threshold: Either a single float, or dict mapping type -> threshold
|
||||
return_mismatches: If True, include lists of unmatched predictions/golden
|
||||
"""
|
||||
if isinstance(threshold, (int, float)):
|
||||
thresholds = {"lex": threshold, "vec": threshold, "hyde": threshold}
|
||||
else:
|
||||
thresholds = threshold
|
||||
|
||||
metrics = {}
|
||||
mismatches = {}
|
||||
total_tp = 0
|
||||
total_pred = 0
|
||||
total_golden = 0
|
||||
|
||||
for exp_type in ["lex", "vec", "hyde"]:
|
||||
preds = predictions.get(exp_type, [])
|
||||
golds = golden.get(exp_type, [])
|
||||
type_threshold = thresholds.get(exp_type, 0.4)
|
||||
|
||||
if not preds and not golds:
|
||||
continue
|
||||
|
||||
# Track which golden items were matched
|
||||
matched_golden = set()
|
||||
unmatched_preds = []
|
||||
tp = 0
|
||||
|
||||
for pred in preds:
|
||||
match, score = find_best_match(pred, golds, type_threshold)
|
||||
if match is not None:
|
||||
tp += 1
|
||||
matched_golden.add(match)
|
||||
else:
|
||||
unmatched_preds.append((pred, score))
|
||||
|
||||
unmatched_golden = [g for g in golds if g not in matched_golden]
|
||||
|
||||
precision = tp / len(preds) if preds else 0.0
|
||||
recall = len(matched_golden) / len(golds) if golds else 0.0
|
||||
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
metrics[exp_type] = {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1": f1,
|
||||
"pred_count": len(preds),
|
||||
"golden_count": len(golds),
|
||||
"matched": tp,
|
||||
}
|
||||
|
||||
if return_mismatches:
|
||||
mismatches[exp_type] = {
|
||||
"unmatched_preds": unmatched_preds,
|
||||
"unmatched_golden": unmatched_golden,
|
||||
}
|
||||
|
||||
total_tp += tp
|
||||
total_pred += len(preds)
|
||||
total_golden += len(golds)
|
||||
|
||||
# Overall metrics (micro-averaged)
|
||||
overall_precision = total_tp / total_pred if total_pred > 0 else 0.0
|
||||
overall_recall = total_tp / total_golden if total_golden > 0 else 0.0
|
||||
overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0
|
||||
|
||||
metrics["overall"] = {
|
||||
"precision": overall_precision,
|
||||
"recall": overall_recall,
|
||||
"f1": overall_f1,
|
||||
"pred_count": total_pred,
|
||||
"golden_count": total_golden,
|
||||
"matched": total_tp,
|
||||
}
|
||||
|
||||
if return_mismatches:
|
||||
metrics["_mismatches"] = mismatches
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Loading and Generation
|
||||
# =============================================================================
|
||||
|
||||
def load_model(model_path: str):
|
||||
"""Load model (adapter or merged)."""
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_path = Path(model_path)
|
||||
adapter_config = model_path / "adapter_config.json"
|
||||
|
||||
# Get base model from adapter config or default
|
||||
base_model = "Qwen/Qwen3-1.7B"
|
||||
if adapter_config.exists():
|
||||
with open(adapter_config) as f:
|
||||
cfg = json.load(f)
|
||||
base_model = cfg.get("base_model_name_or_path", base_model)
|
||||
|
||||
print(f"Loading base: {base_model}", file=sys.stderr)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
config = AutoConfig.from_pretrained(base_model)
|
||||
config.tie_word_embeddings = False
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model, dtype=torch.bfloat16, device_map={"": 0}, config=config
|
||||
)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = False
|
||||
model.generation_config.temperature = None
|
||||
model.generation_config.top_p = None
|
||||
model.generation_config.top_k = None
|
||||
|
||||
# Load adapter if present
|
||||
if adapter_config.exists():
|
||||
print(f"Loading adapter: {model_path}", file=sys.stderr)
|
||||
model = PeftModel.from_pretrained(model, str(model_path))
|
||||
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 400) -> str:
|
||||
"""Generate expansion for a single query."""
|
||||
import torch
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": f"/no_think Expand this search query: {query}"}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
|
||||
with torch.inference_mode():
|
||||
out = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
gen_tokens = out[0][input_len:]
|
||||
return tokenizer.decode(gen_tokens, skip_special_tokens=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Evaluation
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="QMD Retrieval-Based Evaluation")
|
||||
parser.add_argument("model", help="Model path (local or HF)")
|
||||
parser.add_argument("--golden", default="data/qmd_expansion_v3_structured.jsonl",
|
||||
help="Golden data JSONL file")
|
||||
parser.add_argument("--threshold", type=float, default=None,
|
||||
help="Jaccard similarity threshold for all types (overrides --type-thresholds)")
|
||||
parser.add_argument("--type-thresholds", action="store_true",
|
||||
help="Use type-specific thresholds (lex=0.5, vec=0.35, hyde=0.25)")
|
||||
parser.add_argument("--sample", type=int, default=0,
|
||||
help="Sample N queries (0 = all)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed for sampling")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=400,
|
||||
help="Max new tokens to generate")
|
||||
parser.add_argument("--verbose", "-v", action="store_true",
|
||||
help="Show per-query details")
|
||||
parser.add_argument("--show-mismatches", action="store_true",
|
||||
help="Show examples of mismatched predictions")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine thresholds
|
||||
if args.threshold is not None:
|
||||
thresholds = args.threshold
|
||||
elif args.type_thresholds:
|
||||
thresholds = DEFAULT_THRESHOLDS.copy()
|
||||
else:
|
||||
thresholds = 0.4 # Default single threshold
|
||||
|
||||
# Load golden data
|
||||
golden_path = Path(args.golden)
|
||||
if not golden_path.exists():
|
||||
# Try relative to script directory
|
||||
golden_path = Path(__file__).parent / args.golden
|
||||
|
||||
if not golden_path.exists():
|
||||
print(f"Error: Golden data file not found: {args.golden}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading golden data from {golden_path}...", file=sys.stderr)
|
||||
golden_data = load_golden_data(golden_path)
|
||||
print(f"Loaded {len(golden_data)} golden examples", file=sys.stderr)
|
||||
|
||||
# Sample if requested
|
||||
if args.sample > 0 and args.sample < len(golden_data):
|
||||
random.seed(args.seed)
|
||||
golden_data = random.sample(golden_data, args.sample)
|
||||
print(f"Sampled {len(golden_data)} examples", file=sys.stderr)
|
||||
|
||||
# Load model
|
||||
model, tokenizer = load_model(args.model)
|
||||
|
||||
# Evaluate
|
||||
all_metrics = []
|
||||
all_mismatches = []
|
||||
type_aggregates = defaultdict(lambda: {"precision": [], "recall": [], "f1": []})
|
||||
|
||||
threshold_desc = thresholds if isinstance(thresholds, (int, float)) else f"lex={thresholds['lex']}, vec={thresholds['vec']}, hyde={thresholds['hyde']}"
|
||||
print(f"\nEvaluating {len(golden_data)} queries (thresholds: {threshold_desc})...\n")
|
||||
|
||||
for i, item in enumerate(golden_data, 1):
|
||||
query = item["query"]
|
||||
golden_parsed = parse_golden_data(item["searches"])
|
||||
|
||||
# Generate model output
|
||||
output = generate_expansion(model, tokenizer, query, args.max_new_tokens)
|
||||
pred_parsed = parse_model_output(output)
|
||||
|
||||
# Calculate metrics
|
||||
metrics = calculate_metrics(pred_parsed, golden_parsed, thresholds, return_mismatches=args.show_mismatches)
|
||||
all_metrics.append({"query": query, "metrics": metrics, "pred": pred_parsed, "golden": golden_parsed})
|
||||
|
||||
if args.show_mismatches and "_mismatches" in metrics:
|
||||
all_mismatches.append({"query": query, "mismatches": metrics.pop("_mismatches")})
|
||||
|
||||
# Aggregate by type
|
||||
for exp_type in ["lex", "vec", "hyde", "overall"]:
|
||||
if exp_type in metrics:
|
||||
type_aggregates[exp_type]["precision"].append(metrics[exp_type]["precision"])
|
||||
type_aggregates[exp_type]["recall"].append(metrics[exp_type]["recall"])
|
||||
type_aggregates[exp_type]["f1"].append(metrics[exp_type]["f1"])
|
||||
|
||||
# Progress
|
||||
overall = metrics.get("overall", {})
|
||||
p = overall.get("precision", 0) * 100
|
||||
r = overall.get("recall", 0) * 100
|
||||
f = overall.get("f1", 0) * 100
|
||||
|
||||
if args.verbose:
|
||||
print(f"[{i:3d}/{len(golden_data)}] P={p:5.1f}% R={r:5.1f}% F1={f:5.1f}% {query[:50]}")
|
||||
elif i % 50 == 0 or i == len(golden_data):
|
||||
print(f" Processed {i}/{len(golden_data)}...", file=sys.stderr)
|
||||
|
||||
# Summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RESULTS: {args.model}")
|
||||
print(f"{'='*60}")
|
||||
print(f"Threshold: {args.threshold} | Samples: {len(golden_data)}")
|
||||
print()
|
||||
|
||||
print(f"{'Type':<10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
|
||||
print("-" * 42)
|
||||
|
||||
for exp_type in ["lex", "vec", "hyde", "overall"]:
|
||||
if exp_type in type_aggregates:
|
||||
agg = type_aggregates[exp_type]
|
||||
avg_p = sum(agg["precision"]) / len(agg["precision"]) * 100 if agg["precision"] else 0
|
||||
avg_r = sum(agg["recall"]) / len(agg["recall"]) * 100 if agg["recall"] else 0
|
||||
avg_f = sum(agg["f1"]) / len(agg["f1"]) * 100 if agg["f1"] else 0
|
||||
label = exp_type.upper() if exp_type != "overall" else "OVERALL"
|
||||
print(f"{label:<10} {avg_p:>9.1f}% {avg_r:>9.1f}% {avg_f:>9.1f}%")
|
||||
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Show worst examples
|
||||
print("\nBottom 5 by F1:")
|
||||
sorted_by_f1 = sorted(all_metrics, key=lambda x: x["metrics"].get("overall", {}).get("f1", 0))
|
||||
for item in sorted_by_f1[:5]:
|
||||
f1 = item["metrics"].get("overall", {}).get("f1", 0) * 100
|
||||
print(f" {f1:5.1f}% {item['query'][:60]}")
|
||||
|
||||
# Show mismatches if requested
|
||||
if args.show_mismatches and all_mismatches:
|
||||
print(f"\n{'='*60}")
|
||||
print("MISMATCH EXAMPLES")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Group by type and show up to 3 examples per type
|
||||
for exp_type in ["lex", "vec", "hyde"]:
|
||||
type_mismatches = []
|
||||
for item in all_mismatches:
|
||||
if exp_type in item["mismatches"]:
|
||||
mm = item["mismatches"][exp_type]
|
||||
if mm["unmatched_preds"] or mm["unmatched_golden"]:
|
||||
type_mismatches.append({
|
||||
"query": item["query"],
|
||||
**mm
|
||||
})
|
||||
|
||||
if type_mismatches:
|
||||
print(f"\n--- {exp_type.upper()} mismatches ({len(type_mismatches)} queries) ---")
|
||||
for example in type_mismatches[:3]:
|
||||
print(f"\nQuery: {example['query'][:60]}")
|
||||
if example["unmatched_preds"]:
|
||||
print(f" Unmatched predictions:")
|
||||
for pred, score in example["unmatched_preds"][:2]:
|
||||
print(f" - [{score:.2f}] {pred[:80]}{'...' if len(pred) > 80 else ''}")
|
||||
if example["unmatched_golden"]:
|
||||
print(f" Missing golden:")
|
||||
for g in example["unmatched_golden"][:2]:
|
||||
print(f" - {g[:80]}{'...' if len(g) > 80 else ''}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
20
finetune/train-0.8B.log
Normal file
20
finetune/train-0.8B.log
Normal file
@ -0,0 +1,20 @@
|
||||
============================================================
|
||||
QMD Query Expansion — Unsloth SFT
|
||||
Base model: unsloth/Qwen3.5-0.8B
|
||||
Output: outputs/qwen3.5-0.8B
|
||||
Data: data/train/train.jsonl
|
||||
Epochs: 5
|
||||
Batch: 4 x 4 accum
|
||||
LR: 0.0002
|
||||
LoRA rank: 16
|
||||
Max seq len: 512
|
||||
============================================================
|
||||
Traceback (most recent call last):
|
||||
File "/home/tobi/src/github.com/tobi/qmd/finetune/train_unsloth.py", line 198, in <module>
|
||||
main()
|
||||
~~~~^^
|
||||
File "/home/tobi/src/github.com/tobi/qmd/finetune/train_unsloth.py", line 68, in main
|
||||
from unsloth import FastLanguageModel
|
||||
File "/home/tobi/src/github.com/tobi/qmd/finetune/.venv-unsloth/lib/python3.14/site-packages/unsloth/__init__.py", line 93, in <module>
|
||||
raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!")
|
||||
NotImplementedError: Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!
|
||||
198
finetune/train_unsloth.py
Normal file
198
finetune/train_unsloth.py
Normal file
@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).
|
||||
|
||||
Usage:
|
||||
python train_unsloth.py --model 0.8B
|
||||
python train_unsloth.py --model 2B
|
||||
python train_unsloth.py --model 4B --epochs 3
|
||||
|
||||
Requires: pip install unsloth unsloth_zoo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
MODEL_MAP = {
|
||||
"0.8B": "unsloth/Qwen3.5-0.8B",
|
||||
"2B": "unsloth/Qwen3.5-2B",
|
||||
"4B": "unsloth/Qwen3.5-4B",
|
||||
"9B": "unsloth/Qwen3.5-9B",
|
||||
"27B": "unsloth/Qwen3.5-27B",
|
||||
}
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
|
||||
parser.add_argument("--model", required=True, choices=list(MODEL_MAP.keys()),
|
||||
help="Model size to train")
|
||||
parser.add_argument("--epochs", type=int, default=5)
|
||||
parser.add_argument("--batch-size", type=int, default=4)
|
||||
parser.add_argument("--grad-accum", type=int, default=4)
|
||||
parser.add_argument("--lr", type=float, default=2e-4)
|
||||
parser.add_argument("--max-seq-len", type=int, default=512)
|
||||
parser.add_argument("--lora-rank", type=int, default=16)
|
||||
parser.add_argument("--data", type=str, default="data/train/train.jsonl")
|
||||
parser.add_argument("--output", type=str, default=None,
|
||||
help="Output directory (default: outputs/qwen3.5-{size})")
|
||||
parser.add_argument("--push-hub", type=str, default=None,
|
||||
help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)")
|
||||
parser.add_argument("--no-gguf", action="store_true")
|
||||
parser.add_argument("--no-eval", action="store_true")
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_name = MODEL_MAP[args.model]
|
||||
output_dir = args.output or f"outputs/qwen3.5-{args.model}"
|
||||
|
||||
print(f"{'='*60}")
|
||||
print(f"QMD Query Expansion — Unsloth SFT")
|
||||
print(f" Base model: {model_name}")
|
||||
print(f" Output: {output_dir}")
|
||||
print(f" Data: {args.data}")
|
||||
print(f" Epochs: {args.epochs}")
|
||||
print(f" Batch: {args.batch_size} x {args.grad_accum} accum")
|
||||
print(f" LR: {args.lr}")
|
||||
print(f" LoRA rank: {args.lora_rank}")
|
||||
print(f" Max seq len: {args.max_seq_len}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if args.dry_run:
|
||||
print("Dry run — exiting.")
|
||||
return
|
||||
|
||||
# --- Imports (heavy) ---
|
||||
import os
|
||||
import torch
|
||||
from unsloth import FastLanguageModel
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
# --- Load model ---
|
||||
print(f"\nLoading {model_name}...")
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=model_name,
|
||||
max_seq_length=args.max_seq_len,
|
||||
load_in_4bit=False,
|
||||
load_in_16bit=True,
|
||||
full_finetuning=False,
|
||||
)
|
||||
|
||||
# --- LoRA ---
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=args.lora_rank,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
lora_alpha=args.lora_rank,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=3407,
|
||||
max_seq_length=args.max_seq_len,
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
print(f"Loading dataset from {args.data}...")
|
||||
dataset = load_dataset("json", data_files=args.data, split="train")
|
||||
dataset = dataset.shuffle(seed=42)
|
||||
split = dataset.train_test_split(test_size=0.1, seed=42)
|
||||
train_ds = split["train"]
|
||||
eval_ds = split["test"]
|
||||
print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")
|
||||
|
||||
# --- Tracking ---
|
||||
report_to = "none"
|
||||
if os.environ.get("HF_TOKEN"):
|
||||
try:
|
||||
import trackio
|
||||
report_to = "trackio"
|
||||
os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# --- Trainer ---
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=eval_ds,
|
||||
args=SFTConfig(
|
||||
output_dir=output_dir,
|
||||
max_seq_length=args.max_seq_len,
|
||||
num_train_epochs=args.epochs,
|
||||
per_device_train_batch_size=args.batch_size,
|
||||
gradient_accumulation_steps=args.grad_accum,
|
||||
learning_rate=args.lr,
|
||||
warmup_ratio=0.03,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=200,
|
||||
save_total_limit=3,
|
||||
eval_strategy="steps",
|
||||
eval_steps=200,
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
seed=3407,
|
||||
dataset_num_proc=4,
|
||||
report_to=report_to,
|
||||
run_name=f"sft-qwen3.5-{args.model}",
|
||||
),
|
||||
)
|
||||
|
||||
print("\nStarting training...")
|
||||
stats = trainer.train()
|
||||
print(f"\nTraining complete!")
|
||||
print(f" Total steps: {stats.global_step}")
|
||||
print(f" Final loss: {stats.training_loss:.4f}")
|
||||
|
||||
# --- Save ---
|
||||
trainer.save_model(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print(f"Adapter saved to {output_dir}")
|
||||
|
||||
# --- GGUF export ---
|
||||
if not args.no_gguf:
|
||||
print("\nExporting GGUF quantizations...")
|
||||
gguf_dir = f"{output_dir}/gguf"
|
||||
for quant in ["q4_k_m", "q8_0"]:
|
||||
print(f" {quant}...")
|
||||
try:
|
||||
model.save_pretrained_gguf(
|
||||
gguf_dir, tokenizer, quantization_method=quant
|
||||
)
|
||||
print(f" ✓ {quant} saved")
|
||||
except Exception as e:
|
||||
print(f" ✗ {quant} failed: {e}")
|
||||
|
||||
# --- Push to Hub ---
|
||||
if args.push_hub:
|
||||
print(f"\nPushing to {args.push_hub}...")
|
||||
model.push_to_hub_merged(args.push_hub, tokenizer, save_method="lora")
|
||||
if not args.no_gguf:
|
||||
for quant in ["q4_k_m", "q8_0"]:
|
||||
try:
|
||||
model.push_to_hub_gguf(args.push_hub, tokenizer, quantization_method=quant)
|
||||
except Exception as e:
|
||||
print(f" GGUF push {quant} failed: {e}")
|
||||
|
||||
# --- Eval ---
|
||||
if not args.no_eval:
|
||||
print("\nRunning evaluation...")
|
||||
import subprocess
|
||||
subprocess.run(
|
||||
[sys.executable, "eval.py", output_dir],
|
||||
cwd=str(Path(__file__).parent),
|
||||
)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Done! Model at: {output_dir}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -48,7 +48,7 @@
|
||||
cp package.json $out/lib/qmd/
|
||||
|
||||
makeWrapper ${pkgs.bun}/bin/bun $out/bin/qmd \
|
||||
--add-flags "$out/lib/qmd/src/qmd.ts" \
|
||||
--add-flags "$out/lib/qmd/src/cli/qmd.ts" \
|
||||
--set DYLD_LIBRARY_PATH "${pkgs.sqlite.out}/lib" \
|
||||
--set LD_LIBRARY_PATH "${pkgs.sqlite.out}/lib"
|
||||
'';
|
||||
@ -81,7 +81,7 @@
|
||||
shellHook = ''
|
||||
export BREW_PREFIX="''${BREW_PREFIX:-${sqliteWithExtensions.out}}"
|
||||
echo "QMD development shell"
|
||||
echo "Run: bun src/qmd.ts <command>"
|
||||
echo "Run: bun src/cli/qmd.ts <command>"
|
||||
'';
|
||||
};
|
||||
}
|
||||
|
||||
23
src/llm.ts
23
src/llm.ts
@ -209,7 +209,9 @@ export const DEFAULT_RERANK_MODEL_URI = DEFAULT_RERANK_MODEL;
|
||||
export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL;
|
||||
|
||||
// Local model cache directory
|
||||
const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models");
|
||||
const MODEL_CACHE_DIR = process.env.XDG_CACHE_HOME
|
||||
? join(process.env.XDG_CACHE_HOME, "qmd", "models")
|
||||
: join(homedir(), ".cache", "qmd", "models");
|
||||
export const DEFAULT_MODEL_CACHE_DIR = MODEL_CACHE_DIR;
|
||||
|
||||
export type PullResult = {
|
||||
@ -757,9 +759,16 @@ export class LlamaCpp implements LLM {
|
||||
* - Combined: drops from 11.6 GB (auto, no flash) to 568 MB per context (20×)
|
||||
*/
|
||||
// Qwen3 reranker template adds ~200 tokens overhead (system prompt, tags, etc.)
|
||||
// Chunks are max 800 tokens, so 800 + 200 + query ≈ 1100 tokens typical.
|
||||
// Use 2048 for safety margin. Still 17× less than auto (40960).
|
||||
private static readonly RERANK_CONTEXT_SIZE = 2048;
|
||||
// Default 2048 was too small for longer documents (e.g. session transcripts,
|
||||
// CJK text, or large markdown files) — callers hit "input lengths exceed
|
||||
// context size" errors even after truncation because the overhead estimate
|
||||
// was insufficient. 4096 comfortably fits the largest real-world chunks
|
||||
// while staying well below the 40 960-token auto size.
|
||||
// Override with QMD_RERANK_CONTEXT_SIZE env var if you need more headroom.
|
||||
private static readonly RERANK_CONTEXT_SIZE: number = (() => {
|
||||
const v = parseInt(process.env.QMD_RERANK_CONTEXT_SIZE ?? "", 10);
|
||||
return Number.isFinite(v) && v > 0 ? v : 4096;
|
||||
})();
|
||||
private async ensureRerankContexts(): Promise<Awaited<ReturnType<LlamaModel["createRankingContext"]>>[]> {
|
||||
if (this.rerankContexts.length === 0) {
|
||||
const model = await this.ensureRerankModel();
|
||||
@ -1099,8 +1108,10 @@ export class LlamaCpp implements LLM {
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen3 reranker chat template overhead (system prompt, tags, separators)
|
||||
private static readonly RERANK_TEMPLATE_OVERHEAD = 200;
|
||||
// Qwen3 reranker chat template overhead (system prompt, tags, separators).
|
||||
// Measured at ~350 tokens on real queries; use 512 as a safe upper bound so
|
||||
// the truncation budget never lets a document slip past the context limit.
|
||||
private static readonly RERANK_TEMPLATE_OVERHEAD = 512;
|
||||
private static readonly RERANK_TARGET_DOCS_PER_CONTEXT = 10;
|
||||
|
||||
async rerank(
|
||||
|
||||
@ -296,9 +296,12 @@ Intent-aware lex (C++ performance, not sports):
|
||||
intent: z.string().optional().describe(
|
||||
"Background context to disambiguate the query. Example: query='performance', intent='web page load times and Core Web Vitals'. Does not search on its own."
|
||||
),
|
||||
rerank: z.boolean().optional().default(true).describe(
|
||||
"Rerank results using LLM (default: true). Set to false for faster results on CPU-only machines."
|
||||
),
|
||||
},
|
||||
},
|
||||
async ({ searches, limit, minScore, candidateLimit, collections, intent }) => {
|
||||
async ({ searches, limit, minScore, candidateLimit, collections, intent, rerank }) => {
|
||||
// Map to internal format
|
||||
const queries: ExpandedQuery[] = searches.map(s => ({
|
||||
type: s.type,
|
||||
@ -313,6 +316,7 @@ Intent-aware lex (C++ performance, not sports):
|
||||
collections: effectiveCollections.length > 0 ? effectiveCollections : undefined,
|
||||
limit,
|
||||
minScore,
|
||||
rerank,
|
||||
intent,
|
||||
});
|
||||
|
||||
|
||||
161
src/store.ts
161
src/store.ts
@ -1421,6 +1421,12 @@ export async function generateEmbeddings(
|
||||
const batches = buildEmbeddingBatches(docsToEmbed, maxDocsPerBatch, maxBatchBytes);
|
||||
|
||||
for (const batchMeta of batches) {
|
||||
// Abort early if session has been invalidated
|
||||
if (!session.isValid) {
|
||||
console.warn(`⚠ Session expired — skipping remaining document batches`);
|
||||
break;
|
||||
}
|
||||
|
||||
const batchDocs = getEmbeddingDocsForBatch(db, batchMeta);
|
||||
const batchChunks: ChunkItem[] = [];
|
||||
const batchBytes = batchMeta.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0);
|
||||
@ -1434,6 +1440,7 @@ export async function generateEmbeddings(
|
||||
undefined, undefined, undefined,
|
||||
doc.path,
|
||||
options?.chunkStrategy,
|
||||
session.signal,
|
||||
);
|
||||
|
||||
for (let seq = 0; seq < chunks.length; seq++) {
|
||||
@ -1472,6 +1479,23 @@ export async function generateEmbeddings(
|
||||
let batchChunkBytesProcessed = 0;
|
||||
|
||||
for (let batchStart = 0; batchStart < batchChunks.length; batchStart += BATCH_SIZE) {
|
||||
// Abort early if session has been invalidated (e.g. max duration exceeded)
|
||||
if (!session.isValid) {
|
||||
const remaining = batchChunks.length - batchStart;
|
||||
errors += remaining;
|
||||
console.warn(`⚠ Session expired — skipping ${remaining} remaining chunks`);
|
||||
break;
|
||||
}
|
||||
|
||||
// Abort early if error rate is too high (>80% of processed chunks failed)
|
||||
const processed = chunksEmbedded + errors;
|
||||
if (processed >= BATCH_SIZE && errors > processed * 0.8) {
|
||||
const remaining = batchChunks.length - batchStart;
|
||||
errors += remaining;
|
||||
console.warn(`⚠ Error rate too high (${errors}/${processed}) — aborting embedding`);
|
||||
break;
|
||||
}
|
||||
|
||||
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
|
||||
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
|
||||
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
||||
@ -1491,20 +1515,26 @@ export async function generateEmbeddings(
|
||||
}
|
||||
} catch {
|
||||
// Batch failed — try individual embeddings as fallback
|
||||
for (const chunk of chunkBatch) {
|
||||
try {
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||
const result = await session.embed(text);
|
||||
if (result) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
} else {
|
||||
// But skip if session is already invalid (avoids N doomed retries)
|
||||
if (!session.isValid) {
|
||||
errors += chunkBatch.length;
|
||||
batchChunkBytesProcessed += chunkBatch.reduce((sum, c) => sum + c.bytes, 0);
|
||||
} else {
|
||||
for (const chunk of chunkBatch) {
|
||||
try {
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||
const result = await session.embed(text);
|
||||
if (result) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
} else {
|
||||
errors++;
|
||||
}
|
||||
} catch {
|
||||
errors++;
|
||||
}
|
||||
} catch {
|
||||
errors++;
|
||||
batchChunkBytesProcessed += chunk.bytes;
|
||||
}
|
||||
batchChunkBytesProcessed += chunk.bytes;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1684,7 +1714,6 @@ export function handelize(path: string): string {
|
||||
|
||||
const result = path
|
||||
.replace(/___/g, '/') // Triple underscore becomes folder separator
|
||||
.toLowerCase()
|
||||
.split('/')
|
||||
.map((segment, idx, arr) => {
|
||||
const isLastSegment = idx === arr.length - 1;
|
||||
@ -1699,7 +1728,7 @@ export function handelize(path: string): string {
|
||||
const nameWithoutExt = ext ? segment.slice(0, -ext.length) : segment;
|
||||
|
||||
const cleanedName = nameWithoutExt
|
||||
.replace(/[^\p{L}\p{N}$]+/gu, '-') // Keep route marker "$", dash-separate other chars
|
||||
.replace(/[^\p{L}\p{N}.$]+/gu, '-') // Keep letters, numbers, dots, "$"; dash-separate rest
|
||||
.replace(/^-+|-+$/g, ''); // Remove leading/trailing dashes
|
||||
|
||||
return cleanedName + ext;
|
||||
@ -2170,6 +2199,7 @@ export async function chunkDocumentByTokens(
|
||||
windowTokens: number = CHUNK_WINDOW_TOKENS,
|
||||
filepath?: string,
|
||||
chunkStrategy: ChunkStrategy = "regex",
|
||||
signal?: AbortSignal
|
||||
): Promise<{ text: string; pos: number; tokens: number }[]> {
|
||||
const llm = getDefaultLlamaCpp();
|
||||
|
||||
@ -2188,6 +2218,9 @@ export async function chunkDocumentByTokens(
|
||||
const results: { text: string; pos: number; tokens: number }[] = [];
|
||||
|
||||
for (const chunk of charChunks) {
|
||||
// Respect abort signal to avoid runaway tokenization
|
||||
if (signal?.aborted) break;
|
||||
|
||||
const tokens = await llm.tokenize(chunk.text);
|
||||
|
||||
if (tokens.length <= maxTokens) {
|
||||
@ -2201,6 +2234,7 @@ export async function chunkDocumentByTokens(
|
||||
const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
|
||||
|
||||
for (const subChunk of subChunks) {
|
||||
if (signal?.aborted) break;
|
||||
const subTokens = await llm.tokenize(subChunk.text);
|
||||
results.push({
|
||||
text: subChunk.text,
|
||||
@ -2732,20 +2766,46 @@ function sanitizeFTS5Term(term: string): string {
|
||||
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a token is a hyphenated compound word (e.g., multi-agent, DEC-0054, gpt-4).
|
||||
* Returns true if the token contains internal hyphens between word/digit characters.
|
||||
*/
|
||||
function isHyphenatedToken(token: string): boolean {
|
||||
return /^[\p{L}\p{N}][\p{L}\p{N}'-]*-[\p{L}\p{N}][\p{L}\p{N}'-]*$/u.test(token);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize a hyphenated term into an FTS5 phrase by splitting on hyphens
|
||||
* and sanitizing each part. Returns the parts joined by spaces for use
|
||||
* inside FTS5 quotes: "multi agent" matches "multi-agent" in porter tokenizer.
|
||||
*/
|
||||
function sanitizeHyphenatedTerm(term: string): string {
|
||||
return term.split('-').map(t => sanitizeFTS5Term(t)).filter(t => t).join(' ');
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse lex query syntax into FTS5 query.
|
||||
*
|
||||
* Supports:
|
||||
* - Quoted phrases: "exact phrase" → "exact phrase" (exact match)
|
||||
* - Negation: -term or -"phrase" → uses FTS5 NOT operator
|
||||
* - Hyphenated tokens: multi-agent, DEC-0054, gpt-4 → treated as phrases
|
||||
* - Plain terms: term → "term"* (prefix match)
|
||||
*
|
||||
* FTS5 NOT is a binary operator: `term1 NOT term2` means "match term1 but not term2".
|
||||
* So `-term` only works when there are also positive terms.
|
||||
*
|
||||
* Hyphen disambiguation: `-sports` at a word boundary is negation, but `multi-agent`
|
||||
* (where `-` is between word characters) is treated as a hyphenated phrase.
|
||||
* When a leading `-` is followed by what looks like a hyphenated compound word
|
||||
* (e.g., `-multi-agent`), the entire token is treated as a negated phrase.
|
||||
*
|
||||
* Examples:
|
||||
* performance -sports → "performance"* NOT "sports"*
|
||||
* "machine learning" → "machine learning"
|
||||
* multi-agent memory → "multi agent" AND "memory"*
|
||||
* DEC-0054 → "dec 0054"
|
||||
* -multi-agent → NOT "multi agent"
|
||||
*/
|
||||
function buildFTS5Query(query: string): string | null {
|
||||
const positive: string[] = [];
|
||||
@ -2787,13 +2847,27 @@ function buildFTS5Query(query: string): string | null {
|
||||
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
||||
const term = s.slice(start, i);
|
||||
|
||||
const sanitized = sanitizeFTS5Term(term);
|
||||
if (sanitized) {
|
||||
const ftsTerm = `"${sanitized}"*`; // Prefix match
|
||||
if (negated) {
|
||||
negative.push(ftsTerm);
|
||||
} else {
|
||||
positive.push(ftsTerm);
|
||||
// Handle hyphenated tokens: multi-agent, DEC-0054, gpt-4
|
||||
// These get split into phrase queries so FTS5 porter tokenizer matches them.
|
||||
if (isHyphenatedToken(term)) {
|
||||
const sanitized = sanitizeHyphenatedTerm(term);
|
||||
if (sanitized) {
|
||||
const ftsPhrase = `"${sanitized}"`; // Phrase match (no prefix)
|
||||
if (negated) {
|
||||
negative.push(ftsPhrase);
|
||||
} else {
|
||||
positive.push(ftsPhrase);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const sanitized = sanitizeFTS5Term(term);
|
||||
if (sanitized) {
|
||||
const ftsTerm = `"${sanitized}"*`; // Prefix match
|
||||
if (negated) {
|
||||
negative.push(ftsTerm);
|
||||
} else {
|
||||
positive.push(ftsTerm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2842,20 +2916,38 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
|
||||
const ftsQuery = buildFTS5Query(query);
|
||||
if (!ftsQuery) return [];
|
||||
|
||||
// Use a CTE to force FTS5 to run first, then filter by collection.
|
||||
// Without the CTE, SQLite's query planner combines FTS5 MATCH with the
|
||||
// collection filter in a single WHERE clause, which can cause it to
|
||||
// abandon the FTS5 index and fall back to a full scan — turning an 8ms
|
||||
// query into a 17-second query on large collections.
|
||||
const params: (string | number)[] = [ftsQuery];
|
||||
|
||||
// When filtering by collection, fetch extra candidates from the FTS index
|
||||
// since some will be filtered out. Without a collection filter we can
|
||||
// fetch exactly the requested limit.
|
||||
const ftsLimit = collectionName ? limit * 10 : limit;
|
||||
|
||||
let sql = `
|
||||
WITH fts_matches AS (
|
||||
SELECT rowid, bm25(documents_fts, 1.5, 4.0, 1.0) as bm25_score
|
||||
FROM documents_fts
|
||||
WHERE documents_fts MATCH ?
|
||||
ORDER BY bm25_score ASC
|
||||
LIMIT ${ftsLimit}
|
||||
)
|
||||
SELECT
|
||||
'qmd://' || d.collection || '/' || d.path as filepath,
|
||||
d.collection || '/' || d.path as display_path,
|
||||
d.title,
|
||||
content.doc as body,
|
||||
d.hash,
|
||||
bm25(documents_fts, 10.0, 1.0) as bm25_score
|
||||
FROM documents_fts f
|
||||
JOIN documents d ON d.id = f.rowid
|
||||
fm.bm25_score
|
||||
FROM fts_matches fm
|
||||
JOIN documents d ON d.id = fm.rowid
|
||||
JOIN content ON content.hash = d.hash
|
||||
WHERE documents_fts MATCH ? AND d.active = 1
|
||||
WHERE d.active = 1
|
||||
`;
|
||||
const params: (string | number)[] = [ftsQuery];
|
||||
|
||||
if (collectionName) {
|
||||
sql += ` AND d.collection = ?`;
|
||||
@ -2863,7 +2955,7 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
|
||||
}
|
||||
|
||||
// bm25 lower is better; sort ascending.
|
||||
sql += ` ORDER BY bm25_score ASC LIMIT ?`;
|
||||
sql += ` ORDER BY fm.bm25_score ASC LIMIT ?`;
|
||||
params.push(limit);
|
||||
|
||||
const rows = db.prepare(sql).all(...params) as { filepath: string; display_path: string; title: string; body: string; hash: string; bm25_score: number }[];
|
||||
@ -3021,6 +3113,12 @@ export function clearAllEmbeddings(db: Database): void {
|
||||
/**
|
||||
* Insert a single embedding into both content_vectors and vectors_vec tables.
|
||||
* The hash_seq key is formatted as "hash_seq" for the vectors_vec table.
|
||||
*
|
||||
* content_vectors is inserted first so that getHashesForEmbedding (which checks
|
||||
* only content_vectors) won't re-select the hash on a crash between the two inserts.
|
||||
*
|
||||
* vectors_vec uses DELETE + INSERT instead of INSERT OR REPLACE because sqlite-vec's
|
||||
* vec0 virtual tables silently ignore the OR REPLACE conflict clause.
|
||||
*/
|
||||
export function insertEmbedding(
|
||||
db: Database,
|
||||
@ -3032,11 +3130,16 @@ export function insertEmbedding(
|
||||
embeddedAt: string
|
||||
): void {
|
||||
const hashSeq = `${hash}_${seq}`;
|
||||
const insertVecStmt = db.prepare(`INSERT OR REPLACE INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`);
|
||||
const insertContentVectorStmt = db.prepare(`INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, ?, ?, ?, ?)`);
|
||||
|
||||
insertVecStmt.run(hashSeq, embedding);
|
||||
// Insert content_vectors first — crash-safe ordering (see getHashesForEmbedding)
|
||||
const insertContentVectorStmt = db.prepare(`INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, ?, ?, ?, ?)`);
|
||||
insertContentVectorStmt.run(hash, seq, pos, model, embeddedAt);
|
||||
|
||||
// vec0 virtual tables don't support OR REPLACE — use DELETE + INSERT
|
||||
const deleteVecStmt = db.prepare(`DELETE FROM vectors_vec WHERE hash_seq = ?`);
|
||||
const insertVecStmt = db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`);
|
||||
deleteVecStmt.run(hashSeq);
|
||||
insertVecStmt.run(hashSeq, embedding);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
25
test/eval-deep-research.jsonl
Normal file
25
test/eval-deep-research.jsonl
Normal file
@ -0,0 +1,25 @@
|
||||
{"query": "that tradeoff between data correctness and always being up", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems architecture", "notes": "CAP theorem - no keywords match"}
|
||||
{"query": "what we learned from the dashboard thing", "expected_doc": "product-launch", "difficulty": "hard", "intent": "project retrospectives", "notes": "Project Phoenix retrospective - vague reference"}
|
||||
{"query": "how much we're burning through each month", "expected_doc": "fundraising", "difficulty": "hard", "intent": "startup finances", "notes": "burn rate - colloquial phrasing"}
|
||||
{"query": "when do I need to be online", "expected_doc": "remote-work", "difficulty": "hard", "intent": "work schedule policies", "notes": "core hours policy - no exact terms"}
|
||||
{"query": "that algorithm for getting nodes to agree", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed consensus", "notes": "consensus/Raft/Paxos - conceptual reference"}
|
||||
{"query": "why we pushed back the release date", "expected_doc": "product-launch", "difficulty": "hard", "intent": "project timeline decisions", "notes": "timeline pressure - implied from retrospective"}
|
||||
{"query": "how to structure URLs for our service", "expected_doc": "api-design", "difficulty": "hard", "intent": "API design patterns", "notes": "REST endpoints - no exact match"}
|
||||
{"query": "preventing the model from just memorizing", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "ML model training", "notes": "overfitting - conceptual synonym"}
|
||||
{"query": "who we're pitching to first", "expected_doc": "fundraising", "difficulty": "hard", "intent": "investor outreach strategy", "notes": "tier 1 investors - colloquial"}
|
||||
{"query": "can I work from another country", "expected_doc": "remote-work", "difficulty": "hard", "intent": "remote work eligibility", "notes": "remote eligibility - implied question"}
|
||||
{"query": "how the beta users found problems", "expected_doc": "product-launch", "difficulty": "hard", "intent": "product testing feedback", "notes": "beta program bugs - indirect reference"}
|
||||
{"query": "that thing Leslie Lamport invented", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems history", "notes": "Paxos - person reference only"}
|
||||
{"query": "what happens when the network splits", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "network failure handling", "notes": "partition tolerance - rephrased concept"}
|
||||
{"query": "teaching computers to find patterns", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "machine learning fundamentals", "notes": "ML definition - abstract description"}
|
||||
{"query": "how much runway before we're out of cash", "expected_doc": "fundraising", "difficulty": "hard", "intent": "startup financial planning", "notes": "runway months - colloquial finance term"}
|
||||
{"query": "the 47 issues we found before shipping", "expected_doc": "product-launch", "difficulty": "hard", "intent": "pre-launch QA", "notes": "beta bugs - specific number, no keywords"}
|
||||
{"query": "grouping customers by behavior", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "customer analytics", "notes": "clustering/segmentation - conceptual"}
|
||||
{"query": "why URLs should be things not actions", "expected_doc": "api-design", "difficulty": "hard", "intent": "RESTful design principles", "notes": "nouns not verbs - conceptual inversion"}
|
||||
{"query": "what Eric Brewer proved you can't have", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems theory", "notes": "CAP theorem - person + concept"}
|
||||
{"query": "how fast the new feature loaded", "expected_doc": "product-launch", "difficulty": "hard", "intent": "performance metrics", "notes": "performance 4.2s - indirect reference"}
|
||||
{"query": "days everyone needs to be in the office", "expected_doc": "remote-work", "difficulty": "hard", "intent": "hybrid work schedule", "notes": "collaboration days - rephrased"}
|
||||
{"query": "the number that shows customers are expanding", "expected_doc": "fundraising", "difficulty": "hard", "intent": "SaaS growth metrics", "notes": "NRR 124% - metric description"}
|
||||
{"query": "telling spam from real email", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "classification use cases", "notes": "classification example - specific use case"}
|
||||
{"query": "how to get user 123's purchases", "expected_doc": "api-design", "difficulty": "hard", "intent": "API endpoint design", "notes": "hierarchical URLs - example-based query"}
|
||||
{"query": "zookeeper etcd consul what they have in common", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed coordination tools", "notes": "CP systems - asking about category"}
|
||||
209
test/eval-deep-research.ts
Normal file
209
test/eval-deep-research.ts
Normal file
@ -0,0 +1,209 @@
|
||||
/**
|
||||
* Deep Research Evaluation for QMD
|
||||
*
|
||||
* Tests end-to-end retrieval quality: query → expansion → reranking → results
|
||||
*
|
||||
* These are HARD queries with NO exact keyword matches - they require
|
||||
* semantic understanding via query expansion and reranking to succeed.
|
||||
*
|
||||
* Run: bun test/eval-deep-research.ts
|
||||
*/
|
||||
|
||||
import { execSync } from "child_process";
|
||||
import { readFileSync, existsSync } from "fs";
|
||||
import { join, dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
interface EvalQuery {
|
||||
query: string;
|
||||
expected_doc: string;
|
||||
difficulty: string;
|
||||
intent: string; // Domain context hint for future intent-aware retrieval
|
||||
notes: string;
|
||||
}
|
||||
|
||||
interface SearchResult {
|
||||
file: string;
|
||||
score: number;
|
||||
title?: string;
|
||||
}
|
||||
|
||||
function loadQueries(): EvalQuery[] {
|
||||
const path = join(__dirname, "eval-deep-research.jsonl");
|
||||
const content = readFileSync(path, "utf-8");
|
||||
return content
|
||||
.split("\n")
|
||||
.filter((line) => line.trim())
|
||||
.map((line) => JSON.parse(line));
|
||||
}
|
||||
|
||||
function runBM25Search(query: string): SearchResult[] {
|
||||
try {
|
||||
const output = execSync(
|
||||
`bun src/qmd.ts search "${query.replace(/"/g, '\\"')}" -c eval-docs --json -n 5 2>/dev/null`,
|
||||
{ encoding: "utf-8", timeout: 30000 }
|
||||
);
|
||||
return JSON.parse(output);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function runDeepResearch(query: string): SearchResult[] {
|
||||
try {
|
||||
const output = execSync(
|
||||
`bun src/qmd.ts query "${query.replace(/"/g, '\\"')}" -c eval-docs --json -n 5 2>/dev/null`,
|
||||
{ encoding: "utf-8", timeout: 120000 }
|
||||
);
|
||||
return JSON.parse(output);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function matchesExpected(filepath: string, expectedDoc: string): boolean {
|
||||
return filepath.toLowerCase().includes(expectedDoc.toLowerCase());
|
||||
}
|
||||
|
||||
function findRank(results: SearchResult[], expectedDoc: string): number {
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
if (matchesExpected(results[i]!.file, expectedDoc)) {
|
||||
return i + 1;
|
||||
}
|
||||
}
|
||||
return -1; // Not found
|
||||
}
|
||||
|
||||
interface MethodResults {
|
||||
hit1: number;
|
||||
hit3: number;
|
||||
hit5: number;
|
||||
total: number;
|
||||
details: { query: string; rank: number; expected: string; intent?: string }[];
|
||||
}
|
||||
|
||||
function evaluate(
|
||||
queries: EvalQuery[],
|
||||
searchFn: (q: string) => SearchResult[],
|
||||
label: string
|
||||
): MethodResults {
|
||||
const results: MethodResults = {
|
||||
hit1: 0,
|
||||
hit3: 0,
|
||||
hit5: 0,
|
||||
total: queries.length,
|
||||
details: [],
|
||||
};
|
||||
|
||||
console.log(`\n${"=".repeat(60)}`);
|
||||
console.log(` ${label}`);
|
||||
console.log(`${"=".repeat(60)}\n`);
|
||||
|
||||
for (const { query, expected_doc, intent, notes } of queries) {
|
||||
const searchResults = searchFn(query);
|
||||
const rank = findRank(searchResults, expected_doc);
|
||||
|
||||
results.details.push({ query, rank, expected: expected_doc, intent });
|
||||
|
||||
if (rank === 1) results.hit1++;
|
||||
if (rank >= 1 && rank <= 3) results.hit3++;
|
||||
if (rank >= 1 && rank <= 5) results.hit5++;
|
||||
|
||||
const status =
|
||||
rank === 1 ? "✓" : rank > 0 && rank <= 3 ? `@${rank}` : rank > 0 ? `@${rank}` : "✗";
|
||||
const statusPad = status.padEnd(4);
|
||||
console.log(` ${statusPad} "${query.slice(0, 45).padEnd(45)}" → ${expected_doc}`);
|
||||
if (rank === -1) {
|
||||
console.log(` intent: ${intent} | ${notes}`);
|
||||
}
|
||||
}
|
||||
|
||||
const hit1Pct = ((results.hit1 / results.total) * 100).toFixed(0);
|
||||
const hit3Pct = ((results.hit3 / results.total) * 100).toFixed(0);
|
||||
const hit5Pct = ((results.hit5 / results.total) * 100).toFixed(0);
|
||||
|
||||
console.log(`\n ${"─".repeat(50)}`);
|
||||
console.log(` Hit@1: ${hit1Pct}% (${results.hit1}/${results.total})`);
|
||||
console.log(` Hit@3: ${hit3Pct}% (${results.hit3}/${results.total})`);
|
||||
console.log(` Hit@5: ${hit5Pct}% (${results.hit5}/${results.total})`);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
async function main() {
|
||||
console.log("QMD Deep Research Evaluation");
|
||||
console.log("=".repeat(60));
|
||||
console.log("Testing hard queries that require semantic understanding.");
|
||||
console.log("These have NO exact keyword matches in documents.");
|
||||
|
||||
// Check if eval-docs collection exists
|
||||
try {
|
||||
const status = execSync("bun src/qmd.ts status --json 2>/dev/null", {
|
||||
encoding: "utf-8",
|
||||
});
|
||||
if (!status.includes("eval-docs")) {
|
||||
console.log("\n⚠️ eval-docs collection not found. Run:");
|
||||
console.log(" qmd collection add test/eval-docs --name eval-docs");
|
||||
console.log(" qmd embed");
|
||||
process.exit(1);
|
||||
}
|
||||
} catch {
|
||||
console.log("\n⚠️ Could not check status. Make sure qmd is working.");
|
||||
}
|
||||
|
||||
const queries = loadQueries();
|
||||
console.log(`\nLoaded ${queries.length} hard queries.`);
|
||||
|
||||
// Run BM25 baseline (expected to fail on most)
|
||||
const bm25Results = evaluate(queries, runBM25Search, "BM25 BASELINE (keyword search)");
|
||||
|
||||
// Run deep research (expected to succeed via expansion + reranking)
|
||||
const deepResults = evaluate(queries, runDeepResearch, "DEEP RESEARCH (expansion + reranking)");
|
||||
|
||||
// Comparison
|
||||
console.log(`\n${"=".repeat(60)}`);
|
||||
console.log(" COMPARISON");
|
||||
console.log(`${"=".repeat(60)}`);
|
||||
console.log(`\n Method Hit@1 Hit@3 Hit@5`);
|
||||
console.log(` ${"─".repeat(45)}`);
|
||||
console.log(
|
||||
` BM25 (baseline) ${((bm25Results.hit1 / bm25Results.total) * 100).toFixed(0).padStart(3)}% ${((bm25Results.hit3 / bm25Results.total) * 100).toFixed(0).padStart(3)}% ${((bm25Results.hit5 / bm25Results.total) * 100).toFixed(0).padStart(3)}%`
|
||||
);
|
||||
console.log(
|
||||
` Deep Research ${((deepResults.hit1 / deepResults.total) * 100).toFixed(0).padStart(3)}% ${((deepResults.hit3 / deepResults.total) * 100).toFixed(0).padStart(3)}% ${((deepResults.hit5 / deepResults.total) * 100).toFixed(0).padStart(3)}%`
|
||||
);
|
||||
|
||||
const improvement = deepResults.hit3 - bm25Results.hit3;
|
||||
console.log(`\n Improvement (Hit@3): +${improvement} queries (${((improvement / bm25Results.total) * 100).toFixed(0)}%)`);
|
||||
|
||||
// Show queries where deep research recovered failures
|
||||
const recovered = deepResults.details.filter(
|
||||
(d) =>
|
||||
d.rank >= 1 &&
|
||||
d.rank <= 3 &&
|
||||
bm25Results.details.find((b) => b.query === d.query)?.rank === -1
|
||||
);
|
||||
|
||||
if (recovered.length > 0) {
|
||||
console.log(`\n Recovered by expansion + reranking (${recovered.length}):`);
|
||||
for (const { query, rank, expected } of recovered.slice(0, 5)) {
|
||||
console.log(` @${rank} "${query.slice(0, 40)}..." → ${expected}`);
|
||||
}
|
||||
if (recovered.length > 5) {
|
||||
console.log(` ... and ${recovered.length - 5} more`);
|
||||
}
|
||||
}
|
||||
|
||||
// Exit with error if deep research performs poorly
|
||||
const deepHit3Pct = (deepResults.hit3 / deepResults.total) * 100;
|
||||
if (deepHit3Pct < 60) {
|
||||
console.log(`\n❌ Deep research Hit@3 < 60% (${deepHit3Pct.toFixed(0)}%)`);
|
||||
process.exit(1);
|
||||
} else {
|
||||
console.log(`\n✓ Deep research Hit@3 >= 60% (${deepHit3Pct.toFixed(0)}%)`);
|
||||
}
|
||||
}
|
||||
|
||||
main();
|
||||
@ -114,14 +114,14 @@ describe("cleanupOrphanedVectors", () => {
|
||||
// =============================================================================
|
||||
|
||||
describe("handelize", () => {
|
||||
test("converts to lowercase", () => {
|
||||
expect(handelize("README.md")).toBe("readme.md");
|
||||
expect(handelize("MyFile.MD")).toBe("myfile.md");
|
||||
test("preserves original case", () => {
|
||||
expect(handelize("README.md")).toBe("README.md");
|
||||
expect(handelize("MyFile.MD")).toBe("MyFile.MD");
|
||||
});
|
||||
|
||||
test("preserves folder structure", () => {
|
||||
expect(handelize("a/b/c/d.md")).toBe("a/b/c/d.md");
|
||||
expect(handelize("docs/api/README.md")).toBe("docs/api/readme.md");
|
||||
expect(handelize("docs/api/README.md")).toBe("docs/api/README.md");
|
||||
});
|
||||
|
||||
test("replaces non-word characters with dash", () => {
|
||||
@ -151,7 +151,7 @@ describe("handelize", () => {
|
||||
test("handles complex real-world meeting notes", () => {
|
||||
const complexName = "Money Movement Licensing Review - 2025/11/19 10:25 EST - Notes by Gemini.md";
|
||||
const result = handelize(complexName);
|
||||
expect(result).toBe("money-movement-licensing-review-2025-11-19-10-25-est-notes-by-gemini.md");
|
||||
expect(result).toBe("Money-Movement-Licensing-Review-2025-11-19-10-25-EST-Notes-by-Gemini.md");
|
||||
expect(result).not.toContain(" ");
|
||||
expect(result).not.toContain("/");
|
||||
expect(result).not.toContain(":");
|
||||
@ -159,7 +159,7 @@ describe("handelize", () => {
|
||||
|
||||
test("handles unicode characters", () => {
|
||||
expect(handelize("日本語.md")).toBe("日本語.md");
|
||||
expect(handelize("Зоны и проекты.md")).toBe("зоны-и-проекты.md");
|
||||
expect(handelize("Зоны и проекты.md")).toBe("Зоны-и-проекты.md");
|
||||
expect(handelize("café-notes.md")).toBe("café-notes.md");
|
||||
expect(handelize("naïve.md")).toBe("naïve.md");
|
||||
expect(handelize("日本語-notes.md")).toBe("日本語-notes.md");
|
||||
@ -181,13 +181,13 @@ describe("handelize", () => {
|
||||
test("handles dates and times in filenames", () => {
|
||||
expect(handelize("meeting-2025-01-15.md")).toBe("meeting-2025-01-15.md");
|
||||
expect(handelize("notes 2025/01/15.md")).toBe("notes-2025/01/15.md");
|
||||
expect(handelize("call_10:30_AM.md")).toBe("call-10-30-am.md");
|
||||
expect(handelize("call_10:30_AM.md")).toBe("call-10-30-AM.md");
|
||||
});
|
||||
|
||||
test("handles special project naming patterns", () => {
|
||||
expect(handelize("PROJECT_ABC_v2.0.md")).toBe("project-abc-v2-0.md");
|
||||
expect(handelize("[WIP] Feature Request.md")).toBe("wip-feature-request.md");
|
||||
expect(handelize("(DRAFT) Proposal v1.md")).toBe("draft-proposal-v1.md");
|
||||
expect(handelize("PROJECT_ABC_v2.0.md")).toBe("PROJECT-ABC-v2.0.md");
|
||||
expect(handelize("[WIP] Feature Request.md")).toBe("WIP-Feature-Request.md");
|
||||
expect(handelize("(DRAFT) Proposal v1.md")).toBe("DRAFT-Proposal-v1.md");
|
||||
});
|
||||
|
||||
test("handles symbol-only route filenames", () => {
|
||||
|
||||
@ -1327,6 +1327,34 @@ describe("FTS Search", () => {
|
||||
await cleanupTestDb(store);
|
||||
});
|
||||
|
||||
test("searchFTS title boost outweighs higher body frequency", async () => {
|
||||
const store = await createTestStore();
|
||||
const collectionName = await createTestCollection();
|
||||
|
||||
// Document with "quantum" mentioned in a longer body but NOT in the title
|
||||
await insertTestDocument(store.db, collectionName, {
|
||||
name: "body-only",
|
||||
title: "General Science Notes",
|
||||
body: "This research paper discusses quantum mechanics and the quantum model of computation. The quantum approach offers improvements over classical methods.",
|
||||
displayPath: "test/body-only.md",
|
||||
});
|
||||
|
||||
// Document with "quantum" in the title but a shorter body mention
|
||||
await insertTestDocument(store.db, collectionName, {
|
||||
name: "title-match",
|
||||
title: "Quantum Computing Overview",
|
||||
body: "An introduction to the fundamentals of this emerging computing paradigm.",
|
||||
displayPath: "test/title-match.md",
|
||||
});
|
||||
|
||||
const results = store.searchFTS("quantum", 10);
|
||||
expect(results.length).toBe(2);
|
||||
// Title-match doc should rank higher due to BM25 column weights boosting title
|
||||
expect(results[0]!.displayPath).toBe(`${collectionName}/test/title-match.md`);
|
||||
|
||||
await cleanupTestDb(store);
|
||||
});
|
||||
|
||||
test("searchFTS respects limit parameter", async () => {
|
||||
const store = await createTestStore();
|
||||
const collectionName = await createTestCollection();
|
||||
|
||||
@ -399,6 +399,14 @@ describe("buildFTS5Query (lex parser)", () => {
|
||||
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
||||
}
|
||||
|
||||
function isHyphenatedToken(token: string): boolean {
|
||||
return /^[\p{L}\p{N}][\p{L}\p{N}'-]*-[\p{L}\p{N}][\p{L}\p{N}'-]*$/u.test(token);
|
||||
}
|
||||
|
||||
function sanitizeHyphenatedTerm(term: string): string {
|
||||
return term.split('-').map(t => sanitizeFTS5Term(t)).filter(t => t).join(' ');
|
||||
}
|
||||
|
||||
function buildFTS5Query(query: string): string | null {
|
||||
const positive: string[] = [];
|
||||
const negative: string[] = [];
|
||||
@ -424,8 +432,14 @@ describe("buildFTS5Query (lex parser)", () => {
|
||||
const start = i;
|
||||
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
||||
const term = s.slice(start, i);
|
||||
const sanitized = sanitizeFTS5Term(term);
|
||||
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"*`);
|
||||
|
||||
if (isHyphenatedToken(term)) {
|
||||
const sanitized = sanitizeHyphenatedTerm(term);
|
||||
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"`);
|
||||
} else {
|
||||
const sanitized = sanitizeFTS5Term(term);
|
||||
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"*`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -488,4 +502,37 @@ describe("buildFTS5Query (lex parser)", () => {
|
||||
test("special chars in terms stripped", () => {
|
||||
expect(buildFTS5Query("hello!world")).toBe('"helloworld"*');
|
||||
});
|
||||
|
||||
// Hyphenated token tests
|
||||
test("hyphenated term → phrase match", () => {
|
||||
expect(buildFTS5Query("multi-agent")).toBe('"multi agent"');
|
||||
});
|
||||
|
||||
test("hyphenated identifier → phrase match", () => {
|
||||
expect(buildFTS5Query("DEC-0054")).toBe('"dec 0054"');
|
||||
});
|
||||
|
||||
test("hyphenated model name → phrase match", () => {
|
||||
expect(buildFTS5Query("gpt-4")).toBe('"gpt 4"');
|
||||
});
|
||||
|
||||
test("multi-hyphen term → phrase match", () => {
|
||||
expect(buildFTS5Query("foo-bar-baz")).toBe('"foo bar baz"');
|
||||
});
|
||||
|
||||
test("hyphenated term mixed with plain terms", () => {
|
||||
expect(buildFTS5Query("multi-agent memory")).toBe('"multi agent" AND "memory"*');
|
||||
});
|
||||
|
||||
test("negation still works alongside hyphenated terms", () => {
|
||||
expect(buildFTS5Query("multi-agent -sports")).toBe('"multi agent" NOT "sports"*');
|
||||
});
|
||||
|
||||
test("negated hyphenated term", () => {
|
||||
expect(buildFTS5Query("performance -multi-agent")).toBe('"performance"* NOT "multi agent"');
|
||||
});
|
||||
|
||||
test("plain negation still works (not confused with hyphen)", () => {
|
||||
expect(buildFTS5Query("performance -sports")).toBe('"performance"* NOT "sports"*');
|
||||
});
|
||||
});
|
||||
|
||||
Loading…
Reference in New Issue
Block a user