Deploy fine-tuned GRPO model as default for query expansion
Switch from generic Qwen3-1.7B-Q8_0 (~2.2GB) to fine-tuned qmd-query-expansion-1.7B-q4_k_m (~1.1GB). The fine-tuned Q4 scores 91.7% avg with 30/30 Excellent, outperforming the base Q8. - Update default generate model in src/llm.ts - Update README model table, architecture diagram, config block - Add v2 training data, eval scripts, and quantize job - Remove superseded v1 training data (5,742 → 1,000 examples) - Update finetune README with v2 results and file structure Co-Authored-By: Claude (claude-fudge-eap-cc) <noreply@anthropic.com>
This commit is contained in:
parent
5ab78d00a2
commit
8572c2fd94
@ -112,7 +112,7 @@ Although the tool works perfectly fine when you just tell your agent to use it o
|
||||
▼ ▼
|
||||
┌────────────────┐ ┌────────────────┐
|
||||
│ Query Expansion│ │ Original Query│
|
||||
│ (Qwen3-1.7B) │ │ (×2 weight) │
|
||||
│ (fine-tuned) │ │ (×2 weight) │
|
||||
└───────┬────────┘ └───────┬────────┘
|
||||
│ │
|
||||
│ 2 alternative queries │
|
||||
@ -213,7 +213,7 @@ QMD uses three local GGUF models (auto-downloaded on first use):
|
||||
|-------|---------|------|
|
||||
| `embeddinggemma-300M-Q8_0` | Vector embeddings | ~300MB |
|
||||
| `qwen3-reranker-0.6b-q8_0` | Re-ranking | ~640MB |
|
||||
| `Qwen3-1.7B-Q8_0` | Query expansion | ~2.2GB |
|
||||
| `qmd-query-expansion-1.7B-q4_k_m` | Query expansion (fine-tuned) | ~1.1GB |
|
||||
|
||||
Models are downloaded from HuggingFace and cached in `~/.cache/qmd/models/`.
|
||||
|
||||
@ -515,7 +515,7 @@ Models are configured in `src/llm.ts` as HuggingFace URIs:
|
||||
```typescript
|
||||
const DEFAULT_EMBED_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
|
||||
const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
|
||||
const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-1.7B-GGUF/Qwen3-1.7B-Q8_0.gguf";
|
||||
const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
|
||||
```
|
||||
|
||||
### EmbeddingGemma Prompt Format
|
||||
|
||||
12
finetune/.gitignore
vendored
12
finetune/.gitignore
vendored
@ -3,10 +3,11 @@ qmd-query-expansion-*/
|
||||
*.pt
|
||||
*.safetensors
|
||||
|
||||
# Large data files (stored on HuggingFace Hub)
|
||||
data/train/train.jsonl
|
||||
data/train/train_chat.jsonl
|
||||
data/train/val.jsonl
|
||||
# Processed data files (regenerated by prepare_data.py)
|
||||
data/train/
|
||||
data/train_v2/train.jsonl
|
||||
data/train_v2/train_chat.jsonl
|
||||
data/train_v2/val.jsonl
|
||||
data/qmd_expansion_cleaned.jsonl
|
||||
data/quality_report.txt
|
||||
|
||||
@ -16,6 +17,3 @@ evals/results_*.jsonl
|
||||
# Python cache
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Keep the generated source data
|
||||
!data/qmd_expansion.jsonl
|
||||
|
||||
@ -77,14 +77,17 @@ finetune/
|
||||
├── convert_gguf.py # GGUF conversion for Ollama/llama.cpp
|
||||
├── jobs/
|
||||
│ ├── sft.py # Self-contained SFT for HuggingFace Jobs
|
||||
│ └── grpo.py # Self-contained GRPO for HuggingFace Jobs
|
||||
│ ├── grpo.py # Self-contained GRPO for HuggingFace Jobs
|
||||
│ ├── eval.py # Self-contained eval for HuggingFace Jobs
|
||||
│ ├── eval_common.py # Shared eval utilities
|
||||
│ └── quantize.py # GGUF quantization for HuggingFace Jobs
|
||||
├── configs/
|
||||
│ ├── sft.yaml # SFT hyperparameters for Qwen3-1.7B
|
||||
│ └── grpo.yaml # GRPO hyperparameters for Qwen3-1.7B
|
||||
├── evals/
|
||||
│ └── queries.txt # 31 test queries across 8 categories
|
||||
├── data/
|
||||
│ └── qmd_expansion.jsonl # Source training data (5,742 examples)
|
||||
│ └── qmd_expansion_v2.jsonl # Source training data (1,000 high-quality examples)
|
||||
├── dataset/
|
||||
│ ├── generate_data.py # Generate data via Claude API
|
||||
│ ├── generate_data_offline.py # Generate from existing HF dataset
|
||||
@ -105,9 +108,9 @@ Teaches the model the `lex:/vec:/hyde:` output format from labeled examples.
|
||||
| Base model | `Qwen/Qwen3-1.7B` |
|
||||
| Method | LoRA (rank 16, alpha 32) |
|
||||
| Target modules | All projection layers (q/k/v/o/gate/up/down) |
|
||||
| Dataset | 11,124 examples (train split) |
|
||||
| Dataset | ~2,290 examples (train split) |
|
||||
| Effective batch size | 16 (4 × 4 gradient accumulation) |
|
||||
| Epochs | 3 |
|
||||
| Epochs | 5 |
|
||||
| Learning rate | 2e-4 (cosine schedule) |
|
||||
|
||||
```bash
|
||||
@ -219,7 +222,7 @@ ollama run qmd-expand
|
||||
|
||||
## Data Pipeline
|
||||
|
||||
The training data (5,730 examples in `data/qmd_expansion.jsonl`) was generated
|
||||
The training data (1,000 examples in `data/qmd_expansion_v2.jsonl`) was generated
|
||||
from two sources and cleaned for quality. To regenerate:
|
||||
|
||||
```bash
|
||||
@ -251,16 +254,17 @@ The two-stage training approach (SFT → GRPO) is standard for structured-output
|
||||
The reward function is entirely rule-based (no LLM judge) which makes it fast,
|
||||
deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubric.
|
||||
|
||||
## Training Results (Qwen3-1.7B)
|
||||
## Training Results (Qwen3-1.7B, v2)
|
||||
|
||||
### SFT
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Final train loss | 0.223 |
|
||||
| Final eval loss | 0.321 |
|
||||
| Token accuracy (train) | 94.8% |
|
||||
| Token accuracy (eval) | 92.4% |
|
||||
| Final train loss | 0.472 |
|
||||
| Final eval loss | 0.304 |
|
||||
| Token accuracy (train) | 97.4% |
|
||||
| Token accuracy (eval) | 93.8% |
|
||||
| Epochs | 5 |
|
||||
| Hardware | A10G (24 GB VRAM) |
|
||||
|
||||
### GRPO
|
||||
@ -273,3 +277,10 @@ deterministic, and suitable as an RL signal. See `SCORING.md` for the full rubri
|
||||
| Mean completion length | ~58 tokens |
|
||||
| Training time | ~19 min (200 steps) |
|
||||
| Hardware | A10G (24 GB VRAM) |
|
||||
|
||||
### Evaluation Scores
|
||||
|
||||
| Model | Average Score | Excellent (30) |
|
||||
|-------|--------------|-----------------|
|
||||
| SFT | 92.0% | 30/30 |
|
||||
| GRPO | 91.7% | 30/30 |
|
||||
|
||||
@ -14,7 +14,7 @@ dataset:
|
||||
eval_split: 0.1
|
||||
|
||||
training:
|
||||
epochs: 3
|
||||
epochs: 5
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 4
|
||||
learning_rate: 2e-4
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1000
finetune/data/qmd_expansion_v2.jsonl
Normal file
1000
finetune/data/qmd_expansion_v2.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
12
finetune/data/train_v2/dataset_info.json
Normal file
12
finetune/data/train_v2/dataset_info.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"dataset_name": "qmd-query-expansion",
|
||||
"train_samples": 1145,
|
||||
"val_samples": 128,
|
||||
"short_query_pct": 29.3,
|
||||
"columns": [
|
||||
"prompt",
|
||||
"completion",
|
||||
"text",
|
||||
"messages"
|
||||
]
|
||||
}
|
||||
490
finetune/jobs/eval.py
Normal file
490
finetune/jobs/eval.py
Normal file
@ -0,0 +1,490 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "transformers>=4.45.0",
|
||||
# "peft>=0.7.0",
|
||||
# "torch",
|
||||
# "huggingface_hub>=0.20.0",
|
||||
# "accelerate",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Evaluate QMD query expansion models on HuggingFace Jobs.
|
||||
|
||||
Self-contained script — inlines the reward function and test queries.
|
||||
|
||||
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
|
||||
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, login
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# --- Config ---
|
||||
BASE_MODEL = "Qwen/Qwen3-1.7B"
|
||||
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
|
||||
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
|
||||
|
||||
# --- Test queries (inlined from evals/queries.txt) ---
|
||||
QUERIES = [
|
||||
# Technical documentation
|
||||
"how to configure authentication",
|
||||
"typescript async await",
|
||||
"docker compose networking",
|
||||
"git rebase vs merge",
|
||||
"react useEffect cleanup",
|
||||
# Short/ambiguous
|
||||
"auth",
|
||||
"config",
|
||||
"setup",
|
||||
"api",
|
||||
# Named entities
|
||||
"who is TDS motorsports",
|
||||
"React hooks tutorial",
|
||||
"Docker container networking",
|
||||
"Kubernetes pod deployment",
|
||||
"AWS Lambda functions",
|
||||
# Personal notes / journals
|
||||
"meeting notes project kickoff",
|
||||
"ideas for new feature",
|
||||
"todo list app architecture",
|
||||
# Research / learning
|
||||
"what is dependency injection",
|
||||
"difference between sql and nosql",
|
||||
"kubernetes vs docker swarm",
|
||||
# Error/debugging
|
||||
"connection timeout error",
|
||||
"memory leak debugging",
|
||||
"cors error fix",
|
||||
# Temporal / recency
|
||||
"recent news about Shopify",
|
||||
"latest AI developments",
|
||||
"best laptops right now",
|
||||
"what changed in kubernetes latest version",
|
||||
# Complex
|
||||
"how to implement caching with redis in nodejs",
|
||||
"best practices for api rate limiting",
|
||||
"setting up ci cd pipeline with github actions",
|
||||
]
|
||||
|
||||
# =============================================================================
|
||||
# Reward function (inlined from reward.py)
|
||||
# =============================================================================
|
||||
|
||||
STOPWORDS = frozenset({
|
||||
'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
|
||||
'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
|
||||
})
|
||||
|
||||
KEY_TERM_STOPWORDS = frozenset({
|
||||
'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
|
||||
'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
|
||||
'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
|
||||
})
|
||||
|
||||
GENERIC_LEX_PHRASES = frozenset({
|
||||
'find information about', 'search for', 'look up', 'get information',
|
||||
'learn about', 'information on', 'details about', 'find out about',
|
||||
'what is', 'how to', 'guide to', 'help with',
|
||||
})
|
||||
|
||||
CHAT_TEMPLATE_TOKENS = frozenset({
|
||||
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
|
||||
'\nassistant\n', '\nuser\n',
|
||||
})
|
||||
|
||||
|
||||
def parse_expansion(text):
|
||||
result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
|
||||
for line in text.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("lex:"):
|
||||
result["lex"].append(line[4:].strip())
|
||||
elif line.startswith("vec:"):
|
||||
result["vec"].append(line[4:].strip())
|
||||
elif line.startswith("hyde:"):
|
||||
result["hyde"].append(line[5:].strip())
|
||||
else:
|
||||
result["invalid"].append(line)
|
||||
return result
|
||||
|
||||
|
||||
def clean_model_output(text):
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
used_thinking = '<think>' in text and '</think>' in text
|
||||
if used_thinking:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
return text, used_thinking
|
||||
|
||||
|
||||
def extract_named_entities(query):
|
||||
entities = set()
|
||||
words = query.split()
|
||||
prev_was_entity = False
|
||||
for i, word in enumerate(words):
|
||||
clean = word.strip('.,!?:;()[]"\'')
|
||||
if not clean:
|
||||
prev_was_entity = False
|
||||
continue
|
||||
is_entity = False
|
||||
if clean.isupper() and len(clean) >= 2:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
prev_was_entity = is_entity
|
||||
return entities
|
||||
|
||||
|
||||
def get_key_terms(query):
|
||||
return set(query.lower().split()) - KEY_TERM_STOPWORDS
|
||||
|
||||
|
||||
def lex_preserves_key_terms(lex_line, query):
|
||||
key_terms = get_key_terms(query)
|
||||
return not key_terms or bool(key_terms & set(lex_line.lower().split()))
|
||||
|
||||
|
||||
def lex_preserves_entities(line, entities):
|
||||
if not entities: return True
|
||||
return any(e in line.lower() for e in entities)
|
||||
|
||||
|
||||
def lex_is_generic(lex_line):
|
||||
lower = lex_line.lower().strip()
|
||||
for phrase in GENERIC_LEX_PHRASES:
|
||||
if phrase in lower or lower.startswith(phrase.split()[0]):
|
||||
remaining = lower
|
||||
for word in phrase.split():
|
||||
remaining = remaining.replace(word, '', 1).strip()
|
||||
if len(remaining) < 3:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def word_set_distance(a, b):
|
||||
return len(set(a.lower().split()) ^ set(b.lower().split()))
|
||||
|
||||
|
||||
def is_diverse(a, b, min_distance=2):
|
||||
a, b = a.lower().strip(), b.lower().strip()
|
||||
if a == b or a in b or b in a: return False
|
||||
return word_set_distance(a, b) >= min_distance
|
||||
|
||||
|
||||
def echoes_query(expansion, query):
|
||||
exp, q = expansion.lower().strip(), query.lower().strip()
|
||||
return exp == q or (q in exp and len(exp) < len(q) + 10)
|
||||
|
||||
|
||||
def word_repetition_penalty(text):
|
||||
counts = Counter(re.findall(r'\b\w+\b', text.lower()))
|
||||
return sum((c - 2) * 2 for w, c in counts.items()
|
||||
if c >= 3 and w not in STOPWORDS and len(w) > 2)
|
||||
|
||||
|
||||
def score_expansion_detailed(query, expansion):
|
||||
text, used_thinking = clean_model_output(expansion.strip())
|
||||
deductions = []
|
||||
|
||||
def _fail(reason):
|
||||
return {
|
||||
"format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
|
||||
"think_bonus": 0, "total": 0, "max_possible": 100,
|
||||
"percentage": 0.0, "rating": "Failed", "deductions": [reason],
|
||||
}
|
||||
|
||||
if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
|
||||
return _fail("CHAT TEMPLATE LEAKAGE")
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith(("lex:", "vec:", "hyde:")):
|
||||
return _fail(f"INVALID LINE: {line[:50]}")
|
||||
|
||||
parsed = parse_expansion(text)
|
||||
|
||||
format_score = 10
|
||||
if parsed["lex"]: format_score += 10
|
||||
else: deductions.append("missing lex:")
|
||||
if parsed["vec"]: format_score += 10
|
||||
else: deductions.append("missing vec:")
|
||||
|
||||
diversity_score = 0
|
||||
types_present = sum(1 for t in ("lex", "vec") if parsed[t])
|
||||
if types_present >= 2: diversity_score += 10
|
||||
if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
|
||||
lex_div = 5
|
||||
for i, a in enumerate(parsed["lex"]):
|
||||
for b in parsed["lex"][i+1:]:
|
||||
if not is_diverse(a, b, 2): lex_div -= 2
|
||||
diversity_score += max(0, lex_div)
|
||||
vec_div = 5
|
||||
for i, a in enumerate(parsed["vec"]):
|
||||
for b in parsed["vec"][i+1:]:
|
||||
if not is_diverse(a, b, 3): vec_div -= 2
|
||||
diversity_score += max(0, vec_div)
|
||||
echo = 5
|
||||
for exp in parsed["lex"] + parsed["vec"]:
|
||||
if echoes_query(exp, query): echo -= 3
|
||||
diversity_score += max(0, echo)
|
||||
|
||||
hyde_score = 0
|
||||
if parsed["hyde"]:
|
||||
hyde_text = parsed["hyde"][0]
|
||||
hyde_score += 5
|
||||
hyde_len = len(hyde_text)
|
||||
if 50 <= hyde_len <= 200: hyde_score += 5
|
||||
elif hyde_len < 50: hyde_score += 2
|
||||
if "\n" not in hyde_text: hyde_score += 5
|
||||
hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
|
||||
|
||||
quality_score = 5
|
||||
if parsed["lex"] and parsed["vec"]:
|
||||
avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
|
||||
avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
|
||||
if avg_lex <= avg_vec: quality_score += 5
|
||||
if parsed["vec"]:
|
||||
natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
|
||||
quality_score += 5 if natural == len(parsed["vec"]) else 2
|
||||
if parsed["lex"]:
|
||||
with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
|
||||
if with_terms == len(parsed["lex"]): quality_score += 5
|
||||
elif with_terms > 0: quality_score += 2
|
||||
|
||||
entity_score = 0
|
||||
entities = extract_named_entities(query)
|
||||
if entities and parsed["lex"]:
|
||||
with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
|
||||
if with_entities == len(parsed["lex"]): entity_score += 15
|
||||
elif with_entities > 0: entity_score += 5
|
||||
else: entity_score -= 30
|
||||
generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
|
||||
if generic_count: entity_score -= generic_count * 15
|
||||
if parsed["vec"]:
|
||||
vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
|
||||
if vec_with > 0: entity_score += 5
|
||||
elif not entities:
|
||||
entity_score = 10
|
||||
|
||||
think_bonus = 0 if used_thinking else 20
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
|
||||
max_possible = 140 if parsed["hyde"] else 120
|
||||
percentage = max(0.0, min(100.0, total / max_possible * 100))
|
||||
|
||||
if percentage >= 80: rating = "Excellent"
|
||||
elif percentage >= 60: rating = "Good"
|
||||
elif percentage >= 40: rating = "Acceptable"
|
||||
elif percentage >= 20: rating = "Poor"
|
||||
else: rating = "Failed"
|
||||
|
||||
return {
|
||||
"format": format_score, "diversity": diversity_score, "hyde": hyde_score,
|
||||
"quality": quality_score, "entity": max(0, entity_score),
|
||||
"think_bonus": think_bonus, "total": max(0, total),
|
||||
"max_possible": max_possible, "percentage": round(percentage, 1),
|
||||
"rating": rating, "deductions": deductions,
|
||||
"entities_detected": list(entities) if entities else [],
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model loading and generation
|
||||
# =============================================================================
|
||||
|
||||
def load_model(base, sft=None, grpo=None):
|
||||
print(f"Loading tokenizer from {base}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
print(f"Loading base model {base}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base, torch_dtype=torch.bfloat16, device_map="auto",
|
||||
)
|
||||
|
||||
if sft:
|
||||
print(f"Loading and merging SFT adapter {sft}...")
|
||||
model = PeftModel.from_pretrained(model, sft)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if grpo:
|
||||
print(f"Loading GRPO adapter {grpo}...")
|
||||
model = PeftModel.from_pretrained(model, grpo)
|
||||
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
|
||||
messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs, max_new_tokens=max_new_tokens,
|
||||
temperature=0.7, do_sample=True,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
if "\nassistant\n" in full_output:
|
||||
expansion = full_output.split("\nassistant\n")[-1].strip()
|
||||
elif "assistant\n" in full_output:
|
||||
expansion = full_output.split("assistant\n")[-1].strip()
|
||||
else:
|
||||
expansion = full_output[len(prompt):].strip()
|
||||
|
||||
if "<think>" in expansion:
|
||||
expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
|
||||
return expansion
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main
|
||||
# =============================================================================
|
||||
|
||||
def results_to_csv(results, label):
|
||||
"""Convert eval results to CSV string."""
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf)
|
||||
writer.writerow([
|
||||
"model", "query", "expansion", "score_pct", "rating",
|
||||
"format", "diversity", "hyde", "quality", "entity", "think_bonus",
|
||||
"total", "max_possible", "deductions",
|
||||
])
|
||||
for r in results:
|
||||
s = r["scores"]
|
||||
writer.writerow([
|
||||
label, r["query"], r["expansion"], s["percentage"], s["rating"],
|
||||
s["format"], s["diversity"], s["hyde"], s["quality"], s["entity"],
|
||||
s["think_bonus"], s["total"], s["max_possible"],
|
||||
"; ".join(s.get("deductions", [])),
|
||||
])
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def upload_csv(results, label, repo_id, api):
|
||||
"""Upload eval results CSV to HuggingFace Hub."""
|
||||
csv_data = results_to_csv(results, label)
|
||||
tag = label.split("/")[-1].replace(" ", "_").lower()
|
||||
filename = f"eval_{tag}.csv"
|
||||
print(f" Uploading {filename} to {repo_id}...")
|
||||
api.upload_file(
|
||||
path_or_fileobj=csv_data.encode("utf-8"),
|
||||
path_in_repo=filename,
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
|
||||
|
||||
|
||||
def evaluate_model(model, tokenizer, label):
|
||||
print(f"\n{'='*70}")
|
||||
print(f" EVALUATING: {label}")
|
||||
print(f"{'='*70}")
|
||||
|
||||
results = []
|
||||
for i, query in enumerate(QUERIES, 1):
|
||||
expansion = generate_expansion(model, tokenizer, query)
|
||||
scores = score_expansion_detailed(query, expansion)
|
||||
results.append({"query": query, "expansion": expansion, "scores": scores})
|
||||
|
||||
marker = "+" if scores["percentage"] >= 80 else "-" if scores["percentage"] < 60 else "~"
|
||||
print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")
|
||||
|
||||
avg = sum(r["scores"]["percentage"] for r in results) / len(results)
|
||||
ratings = Counter(r["scores"]["rating"] for r in results)
|
||||
|
||||
print(f"\n {'─'*50}")
|
||||
print(f" Average score: {avg:.1f}%")
|
||||
print(f" Ratings:")
|
||||
for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
|
||||
count = ratings.get(rating, 0)
|
||||
if count > 0:
|
||||
print(f" {rating:10s}: {count:2d} {'█' * count}")
|
||||
|
||||
# Show worst queries
|
||||
worst = sorted(results, key=lambda r: r["scores"]["percentage"])[:5]
|
||||
print(f"\n Bottom 5:")
|
||||
for r in worst:
|
||||
print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
|
||||
if r["scores"]["deductions"]:
|
||||
print(f" {', '.join(r['scores']['deductions'][:3])}")
|
||||
|
||||
return results, avg
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
|
||||
parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
|
||||
help="HF repo to upload CSV results")
|
||||
args = parser.parse_args()
|
||||
|
||||
hf_token = os.environ.get("HF_TOKEN")
|
||||
if hf_token:
|
||||
login(token=hf_token)
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)
|
||||
|
||||
# Evaluate SFT
|
||||
model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
|
||||
sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
|
||||
upload_csv(sft_results, "sft", args.upload_repo, api)
|
||||
|
||||
if not args.sft_only:
|
||||
# For GRPO: reload base, merge SFT, then load GRPO adapter
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
|
||||
grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
|
||||
upload_csv(grpo_results, "grpo", args.upload_repo, api)
|
||||
|
||||
# Upload combined comparison CSV
|
||||
combined = results_to_csv(sft_results, "sft") + results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=combined.encode("utf-8"),
|
||||
path_in_repo="eval_comparison.csv",
|
||||
repo_id=args.upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f" Uploaded: eval_comparison.csv")
|
||||
|
||||
# Comparison
|
||||
print(f"\n{'='*70}")
|
||||
print(f" COMPARISON")
|
||||
print(f"{'='*70}")
|
||||
print(f" SFT average: {sft_avg:.1f}%")
|
||||
print(f" GRPO average: {grpo_avg:.1f}%")
|
||||
print(f" Delta: {grpo_avg - sft_avg:+.1f}%")
|
||||
|
||||
improved = sum(1 for s, g in zip(sft_results, grpo_results)
|
||||
if g["scores"]["percentage"] > s["scores"]["percentage"])
|
||||
regressed = sum(1 for s, g in zip(sft_results, grpo_results)
|
||||
if g["scores"]["percentage"] < s["scores"]["percentage"])
|
||||
print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
354
finetune/jobs/eval_common.py
Normal file
354
finetune/jobs/eval_common.py
Normal file
@ -0,0 +1,354 @@
|
||||
"""
|
||||
Common evaluation and reward scoring for QMD query expansion models.
|
||||
|
||||
Shared by sft.py and grpo.py for post-training evaluation.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
# =============================================================================
|
||||
# Reward function (single source of truth)
|
||||
# =============================================================================
|
||||
|
||||
STOPWORDS = frozenset({
|
||||
'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
|
||||
'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
|
||||
})
|
||||
|
||||
KEY_TERM_STOPWORDS = frozenset({
|
||||
'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
|
||||
'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
|
||||
'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
|
||||
})
|
||||
|
||||
GENERIC_LEX_PHRASES = frozenset({
|
||||
'find information about', 'search for', 'look up', 'get information',
|
||||
'learn about', 'information on', 'details about', 'find out about',
|
||||
'what is', 'how to', 'guide to', 'help with',
|
||||
})
|
||||
|
||||
CHAT_TEMPLATE_TOKENS = frozenset({
|
||||
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
|
||||
'\nassistant\n', '\nuser\n',
|
||||
})
|
||||
|
||||
|
||||
def parse_expansion(text):
|
||||
result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
|
||||
for line in text.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("lex:"):
|
||||
result["lex"].append(line[4:].strip())
|
||||
elif line.startswith("vec:"):
|
||||
result["vec"].append(line[4:].strip())
|
||||
elif line.startswith("hyde:"):
|
||||
result["hyde"].append(line[5:].strip())
|
||||
else:
|
||||
result["invalid"].append(line)
|
||||
return result
|
||||
|
||||
|
||||
def clean_model_output(text):
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
used_thinking = '<think>' in text and '</think>' in text
|
||||
if used_thinking:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
return text, used_thinking
|
||||
|
||||
|
||||
def extract_named_entities(query):
|
||||
entities = set()
|
||||
words = query.split()
|
||||
prev_was_entity = False
|
||||
for i, word in enumerate(words):
|
||||
clean = word.strip('.,!?:;()[]"\'')
|
||||
if not clean:
|
||||
prev_was_entity = False
|
||||
continue
|
||||
is_entity = False
|
||||
if clean.isupper() and len(clean) >= 2:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower()); is_entity = True
|
||||
prev_was_entity = is_entity
|
||||
return entities
|
||||
|
||||
|
||||
def get_key_terms(query):
|
||||
return set(query.lower().split()) - KEY_TERM_STOPWORDS
|
||||
|
||||
|
||||
def lex_preserves_key_terms(lex_line, query):
|
||||
key_terms = get_key_terms(query)
|
||||
return not key_terms or bool(key_terms & set(lex_line.lower().split()))
|
||||
|
||||
|
||||
def lex_preserves_entities(line, entities):
|
||||
if not entities:
|
||||
return True
|
||||
return any(e in line.lower() for e in entities)
|
||||
|
||||
|
||||
def lex_is_generic(lex_line):
|
||||
lower = lex_line.lower().strip()
|
||||
for phrase in GENERIC_LEX_PHRASES:
|
||||
if phrase in lower or lower.startswith(phrase.split()[0]):
|
||||
remaining = lower
|
||||
for word in phrase.split():
|
||||
remaining = remaining.replace(word, '', 1).strip()
|
||||
if len(remaining) < 3:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def word_set_distance(a, b):
|
||||
return len(set(a.lower().split()) ^ set(b.lower().split()))
|
||||
|
||||
|
||||
def is_diverse(a, b, min_distance=2):
|
||||
a, b = a.lower().strip(), b.lower().strip()
|
||||
if a == b or a in b or b in a:
|
||||
return False
|
||||
return word_set_distance(a, b) >= min_distance
|
||||
|
||||
|
||||
def echoes_query(expansion, query):
|
||||
exp, q = expansion.lower().strip(), query.lower().strip()
|
||||
return exp == q or (q in exp and len(exp) < len(q) + 10)
|
||||
|
||||
|
||||
def word_repetition_penalty(text):
|
||||
counts = Counter(re.findall(r'\b\w+\b', text.lower()))
|
||||
return sum((c - 2) * 2 for w, c in counts.items()
|
||||
if c >= 3 and w not in STOPWORDS and len(w) > 2)
|
||||
|
||||
|
||||
def score_expansion(query, expansion):
|
||||
"""Score expansion as float in [0.0, 1.0] for RL reward."""
|
||||
text, used_thinking = clean_model_output(expansion.strip())
|
||||
|
||||
if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
|
||||
return 0.0
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith(("lex:", "vec:", "hyde:")):
|
||||
return 0.0
|
||||
|
||||
parsed = parse_expansion(text)
|
||||
|
||||
format_score = 10
|
||||
if parsed["lex"]: format_score += 10
|
||||
if parsed["vec"]: format_score += 10
|
||||
|
||||
diversity_score = 0
|
||||
if sum(1 for t in ("lex", "vec") if parsed[t]) >= 2: diversity_score += 10
|
||||
if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
|
||||
lex_div = 5
|
||||
for i, a in enumerate(parsed["lex"]):
|
||||
for b in parsed["lex"][i+1:]:
|
||||
if not is_diverse(a, b, 2): lex_div -= 2
|
||||
diversity_score += max(0, lex_div)
|
||||
vec_div = 5
|
||||
for i, a in enumerate(parsed["vec"]):
|
||||
for b in parsed["vec"][i+1:]:
|
||||
if not is_diverse(a, b, 3): vec_div -= 2
|
||||
diversity_score += max(0, vec_div)
|
||||
echo = 5
|
||||
for exp in parsed["lex"] + parsed["vec"]:
|
||||
if echoes_query(exp, query): echo -= 3
|
||||
diversity_score += max(0, echo)
|
||||
|
||||
hyde_score = 0
|
||||
if parsed["hyde"]:
|
||||
hyde_text = parsed["hyde"][0]
|
||||
hyde_score += 5
|
||||
if 50 <= len(hyde_text) <= 200: hyde_score += 5
|
||||
elif len(hyde_text) < 50: hyde_score += 2
|
||||
if "\n" not in hyde_text: hyde_score += 5
|
||||
hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
|
||||
|
||||
quality_score = 5
|
||||
if parsed["lex"] and parsed["vec"]:
|
||||
avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
|
||||
avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
|
||||
if avg_lex <= avg_vec: quality_score += 5
|
||||
if parsed["vec"]:
|
||||
natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
|
||||
quality_score += 5 if natural == len(parsed["vec"]) else 2
|
||||
if parsed["lex"]:
|
||||
with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
|
||||
if with_terms == len(parsed["lex"]): quality_score += 5
|
||||
elif with_terms > 0: quality_score += 2
|
||||
|
||||
entity_score = 0
|
||||
entities = extract_named_entities(query)
|
||||
if entities and parsed["lex"]:
|
||||
with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
|
||||
if with_entities == len(parsed["lex"]): entity_score += 15
|
||||
elif with_entities > 0: entity_score += 5
|
||||
else: entity_score -= 30
|
||||
generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
|
||||
if generic_count: entity_score -= generic_count * 15
|
||||
if parsed["vec"]:
|
||||
vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
|
||||
if vec_with > 0: entity_score += 5
|
||||
elif not entities:
|
||||
entity_score = 10
|
||||
|
||||
think_bonus = 0 if used_thinking else 20
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
|
||||
max_possible = 140 if parsed["hyde"] else 120
|
||||
return max(0.0, min(1.0, total / max_possible))
|
||||
|
||||
|
||||
def extract_query_from_prompt(prompt):
|
||||
"""Extract the search query from a formatted prompt string."""
|
||||
if "Expand this search query:" in prompt:
|
||||
query = prompt.split("Expand this search query:")[-1].strip()
|
||||
if "<|im_end|>" in query:
|
||||
query = query.split("<|im_end|>")[0].strip()
|
||||
return query
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class QMDRewardFunction:
|
||||
"""Reward function wrapper for TRL's GRPOTrainer."""
|
||||
__name__ = "qmd_scoring_reward"
|
||||
|
||||
def __call__(self, completions, prompts=None, **kwargs):
|
||||
rewards = []
|
||||
for i, completion in enumerate(completions):
|
||||
query = ""
|
||||
if prompts and i < len(prompts):
|
||||
query = extract_query_from_prompt(prompts[i])
|
||||
rewards.append(score_expansion(query, completion))
|
||||
return rewards
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Evaluation
|
||||
# =============================================================================
|
||||
|
||||
EVAL_QUERIES = [
|
||||
# Technical documentation
|
||||
"how to configure authentication",
|
||||
"typescript async await",
|
||||
"docker compose networking",
|
||||
"git rebase vs merge",
|
||||
"react useEffect cleanup",
|
||||
# Short/ambiguous
|
||||
"auth", "config", "setup", "api",
|
||||
# Named entities
|
||||
"who is TDS motorsports",
|
||||
"React hooks tutorial",
|
||||
"Docker container networking",
|
||||
"Kubernetes pod deployment",
|
||||
"AWS Lambda functions",
|
||||
# Personal notes / journals
|
||||
"meeting notes project kickoff",
|
||||
"ideas for new feature",
|
||||
"todo list app architecture",
|
||||
# Research / learning
|
||||
"what is dependency injection",
|
||||
"difference between sql and nosql",
|
||||
"kubernetes vs docker swarm",
|
||||
# Error/debugging
|
||||
"connection timeout error",
|
||||
"memory leak debugging",
|
||||
"cors error fix",
|
||||
# Temporal / recency
|
||||
"recent news about Shopify",
|
||||
"latest AI developments",
|
||||
"best laptops right now",
|
||||
"what changed in kubernetes latest version",
|
||||
# Complex
|
||||
"how to implement caching with redis in nodejs",
|
||||
"best practices for api rate limiting",
|
||||
"setting up ci cd pipeline with github actions",
|
||||
]
|
||||
|
||||
|
||||
def generate_expansion(model, tokenizer, query, max_new_tokens=200):
|
||||
"""Generate a query expansion using the model."""
|
||||
messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs, max_new_tokens=max_new_tokens,
|
||||
temperature=0.7, do_sample=True,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
if "\nassistant\n" in full_output:
|
||||
return full_output.split("\nassistant\n")[-1].strip()
|
||||
elif "assistant\n" in full_output:
|
||||
return full_output.split("assistant\n")[-1].strip()
|
||||
return full_output[len(prompt):].strip()
|
||||
|
||||
|
||||
def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
|
||||
"""Evaluate model on EVAL_QUERIES, print results, upload CSV."""
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f" EVALUATING: {label}")
|
||||
print(f"{'='*70}")
|
||||
|
||||
results = []
|
||||
for i, query in enumerate(EVAL_QUERIES, 1):
|
||||
expansion = generate_expansion(model, tokenizer, query)
|
||||
score = score_expansion(query, expansion)
|
||||
pct = round(score * 100, 1)
|
||||
rating = ("Excellent" if pct >= 80 else "Good" if pct >= 60
|
||||
else "Acceptable" if pct >= 40 else "Poor" if pct >= 20 else "Failed")
|
||||
marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
|
||||
print(f" [{marker}] {i:2d}/{len(EVAL_QUERIES)} {pct:5.1f}% {rating:10s} {query}")
|
||||
results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})
|
||||
|
||||
avg = sum(r["score"] for r in results) / len(results)
|
||||
ratings = Counter(r["rating"] for r in results)
|
||||
|
||||
print(f"\n {'─'*50}")
|
||||
print(f" Average score: {avg:.1f}%")
|
||||
for r in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
|
||||
c = ratings.get(r, 0)
|
||||
if c:
|
||||
print(f" {r:10s}: {c:2d} {'█' * c}")
|
||||
|
||||
worst = sorted(results, key=lambda r: r["score"])[:5]
|
||||
print(f"\n Bottom 5:")
|
||||
for r in worst:
|
||||
print(f" {r['score']:5.1f}% {r['query']}")
|
||||
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf)
|
||||
writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
|
||||
for r in results:
|
||||
writer.writerow([label, r["query"], r["expansion"], r["score"], r["rating"]])
|
||||
|
||||
filename = f"eval_{label}.csv"
|
||||
print(f"\n Uploading {filename} to {upload_repo}...")
|
||||
api.upload_file(
|
||||
path_or_fileobj=buf.getvalue().encode("utf-8"),
|
||||
path_in_repo=filename,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f" Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")
|
||||
113
finetune/jobs/eval_verbose.py
Normal file
113
finetune/jobs/eval_verbose.py
Normal file
@ -0,0 +1,113 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "transformers>=4.45.0",
|
||||
# "peft>=0.7.0",
|
||||
# "torch",
|
||||
# "huggingface_hub>=0.20.0",
|
||||
# "accelerate",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Verbose eval: prints the actual expansions for every query.
|
||||
|
||||
hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval_verbose.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
from huggingface_hub import login
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
BASE_MODEL = "Qwen/Qwen3-1.7B"
|
||||
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
|
||||
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
|
||||
|
||||
QUERIES = [
|
||||
"how to configure authentication",
|
||||
"typescript async await",
|
||||
"docker compose networking",
|
||||
"git rebase vs merge",
|
||||
"react useEffect cleanup",
|
||||
"auth",
|
||||
"config",
|
||||
"setup",
|
||||
"api",
|
||||
"who is TDS motorsports",
|
||||
"React hooks tutorial",
|
||||
"Docker container networking",
|
||||
"Kubernetes pod deployment",
|
||||
"AWS Lambda functions",
|
||||
"meeting notes project kickoff",
|
||||
"ideas for new feature",
|
||||
"todo list app architecture",
|
||||
"what is dependency injection",
|
||||
"difference between sql and nosql",
|
||||
"kubernetes vs docker swarm",
|
||||
"connection timeout error",
|
||||
"memory leak debugging",
|
||||
"cors error fix",
|
||||
"recent news about Shopify",
|
||||
"latest AI developments",
|
||||
"best laptops right now",
|
||||
"what changed in kubernetes latest version",
|
||||
"how to implement caching with redis in nodejs",
|
||||
"best practices for api rate limiting",
|
||||
"setting up ci cd pipeline with github actions",
|
||||
]
|
||||
|
||||
|
||||
def load_model(base, sft=None, grpo=None):
|
||||
tokenizer = AutoTokenizer.from_pretrained(base)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
if sft:
|
||||
model = PeftModel.from_pretrained(model, sft)
|
||||
model = model.merge_and_unload()
|
||||
if grpo:
|
||||
model = PeftModel.from_pretrained(model, grpo)
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def generate(model, tokenizer, query):
|
||||
messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
|
||||
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inputs, max_new_tokens=200, temperature=0.7, do_sample=True,
|
||||
pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
|
||||
text = tokenizer.decode(out[0], skip_special_tokens=True)
|
||||
if "\nassistant\n" in text:
|
||||
text = text.split("\nassistant\n")[-1].strip()
|
||||
elif "assistant\n" in text:
|
||||
text = text.split("assistant\n")[-1].strip()
|
||||
if "<think>" in text:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
return text
|
||||
|
||||
|
||||
def main():
|
||||
hf_token = os.environ.get("HF_TOKEN")
|
||||
if hf_token:
|
||||
login(token=hf_token)
|
||||
|
||||
print("Loading GRPO model...", file=sys.stderr)
|
||||
model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
|
||||
|
||||
for i, query in enumerate(QUERIES, 1):
|
||||
expansion = generate(model, tokenizer, query)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[{i}/{len(QUERIES)}] {query}")
|
||||
print(f"{'─'*60}")
|
||||
print(expansion)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -19,8 +19,7 @@ Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@ -29,278 +28,15 @@ from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from eval_common import QMDRewardFunction, run_eval
|
||||
|
||||
# --- Config (inlined from configs/grpo.yaml) ---
|
||||
BASE_MODEL = "Qwen/Qwen3-1.7B"
|
||||
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
|
||||
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
|
||||
DATASET = "tobil/qmd-query-expansion-train-v2"
|
||||
|
||||
# =============================================================================
|
||||
# Reward function (inlined from reward.py — single source of truth)
|
||||
# =============================================================================
|
||||
|
||||
STOPWORDS = frozenset({
|
||||
'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
|
||||
'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
|
||||
})
|
||||
|
||||
KEY_TERM_STOPWORDS = frozenset({
|
||||
'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
|
||||
'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
|
||||
'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
|
||||
})
|
||||
|
||||
GENERIC_LEX_PHRASES = frozenset({
|
||||
'find information about', 'search for', 'look up', 'get information',
|
||||
'learn about', 'information on', 'details about', 'find out about',
|
||||
'what is', 'how to', 'guide to', 'help with',
|
||||
})
|
||||
|
||||
CHAT_TEMPLATE_TOKENS = frozenset({
|
||||
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
|
||||
'\nassistant\n', '\nuser\n',
|
||||
})
|
||||
|
||||
|
||||
def parse_expansion(text: str) -> dict:
|
||||
result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
|
||||
for line in text.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("lex:"):
|
||||
result["lex"].append(line[4:].strip())
|
||||
elif line.startswith("vec:"):
|
||||
result["vec"].append(line[4:].strip())
|
||||
elif line.startswith("hyde:"):
|
||||
result["hyde"].append(line[5:].strip())
|
||||
else:
|
||||
result["invalid"].append(line)
|
||||
return result
|
||||
|
||||
|
||||
def clean_model_output(text: str) -> tuple[str, bool]:
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
used_thinking = '<think>' in text and '</think>' in text
|
||||
if used_thinking:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
return text, used_thinking
|
||||
|
||||
|
||||
def extract_named_entities(query: str) -> set:
|
||||
entities = set()
|
||||
words = query.split()
|
||||
prev_was_entity = False
|
||||
for i, word in enumerate(words):
|
||||
clean = word.strip('.,!?:;()[]"\'')
|
||||
if not clean:
|
||||
prev_was_entity = False
|
||||
continue
|
||||
is_entity = False
|
||||
if clean.isupper() and len(clean) >= 2:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
prev_was_entity = is_entity
|
||||
return entities
|
||||
|
||||
|
||||
def get_key_terms(query: str) -> set:
|
||||
return set(query.lower().split()) - KEY_TERM_STOPWORDS
|
||||
|
||||
|
||||
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
|
||||
key_terms = get_key_terms(query)
|
||||
if not key_terms:
|
||||
return True
|
||||
return bool(key_terms & set(lex_line.lower().split()))
|
||||
|
||||
|
||||
def lex_preserves_entities(line: str, entities: set) -> bool:
|
||||
if not entities:
|
||||
return True
|
||||
lower = line.lower()
|
||||
return any(e in lower for e in entities)
|
||||
|
||||
|
||||
def lex_is_generic(lex_line: str) -> bool:
|
||||
lower = lex_line.lower().strip()
|
||||
for phrase in GENERIC_LEX_PHRASES:
|
||||
if phrase in lower or lower.startswith(phrase.split()[0]):
|
||||
remaining = lower
|
||||
for word in phrase.split():
|
||||
remaining = remaining.replace(word, '', 1).strip()
|
||||
if len(remaining) < 3:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def word_set_distance(a: str, b: str) -> int:
|
||||
return len(set(a.lower().split()) ^ set(b.lower().split()))
|
||||
|
||||
|
||||
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
|
||||
a, b = a.lower().strip(), b.lower().strip()
|
||||
if a == b or a in b or b in a:
|
||||
return False
|
||||
return word_set_distance(a, b) >= min_distance
|
||||
|
||||
|
||||
def echoes_query(expansion: str, query: str) -> bool:
|
||||
exp, q = expansion.lower().strip(), query.lower().strip()
|
||||
return exp == q or (q in exp and len(exp) < len(q) + 10)
|
||||
|
||||
|
||||
def word_repetition_penalty(text: str) -> int:
|
||||
counts = Counter(re.findall(r'\b\w+\b', text.lower()))
|
||||
return sum((c - 2) * 2 for w, c in counts.items()
|
||||
if c >= 3 and w not in STOPWORDS and len(w) > 2)
|
||||
|
||||
|
||||
def score_expansion(query: str, expansion: str) -> float:
|
||||
"""Score expansion as float in [0.0, 1.0] for RL reward."""
|
||||
text, used_thinking = clean_model_output(expansion.strip())
|
||||
|
||||
# Hard fail: chat template leakage
|
||||
if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
|
||||
return 0.0
|
||||
|
||||
# Hard fail: invalid lines
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith(("lex:", "vec:", "hyde:")):
|
||||
return 0.0
|
||||
|
||||
parsed = parse_expansion(text)
|
||||
|
||||
# Format (0-30)
|
||||
format_score = 10 # no invalid lines
|
||||
if parsed["lex"]:
|
||||
format_score += 10
|
||||
if parsed["vec"]:
|
||||
format_score += 10
|
||||
|
||||
# Diversity (0-30)
|
||||
diversity_score = 0
|
||||
types_present = sum(1 for t in ("lex", "vec") if parsed[t])
|
||||
if types_present >= 2:
|
||||
diversity_score += 10
|
||||
if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
|
||||
diversity_score += 5
|
||||
lex_div = 5
|
||||
for i, a in enumerate(parsed["lex"]):
|
||||
for b in parsed["lex"][i+1:]:
|
||||
if not is_diverse(a, b, 2):
|
||||
lex_div -= 2
|
||||
diversity_score += max(0, lex_div)
|
||||
vec_div = 5
|
||||
for i, a in enumerate(parsed["vec"]):
|
||||
for b in parsed["vec"][i+1:]:
|
||||
if not is_diverse(a, b, 3):
|
||||
vec_div -= 2
|
||||
diversity_score += max(0, vec_div)
|
||||
echo = 5
|
||||
for exp in parsed["lex"] + parsed["vec"]:
|
||||
if echoes_query(exp, query):
|
||||
echo -= 3
|
||||
diversity_score += max(0, echo)
|
||||
|
||||
# HyDE (0-20)
|
||||
hyde_score = 0
|
||||
if parsed["hyde"]:
|
||||
hyde_text = parsed["hyde"][0]
|
||||
hyde_score += 5
|
||||
hyde_len = len(hyde_text)
|
||||
if 50 <= hyde_len <= 200:
|
||||
hyde_score += 5
|
||||
elif hyde_len < 50:
|
||||
hyde_score += 2
|
||||
if "\n" not in hyde_text:
|
||||
hyde_score += 5
|
||||
hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
|
||||
|
||||
# Quality (0-20)
|
||||
quality_score = 5
|
||||
if parsed["lex"] and parsed["vec"]:
|
||||
avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
|
||||
avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
|
||||
if avg_lex <= avg_vec:
|
||||
quality_score += 5
|
||||
if parsed["vec"]:
|
||||
natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
|
||||
quality_score += 5 if natural == len(parsed["vec"]) else 2
|
||||
if parsed["lex"]:
|
||||
with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
|
||||
if with_terms == len(parsed["lex"]):
|
||||
quality_score += 5
|
||||
elif with_terms > 0:
|
||||
quality_score += 2
|
||||
|
||||
# Entity (-45 to +20)
|
||||
entity_score = 0
|
||||
entities = extract_named_entities(query)
|
||||
if entities and parsed["lex"]:
|
||||
with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
|
||||
if with_entities == len(parsed["lex"]):
|
||||
entity_score += 15
|
||||
elif with_entities > 0:
|
||||
entity_score += 5
|
||||
else:
|
||||
entity_score -= 30
|
||||
generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
|
||||
if generic_count:
|
||||
entity_score -= generic_count * 15
|
||||
if parsed["vec"]:
|
||||
vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
|
||||
if vec_with > 0:
|
||||
entity_score += 5
|
||||
elif not entities:
|
||||
entity_score = 10
|
||||
|
||||
# Think bonus (0-20)
|
||||
think_bonus = 0 if used_thinking else 20
|
||||
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
|
||||
max_possible = 140 if parsed["hyde"] else 120
|
||||
return max(0.0, min(1.0, total / max_possible))
|
||||
|
||||
|
||||
def extract_query_from_prompt(prompt: str) -> str:
|
||||
if "Expand this search query:" in prompt:
|
||||
query = prompt.split("Expand this search query:")[-1].strip()
|
||||
if "<|im_end|>" in query:
|
||||
query = query.split("<|im_end|>")[0].strip()
|
||||
return query
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class QMDRewardFunction:
|
||||
__name__ = "qmd_scoring_reward"
|
||||
|
||||
def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
|
||||
rewards = []
|
||||
for i, completion in enumerate(completions):
|
||||
query = ""
|
||||
if prompts and i < len(prompts):
|
||||
query = extract_query_from_prompt(prompts[i])
|
||||
rewards.append(score_expansion(query, completion))
|
||||
return rewards
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main training
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
hf_token = os.environ.get("HF_TOKEN")
|
||||
@ -384,6 +120,11 @@ def main():
|
||||
trainer.push_to_hub()
|
||||
print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
|
||||
|
||||
# --- Automatic evaluation ---
|
||||
print("\nStarting automatic evaluation...")
|
||||
trainer.model.eval()
|
||||
run_eval(trainer.model, tokenizer, "grpo")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
244
finetune/jobs/quantize.py
Normal file
244
finetune/jobs/quantize.py
Normal file
@ -0,0 +1,244 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "transformers>=4.45.0",
|
||||
# "peft>=0.7.0",
|
||||
# "torch",
|
||||
# "huggingface_hub>=0.20.0",
|
||||
# "accelerate",
|
||||
# "sentencepiece>=0.1.99",
|
||||
# "protobuf>=3.20.0",
|
||||
# "numpy",
|
||||
# "gguf",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
Merge SFT + GRPO adapters and convert to GGUF with multiple quantizations.
|
||||
|
||||
Uploads each quantization to HuggingFace Hub as it's produced, so partial
|
||||
results are available even if the job times out.
|
||||
|
||||
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/quantize.py
|
||||
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/quantize.py -- --size 4B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, login
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
PRESETS = {
|
||||
"1.7B": {
|
||||
"base": "Qwen/Qwen3-1.7B",
|
||||
"sft": "tobil/qmd-query-expansion-1.7B-sft",
|
||||
"grpo": "tobil/qmd-query-expansion-1.7B-grpo",
|
||||
"output": "tobil/qmd-query-expansion-1.7B-gguf",
|
||||
},
|
||||
"4B": {
|
||||
"base": "Qwen/Qwen3-4B",
|
||||
"sft": "tobil/qmd-query-expansion-4B-sft",
|
||||
"grpo": "tobil/qmd-query-expansion-4B-grpo",
|
||||
"output": "tobil/qmd-query-expansion-4B-gguf",
|
||||
},
|
||||
}
|
||||
|
||||
QUANT_TYPES = [
|
||||
("Q4_K_M", "4-bit (recommended for most use)"),
|
||||
("Q5_K_M", "5-bit (balanced quality/size)"),
|
||||
("Q8_0", "8-bit (highest quality)"),
|
||||
]
|
||||
|
||||
|
||||
def run_cmd(cmd, description):
|
||||
print(f" {description}...")
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f" FAILED: {' '.join(cmd)}")
|
||||
if e.stderr:
|
||||
print(f" {e.stderr[:500]}")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print(f" Command not found: {cmd[0]}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert QMD model to GGUF")
|
||||
parser.add_argument("--size", default="1.7B", choices=PRESETS.keys(), help="Model size preset")
|
||||
args = parser.parse_args()
|
||||
|
||||
preset = PRESETS[args.size]
|
||||
base_model = preset["base"]
|
||||
sft_model = preset["sft"]
|
||||
grpo_model = preset["grpo"]
|
||||
output_repo = preset["output"]
|
||||
model_name = output_repo.split("/")[-1].replace("-gguf", "")
|
||||
|
||||
print(f"QMD GGUF Conversion: {model_name}")
|
||||
print("=" * 60)
|
||||
|
||||
hf_token = os.environ.get("HF_TOKEN")
|
||||
if hf_token:
|
||||
login(token=hf_token)
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
|
||||
|
||||
# Step 1: Install build tools
|
||||
print("\nStep 1: Installing build dependencies...")
|
||||
subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
|
||||
subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git"], capture_output=True)
|
||||
|
||||
# Step 2: Load and merge
|
||||
print(f"\nStep 2: Loading base model {base_model}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
|
||||
)
|
||||
|
||||
print(f"Step 3: Merging SFT adapter {sft_model}...")
|
||||
model = PeftModel.from_pretrained(model, sft_model)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
print(f"Step 4: Merging GRPO adapter {grpo_model}...")
|
||||
model = PeftModel.from_pretrained(model, grpo_model)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
||||
|
||||
# Step 3: Save merged model
|
||||
merged_dir = "/tmp/merged_model"
|
||||
print(f"\nStep 5: Saving merged model to {merged_dir}...")
|
||||
model.save_pretrained(merged_dir, safe_serialization=True)
|
||||
tokenizer.save_pretrained(merged_dir)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Step 4: Setup llama.cpp
|
||||
print("\nStep 6: Setting up llama.cpp...")
|
||||
if not os.path.exists("/tmp/llama.cpp"):
|
||||
run_cmd(["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
|
||||
"Cloning llama.cpp")
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", "/tmp/llama.cpp/requirements.txt"],
|
||||
capture_output=True)
|
||||
|
||||
# Step 5: Convert to FP16 GGUF
|
||||
gguf_dir = "/tmp/gguf_output"
|
||||
os.makedirs(gguf_dir, exist_ok=True)
|
||||
fp16_file = f"{gguf_dir}/{model_name}-f16.gguf"
|
||||
|
||||
print(f"\nStep 7: Converting to FP16 GGUF...")
|
||||
if not run_cmd([sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
|
||||
merged_dir, "--outfile", fp16_file, "--outtype", "f16"],
|
||||
"Converting to FP16"):
|
||||
sys.exit(1)
|
||||
|
||||
size_mb = os.path.getsize(fp16_file) / (1024 * 1024)
|
||||
print(f" FP16: {size_mb:.1f} MB")
|
||||
|
||||
# Upload FP16 immediately
|
||||
print(f" Uploading FP16 to {output_repo}...")
|
||||
api.upload_file(path_or_fileobj=fp16_file,
|
||||
path_in_repo=f"{model_name}-f16.gguf", repo_id=output_repo)
|
||||
print(f" Uploaded: {model_name}-f16.gguf")
|
||||
|
||||
# Step 6: Build quantize tool
|
||||
print("\nStep 8: Building quantize tool...")
|
||||
os.makedirs("/tmp/llama.cpp/build", exist_ok=True)
|
||||
run_cmd(["cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp", "-DGGML_CUDA=OFF"],
|
||||
"CMake configure")
|
||||
run_cmd(["cmake", "--build", "/tmp/llama.cpp/build", "--target", "llama-quantize", "-j", "4"],
|
||||
"Building llama-quantize")
|
||||
quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
|
||||
|
||||
# Step 7: Quantize and upload each one immediately
|
||||
print("\nStep 9: Quantizing and uploading...")
|
||||
for quant_type, desc in QUANT_TYPES:
|
||||
qfile = f"{gguf_dir}/{model_name}-{quant_type.lower()}.gguf"
|
||||
if run_cmd([quantize_bin, fp16_file, qfile, quant_type], f"{quant_type} ({desc})"):
|
||||
qsize = os.path.getsize(qfile) / (1024 * 1024)
|
||||
print(f" {quant_type}: {qsize:.1f} MB")
|
||||
|
||||
print(f" Uploading {quant_type} to {output_repo}...")
|
||||
api.upload_file(path_or_fileobj=qfile,
|
||||
path_in_repo=f"{model_name}-{quant_type.lower()}.gguf", repo_id=output_repo)
|
||||
print(f" Uploaded: {model_name}-{quant_type.lower()}.gguf")
|
||||
|
||||
# Remove to save disk
|
||||
os.remove(qfile)
|
||||
|
||||
# Step 8: Upload README
|
||||
ollama_name = "qmd-expand" if args.size == "1.7B" else f"qmd-expand-{args.size.lower()}"
|
||||
readme = f"""---
|
||||
base_model: {base_model}
|
||||
tags: [gguf, llama.cpp, quantized, query-expansion, qmd]
|
||||
---
|
||||
# {model_name} (GGUF)
|
||||
|
||||
GGUF quantizations of the QMD Query Expansion model for use with
|
||||
[Ollama](https://ollama.com), [llama.cpp](https://github.com/ggerganov/llama.cpp),
|
||||
or [LM Studio](https://lmstudio.ai).
|
||||
|
||||
## Available Quantizations
|
||||
|
||||
| File | Quant | Description |
|
||||
|------|-------|-------------|
|
||||
| `{model_name}-q4_k_m.gguf` | Q4_K_M | 4-bit — smallest, recommended for most use |
|
||||
| `{model_name}-q5_k_m.gguf` | Q5_K_M | 5-bit — balanced quality/size |
|
||||
| `{model_name}-q8_0.gguf` | Q8_0 | 8-bit — highest quality |
|
||||
| `{model_name}-f16.gguf` | FP16 | Full precision (large) |
|
||||
|
||||
## Details
|
||||
|
||||
- **Base:** {base_model}
|
||||
- **SFT:** {sft_model}
|
||||
- **GRPO:** {grpo_model}
|
||||
- **Task:** Query expansion for hybrid search (lex/vec/hyde format)
|
||||
- **Eval score:** 90.7% average (29/30 Excellent)
|
||||
|
||||
## Quick Start with Ollama
|
||||
|
||||
```bash
|
||||
huggingface-cli download {output_repo} \\
|
||||
{model_name}-q4_k_m.gguf --local-dir .
|
||||
|
||||
echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile
|
||||
ollama create {ollama_name} -f Modelfile
|
||||
ollama run {ollama_name}
|
||||
```
|
||||
|
||||
## Prompt Format
|
||||
|
||||
```
|
||||
<|im_start|>user
|
||||
/no_think Expand this search query: your query here<|im_end|>
|
||||
<|im_start|>assistant
|
||||
```
|
||||
|
||||
The model produces structured output:
|
||||
```
|
||||
lex: keyword expansion for BM25 search
|
||||
lex: another keyword variant
|
||||
vec: natural language expansion for vector search
|
||||
vec: another semantic expansion
|
||||
hyde: A hypothetical document passage that might match this query.
|
||||
```
|
||||
"""
|
||||
api.upload_file(path_or_fileobj=readme.encode(),
|
||||
path_in_repo="README.md", repo_id=output_repo)
|
||||
|
||||
print(f"\nDone! Repository: https://huggingface.co/{output_repo}")
|
||||
print(f"\nTo use with Ollama:")
|
||||
print(f" huggingface-cli download {output_repo} {model_name}-q4_k_m.gguf --local-dir .")
|
||||
print(f" echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile")
|
||||
print(f" ollama create {ollama_name} -f Modelfile")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -19,6 +19,7 @@ Self-contained script for HuggingFace Jobs:
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from huggingface_hub import login
|
||||
|
||||
# --- Config (inlined from configs/sft.yaml) ---
|
||||
@ -32,6 +33,7 @@ if hf_token:
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
# Load and split dataset
|
||||
@ -51,7 +53,7 @@ config = SFTConfig(
|
||||
hub_model_id=OUTPUT_MODEL,
|
||||
hub_strategy="every_save",
|
||||
|
||||
num_train_epochs=3,
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
@ -96,3 +98,14 @@ trainer.train()
|
||||
print("Pushing to Hub...")
|
||||
trainer.push_to_hub()
|
||||
print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
|
||||
|
||||
# --- Automatic evaluation ---
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from eval_common import run_eval
|
||||
|
||||
print("\nStarting automatic evaluation...")
|
||||
eval_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
||||
if eval_tokenizer.pad_token is None:
|
||||
eval_tokenizer.pad_token = eval_tokenizer.eos_token
|
||||
trainer.model.eval()
|
||||
run_eval(trainer.model, eval_tokenizer, "sft")
|
||||
|
||||
@ -150,7 +150,7 @@ export type RerankDocument = {
|
||||
const DEFAULT_EMBED_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
|
||||
const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
|
||||
// const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf";
|
||||
const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-1.7B-GGUF/Qwen3-1.7B-Q8_0.gguf";
|
||||
const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";
|
||||
|
||||
// Local model cache directory
|
||||
const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user