Merge origin/main into feat/ast-aware-chunking
Resolve conflicts: combine AST chunking args (filepath, chunkStrategy) with abort signal parameter from #458. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
1fb2e2819e
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
### Fixes
|
### Fixes
|
||||||
|
|
||||||
|
- Fix paths in nix flake
|
||||||
- Sync stale `bun.lock` (`better-sqlite3` 11.x → 12.x). CI and release
|
- Sync stale `bun.lock` (`better-sqlite3` 11.x → 12.x). CI and release
|
||||||
script now use `--frozen-lockfile` to prevent recurrence. #386
|
script now use `--frozen-lockfile` to prevent recurrence. #386
|
||||||
(thanks @Mic92)
|
(thanks @Mic92)
|
||||||
|
|||||||
2
bun.lock
2
bun.lock
@ -12,7 +12,7 @@
|
|||||||
"picomatch": "^4.0.0",
|
"picomatch": "^4.0.0",
|
||||||
"sqlite-vec": "^0.1.7-alpha.2",
|
"sqlite-vec": "^0.1.7-alpha.2",
|
||||||
"yaml": "^2.8.2",
|
"yaml": "^2.8.2",
|
||||||
"zod": "^4.2.1",
|
"zod": "4.2.1",
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/better-sqlite3": "^7.6.0",
|
"@types/better-sqlite3": "^7.6.0",
|
||||||
|
|||||||
182
finetune/benchmark.py
Normal file
182
finetune/benchmark.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Benchmark QMD query expansion: LFM2.5 vs Qwen3 finetuned models."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
# Fixed set of benchmark prompts; each entry is wrapped in the
# "Expand this search query: ..." instruction before generation.
QUERIES = [
    "kubernetes pod networking",
    "best practices for React server components",
    "how to optimize PostgreSQL queries for large tables",
    "what is retrieval augmented generation",
    "python async await concurrency patterns",
    "nginx reverse proxy load balancing",
    "git rebase vs merge workflow",
    "rust ownership and borrowing explained",
    "docker compose multi-stage builds",
    "elasticsearch full text search performance",
    "shopify liquid template customization",
    "machine learning feature engineering techniques",
    "aws lambda cold start optimization",
    "typescript generics and utility types",
    "redis caching strategies for web apps",
]
|
||||||
|
|
||||||
|
def load_model(base_name, adapter_dir, device, trust_remote=False):
    """Load a base model, fold its PEFT adapter into it, and pick a generation config.

    Args:
        base_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the base causal LM.
        adapter_dir: Local directory holding the PEFT adapter (loaded with
            ``local_files_only=True``) and, optionally, ``generation_config.json``.
        device: Value forwarded as ``device_map`` when loading the base model.
        trust_remote: Forwarded as ``trust_remote_code`` to the tokenizer and model.

    Returns:
        ``(model, tokenizer, gen_config)`` where ``model`` is the merged model
        in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        base_name, trust_remote_code=trust_remote
    )
    backbone = AutoModelForCausalLM.from_pretrained(
        base_name,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=trust_remote,
    )

    # Merge the adapter weights into the base model and drop the PEFT wrapper.
    model = PeftModel.from_pretrained(
        backbone, adapter_dir, local_files_only=True
    ).merge_and_unload()
    model.eval()

    # Prefer the generation settings saved next to the adapter; otherwise use
    # the hard-coded sampling defaults below.
    if (Path(adapter_dir) / "generation_config.json").exists():
        gen_config = GenerationConfig.from_pretrained(adapter_dir)
    else:
        fallback_settings = {
            "temperature": 0.1,
            "top_k": 50,
            "top_p": 0.1,
            "repetition_penalty": 1.05,
            "do_sample": True,
            "max_new_tokens": 300,
        }
        gen_config = GenerationConfig(**fallback_settings)

    return model, tokenizer, gen_config
|
||||||
|
|
||||||
|
def run_inference(model, tokenizer, gen_config, query, device):
    """Generate one query expansion and time the ``generate`` call.

    Args:
        model: Object exposing ``generate``.
        tokenizer: Object exposing ``apply_chat_template``, ``__call__`` and ``decode``.
        gen_config: Passed to ``generate`` as ``generation_config``.
        query: Raw search query; wrapped in the "Expand this search query"
            user message.
        device: Target passed to ``.to()`` on the tokenized inputs.

    Returns:
        ``(text, seconds, new_token_count)``: the decoded continuation (prompt
        tokens removed), wall-clock time spent in ``generate``, and the number
        of positions generated beyond the prompt.
    """
    chat = [{"role": "user", "content": f"Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(device)

    # Only the generate() call is timed; templating and tokenization are excluded.
    began = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(
            **encoded, generation_config=gen_config, max_new_tokens=300
        )
    elapsed = time.perf_counter() - began

    prompt_len = encoded["input_ids"].shape[-1]
    new_tokens = generated.shape[-1] - prompt_len
    result = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return result, elapsed, new_tokens
|
||||||
|
|
||||||
|
def score_output(output):
    """Heuristically score an expansion by its lex/vec/hyde lines.

    Scoring rules:
      * every line starting with ``lex:`` or ``vec:`` adds 1 point;
      * every line starting with ``hyde:`` adds 2 points, and the last such
        line supplies the hyde text used below;
      * a hyde text of 80-200 characters adds 2 points, otherwise 50-250
        characters adds 1 point;
      * each generic stock phrase found in the hyde text subtracts 1 point.

    Returns:
        ``(score, details)`` where ``details`` has the keys ``has_lex``,
        ``has_vec``, ``has_hyde`` and ``hyde_len``.
    """
    points = 0
    found_lex = False
    found_vec = False
    found_hyde = False
    hyde_body = ""

    for raw_line in output.strip().split("\n"):
        entry = raw_line.strip()
        if entry.startswith("lex:"):
            found_lex = True
            points += 1
        elif entry.startswith("vec:"):
            found_vec = True
            points += 1
        elif entry.startswith("hyde:"):
            found_hyde = True
            hyde_body = entry[5:].strip()
            points += 2  # hyde is worth more

    if hyde_body:
        # Length bonus: full bonus inside the 80-200 char sweet spot,
        # half bonus inside the wider 50-250 band.
        size = len(hyde_body)
        if 80 <= size <= 200:
            points += 2
        elif 50 <= size <= 250:
            points += 1

        # Penalise boilerplate phrasing, one point per phrase present.
        lowered = hyde_body.lower()
        stock_phrases = (
            "comprehensive guide",
            "everything you need to know",
            "beginners and advanced users",
        )
        points -= sum(1 for phrase in stock_phrases if phrase in lowered)

    details = {
        "has_lex": found_lex,
        "has_vec": found_vec,
        "has_hyde": found_hyde,
        "hyde_len": len(hyde_body),
    }
    return points, details
|
||||||
|
|
||||||
|
def main():
    """Benchmark each configured finetuned model on ``QUERIES`` and save the results.

    For every model: load it, generate an expansion per query, score it,
    print a per-query line and a summary, then release the model. Finally
    print a side-by-side comparison and write all results to
    ``outputs/benchmark_results.json``.
    """
    device = "cuda:0"
    rule = "=" * 60

    models = {
        "LFM2.5-1.2B (finetuned)": {
            "base": "LiquidAI/LFM2.5-1.2B-Instruct",
            "adapter": "outputs/sft-lfm2",
            "trust_remote": True,
        },
        "Qwen3-1.7B (finetuned)": {
            "base": "Qwen/Qwen3-1.7B",
            "adapter": "outputs/sft",
            "trust_remote": False,
        },
    }

    results = {}

    for name, cfg in models.items():
        print(f"\n{rule}")
        print(f"Loading {name}...")
        model, tokenizer, gen_config = load_model(
            cfg["base"], cfg["adapter"], device, cfg["trust_remote"]
        )

        per_query = []
        time_sum = 0
        token_sum = 0
        score_sum = 0

        for query in QUERIES:
            output, elapsed, n_tokens = run_inference(
                model, tokenizer, gen_config, query, device
            )
            score, details = score_output(output)

            record = {
                "query": query,
                "output": output,
                "time_s": round(elapsed, 3),
                "tokens": n_tokens,
                "score": score,
                "details": details,
            }
            per_query.append(record)

            time_sum += elapsed
            token_sum += n_tokens
            score_sum += score

            # Guard against a zero-duration measurement.
            tok_s = n_tokens / elapsed if elapsed > 0 else 0
            print(f" [{score:2d}] {query[:40]:<40} {elapsed:.2f}s {n_tokens:3d}tok {tok_s:.0f}tok/s")

        query_count = len(QUERIES)
        avg_time = time_sum / query_count
        avg_score = score_sum / query_count
        # Throughput is total tokens over total generation time.
        avg_toks = token_sum / time_sum if time_sum > 0 else 0

        results[name] = {
            "queries": per_query,
            "avg_time_s": round(avg_time, 3),
            "avg_score": round(avg_score, 2),
            "avg_tok_s": round(avg_toks, 1),
            "total_score": score_sum,
        }

        print(f"\n Summary: avg_score={avg_score:.2f} avg_time={avg_time:.2f}s avg_tok/s={avg_toks:.0f}")

        # Free GPU memory
        del model
        torch.cuda.empty_cache()

    # Print comparison
    print(f"\n{rule}")
    print("COMPARISON")
    print(f"{rule}")
    for name, summary in results.items():
        print(f"\n{name}:")
        print(f" Total Score: {summary['total_score']} / {len(QUERIES) * 8}")  # max ~8 per query
        print(f" Avg Score: {summary['avg_score']}")
        print(f" Avg Time: {summary['avg_time_s']}s")
        print(f" Throughput: {summary['avg_tok_s']} tok/s")

    # Save full results
    with open("outputs/benchmark_results.json", "w") as f:
        json.dump(results, f, indent=2)
    print("\nFull results saved to outputs/benchmark_results.json")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
58
finetune/configs/sft-lfm2.yaml
Normal file
58
finetune/configs/sft-lfm2.yaml
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# SFT Training Config for QMD Query Expansion
|
||||||
|
# Target: LiquidAI LFM2.5-1.2B-Instruct with LoRA
|
||||||
|
#
|
||||||
|
# LFM2.5 is a hybrid model: 10 conv blocks + 6 GQA attention blocks
|
||||||
|
# Uses ChatML template: <|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n
|
||||||
|
# No /no_think needed (not Qwen3)
|
||||||
|
#
|
||||||
|
# Usage: uv run train.py sft --config configs/sft-lfm2.yaml
|
||||||
|
|
||||||
|
model:
|
||||||
|
base: "LiquidAI/LFM2.5-1.2B-Instruct"
|
||||||
|
output: "outputs/sft-lfm2"
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
name: "data/train-lfm2/"
|
||||||
|
text_field: "text"
|
||||||
|
split: "train"
|
||||||
|
eval_split: 0.1
|
||||||
|
|
||||||
|
training:
|
||||||
|
epochs: 5
|
||||||
|
batch_size: 4
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
learning_rate: 2e-4
|
||||||
|
max_length: 512
|
||||||
|
warmup_ratio: 0.03
|
||||||
|
lr_scheduler: "cosine"
|
||||||
|
|
||||||
|
lora:
|
||||||
|
rank: 16
|
||||||
|
alpha: 32
|
||||||
|
dropout: 0.0
|
||||||
|
target_modules:
|
||||||
|
# Convolution blocks (layers 0,1,3,4,6,7,9,11,13,15)
|
||||||
|
- "conv.in_proj"
|
||||||
|
- "conv.out_proj"
|
||||||
|
# Attention blocks (layers 2,5,8,10,12,14)
|
||||||
|
- "q_proj"
|
||||||
|
- "k_proj"
|
||||||
|
- "v_proj"
|
||||||
|
- "out_proj"
|
||||||
|
# FFN (all 16 layers)
|
||||||
|
- "feed_forward.w1"
|
||||||
|
- "feed_forward.w2"
|
||||||
|
- "feed_forward.w3"
|
||||||
|
|
||||||
|
generation:
|
||||||
|
temperature: 0.1
|
||||||
|
top_k: 50
|
||||||
|
top_p: 0.1
|
||||||
|
repetition_penalty: 1.05
|
||||||
|
|
||||||
|
gguf: false # LFM2.5 hybrid arch not supported by llama.cpp
|
||||||
|
|
||||||
|
tracking:
|
||||||
|
project: "qmd-query-expansion"
|
||||||
|
run_name: "sft-lfm2-1.2B"
|
||||||
1
finetune/data/fix_hyde_checkpoint.json
Normal file
1
finetune/data/fix_hyde_checkpoint.json
Normal file
File diff suppressed because one or more lines are too long
@ -73,6 +73,8 @@ def format_for_training(ex: TrainingExample) -> dict:
|
|||||||
text = text.replace("<think>\n\n</think>\n\n", "")
|
text = text.replace("<think>\n\n</think>\n\n", "")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"query": ex.query,
|
||||||
|
"output": ex.output_as_lists(),
|
||||||
"text": text,
|
"text": text,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
}
|
}
|
||||||
|
|||||||
85
finetune/dataset/prepare_data_lfm2.py
Normal file
85
finetune/dataset/prepare_data_lfm2.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Prepare QMD query expansion data for LFM2.5-1.2B-Instruct training.
|
||||||
|
|
||||||
|
LFM2.5 uses ChatML format:
|
||||||
|
<|startoftext|><|im_start|>user
|
||||||
|
Expand this search query: {query}<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
{output}<|im_end|>
|
||||||
|
|
||||||
|
No /no_think needed (that's Qwen3-specific).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
from dataset.schema import normalize_output_items, output_items_to_text
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def format_for_training(query_text: str, output_items: list[list[str]], tokenizer) -> dict:
    """Render one query/expansion pair as a chat-templated training string.

    Args:
        query_text: The raw search query.
        output_items: Expansion items, serialised with ``output_items_to_text``
            to form the assistant turn.
        tokenizer: Tokenizer whose ``apply_chat_template`` produces the text.

    Returns:
        ``{"text": <templated conversation>}``.
    """
    user_turn = {"role": "user", "content": f"Expand this search query: {query_text}"}
    assistant_turn = {"role": "assistant", "content": output_items_to_text(output_items)}

    # add_generation_prompt=False: the assistant turn is already included.
    rendered = tokenizer.apply_chat_template(
        [user_turn, assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build the LFM2.5 train/val JSONL files from the raw expansion dataset.

    Reads ``data/qmd_expansion_v2.jsonl``, renders every row through the
    LFM2.5 chat template, shuffles with a fixed seed (42) and writes a 90/10
    split to ``data/train-lfm2/train.jsonl`` and ``data/train-lfm2/val.jsonl``.
    Paths are relative to the current working directory.
    """
    input_path = Path("data/qmd_expansion_v2.jsonl")
    output_dir = Path("data/train-lfm2")
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading LFM2.5 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct", trust_remote_code=True
    )

    examples = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # letting json.loads raise on them.
            if not line.strip():
                continue
            row = json.loads(line)
            items = normalize_output_items(row["output"])
            examples.append(format_for_training(row["query"], items, tokenizer))

    # Shuffle and split (seeded so the split is reproducible)
    random.seed(42)
    random.shuffle(examples)

    split_idx = int(len(examples) * 0.9)
    train = examples[:split_idx]
    val = examples[split_idx:]

    # Write as JSONL
    train_path = output_dir / "train.jsonl"
    val_path = output_dir / "val.jsonl"

    for path, subset in ((train_path, train), (val_path, val)):
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex) + "\n" for ex in subset)

    print(f"Written {len(train)} train, {len(val)} val examples to {output_dir}")

    # Only preview a sample when there is one; an empty or tiny input file
    # would otherwise raise IndexError after the files were already written.
    if train:
        print("\nSample formatted text:")
        print(train[0]["text"][:500])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
