Make TUI model list dynamic from HuggingFace Hub

- Fetch available qmd-query-expansion models from tobil/ on Hub
- Auto-detect model size (0.6B, 1.7B, 4B) and use correct base model
- Group models by type (SFT vs GRPO) in menu
- Skip GGUF repos in model listing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Tobi Lutke 2026-01-25 17:17:40 -05:00
parent 891f3262cf
commit 3ea85eff50
No known key found for this signature in database

View File

@ -7,6 +7,7 @@
# "peft>=0.7.0",
# "torch",
# "prompt_toolkit>=3.0.0",
# "huggingface_hub>=0.20.0",
# ]
# ///
"""
@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models.
from collections import deque
from dataclasses import dataclass
from typing import Optional
import re
import torch
from huggingface_hub import HfApi
from peft import PeftModel
from prompt_toolkit import prompt
from prompt_toolkit.history import InMemoryHistory
@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
# Model configs: (name, path, version, sft_base)
# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
MODELS = {
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
# Base models by size
BASE_MODELS = {
"0.6B": "Qwen/Qwen3-0.6B",
"1.7B": "Qwen/Qwen3-1.7B",
"4B": "Qwen/Qwen3-4B",
}
BASE_MODEL = "Qwen/Qwen3-0.6B"
def get_model_size(model_id: str) -> str:
"""Extract model size from model ID (e.g., '0.6B', '1.7B', '4B')."""
match = re.search(r'(\d+\.?\d*B)', model_id)
return match.group(1) if match else "0.6B"
def fetch_available_models() -> dict:
"""Dynamically fetch available qmd-query-expansion models from Hub."""
api = HfApi()
models = {}
idx = 1
try:
# Search for all qmd-query-expansion models
hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion"))
# Group by size and type (SFT vs GRPO)
sft_models = []
grpo_models = []
for m in hub_models:
model_id = m.id
# Skip GGUF repos
if "gguf" in model_id.lower():
continue
if "grpo" in model_id.lower():
grpo_models.append(model_id)
elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]):
sft_models.append(model_id)
# Sort by size (0.6B, 1.7B, 4B)
def size_sort_key(m):
size = get_model_size(m)
return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3)
sft_models.sort(key=size_sort_key)
grpo_models.sort(key=size_sort_key)
# Add SFT models
for model_id in sft_models:
size = get_model_size(model_id)
models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size)
idx += 1
# Add GRPO models (need to find matching SFT base)
for model_id in grpo_models:
size = get_model_size(model_id)
# Find matching SFT model
sft_base = None
for sft in sft_models:
if get_model_size(sft) == size:
sft_base = sft
break
models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
idx += 1
except Exception as e:
# Fallback to default models if Hub fetch fails
models = {
"1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
"2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
}
return models
# Will be populated on startup
MODELS = {}
# v1 used simple format (before proper chat template)
PROMPT_TEMPLATE_V1 = """Expand this search query:
@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str:
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
console.print(f"[{DIM}]{'' * 50}[/]")
for key, (name, path, version, sft_base) in MODELS.items():
for key, model_info in MODELS.items():
name, path, version, sft_base = model_info[:4]
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
@ -331,12 +401,14 @@ class LoadedModel:
version: str # "v1" or "v3" - determines prompt template
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
"""Load model with progress indicator.
For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
merged into the base model, then the GRPO adapter is applied on top.
"""
base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
with Progress(
SpinnerColumn(spinner_name="dots", style=CYAN),
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona
transient=True,
) as progress:
task = progress.add_task("tokenizer", total=None)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
progress.update(task, description="base model")
progress.update(task, description=f"base model ({size})")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
base_model,
torch_dtype=torch.bfloat16,
device_map="auto",
)
@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str:
# ═══════════════════════════════════════════════════════════════════════════════
def main():
global MODELS
console.clear()
show_banner()
# Fetch available models from Hub
console.print(f"[{DIM}]Fetching available models...[/]")
MODELS = fetch_available_models()
if not MODELS:
console.print(f"[{RED}]No models found. Exiting.[/]")
return
# Model selection
choice = show_model_menu()
if choice not in MODELS:
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
return
model_name, model_path, model_version, sft_base = MODELS[choice]
model_info = MODELS[choice]
model_name, model_path, model_version, sft_base = model_info[:4]
model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
console.print()
try:
loaded = load_model(model_path, model_name, model_version, sft_base)
loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
except Exception as e:
console.print(f"[{RED}]Failed to load model: {e}[/]")
return
@ -483,10 +566,12 @@ def main():
show_banner()
choice = show_model_menu(loaded.path)
if choice in MODELS:
new_name, new_path, new_version, new_sft_base = MODELS[choice]
new_info = MODELS[choice]
new_name, new_path, new_version, new_sft_base = new_info[:4]
new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
if new_path != loaded.path:
console.print()
loaded = load_model(new_path, new_name, new_version, new_sft_base)
loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
console.print()
continue