488
finetune/eval_retrieval.py
Normal file
488
finetune/eval_retrieval.py
Normal file
@ -0,0 +1,488 @@
|
|||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.10"
|
||||||
|
# dependencies = [
|
||||||
|
# "transformers>=4.45.0",
|
||||||
|
# "peft>=0.7.0",
|
||||||
|
# "torch",
|
||||||
|
# "accelerate",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
"""
|
||||||
|
QMD Retrieval-Based Evaluation with Precision & Recall
|
||||||
|
|
||||||
|
Evaluates model outputs against golden data (training set).
|
||||||
|
Measures how well the model reproduces the expected expansions.
|
||||||
|
|
||||||
|
Metrics:
|
||||||
|
- Precision: Of model-generated expansions, how many match golden?
|
||||||
|
- Recall: Of golden expansions, how many did the model generate?
|
||||||
|
- F1: Harmonic mean of precision and recall
|
||||||
|
|
||||||
|
Matching is done via token overlap (Jaccard similarity) with a threshold.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run eval_retrieval.py ./outputs/sft
|
||||||
|
uv run eval_retrieval.py tobil/qmd-query-expansion-1.7B --golden data/qmd_expansion_v3_structured.jsonl
|
||||||
|
uv run eval_retrieval.py ./outputs/sft --threshold 0.5 --sample 100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Matching Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def tokenize(text: str) -> set[str]:
    """Return the set of lowercase words in *text*, minus stopwords.

    Words are runs of ``\\w`` characters; single-character words and the
    stopwords listed below are dropped.
    """
    ignored = frozenset((
        'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and',
        'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
        'how', 'what', 'do', 'does', 'can', 'you', 'your', 'i',
    ))
    kept = set()
    for word in re.findall(r'\b\w+\b', text.lower()):
        if len(word) > 1 and word not in ignored:
            kept.add(word)
    return kept
|
||||||
|
|
||||||
|
|
||||||
|
def jaccard_similarity(a: str, b: str) -> float:
    """Return |A ∩ B| / |A ∪ B| over the token sets of *a* and *b*.

    Token sets come from ``tokenize``. If either string yields no tokens
    the similarity is 0.0.
    """
    left, right = tokenize(a), tokenize(b)
    if not (left and right):
        return 0.0
    shared = left & right
    combined = left | right
    return len(shared) / len(combined) if combined else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_match(pred: str, golden_list: list[str], threshold: float) -> tuple[str | None, float]:
    """Pick the golden expansion most similar to *pred*.

    Returns:
        ``(golden, score)`` when the best Jaccard score reaches *threshold*,
        otherwise ``(None, best_score)``. On ties the earliest golden entry
        is kept; a best score of 0 never yields a match.
    """
    top_candidate = None
    top_score = 0.0
    for candidate in golden_list:
        similarity = jaccard_similarity(pred, candidate)
        # Strict '>' keeps the first candidate among equal scores.
        if similarity > top_score:
            top_candidate, top_score = candidate, similarity

    if not top_score >= threshold:
        return None, top_score
    return top_candidate, top_score
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Parsing
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def parse_model_output(text: str) -> dict[str, list[str]]:
    """Split raw model text into ``{"lex": [...], "vec": [...], "hyde": [...]}``.

    ``<think>...</think>`` spans and ``<|im_end|>`` markers are removed first.
    Each remaining line that starts with ``lex:``, ``vec:`` or ``hyde:``
    contributes its stripped remainder to that list; other lines are ignored.
    """
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    cleaned = cleaned.replace('<|im_end|>', '').strip()

    buckets = {"lex": [], "vec": [], "hyde": []}
    for raw_line in cleaned.split("\n"):
        entry = raw_line.strip()
        for kind, values in buckets.items():
            tag = kind + ":"
            if entry.startswith(tag):
                values.append(entry[len(tag):].strip())
                break
    return buckets
|
||||||
|
|
||||||
|
|
||||||
|
def parse_golden_data(searches: list[dict] | str) -> dict[str, list[str]]:
|
||||||
|
"""Parse golden data format into {lex: [...], vec: [...], hyde: [...]}."""
|
||||||
|
# If it's a string (from messages format), parse it
|
||||||
|
if isinstance(searches, str):
|
||||||
|
return parse_model_output(searches)
|
||||||
|
|
||||||
|
# Otherwise it's the structured format [{type, query}, ...]
|
||||||
|
result = {"lex": [], "vec": [], "hyde": []}
|
||||||
|
for item in searches:
|
||||||
|
exp_type = item.get("type", "")
|
||||||
|
value = item.get("query", "") or item.get("value", "")
|
||||||
|
if exp_type in result:
|
||||||
|
result[exp_type].append(value)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_golden_data(filepath: Path) -> list[dict]:
    """Load golden examples from a JSONL file.

    Two row shapes are recognised:

    * structured: ``{"query": ..., "searches": [...]}`` — copied through;
    * messages: ``{"messages": [{"role", "content"}, ...]}`` — the query is
      the text after ``"Expand this search query:"`` in the user message (or
      the whole stripped message when that marker is absent) and ``searches``
      is the assistant message content as a raw string.

    Blank lines and rows of any other shape are skipped, as are messages rows
    with no usable query or assistant content.

    Args:
        filepath: Path of the JSONL file, read as UTF-8.

    Returns:
        A list of ``{"query": str, "searches": list | str}`` dicts in file order.
    """
    data = []
    # Explicit UTF-8 so non-ASCII golden text decodes the same on every
    # platform instead of depending on the locale's default encoding.
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)

            # Structured format: {query, searches}
            if "query" in item and "searches" in item:
                data.append({
                    "query": item["query"],
                    "searches": item["searches"],
                })
            # Messages format: {messages: [{role, content}, ...]}
            elif "messages" in item:
                query = None
                searches = None
                for msg in item["messages"]:
                    if msg["role"] == "user":
                        # Extract query from "/no_think Expand this search query: ..."
                        content = msg["content"]
                        if "Expand this search query:" in content:
                            query = content.split("Expand this search query:")[-1].strip()
                        else:
                            query = content.strip()
                    elif msg["role"] == "assistant":
                        # The assistant content IS the expected output
                        searches = msg["content"]
                if query and searches:
                    data.append({
                        "query": query,
                        "searches": searches,  # Will be parsed as string
                    })
    return data
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Metrics Calculation
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Per-type Jaccard thresholds: the looser the expected wording, the lower
# the bar for counting a prediction as a match.
DEFAULT_THRESHOLDS = {
    # Keywords should overlap well
    "lex": 0.5,
    # Semantic sentences have more variation
    "vec": 0.35,
    # Passages have the most variation
    "hyde": 0.25,
}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_metrics(
|
||||||
|
predictions: dict[str, list[str]],
|
||||||
|
golden: dict[str, list[str]],
|
||||||
|
threshold: float | dict[str, float] = 0.4,
|
||||||
|
return_mismatches: bool = False
|
||||||
|
) -> dict:
|
||||||
|
"""Calculate precision, recall, F1 per type and overall.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: Either a single float, or dict mapping type -> threshold
|
||||||
|
return_mismatches: If True, include lists of unmatched predictions/golden
|
||||||
|
"""
|
||||||
|
if isinstance(threshold, (int, float)):
|
||||||
|
thresholds = {"lex": threshold, "vec": threshold, "hyde": threshold}
|
||||||
|
else:
|
||||||
|
thresholds = threshold
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
mismatches = {}
|
||||||
|
total_tp = 0
|
||||||
|
total_pred = 0
|
||||||
|
total_golden = 0
|
||||||
|
|
||||||
|
for exp_type in ["lex", "vec", "hyde"]:
|
||||||
|
preds = predictions.get(exp_type, [])
|
||||||
|
golds = golden.get(exp_type, [])
|
||||||
|
type_threshold = thresholds.get(exp_type, 0.4)
|
||||||
|
|
||||||
|
if not preds and not golds:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Track which golden items were matched
|
||||||
|
matched_golden = set()
|
||||||
|
unmatched_preds = []
|
||||||
|
tp = 0
|
||||||
|
|
||||||
|
for pred in preds:
|
||||||
|
match, score = find_best_match(pred, golds, type_threshold)
|
||||||
|
if match is not None:
|
||||||
|
tp += 1
|
||||||
|
matched_golden.add(match)
|
||||||
|
else:
|
||||||
|
unmatched_preds.append((pred, score))
|
||||||
|
|
||||||
|
unmatched_golden = [g for g in golds if g not in matched_golden]
|
||||||
|
|
||||||
|
precision = tp / len(preds) if preds else 0.0
|
||||||
|
recall = len(matched_golden) / len(golds) if golds else 0.0
|
||||||
|
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||||
|
|
||||||
|
metrics[exp_type] = {
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1": f1,
|
||||||
|
"pred_count": len(preds),
|
||||||
|
"golden_count": len(golds),
|
||||||
|
"matched": tp,
|
||||||
|
}
|
||||||
|
|
||||||
|
if return_mismatches:
|
||||||
|
mismatches[exp_type] = {
|
||||||
|
"unmatched_preds": unmatched_preds,
|
||||||
|
"unmatched_golden": unmatched_golden,
|
||||||
|
}
|
||||||
|
|
||||||
|
total_tp += tp
|
||||||
|
total_pred += len(preds)
|
||||||
|
total_golden += len(golds)
|
||||||
|
|
||||||
|
# Overall metrics (micro-averaged)
|
||||||
|
overall_precision = total_tp / total_pred if total_pred > 0 else 0.0
|
||||||
|
overall_recall = total_tp / total_golden if total_golden > 0 else 0.0
|
||||||
|
overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0
|
||||||
|
|
||||||
|
metrics["overall"] = {
|
||||||
|
"precision": overall_precision,
|
||||||
|
"recall": overall_recall,
|
||||||
|
"f1": overall_f1,
|
||||||
|
"pred_count": total_pred,
|
||||||
|
"golden_count": total_golden,
|
||||||
|
"matched": total_tp,
|
||||||
|
}
|
||||||
|
|
||||||
|
if return_mismatches:
|
||||||
|
metrics["_mismatches"] = mismatches
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Model Loading and Generation
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_model(model_path: str):
    """Load a model for evaluation, from either an adapter directory or a full model.

    * If ``model_path`` is a local directory containing ``adapter_config.json``,
      the base model named there (``base_model_name_or_path``, defaulting to
      ``Qwen/Qwen3-1.7B``) is loaded and the adapter applied on top.
    * Otherwise ``model_path`` itself is loaded as the model (a merged local
      checkpoint or a Hub model id). Without this, a path lacking an adapter
      config would be ignored and the default base model evaluated instead.

    The model is loaded in bfloat16 on device 0 with sampling disabled in its
    generation config, and the tokenizer is set to left padding.

    Args:
        model_path: Local path or Hub id of the adapter/model to evaluate.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    import torch
    from peft import PeftModel
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    # Keep the caller's string: converting a Hub id through Path could alter
    # its separators on some platforms.
    model_id = str(model_path)
    model_path = Path(model_path)
    adapter_config = model_path / "adapter_config.json"
    has_adapter = adapter_config.exists()

    if has_adapter:
        # Get base model from adapter config or default
        with open(adapter_config, encoding="utf-8") as f:
            cfg = json.load(f)
        base_model = cfg.get("base_model_name_or_path", "Qwen/Qwen3-1.7B")
    else:
        # No adapter here: the given path/id IS the model to evaluate.
        base_model = model_id

    print(f"Loading base: {base_model}", file=sys.stderr)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    config = AutoConfig.from_pretrained(base_model)
    config.tie_word_embeddings = False
    model = AutoModelForCausalLM.from_pretrained(
        base_model, dtype=torch.bfloat16, device_map={"": 0}, config=config
    )
    if model.generation_config is not None:
        # Greedy decoding: clear the sampling knobs so they do not conflict
        # with do_sample=False.
        model.generation_config.do_sample = False
        model.generation_config.temperature = None
        model.generation_config.top_p = None
        model.generation_config.top_k = None

    # Load adapter if present
    if has_adapter:
        print(f"Loading adapter: {model_path}", file=sys.stderr)
        model = PeftModel.from_pretrained(model, str(model_path))

    model.eval()
    return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 400) -> str:
    """Greedily generate the expansion text for one query.

    The query is sent as a single user message prefixed with
    ``/no_think Expand this search query:``; only the tokens produced after
    the prompt are decoded and returned.
    """
    import torch

    chat = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    # Deterministic decoding: no sampling, single beam.
    decode_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "num_beams": 1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    }
    with torch.inference_mode():
        generated = model.generate(**encoded, **decode_kwargs)

    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Main Evaluation
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def main():
    """Run the golden-data evaluation from the command line and print a report.

    Parses the CLI arguments, loads (and optionally samples) the golden
    data, loads the model, generates an expansion per query, scores it with
    ``calculate_metrics`` and prints macro-averaged precision/recall/F1 per
    type, the five lowest-F1 queries and, with ``--show-mismatches``, up to
    three mismatch examples per type.

    Returns:
        0 on completion. Exits with status 1 if the golden file is not found.
    """
    parser = argparse.ArgumentParser(description="QMD Retrieval-Based Evaluation")
    parser.add_argument("model", help="Model path (local or HF)")
    parser.add_argument("--golden", default="data/qmd_expansion_v3_structured.jsonl",
                        help="Golden data JSONL file")
    parser.add_argument("--threshold", type=float, default=None,
                        help="Jaccard similarity threshold for all types (overrides --type-thresholds)")
    parser.add_argument("--type-thresholds", action="store_true",
                        help="Use type-specific thresholds (lex=0.5, vec=0.35, hyde=0.25)")
    parser.add_argument("--sample", type=int, default=0,
                        help="Sample N queries (0 = all)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for sampling")
    parser.add_argument("--max-new-tokens", type=int, default=400,
                        help="Max new tokens to generate")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Show per-query details")
    parser.add_argument("--show-mismatches", action="store_true",
                        help="Show examples of mismatched predictions")
    args = parser.parse_args()

    # Determine thresholds
    if args.threshold is not None:
        thresholds = args.threshold
    elif args.type_thresholds:
        thresholds = DEFAULT_THRESHOLDS.copy()
    else:
        thresholds = 0.4  # Default single threshold

    # Load golden data
    golden_path = Path(args.golden)
    if not golden_path.exists():
        # Try relative to script directory
        golden_path = Path(__file__).parent / args.golden

    if not golden_path.exists():
        print(f"Error: Golden data file not found: {args.golden}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading golden data from {golden_path}...", file=sys.stderr)
    golden_data = load_golden_data(golden_path)
    print(f"Loaded {len(golden_data)} golden examples", file=sys.stderr)

    # Sample if requested
    if args.sample > 0 and args.sample < len(golden_data):
        random.seed(args.seed)
        golden_data = random.sample(golden_data, args.sample)
        print(f"Sampled {len(golden_data)} examples", file=sys.stderr)

    # Load model
    model, tokenizer = load_model(args.model)

    # Evaluate
    all_metrics = []
    all_mismatches = []
    type_aggregates = defaultdict(lambda: {"precision": [], "recall": [], "f1": []})

    threshold_desc = thresholds if isinstance(thresholds, (int, float)) else f"lex={thresholds['lex']}, vec={thresholds['vec']}, hyde={thresholds['hyde']}"
    print(f"\nEvaluating {len(golden_data)} queries (thresholds: {threshold_desc})...\n")

    for i, item in enumerate(golden_data, 1):
        query = item["query"]
        golden_parsed = parse_golden_data(item["searches"])

        # Generate model output
        output = generate_expansion(model, tokenizer, query, args.max_new_tokens)
        pred_parsed = parse_model_output(output)

        # Calculate metrics
        metrics = calculate_metrics(pred_parsed, golden_parsed, thresholds, return_mismatches=args.show_mismatches)
        all_metrics.append({"query": query, "metrics": metrics, "pred": pred_parsed, "golden": golden_parsed})

        # Move the mismatch detail out of the metrics dict so only the
        # numeric entries remain there.
        if args.show_mismatches and "_mismatches" in metrics:
            all_mismatches.append({"query": query, "mismatches": metrics.pop("_mismatches")})

        # Aggregate by type
        for exp_type in ["lex", "vec", "hyde", "overall"]:
            if exp_type in metrics:
                for key in ("precision", "recall", "f1"):
                    type_aggregates[exp_type][key].append(metrics[exp_type][key])

        # Progress
        overall = metrics.get("overall", {})
        p = overall.get("precision", 0) * 100
        r = overall.get("recall", 0) * 100
        f = overall.get("f1", 0) * 100

        if args.verbose:
            print(f"[{i:3d}/{len(golden_data)}] P={p:5.1f}% R={r:5.1f}% F1={f:5.1f}% {query[:50]}")
        elif i % 50 == 0 or i == len(golden_data):
            print(f" Processed {i}/{len(golden_data)}...", file=sys.stderr)

    # Summary
    print(f"\n{'='*60}")
    print(f"RESULTS: {args.model}")
    print(f"{'='*60}")
    # Report the thresholds actually used; args.threshold is None unless
    # --threshold was given, which would print "Threshold: None".
    print(f"Threshold: {threshold_desc} | Samples: {len(golden_data)}")
    print()

    print(f"{'Type':<10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print("-" * 42)

    for exp_type in ["lex", "vec", "hyde", "overall"]:
        if exp_type in type_aggregates:
            agg = type_aggregates[exp_type]
            # Macro average: mean of the per-query values.
            avg_p = sum(agg["precision"]) / len(agg["precision"]) * 100 if agg["precision"] else 0
            avg_r = sum(agg["recall"]) / len(agg["recall"]) * 100 if agg["recall"] else 0
            avg_f = sum(agg["f1"]) / len(agg["f1"]) * 100 if agg["f1"] else 0
            label = exp_type.upper() if exp_type != "overall" else "OVERALL"
            print(f"{label:<10} {avg_p:>9.1f}% {avg_r:>9.1f}% {avg_f:>9.1f}%")

    print(f"{'='*60}")

    # Show worst examples
    print("\nBottom 5 by F1:")
    sorted_by_f1 = sorted(all_metrics, key=lambda x: x["metrics"].get("overall", {}).get("f1", 0))
    for item in sorted_by_f1[:5]:
        f1 = item["metrics"].get("overall", {}).get("f1", 0) * 100
        print(f" {f1:5.1f}% {item['query'][:60]}")

    # Show mismatches if requested
    if args.show_mismatches and all_mismatches:
        print(f"\n{'='*60}")
        print("MISMATCH EXAMPLES")
        print(f"{'='*60}")

        # Group by type and show up to 3 examples per type
        for exp_type in ["lex", "vec", "hyde"]:
            type_mismatches = []
            for item in all_mismatches:
                if exp_type in item["mismatches"]:
                    mm = item["mismatches"][exp_type]
                    if mm["unmatched_preds"] or mm["unmatched_golden"]:
                        type_mismatches.append({
                            "query": item["query"],
                            **mm
                        })

            if type_mismatches:
                print(f"\n--- {exp_type.upper()} mismatches ({len(type_mismatches)} queries) ---")
                for example in type_mismatches[:3]:
                    print(f"\nQuery: {example['query'][:60]}")
                    if example["unmatched_preds"]:
                        print(f" Unmatched predictions:")
                        for pred, score in example["unmatched_preds"][:2]:
                            print(f" - [{score:.2f}] {pred[:80]}{'...' if len(pred) > 80 else ''}")
                    if example["unmatched_golden"]:
                        print(f" Missing golden:")
                        for g in example["unmatched_golden"][:2]:
                            print(f" - {g[:80]}{'...' if len(g) > 80 else ''}")

    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
20
finetune/train-0.8B.log
Normal file
20
finetune/train-0.8B.log
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
============================================================
|
||||||
|
QMD Query Expansion — Unsloth SFT
|
||||||
|
Base model: unsloth/Qwen3.5-0.8B
|
||||||
|
Output: outputs/qwen3.5-0.8B
|
||||||
|
Data: data/train/train.jsonl
|
||||||
|
Epochs: 5
|
||||||
|
Batch: 4 x 4 accum
|
||||||
|
LR: 0.0002
|
||||||
|
LoRA rank: 16
|
||||||
|
Max seq len: 512
|
||||||
|
============================================================
|
||||||
|
Traceback (most recent call last):
|
||||||
|
File "/home/tobi/src/github.com/tobi/qmd/finetune/train_unsloth.py", line 198, in <module>
|
||||||
|
main()
|
||||||
|
~~~~^^
|
||||||
|
File "/home/tobi/src/github.com/tobi/qmd/finetune/train_unsloth.py", line 68, in main
|
||||||
|
from unsloth import FastLanguageModel
|
||||||
|
File "/home/tobi/src/github.com/tobi/qmd/finetune/.venv-unsloth/lib/python3.14/site-packages/unsloth/__init__.py", line 93, in <module>
|
||||||
|
raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!")
|
||||||
|
NotImplementedError: Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!
|
||||||
198
finetune/train_unsloth.py
Normal file
198
finetune/train_unsloth.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python train_unsloth.py --model 0.8B
|
||||||
|
python train_unsloth.py --model 2B
|
||||||
|
python train_unsloth.py --model 4B --epochs 3
|
||||||
|
|
||||||
|
Requires: pip install unsloth unsloth_zoo
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Maps each accepted --model size label to its Unsloth Qwen3.5 base-model
# repository id. Every id follows the same "unsloth/Qwen3.5-<size>" pattern,
# so the table is derived from the ordered list of supported sizes.
MODEL_MAP = {
    size: f"unsloth/Qwen3.5-{size}"
    for size in ("0.8B", "2B", "4B", "9B", "27B")
}
|
||||||
|
|
||||||
|
def main():
    """Fine-tune a Qwen3.5 base model for QMD query expansion with Unsloth.

    Parses the command line, prints the run configuration and, unless
    ``--dry-run`` is given, then:

    1. loads the base model selected by ``--model`` in 16-bit,
    2. attaches LoRA adapters to the attention and MLP projections,
    3. trains with TRL's ``SFTTrainer`` on a shuffled 90/10 train/eval
       split of the JSON dataset given by ``--data``,
    4. saves the adapter and tokenizer to the output directory,
    5. exports GGUF quantizations (skipped with ``--no-gguf``),
    6. pushes to the Hugging Face Hub when ``--push-hub`` is set,
    7. runs ``eval.py`` on the output directory (skipped with ``--no-eval``).

    GGUF export/push failures and a failing evaluation run are reported on
    stdout but do not abort the script.

    Returns:
        None.
    """
    parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
    parser.add_argument("--model", required=True, choices=list(MODEL_MAP.keys()),
                        help="Model size to train")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=512)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--data", type=str, default="data/train/train.jsonl")
    parser.add_argument("--output", type=str, default=None,
                        help="Output directory (default: outputs/qwen3.5-{size})")
    parser.add_argument("--push-hub", type=str, default=None,
                        help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)")
    parser.add_argument("--no-gguf", action="store_true")
    parser.add_argument("--no-eval", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    model_name = MODEL_MAP[args.model]
    output_dir = args.output or f"outputs/qwen3.5-{args.model}"

    print(f"{'='*60}")
    print("QMD Query Expansion — Unsloth SFT")
    print(f" Base model: {model_name}")
    print(f" Output: {output_dir}")
    print(f" Data: {args.data}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch: {args.batch_size} x {args.grad_accum} accum")
    print(f" LR: {args.lr}")
    print(f" LoRA rank: {args.lora_rank}")
    print(f" Max seq len: {args.max_seq_len}")
    print(f"{'='*60}")

    if args.dry_run:
        print("Dry run — exiting.")
        return

    # --- Imports (heavy) ---
    # Deferred until after the dry-run check so that --dry-run works without
    # the training stack being importable.
    import os
    # NOTE(review): torch is not referenced directly in this function —
    # confirm whether this import is still needed.
    import torch
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    # --- Load model ---
    print(f"\nLoading {model_name}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=args.max_seq_len,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
    )

    # --- LoRA ---
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=args.max_seq_len,
    )

    # --- Dataset ---
    print(f"Loading dataset from {args.data}...")
    dataset = load_dataset("json", data_files=args.data, split="train")
    dataset = dataset.shuffle(seed=42)
    # Fixed seed keeps the 90/10 train/eval split reproducible across runs.
    split = dataset.train_test_split(test_size=0.1, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]
    print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    # --- Tracking ---
    # Report to trackio only when HF_TOKEN is set and trackio is installed;
    # otherwise training runs without experiment tracking.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401 -- imported only to test availability
            report_to = "trackio"
            os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")
        except ImportError:
            pass

    # --- Trainer ---
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=SFTConfig(
            output_dir=output_dir,
            max_seq_length=args.max_seq_len,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            learning_rate=args.lr,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
            logging_steps=10,
            save_strategy="steps",
            save_steps=200,
            save_total_limit=3,
            eval_strategy="steps",
            eval_steps=200,
            bf16=True,
            optim="adamw_8bit",
            seed=3407,
            dataset_num_proc=4,
            report_to=report_to,
            run_name=f"sft-qwen3.5-{args.model}",
        ),
    )

    print("\nStarting training...")
    stats = trainer.train()
    print("\nTraining complete!")
    print(f" Total steps: {stats.global_step}")
    print(f" Final loss: {stats.training_loss:.4f}")

    # --- Save ---
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Adapter saved to {output_dir}")

    # --- GGUF export ---
    if not args.no_gguf:
        print("\nExporting GGUF quantizations...")
        gguf_dir = f"{output_dir}/gguf"
        for quant in ["q4_k_m", "q8_0"]:
            print(f" {quant}...")
            # Best effort: one failed quantization must not lose the adapter
            # that was already saved, nor block the other quantization.
            try:
                model.save_pretrained_gguf(
                    gguf_dir, tokenizer, quantization_method=quant
                )
                print(f" ✓ {quant} saved")
            except Exception as e:
                print(f" ✗ {quant} failed: {e}")

    # --- Push to Hub ---
    if args.push_hub:
        print(f"\nPushing to {args.push_hub}...")
        model.push_to_hub_merged(args.push_hub, tokenizer, save_method="lora")
        if not args.no_gguf:
            for quant in ["q4_k_m", "q8_0"]:
                try:
                    model.push_to_hub_gguf(args.push_hub, tokenizer, quantization_method=quant)
                except Exception as e:
                    print(f" GGUF push {quant} failed: {e}")

    # --- Eval ---
    if not args.no_eval:
        print("\nRunning evaluation...")
        import subprocess

        # eval.py runs with its working directory set to this script's
        # directory, while output_dir was resolved against the caller's
        # working directory when saving. Pass an absolute path so both agree
        # even when the script is launched from somewhere else.
        eval_result = subprocess.run(
            [sys.executable, "eval.py", str(Path(output_dir).resolve())],
            cwd=str(Path(__file__).parent),
        )
        if eval_result.returncode != 0:
            # Surface the failure instead of silently reporting success below.
            print(f"Evaluation exited with status {eval_result.returncode}")

    print(f"\n{'='*60}")
    print(f"Done! Model at: {output_dir}")
    print(f"{'='*60}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the training CLI only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
||||||
@ -48,7 +48,7 @@
|
|||||||
cp package.json $out/lib/qmd/
|
cp package.json $out/lib/qmd/
|
||||||
|
|
||||||
makeWrapper ${pkgs.bun}/bin/bun $out/bin/qmd \
|
makeWrapper ${pkgs.bun}/bin/bun $out/bin/qmd \
|
||||||
--add-flags "$out/lib/qmd/src/qmd.ts" \
|
--add-flags "$out/lib/qmd/src/cli/qmd.ts" \
|
||||||
--set DYLD_LIBRARY_PATH "${pkgs.sqlite.out}/lib" \
|
--set DYLD_LIBRARY_PATH "${pkgs.sqlite.out}/lib" \
|
||||||
--set LD_LIBRARY_PATH "${pkgs.sqlite.out}/lib"
|
--set LD_LIBRARY_PATH "${pkgs.sqlite.out}/lib"
|
||||||
'';
|
'';
|
||||||
@ -81,7 +81,7 @@
|
|||||||
shellHook = ''
|
shellHook = ''
|
||||||
export BREW_PREFIX="''${BREW_PREFIX:-${sqliteWithExtensions.out}}"
|
export BREW_PREFIX="''${BREW_PREFIX:-${sqliteWithExtensions.out}}"
|
||||||
echo "QMD development shell"
|
echo "QMD development shell"
|
||||||
echo "Run: bun src/qmd.ts <command>"
|
echo "Run: bun src/cli/qmd.ts <command>"
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
23
src/llm.ts
23
src/llm.ts
@ -209,7 +209,9 @@ export const DEFAULT_RERANK_MODEL_URI = DEFAULT_RERANK_MODEL;
|
|||||||
export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL;
|
export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL;
|
||||||
|
|
||||||
// Local model cache directory
|
// Local model cache directory
|
||||||
const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models");
|
const MODEL_CACHE_DIR = process.env.XDG_CACHE_HOME
|
||||||
|
? join(process.env.XDG_CACHE_HOME, "qmd", "models")
|
||||||
|
: join(homedir(), ".cache", "qmd", "models");
|
||||||
export const DEFAULT_MODEL_CACHE_DIR = MODEL_CACHE_DIR;
|
export const DEFAULT_MODEL_CACHE_DIR = MODEL_CACHE_DIR;
|
||||||
|
|
||||||
export type PullResult = {
|
export type PullResult = {
|
||||||
@ -757,9 +759,16 @@ export class LlamaCpp implements LLM {
|
|||||||
* - Combined: drops from 11.6 GB (auto, no flash) to 568 MB per context (20×)
|
* - Combined: drops from 11.6 GB (auto, no flash) to 568 MB per context (20×)
|
||||||
*/
|
*/
|
||||||
// Qwen3 reranker template adds ~200 tokens overhead (system prompt, tags, etc.)
|
// Qwen3 reranker template adds ~200 tokens overhead (system prompt, tags, etc.)
|
||||||
// Chunks are max 800 tokens, so 800 + 200 + query ≈ 1100 tokens typical.
|
// Default 2048 was too small for longer documents (e.g. session transcripts,
|
||||||
// Use 2048 for safety margin. Still 17× less than auto (40960).
|
// CJK text, or large markdown files) — callers hit "input lengths exceed
|
||||||
private static readonly RERANK_CONTEXT_SIZE = 2048;
|
// context size" errors even after truncation because the overhead estimate
|
||||||
|
// was insufficient. 4096 comfortably fits the largest real-world chunks
|
||||||
|
// while staying well below the 40 960-token auto size.
|
||||||
|
// Override with QMD_RERANK_CONTEXT_SIZE env var if you need more headroom.
|
||||||
|
private static readonly RERANK_CONTEXT_SIZE: number = (() => {
|
||||||
|
const v = parseInt(process.env.QMD_RERANK_CONTEXT_SIZE ?? "", 10);
|
||||||
|
return Number.isFinite(v) && v > 0 ? v : 4096;
|
||||||
|
})();
|
||||||
private async ensureRerankContexts(): Promise<Awaited<ReturnType<LlamaModel["createRankingContext"]>>[]> {
|
private async ensureRerankContexts(): Promise<Awaited<ReturnType<LlamaModel["createRankingContext"]>>[]> {
|
||||||
if (this.rerankContexts.length === 0) {
|
if (this.rerankContexts.length === 0) {
|
||||||
const model = await this.ensureRerankModel();
|
const model = await this.ensureRerankModel();
|
||||||
@ -1099,8 +1108,10 @@ export class LlamaCpp implements LLM {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Qwen3 reranker chat template overhead (system prompt, tags, separators)
|
// Qwen3 reranker chat template overhead (system prompt, tags, separators).
|
||||||
private static readonly RERANK_TEMPLATE_OVERHEAD = 200;
|
// Measured at ~350 tokens on real queries; use 512 as a safe upper bound so
|
||||||
|
// the truncation budget never lets a document slip past the context limit.
|
||||||
|
private static readonly RERANK_TEMPLATE_OVERHEAD = 512;
|
||||||
private static readonly RERANK_TARGET_DOCS_PER_CONTEXT = 10;
|
private static readonly RERANK_TARGET_DOCS_PER_CONTEXT = 10;
|
||||||
|
|
||||||
async rerank(
|
async rerank(
|
||||||
|
|||||||
@ -296,9 +296,12 @@ Intent-aware lex (C++ performance, not sports):
|
|||||||
intent: z.string().optional().describe(
|
intent: z.string().optional().describe(
|
||||||
"Background context to disambiguate the query. Example: query='performance', intent='web page load times and Core Web Vitals'. Does not search on its own."
|
"Background context to disambiguate the query. Example: query='performance', intent='web page load times and Core Web Vitals'. Does not search on its own."
|
||||||
),
|
),
|
||||||
|
rerank: z.boolean().optional().default(true).describe(
|
||||||
|
"Rerank results using LLM (default: true). Set to false for faster results on CPU-only machines."
|
||||||
|
),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
async ({ searches, limit, minScore, candidateLimit, collections, intent }) => {
|
async ({ searches, limit, minScore, candidateLimit, collections, intent, rerank }) => {
|
||||||
// Map to internal format
|
// Map to internal format
|
||||||
const queries: ExpandedQuery[] = searches.map(s => ({
|
const queries: ExpandedQuery[] = searches.map(s => ({
|
||||||
type: s.type,
|
type: s.type,
|
||||||
@ -313,6 +316,7 @@ Intent-aware lex (C++ performance, not sports):
|
|||||||
collections: effectiveCollections.length > 0 ? effectiveCollections : undefined,
|
collections: effectiveCollections.length > 0 ? effectiveCollections : undefined,
|
||||||
limit,
|
limit,
|
||||||
minScore,
|
minScore,
|
||||||
|
rerank,
|
||||||
intent,
|
intent,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
161
src/store.ts
161
src/store.ts
@ -1421,6 +1421,12 @@ export async function generateEmbeddings(
|
|||||||
const batches = buildEmbeddingBatches(docsToEmbed, maxDocsPerBatch, maxBatchBytes);
|
const batches = buildEmbeddingBatches(docsToEmbed, maxDocsPerBatch, maxBatchBytes);
|
||||||
|
|
||||||
for (const batchMeta of batches) {
|
for (const batchMeta of batches) {
|
||||||
|
// Abort early if session has been invalidated
|
||||||
|
if (!session.isValid) {
|
||||||
|
console.warn(`⚠ Session expired — skipping remaining document batches`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
const batchDocs = getEmbeddingDocsForBatch(db, batchMeta);
|
const batchDocs = getEmbeddingDocsForBatch(db, batchMeta);
|
||||||
const batchChunks: ChunkItem[] = [];
|
const batchChunks: ChunkItem[] = [];
|
||||||
const batchBytes = batchMeta.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0);
|
const batchBytes = batchMeta.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0);
|
||||||
@ -1434,6 +1440,7 @@ export async function generateEmbeddings(
|
|||||||
undefined, undefined, undefined,
|
undefined, undefined, undefined,
|
||||||
doc.path,
|
doc.path,
|
||||||
options?.chunkStrategy,
|
options?.chunkStrategy,
|
||||||
|
session.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
for (let seq = 0; seq < chunks.length; seq++) {
|
for (let seq = 0; seq < chunks.length; seq++) {
|
||||||
@ -1472,6 +1479,23 @@ export async function generateEmbeddings(
|
|||||||
let batchChunkBytesProcessed = 0;
|
let batchChunkBytesProcessed = 0;
|
||||||
|
|
||||||
for (let batchStart = 0; batchStart < batchChunks.length; batchStart += BATCH_SIZE) {
|
for (let batchStart = 0; batchStart < batchChunks.length; batchStart += BATCH_SIZE) {
|
||||||
|
// Abort early if session has been invalidated (e.g. max duration exceeded)
|
||||||
|
if (!session.isValid) {
|
||||||
|
const remaining = batchChunks.length - batchStart;
|
||||||
|
errors += remaining;
|
||||||
|
console.warn(`⚠ Session expired — skipping ${remaining} remaining chunks`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort early if error rate is too high (>80% of processed chunks failed)
|
||||||
|
const processed = chunksEmbedded + errors;
|
||||||
|
if (processed >= BATCH_SIZE && errors > processed * 0.8) {
|
||||||
|
const remaining = batchChunks.length - batchStart;
|
||||||
|
errors += remaining;
|
||||||
|
console.warn(`⚠ Error rate too high (${errors}/${processed}) — aborting embedding`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
|
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
|
||||||
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
|
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
|
||||||
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
||||||
@ -1491,20 +1515,26 @@ export async function generateEmbeddings(
|
|||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// Batch failed — try individual embeddings as fallback
|
// Batch failed — try individual embeddings as fallback
|
||||||
for (const chunk of chunkBatch) {
|
// But skip if session is already invalid (avoids N doomed retries)
|
||||||
try {
|
if (!session.isValid) {
|
||||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
errors += chunkBatch.length;
|
||||||
const result = await session.embed(text);
|
batchChunkBytesProcessed += chunkBatch.reduce((sum, c) => sum + c.bytes, 0);
|
||||||
if (result) {
|
} else {
|
||||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
for (const chunk of chunkBatch) {
|
||||||
chunksEmbedded++;
|
try {
|
||||||
} else {
|
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||||
|
const result = await session.embed(text);
|
||||||
|
if (result) {
|
||||||
|
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||||
|
chunksEmbedded++;
|
||||||
|
} else {
|
||||||
|
errors++;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
errors++;
|
errors++;
|
||||||
}
|
}
|
||||||
} catch {
|
batchChunkBytesProcessed += chunk.bytes;
|
||||||
errors++;
|
|
||||||
}
|
}
|
||||||
batchChunkBytesProcessed += chunk.bytes;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1684,7 +1714,6 @@ export function handelize(path: string): string {
|
|||||||
|
|
||||||
const result = path
|
const result = path
|
||||||
.replace(/___/g, '/') // Triple underscore becomes folder separator
|
.replace(/___/g, '/') // Triple underscore becomes folder separator
|
||||||
.toLowerCase()
|
|
||||||
.split('/')
|
.split('/')
|
||||||
.map((segment, idx, arr) => {
|
.map((segment, idx, arr) => {
|
||||||
const isLastSegment = idx === arr.length - 1;
|
const isLastSegment = idx === arr.length - 1;
|
||||||
@ -1699,7 +1728,7 @@ export function handelize(path: string): string {
|
|||||||
const nameWithoutExt = ext ? segment.slice(0, -ext.length) : segment;
|
const nameWithoutExt = ext ? segment.slice(0, -ext.length) : segment;
|
||||||
|
|
||||||
const cleanedName = nameWithoutExt
|
const cleanedName = nameWithoutExt
|
||||||
.replace(/[^\p{L}\p{N}$]+/gu, '-') // Keep route marker "$", dash-separate other chars
|
.replace(/[^\p{L}\p{N}.$]+/gu, '-') // Keep letters, numbers, dots, "$"; dash-separate rest
|
||||||
.replace(/^-+|-+$/g, ''); // Remove leading/trailing dashes
|
.replace(/^-+|-+$/g, ''); // Remove leading/trailing dashes
|
||||||
|
|
||||||
return cleanedName + ext;
|
return cleanedName + ext;
|
||||||
@ -2170,6 +2199,7 @@ export async function chunkDocumentByTokens(
|
|||||||
windowTokens: number = CHUNK_WINDOW_TOKENS,
|
windowTokens: number = CHUNK_WINDOW_TOKENS,
|
||||||
filepath?: string,
|
filepath?: string,
|
||||||
chunkStrategy: ChunkStrategy = "regex",
|
chunkStrategy: ChunkStrategy = "regex",
|
||||||
|
signal?: AbortSignal
|
||||||
): Promise<{ text: string; pos: number; tokens: number }[]> {
|
): Promise<{ text: string; pos: number; tokens: number }[]> {
|
||||||
const llm = getDefaultLlamaCpp();
|
const llm = getDefaultLlamaCpp();
|
||||||
|
|
||||||
@ -2188,6 +2218,9 @@ export async function chunkDocumentByTokens(
|
|||||||
const results: { text: string; pos: number; tokens: number }[] = [];
|
const results: { text: string; pos: number; tokens: number }[] = [];
|
||||||
|
|
||||||
for (const chunk of charChunks) {
|
for (const chunk of charChunks) {
|
||||||
|
// Respect abort signal to avoid runaway tokenization
|
||||||
|
if (signal?.aborted) break;
|
||||||
|
|
||||||
const tokens = await llm.tokenize(chunk.text);
|
const tokens = await llm.tokenize(chunk.text);
|
||||||
|
|
||||||
if (tokens.length <= maxTokens) {
|
if (tokens.length <= maxTokens) {
|
||||||
@ -2201,6 +2234,7 @@ export async function chunkDocumentByTokens(
|
|||||||
const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
|
const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
|
||||||
|
|
||||||
for (const subChunk of subChunks) {
|
for (const subChunk of subChunks) {
|
||||||
|
if (signal?.aborted) break;
|
||||||
const subTokens = await llm.tokenize(subChunk.text);
|
const subTokens = await llm.tokenize(subChunk.text);
|
||||||
results.push({
|
results.push({
|
||||||
text: subChunk.text,
|
text: subChunk.text,
|
||||||
@ -2732,20 +2766,46 @@ function sanitizeFTS5Term(term: string): string {
|
|||||||
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Decide whether `token` is a hyphenated compound such as `multi-agent`,
 * `DEC-0054`, or `gpt-4`.
 *
 * The token must start with a letter or digit and contain a hyphen that is
 * immediately followed by a letter or digit; every other character has to be
 * a letter, digit, apostrophe, or hyphen. Tokens like `-sports` or `multi-`
 * therefore do not qualify.
 */
function isHyphenatedToken(token: string): boolean {
  const hyphenatedCompound = /^[\p{L}\p{N}][\p{L}\p{N}'-]*-[\p{L}\p{N}][\p{L}\p{N}'-]*$/u;
  return hyphenatedCompound.test(token);
}
|
||||||
|
|
||||||
|
/**
 * Turn a hyphenated term into the body of an FTS5 phrase.
 *
 * The term is split on hyphens, each piece is passed through
 * sanitizeFTS5Term, pieces that come back empty are dropped, and the rest
 * are joined with single spaces. The result is meant to be wrapped in FTS5
 * quotes by the caller: "multi agent" matches "multi-agent" in porter
 * tokenizer.
 */
function sanitizeHyphenatedTerm(term: string): string {
  const parts: string[] = [];
  for (const piece of term.split('-')) {
    const cleaned = sanitizeFTS5Term(piece);
    if (cleaned) {
      parts.push(cleaned);
    }
  }
  return parts.join(' ');
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parse lex query syntax into FTS5 query.
|
* Parse lex query syntax into FTS5 query.
|
||||||
*
|
*
|
||||||
* Supports:
|
* Supports:
|
||||||
* - Quoted phrases: "exact phrase" → "exact phrase" (exact match)
|
* - Quoted phrases: "exact phrase" → "exact phrase" (exact match)
|
||||||
* - Negation: -term or -"phrase" → uses FTS5 NOT operator
|
* - Negation: -term or -"phrase" → uses FTS5 NOT operator
|
||||||
|
* - Hyphenated tokens: multi-agent, DEC-0054, gpt-4 → treated as phrases
|
||||||
* - Plain terms: term → "term"* (prefix match)
|
* - Plain terms: term → "term"* (prefix match)
|
||||||
*
|
*
|
||||||
* FTS5 NOT is a binary operator: `term1 NOT term2` means "match term1 but not term2".
|
* FTS5 NOT is a binary operator: `term1 NOT term2` means "match term1 but not term2".
|
||||||
* So `-term` only works when there are also positive terms.
|
* So `-term` only works when there are also positive terms.
|
||||||
*
|
*
|
||||||
|
* Hyphen disambiguation: `-sports` at a word boundary is negation, but `multi-agent`
|
||||||
|
* (where `-` is between word characters) is treated as a hyphenated phrase.
|
||||||
|
* When a leading `-` is followed by what looks like a hyphenated compound word
|
||||||
|
* (e.g., `-multi-agent`), the entire token is treated as a negated phrase.
|
||||||
|
*
|
||||||
* Examples:
|
* Examples:
|
||||||
* performance -sports → "performance"* NOT "sports"*
|
* performance -sports → "performance"* NOT "sports"*
|
||||||
* "machine learning" → "machine learning"
|
* "machine learning" → "machine learning"
|
||||||
|
* multi-agent memory → "multi agent" AND "memory"*
|
||||||
|
* DEC-0054 → "dec 0054"
|
||||||
|
* -multi-agent → NOT "multi agent"
|
||||||
*/
|
*/
|
||||||
function buildFTS5Query(query: string): string | null {
|
function buildFTS5Query(query: string): string | null {
|
||||||
const positive: string[] = [];
|
const positive: string[] = [];
|
||||||
@ -2787,13 +2847,27 @@ function buildFTS5Query(query: string): string | null {
|
|||||||
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
||||||
const term = s.slice(start, i);
|
const term = s.slice(start, i);
|
||||||
|
|
||||||
const sanitized = sanitizeFTS5Term(term);
|
// Handle hyphenated tokens: multi-agent, DEC-0054, gpt-4
|
||||||
if (sanitized) {
|
// These get split into phrase queries so FTS5 porter tokenizer matches them.
|
||||||
const ftsTerm = `"${sanitized}"*`; // Prefix match
|
if (isHyphenatedToken(term)) {
|
||||||
if (negated) {
|
const sanitized = sanitizeHyphenatedTerm(term);
|
||||||
negative.push(ftsTerm);
|
if (sanitized) {
|
||||||
} else {
|
const ftsPhrase = `"${sanitized}"`; // Phrase match (no prefix)
|
||||||
positive.push(ftsTerm);
|
if (negated) {
|
||||||
|
negative.push(ftsPhrase);
|
||||||
|
} else {
|
||||||
|
positive.push(ftsPhrase);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const sanitized = sanitizeFTS5Term(term);
|
||||||
|
if (sanitized) {
|
||||||
|
const ftsTerm = `"${sanitized}"*`; // Prefix match
|
||||||
|
if (negated) {
|
||||||
|
negative.push(ftsTerm);
|
||||||
|
} else {
|
||||||
|
positive.push(ftsTerm);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2842,20 +2916,38 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
|
|||||||
const ftsQuery = buildFTS5Query(query);
|
const ftsQuery = buildFTS5Query(query);
|
||||||
if (!ftsQuery) return [];
|
if (!ftsQuery) return [];
|
||||||
|
|
||||||
|
// Use a CTE to force FTS5 to run first, then filter by collection.
|
||||||
|
// Without the CTE, SQLite's query planner combines FTS5 MATCH with the
|
||||||
|
// collection filter in a single WHERE clause, which can cause it to
|
||||||
|
// abandon the FTS5 index and fall back to a full scan — turning an 8ms
|
||||||
|
// query into a 17-second query on large collections.
|
||||||
|
const params: (string | number)[] = [ftsQuery];
|
||||||
|
|
||||||
|
// When filtering by collection, fetch extra candidates from the FTS index
|
||||||
|
// since some will be filtered out. Without a collection filter we can
|
||||||
|
// fetch exactly the requested limit.
|
||||||
|
const ftsLimit = collectionName ? limit * 10 : limit;
|
||||||
|
|
||||||
let sql = `
|
let sql = `
|
||||||
|
WITH fts_matches AS (
|
||||||
|
SELECT rowid, bm25(documents_fts, 1.5, 4.0, 1.0) as bm25_score
|
||||||
|
FROM documents_fts
|
||||||
|
WHERE documents_fts MATCH ?
|
||||||
|
ORDER BY bm25_score ASC
|
||||||
|
LIMIT ${ftsLimit}
|
||||||
|
)
|
||||||
SELECT
|
SELECT
|
||||||
'qmd://' || d.collection || '/' || d.path as filepath,
|
'qmd://' || d.collection || '/' || d.path as filepath,
|
||||||
d.collection || '/' || d.path as display_path,
|
d.collection || '/' || d.path as display_path,
|
||||||
d.title,
|
d.title,
|
||||||
content.doc as body,
|
content.doc as body,
|
||||||
d.hash,
|
d.hash,
|
||||||
bm25(documents_fts, 10.0, 1.0) as bm25_score
|
fm.bm25_score
|
||||||
FROM documents_fts f
|
FROM fts_matches fm
|
||||||
JOIN documents d ON d.id = f.rowid
|
JOIN documents d ON d.id = fm.rowid
|
||||||
JOIN content ON content.hash = d.hash
|
JOIN content ON content.hash = d.hash
|
||||||
WHERE documents_fts MATCH ? AND d.active = 1
|
WHERE d.active = 1
|
||||||
`;
|
`;
|
||||||
const params: (string | number)[] = [ftsQuery];
|
|
||||||
|
|
||||||
if (collectionName) {
|
if (collectionName) {
|
||||||
sql += ` AND d.collection = ?`;
|
sql += ` AND d.collection = ?`;
|
||||||
@ -2863,7 +2955,7 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
|
|||||||
}
|
}
|
||||||
|
|
||||||
// bm25 lower is better; sort ascending.
|
// bm25 lower is better; sort ascending.
|
||||||
sql += ` ORDER BY bm25_score ASC LIMIT ?`;
|
sql += ` ORDER BY fm.bm25_score ASC LIMIT ?`;
|
||||||
params.push(limit);
|
params.push(limit);
|
||||||
|
|
||||||
const rows = db.prepare(sql).all(...params) as { filepath: string; display_path: string; title: string; body: string; hash: string; bm25_score: number }[];
|
const rows = db.prepare(sql).all(...params) as { filepath: string; display_path: string; title: string; body: string; hash: string; bm25_score: number }[];
|
||||||
@ -3021,6 +3113,12 @@ export function clearAllEmbeddings(db: Database): void {
|
|||||||
/**
|
/**
|
||||||
* Insert a single embedding into both content_vectors and vectors_vec tables.
|
* Insert a single embedding into both content_vectors and vectors_vec tables.
|
||||||
* The hash_seq key is formatted as "hash_seq" for the vectors_vec table.
|
* The hash_seq key is formatted as "hash_seq" for the vectors_vec table.
|
||||||
|
*
|
||||||
|
* content_vectors is inserted first so that getHashesForEmbedding (which checks
|
||||||
|
* only content_vectors) won't re-select the hash on a crash between the two inserts.
|
||||||
|
*
|
||||||
|
* vectors_vec uses DELETE + INSERT instead of INSERT OR REPLACE because sqlite-vec's
|
||||||
|
* vec0 virtual tables silently ignore the OR REPLACE conflict clause.
|
||||||
*/
|
*/
|
||||||
export function insertEmbedding(
|
export function insertEmbedding(
|
||||||
db: Database,
|
db: Database,
|
||||||
@ -3032,11 +3130,16 @@ export function insertEmbedding(
|
|||||||
embeddedAt: string
|
embeddedAt: string
|
||||||
): void {
|
): void {
|
||||||
const hashSeq = `${hash}_${seq}`;
|
const hashSeq = `${hash}_${seq}`;
|
||||||
const insertVecStmt = db.prepare(`INSERT OR REPLACE INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`);
|
|
||||||
const insertContentVectorStmt = db.prepare(`INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, ?, ?, ?, ?)`);
|
|
||||||
|
|
||||||
insertVecStmt.run(hashSeq, embedding);
|
// Insert content_vectors first — crash-safe ordering (see getHashesForEmbedding)
|
||||||
|
const insertContentVectorStmt = db.prepare(`INSERT OR REPLACE INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, ?, ?, ?, ?)`);
|
||||||
insertContentVectorStmt.run(hash, seq, pos, model, embeddedAt);
|
insertContentVectorStmt.run(hash, seq, pos, model, embeddedAt);
|
||||||
|
|
||||||
|
// vec0 virtual tables don't support OR REPLACE — use DELETE + INSERT
|
||||||
|
const deleteVecStmt = db.prepare(`DELETE FROM vectors_vec WHERE hash_seq = ?`);
|
||||||
|
const insertVecStmt = db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`);
|
||||||
|
deleteVecStmt.run(hashSeq);
|
||||||
|
insertVecStmt.run(hashSeq, embedding);
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|||||||
25
test/eval-deep-research.jsonl
Normal file
25
test/eval-deep-research.jsonl
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
{"query": "that tradeoff between data correctness and always being up", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems architecture", "notes": "CAP theorem - no keywords match"}
|
||||||
|
{"query": "what we learned from the dashboard thing", "expected_doc": "product-launch", "difficulty": "hard", "intent": "project retrospectives", "notes": "Project Phoenix retrospective - vague reference"}
|
||||||
|
{"query": "how much we're burning through each month", "expected_doc": "fundraising", "difficulty": "hard", "intent": "startup finances", "notes": "burn rate - colloquial phrasing"}
|
||||||
|
{"query": "when do I need to be online", "expected_doc": "remote-work", "difficulty": "hard", "intent": "work schedule policies", "notes": "core hours policy - no exact terms"}
|
||||||
|
{"query": "that algorithm for getting nodes to agree", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed consensus", "notes": "consensus/Raft/Paxos - conceptual reference"}
|
||||||
|
{"query": "why we pushed back the release date", "expected_doc": "product-launch", "difficulty": "hard", "intent": "project timeline decisions", "notes": "timeline pressure - implied from retrospective"}
|
||||||
|
{"query": "how to structure URLs for our service", "expected_doc": "api-design", "difficulty": "hard", "intent": "API design patterns", "notes": "REST endpoints - no exact match"}
|
||||||
|
{"query": "preventing the model from just memorizing", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "ML model training", "notes": "overfitting - conceptual synonym"}
|
||||||
|
{"query": "who we're pitching to first", "expected_doc": "fundraising", "difficulty": "hard", "intent": "investor outreach strategy", "notes": "tier 1 investors - colloquial"}
|
||||||
|
{"query": "can I work from another country", "expected_doc": "remote-work", "difficulty": "hard", "intent": "remote work eligibility", "notes": "remote eligibility - implied question"}
|
||||||
|
{"query": "how the beta users found problems", "expected_doc": "product-launch", "difficulty": "hard", "intent": "product testing feedback", "notes": "beta program bugs - indirect reference"}
|
||||||
|
{"query": "that thing Leslie Lamport invented", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems history", "notes": "Paxos - person reference only"}
|
||||||
|
{"query": "what happens when the network splits", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "network failure handling", "notes": "partition tolerance - rephrased concept"}
|
||||||
|
{"query": "teaching computers to find patterns", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "machine learning fundamentals", "notes": "ML definition - abstract description"}
|
||||||
|
{"query": "how much runway before we're out of cash", "expected_doc": "fundraising", "difficulty": "hard", "intent": "startup financial planning", "notes": "runway months - colloquial finance term"}
|
||||||
|
{"query": "the 47 issues we found before shipping", "expected_doc": "product-launch", "difficulty": "hard", "intent": "pre-launch QA", "notes": "beta bugs - specific number, no keywords"}
|
||||||
|
{"query": "grouping customers by behavior", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "customer analytics", "notes": "clustering/segmentation - conceptual"}
|
||||||
|
{"query": "why URLs should be things not actions", "expected_doc": "api-design", "difficulty": "hard", "intent": "RESTful design principles", "notes": "nouns not verbs - conceptual inversion"}
|
||||||
|
{"query": "what Eric Brewer proved you can't have", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed systems theory", "notes": "CAP theorem - person + concept"}
|
||||||
|
{"query": "how fast the new feature loaded", "expected_doc": "product-launch", "difficulty": "hard", "intent": "performance metrics", "notes": "performance 4.2s - indirect reference"}
|
||||||
|
{"query": "days everyone needs to be in the office", "expected_doc": "remote-work", "difficulty": "hard", "intent": "hybrid work schedule", "notes": "collaboration days - rephrased"}
|
||||||
|
{"query": "the number that shows customers are expanding", "expected_doc": "fundraising", "difficulty": "hard", "intent": "SaaS growth metrics", "notes": "NRR 124% - metric description"}
|
||||||
|
{"query": "telling spam from real email", "expected_doc": "machine-learning", "difficulty": "hard", "intent": "classification use cases", "notes": "classification example - specific use case"}
|
||||||
|
{"query": "how to get user 123's purchases", "expected_doc": "api-design", "difficulty": "hard", "intent": "API endpoint design", "notes": "hierarchical URLs - example-based query"}
|
||||||
|
{"query": "zookeeper etcd consul what they have in common", "expected_doc": "distributed-systems", "difficulty": "hard", "intent": "distributed coordination tools", "notes": "CP systems - asking about category"}
|
||||||
209
test/eval-deep-research.ts
Normal file
209
test/eval-deep-research.ts
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
/**
|
||||||
|
* Deep Research Evaluation for QMD
|
||||||
|
*
|
||||||
|
* Tests end-to-end retrieval quality: query → expansion → reranking → results
|
||||||
|
*
|
||||||
|
* These are HARD queries with NO exact keyword matches - they require
|
||||||
|
* semantic understanding via query expansion and reranking to succeed.
|
||||||
|
*
|
||||||
|
* Run: bun test/eval-deep-research.ts
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { execSync } from "child_process";
|
||||||
|
import { readFileSync, existsSync } from "fs";
|
||||||
|
import { join, dirname } from "path";
|
||||||
|
import { fileURLToPath } from "url";
|
||||||
|
|
||||||
|
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||||
|
|
||||||
|
/**
 * One evaluation case, parsed from a single line of
 * eval-deep-research.jsonl.
 */
interface EvalQuery {
  /** Natural-language query text handed to the search functions. */
  query: string;
  /** Substring that must appear (case-insensitively) in the path of the correct document. */
  expected_doc: string;
  /** Difficulty label carried by the fixture; not read by the scoring code in this file. */
  difficulty: string;
  /** Domain context hint for future intent-aware retrieval. */
  intent: string;
  /** Free-form explanation of the case; printed next to `intent` when a query misses. */
  notes: string;
}
|
||||||
|
|
||||||
|
/**
 * The fields this script reads from one entry of the JSON array printed
 * by the qmd CLI when invoked with `--json`.
 */
interface SearchResult {
  /** Path of the matched document; compared against `expected_doc`. */
  file: string;
  /** Score reported by the CLI (scale not visible from this file; not used for ranking here). */
  score: number;
  /** Document title, when the CLI supplies one. */
  title?: string;
}
|
||||||
|
|
||||||
|
/**
 * Read the JSONL fixture that sits next to this script and parse every
 * non-blank line into an EvalQuery.
 *
 * Throws whatever `readFileSync` or `JSON.parse` throws when the file is
 * missing or a line is not valid JSON.
 */
function loadQueries(): EvalQuery[] {
  const fixturePath = join(__dirname, "eval-deep-research.jsonl");
  const rawLines = readFileSync(fixturePath, "utf-8").split("\n");

  const parsed: EvalQuery[] = [];
  for (const rawLine of rawLines) {
    // Skip blank lines (including the one produced by a trailing newline).
    if (!rawLine.trim()) continue;
    parsed.push(JSON.parse(rawLine));
  }
  return parsed;
}
|
||||||
|
|
||||||
|
/**
 * Quote `value` so a POSIX shell treats it as exactly one literal word.
 *
 * Inside single quotes the shell interprets nothing, so the only character
 * that needs handling is the single quote itself: close the quoted string,
 * emit an escaped quote, and reopen it.
 */
function shellQuote(value: string): string {
  return "'" + value.replace(/'/g, "'\\''") + "'";
}

/**
 * Run `bun src/qmd.ts <subcommand> <query>` against the eval-docs collection
 * and return the parsed JSON results.
 *
 * Best-effort by design: any failure (non-zero exit, timeout, unparsable
 * output) is reported as "no results" so a single bad query cannot abort
 * the whole evaluation run.
 *
 * @param subcommand qmd subcommand to invoke ("search" or "query").
 * @param query      Raw query text; passed to the shell as one quoted word.
 * @param timeoutMs  Maximum time the child process may run, in milliseconds.
 */
function runQmdJson(
  subcommand: "search" | "query",
  query: string,
  timeoutMs: number
): SearchResult[] {
  try {
    // The command already relies on a POSIX shell (`2>/dev/null`), so
    // single-quote quoting is safe here. Escaping only `"` inside double
    // quotes would still let `$`, backticks and backslashes be interpreted.
    const output = execSync(
      `bun src/qmd.ts ${subcommand} ${shellQuote(query)} -c eval-docs --json -n 5 2>/dev/null`,
      { encoding: "utf-8", timeout: timeoutMs }
    );
    const parsed: unknown = JSON.parse(output);
    // Anything other than a JSON array cannot be ranked; treat it as a miss.
    return Array.isArray(parsed) ? (parsed as SearchResult[]) : [];
  } catch {
    return [];
  }
}

/**
 * Keyword baseline: `qmd search`, top 5 results, 30 s timeout.
 * Returns an empty array when the command fails.
 */
function runBM25Search(query: string): SearchResult[] {
  return runQmdJson("search", query, 30000);
}

/**
 * Full pipeline: `qmd query`, top 5 results, 120 s timeout.
 * Returns an empty array when the command fails.
 */
function runDeepResearch(query: string): SearchResult[] {
  return runQmdJson("query", query, 120000);
}
|
||||||
|
|
||||||
|
/**
 * Case-insensitive check that `expectedDoc` occurs somewhere inside
 * `filepath`.
 */
function matchesExpected(filepath: string, expectedDoc: string): boolean {
  const haystack = filepath.toLowerCase();
  const needle = expectedDoc.toLowerCase();
  return haystack.indexOf(needle) !== -1;
}
|
||||||
|
|
||||||
|
/**
 * Return the 1-based position of the first result whose file path matches
 * `expectedDoc`, or -1 when no result matches.
 */
function findRank(results: SearchResult[], expectedDoc: string): number {
  const position = results.findIndex((entry) => matchesExpected(entry.file, expectedDoc));
  if (position === -1) {
    return -1; // Not found
  }
  return position + 1;
}
|
||||||
|
|
||||||
|
/** Aggregated scores for one search method over the whole query set. */
interface MethodResults {
  /** Number of queries whose expected document was ranked first. */
  hit1: number;
  /** Number of queries whose expected document was ranked in the top 3. */
  hit3: number;
  /** Number of queries whose expected document was ranked in the top 5. */
  hit5: number;
  /** Number of queries evaluated. */
  total: number;
  /** Per-query outcome; `rank` is 1-based, or -1 when the document was not returned. */
  details: { query: string; rank: number; expected: string; intent?: string }[];
}
|
||||||
|
|
||||||
|
/**
 * Run every query through `searchFn`, print a per-query report, and return
 * the Hit@1 / Hit@3 / Hit@5 tallies.
 *
 * @param queries  Evaluation cases to run.
 * @param searchFn Search implementation under test; called once per query.
 * @param label    Heading printed above this method's report.
 * @returns Hit counts, the number of queries, and the per-query ranks.
 */
function evaluate(
  queries: EvalQuery[],
  searchFn: (q: string) => SearchResult[],
  label: string
): MethodResults {
  const results: MethodResults = {
    hit1: 0,
    hit3: 0,
    hit5: 0,
    total: queries.length,
    details: [],
  };

  // Whole-number percentage of the query set. An empty set reports "0"
  // instead of dividing by zero and printing "NaN".
  const pct = (hits: number): string =>
    results.total > 0 ? ((hits / results.total) * 100).toFixed(0) : "0";

  console.log(`\n${"=".repeat(60)}`);
  console.log(` ${label}`);
  console.log(`${"=".repeat(60)}\n`);

  for (const { query, expected_doc, intent, notes } of queries) {
    const rank = findRank(searchFn(query), expected_doc);

    results.details.push({ query, rank, expected: expected_doc, intent });

    if (rank === 1) results.hit1++;
    if (rank >= 1 && rank <= 3) results.hit3++;
    if (rank >= 1 && rank <= 5) results.hit5++;

    // "✓" for a top hit, "@N" for any lower rank, "✗" for a miss.
    const status = rank === 1 ? "✓" : rank > 0 ? `@${rank}` : "✗";
    console.log(` ${status.padEnd(4)} "${query.slice(0, 45).padEnd(45)}" → ${expected_doc}`);
    if (rank === -1) {
      // Extra context only for misses, to help diagnose why retrieval failed.
      console.log(` intent: ${intent} | ${notes}`);
    }
  }

  console.log(`\n ${"─".repeat(50)}`);
  console.log(` Hit@1: ${pct(results.hit1)}% (${results.hit1}/${results.total})`);
  console.log(` Hit@3: ${pct(results.hit3)}% (${results.hit3}/${results.total})`);
  console.log(` Hit@5: ${pct(results.hit5)}% (${results.hit5}/${results.total})`);

  return results;
}
|
||||||
|
|
||||||
|
/**
 * Entry point: verify the eval-docs collection is present, run the BM25
 * baseline and the deep-research pipeline over the fixture queries, print a
 * comparison, and exit non-zero when deep-research Hit@3 is below 60% or
 * when there are no queries to evaluate.
 */
async function main() {
  console.log("QMD Deep Research Evaluation");
  console.log("=".repeat(60));
  console.log("Testing hard queries that require semantic understanding.");
  console.log("These have NO exact keyword matches in documents.");

  // Check if eval-docs collection exists
  try {
    const status = execSync("bun src/qmd.ts status --json 2>/dev/null", {
      encoding: "utf-8",
    });
    if (!status.includes("eval-docs")) {
      console.log("\n⚠️ eval-docs collection not found. Run:");
      console.log(" qmd collection add test/eval-docs --name eval-docs");
      console.log(" qmd embed");
      process.exit(1);
    }
  } catch {
    // A failing status command is only a warning; the evaluation still runs.
    console.log("\n⚠️ Could not check status. Make sure qmd is working.");
  }

  const queries = loadQueries();
  console.log(`\nLoaded ${queries.length} hard queries.`);

  // With zero queries every percentage below is NaN, and `NaN < 60` is
  // false, so the run would be reported as a pass. Fail loudly instead.
  if (queries.length === 0) {
    console.log("\n❌ No evaluation queries loaded.");
    process.exit(1);
  }

  // Run BM25 baseline (expected to fail on most)
  const bm25Results = evaluate(queries, runBM25Search, "BM25 BASELINE (keyword search)");

  // Run deep research (expected to succeed via expansion + reranking)
  const deepResults = evaluate(queries, runDeepResearch, "DEEP RESEARCH (expansion + reranking)");

  // Right-aligned whole-number percentage for the comparison table.
  const cell = (hits: number, total: number): string =>
    ((hits / total) * 100).toFixed(0).padStart(3);

  // Comparison
  console.log(`\n${"=".repeat(60)}`);
  console.log(" COMPARISON");
  console.log(`${"=".repeat(60)}`);
  console.log(`\n Method Hit@1 Hit@3 Hit@5`);
  console.log(` ${"─".repeat(45)}`);
  console.log(
    ` BM25 (baseline) ${cell(bm25Results.hit1, bm25Results.total)}% ${cell(bm25Results.hit3, bm25Results.total)}% ${cell(bm25Results.hit5, bm25Results.total)}%`
  );
  console.log(
    ` Deep Research ${cell(deepResults.hit1, deepResults.total)}% ${cell(deepResults.hit3, deepResults.total)}% ${cell(deepResults.hit5, deepResults.total)}%`
  );

  const improvement = deepResults.hit3 - bm25Results.hit3;
  // Only prefix "+" for non-negative deltas; a regression already carries
  // its own "-" and must not be printed as "+-N".
  const signedImprovement = improvement >= 0 ? `+${improvement}` : `${improvement}`;
  console.log(`\n Improvement (Hit@3): ${signedImprovement} queries (${((improvement / bm25Results.total) * 100).toFixed(0)}%)`);

  // Show queries where deep research recovered failures
  const recovered = deepResults.details.filter(
    (d) =>
      d.rank >= 1 &&
      d.rank <= 3 &&
      bm25Results.details.find((b) => b.query === d.query)?.rank === -1
  );

  if (recovered.length > 0) {
    console.log(`\n Recovered by expansion + reranking (${recovered.length}):`);
    for (const { query, rank, expected } of recovered.slice(0, 5)) {
      console.log(` @${rank} "${query.slice(0, 40)}..." → ${expected}`);
    }
    if (recovered.length > 5) {
      console.log(` ... and ${recovered.length - 5} more`);
    }
  }

  // Exit with error if deep research performs poorly
  const deepHit3Pct = (deepResults.hit3 / deepResults.total) * 100;
  if (deepHit3Pct < 60) {
    console.log(`\n❌ Deep research Hit@3 < 60% (${deepHit3Pct.toFixed(0)}%)`);
    process.exit(1);
  } else {
    console.log(`\n✓ Deep research Hit@3 >= 60% (${deepHit3Pct.toFixed(0)}%)`);
  }
}
|
||||||
|
|
||||||
|
// Report any failure from the evaluation run (e.g. a missing or malformed
// fixture file) and exit non-zero, rather than leaving the promise floating.
main().catch((err) => {
  console.error(err);
  process.exit(1);
});
|
||||||
@ -114,14 +114,14 @@ describe("cleanupOrphanedVectors", () => {
|
|||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
describe("handelize", () => {
|
describe("handelize", () => {
|
||||||
test("converts to lowercase", () => {
|
test("preserves original case", () => {
|
||||||
expect(handelize("README.md")).toBe("readme.md");
|
expect(handelize("README.md")).toBe("README.md");
|
||||||
expect(handelize("MyFile.MD")).toBe("myfile.md");
|
expect(handelize("MyFile.MD")).toBe("MyFile.MD");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("preserves folder structure", () => {
|
test("preserves folder structure", () => {
|
||||||
expect(handelize("a/b/c/d.md")).toBe("a/b/c/d.md");
|
expect(handelize("a/b/c/d.md")).toBe("a/b/c/d.md");
|
||||||
expect(handelize("docs/api/README.md")).toBe("docs/api/readme.md");
|
expect(handelize("docs/api/README.md")).toBe("docs/api/README.md");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("replaces non-word characters with dash", () => {
|
test("replaces non-word characters with dash", () => {
|
||||||
@ -151,7 +151,7 @@ describe("handelize", () => {
|
|||||||
test("handles complex real-world meeting notes", () => {
|
test("handles complex real-world meeting notes", () => {
|
||||||
const complexName = "Money Movement Licensing Review - 2025/11/19 10:25 EST - Notes by Gemini.md";
|
const complexName = "Money Movement Licensing Review - 2025/11/19 10:25 EST - Notes by Gemini.md";
|
||||||
const result = handelize(complexName);
|
const result = handelize(complexName);
|
||||||
expect(result).toBe("money-movement-licensing-review-2025-11-19-10-25-est-notes-by-gemini.md");
|
expect(result).toBe("Money-Movement-Licensing-Review-2025-11-19-10-25-EST-Notes-by-Gemini.md");
|
||||||
expect(result).not.toContain(" ");
|
expect(result).not.toContain(" ");
|
||||||
expect(result).not.toContain("/");
|
expect(result).not.toContain("/");
|
||||||
expect(result).not.toContain(":");
|
expect(result).not.toContain(":");
|
||||||
@ -159,7 +159,7 @@ describe("handelize", () => {
|
|||||||
|
|
||||||
test("handles unicode characters", () => {
|
test("handles unicode characters", () => {
|
||||||
expect(handelize("日本語.md")).toBe("日本語.md");
|
expect(handelize("日本語.md")).toBe("日本語.md");
|
||||||
expect(handelize("Зоны и проекты.md")).toBe("зоны-и-проекты.md");
|
expect(handelize("Зоны и проекты.md")).toBe("Зоны-и-проекты.md");
|
||||||
expect(handelize("café-notes.md")).toBe("café-notes.md");
|
expect(handelize("café-notes.md")).toBe("café-notes.md");
|
||||||
expect(handelize("naïve.md")).toBe("naïve.md");
|
expect(handelize("naïve.md")).toBe("naïve.md");
|
||||||
expect(handelize("日本語-notes.md")).toBe("日本語-notes.md");
|
expect(handelize("日本語-notes.md")).toBe("日本語-notes.md");
|
||||||
@ -181,13 +181,13 @@ describe("handelize", () => {
|
|||||||
test("handles dates and times in filenames", () => {
|
test("handles dates and times in filenames", () => {
|
||||||
expect(handelize("meeting-2025-01-15.md")).toBe("meeting-2025-01-15.md");
|
expect(handelize("meeting-2025-01-15.md")).toBe("meeting-2025-01-15.md");
|
||||||
expect(handelize("notes 2025/01/15.md")).toBe("notes-2025/01/15.md");
|
expect(handelize("notes 2025/01/15.md")).toBe("notes-2025/01/15.md");
|
||||||
expect(handelize("call_10:30_AM.md")).toBe("call-10-30-am.md");
|
expect(handelize("call_10:30_AM.md")).toBe("call-10-30-AM.md");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("handles special project naming patterns", () => {
|
test("handles special project naming patterns", () => {
|
||||||
expect(handelize("PROJECT_ABC_v2.0.md")).toBe("project-abc-v2-0.md");
|
expect(handelize("PROJECT_ABC_v2.0.md")).toBe("PROJECT-ABC-v2.0.md");
|
||||||
expect(handelize("[WIP] Feature Request.md")).toBe("wip-feature-request.md");
|
expect(handelize("[WIP] Feature Request.md")).toBe("WIP-Feature-Request.md");
|
||||||
expect(handelize("(DRAFT) Proposal v1.md")).toBe("draft-proposal-v1.md");
|
expect(handelize("(DRAFT) Proposal v1.md")).toBe("DRAFT-Proposal-v1.md");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("handles symbol-only route filenames", () => {
|
test("handles symbol-only route filenames", () => {
|
||||||
|
|||||||
@ -1327,6 +1327,34 @@ describe("FTS Search", () => {
|
|||||||
await cleanupTestDb(store);
|
await cleanupTestDb(store);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("searchFTS title boost outweighs higher body frequency", async () => {
|
||||||
|
const store = await createTestStore();
|
||||||
|
const collectionName = await createTestCollection();
|
||||||
|
|
||||||
|
// Document with "quantum" mentioned in a longer body but NOT in the title
|
||||||
|
await insertTestDocument(store.db, collectionName, {
|
||||||
|
name: "body-only",
|
||||||
|
title: "General Science Notes",
|
||||||
|
body: "This research paper discusses quantum mechanics and the quantum model of computation. The quantum approach offers improvements over classical methods.",
|
||||||
|
displayPath: "test/body-only.md",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Document with "quantum" in the title but a shorter body mention
|
||||||
|
await insertTestDocument(store.db, collectionName, {
|
||||||
|
name: "title-match",
|
||||||
|
title: "Quantum Computing Overview",
|
||||||
|
body: "An introduction to the fundamentals of this emerging computing paradigm.",
|
||||||
|
displayPath: "test/title-match.md",
|
||||||
|
});
|
||||||
|
|
||||||
|
const results = store.searchFTS("quantum", 10);
|
||||||
|
expect(results.length).toBe(2);
|
||||||
|
// Title-match doc should rank higher due to BM25 column weights boosting title
|
||||||
|
expect(results[0]!.displayPath).toBe(`${collectionName}/test/title-match.md`);
|
||||||
|
|
||||||
|
await cleanupTestDb(store);
|
||||||
|
});
|
||||||
|
|
||||||
test("searchFTS respects limit parameter", async () => {
|
test("searchFTS respects limit parameter", async () => {
|
||||||
const store = await createTestStore();
|
const store = await createTestStore();
|
||||||
const collectionName = await createTestCollection();
|
const collectionName = await createTestCollection();
|
||||||
|
|||||||
@ -399,6 +399,14 @@ describe("buildFTS5Query (lex parser)", () => {
|
|||||||
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
return term.replace(/[^\p{L}\p{N}']/gu, '').toLowerCase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * True when `token` is a hyphen-joined compound such as "multi-agent" or
 * "DEC-0054": it starts with a letter or digit and contains at least one
 * hyphen immediately followed by a letter or digit. A leading hyphen
 * (negation syntax like "-sports") therefore does not qualify.
 */
function isHyphenatedToken(token: string): boolean {
  const hyphenatedPattern = /^[\p{L}\p{N}][\p{L}\p{N}'-]*-[\p{L}\p{N}][\p{L}\p{N}'-]*$/u;
  return hyphenatedPattern.test(token);
}
|
||||||
|
|
||||||
|
/**
 * Split `term` on hyphens, sanitize each piece with `sanitizeFTS5Term`,
 * drop pieces that sanitize to an empty string, and join the rest with
 * single spaces.
 */
function sanitizeHyphenatedTerm(term: string): string {
  const cleanedParts: string[] = [];
  for (const part of term.split('-')) {
    const cleaned = sanitizeFTS5Term(part);
    if (cleaned) {
      cleanedParts.push(cleaned);
    }
  }
  return cleanedParts.join(' ');
}
|
||||||
|
|
||||||
function buildFTS5Query(query: string): string | null {
|
function buildFTS5Query(query: string): string | null {
|
||||||
const positive: string[] = [];
|
const positive: string[] = [];
|
||||||
const negative: string[] = [];
|
const negative: string[] = [];
|
||||||
@ -424,8 +432,14 @@ describe("buildFTS5Query (lex parser)", () => {
|
|||||||
const start = i;
|
const start = i;
|
||||||
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
while (i < s.length && !/[\s"]/.test(s[i]!)) i++;
|
||||||
const term = s.slice(start, i);
|
const term = s.slice(start, i);
|
||||||
const sanitized = sanitizeFTS5Term(term);
|
|
||||||
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"*`);
|
if (isHyphenatedToken(term)) {
|
||||||
|
const sanitized = sanitizeHyphenatedTerm(term);
|
||||||
|
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"`);
|
||||||
|
} else {
|
||||||
|
const sanitized = sanitizeFTS5Term(term);
|
||||||
|
if (sanitized) (negated ? negative : positive).push(`"${sanitized}"*`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -488,4 +502,37 @@ describe("buildFTS5Query (lex parser)", () => {
|
|||||||
test("special chars in terms stripped", () => {
|
test("special chars in terms stripped", () => {
|
||||||
expect(buildFTS5Query("hello!world")).toBe('"helloworld"*');
|
expect(buildFTS5Query("hello!world")).toBe('"helloworld"*');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Hyphenated token tests
|
||||||
|
test("hyphenated term → phrase match", () => {
|
||||||
|
expect(buildFTS5Query("multi-agent")).toBe('"multi agent"');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("hyphenated identifier → phrase match", () => {
|
||||||
|
expect(buildFTS5Query("DEC-0054")).toBe('"dec 0054"');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("hyphenated model name → phrase match", () => {
|
||||||
|
expect(buildFTS5Query("gpt-4")).toBe('"gpt 4"');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("multi-hyphen term → phrase match", () => {
|
||||||
|
expect(buildFTS5Query("foo-bar-baz")).toBe('"foo bar baz"');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("hyphenated term mixed with plain terms", () => {
|
||||||
|
expect(buildFTS5Query("multi-agent memory")).toBe('"multi agent" AND "memory"*');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("negation still works alongside hyphenated terms", () => {
|
||||||
|
expect(buildFTS5Query("multi-agent -sports")).toBe('"multi agent" NOT "sports"*');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("negated hyphenated term", () => {
|
||||||
|
expect(buildFTS5Query("performance -multi-agent")).toBe('"performance"* NOT "multi agent"');
|
||||||
|
});
|
||||||
|
|
||||||
|
test("plain negation still works (not confused with hyphen)", () => {
|
||||||
|
expect(buildFTS5Query("performance -sports")).toBe('"performance"* NOT "sports"*');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user