Make TUI model list dynamic from HuggingFace Hub
- Fetch available qmd-query-expansion models from tobil/ on Hub - Auto-detect model size (0.6B, 1.7B, 4B) and use correct base model - Group models by type (SFT vs GRPO) in menu - Skip GGUF repos in model listing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
891f3262cf
commit
3ea85eff50
119
finetune/tui.py
119
finetune/tui.py
@ -7,6 +7,7 @@
|
||||
# "peft>=0.7.0",
|
||||
# "torch",
|
||||
# "prompt_toolkit>=3.0.0",
|
||||
# "huggingface_hub>=0.20.0",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models.
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from peft import PeftModel
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
# CONFIGURATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Model configs: (name, path, version, sft_base)
|
||||
# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
|
||||
MODELS = {
|
||||
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
|
||||
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
||||
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
|
||||
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
||||
# Base models by size
|
||||
BASE_MODELS = {
|
||||
"0.6B": "Qwen/Qwen3-0.6B",
|
||||
"1.7B": "Qwen/Qwen3-1.7B",
|
||||
"4B": "Qwen/Qwen3-4B",
|
||||
}
|
||||
|
||||
BASE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
def get_model_size(model_id: str) -> str:
|
||||
"""Extract model size from model ID (e.g., '0.6B', '1.7B', '4B')."""
|
||||
match = re.search(r'(\d+\.?\d*B)', model_id)
|
||||
return match.group(1) if match else "0.6B"
|
||||
|
||||
|
||||
def fetch_available_models() -> dict:
|
||||
"""Dynamically fetch available qmd-query-expansion models from Hub."""
|
||||
api = HfApi()
|
||||
models = {}
|
||||
idx = 1
|
||||
|
||||
try:
|
||||
# Search for all qmd-query-expansion models
|
||||
hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion"))
|
||||
|
||||
# Group by size and type (SFT vs GRPO)
|
||||
sft_models = []
|
||||
grpo_models = []
|
||||
|
||||
for m in hub_models:
|
||||
model_id = m.id
|
||||
# Skip GGUF repos
|
||||
if "gguf" in model_id.lower():
|
||||
continue
|
||||
if "grpo" in model_id.lower():
|
||||
grpo_models.append(model_id)
|
||||
elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]):
|
||||
sft_models.append(model_id)
|
||||
|
||||
# Sort by size (0.6B, 1.7B, 4B)
|
||||
def size_sort_key(m):
|
||||
size = get_model_size(m)
|
||||
return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3)
|
||||
|
||||
sft_models.sort(key=size_sort_key)
|
||||
grpo_models.sort(key=size_sort_key)
|
||||
|
||||
# Add SFT models
|
||||
for model_id in sft_models:
|
||||
size = get_model_size(model_id)
|
||||
models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size)
|
||||
idx += 1
|
||||
|
||||
# Add GRPO models (need to find matching SFT base)
|
||||
for model_id in grpo_models:
|
||||
size = get_model_size(model_id)
|
||||
# Find matching SFT model
|
||||
sft_base = None
|
||||
for sft in sft_models:
|
||||
if get_model_size(sft) == size:
|
||||
sft_base = sft
|
||||
break
|
||||
models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
|
||||
idx += 1
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to default models if Hub fetch fails
|
||||
models = {
|
||||
"1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
|
||||
"2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
|
||||
}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
# Will be populated on startup
|
||||
MODELS = {}
|
||||
|
||||
# v1 used simple format (before proper chat template)
|
||||
PROMPT_TEMPLATE_V1 = """Expand this search query:
|
||||
@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str:
|
||||
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
|
||||
console.print(f"[{DIM}]{'─' * 50}[/]")
|
||||
|
||||
for key, (name, path, version, sft_base) in MODELS.items():
|
||||
for key, model_info in MODELS.items():
|
||||
name, path, version, sft_base = model_info[:4]
|
||||
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
|
||||
sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
|
||||
console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
|
||||
@ -331,12 +401,14 @@ class LoadedModel:
|
||||
version: str # "v1" or "v3" - determines prompt template
|
||||
|
||||
|
||||
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
|
||||
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
|
||||
"""Load model with progress indicator.
|
||||
|
||||
For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
|
||||
merged into the base model, then the GRPO adapter is applied on top.
|
||||
"""
|
||||
base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(spinner_name="dots", style=CYAN),
|
||||
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
|
||||
@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona
|
||||
transient=True,
|
||||
) as progress:
|
||||
task = progress.add_task("tokenizer", total=None)
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
progress.update(task, description="base model")
|
||||
progress.update(task, description=f"base model ({size})")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
base_model,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str:
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def main():
|
||||
global MODELS
|
||||
console.clear()
|
||||
show_banner()
|
||||
|
||||
# Fetch available models from Hub
|
||||
console.print(f"[{DIM}]Fetching available models...[/]")
|
||||
MODELS = fetch_available_models()
|
||||
|
||||
if not MODELS:
|
||||
console.print(f"[{RED}]No models found. Exiting.[/]")
|
||||
return
|
||||
|
||||
# Model selection
|
||||
choice = show_model_menu()
|
||||
if choice not in MODELS:
|
||||
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
|
||||
return
|
||||
|
||||
model_name, model_path, model_version, sft_base = MODELS[choice]
|
||||
model_info = MODELS[choice]
|
||||
model_name, model_path, model_version, sft_base = model_info[:4]
|
||||
model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
|
||||
console.print()
|
||||
|
||||
try:
|
||||
loaded = load_model(model_path, model_name, model_version, sft_base)
|
||||
loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
|
||||
except Exception as e:
|
||||
console.print(f"[{RED}]Failed to load model: {e}[/]")
|
||||
return
|
||||
@ -483,10 +566,12 @@ def main():
|
||||
show_banner()
|
||||
choice = show_model_menu(loaded.path)
|
||||
if choice in MODELS:
|
||||
new_name, new_path, new_version, new_sft_base = MODELS[choice]
|
||||
new_info = MODELS[choice]
|
||||
new_name, new_path, new_version, new_sft_base = new_info[:4]
|
||||
new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
|
||||
if new_path != loaded.path:
|
||||
console.print()
|
||||
loaded = load_model(new_path, new_name, new_version, new_sft_base)
|
||||
loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
|
||||
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
|
||||
console.print()
|
||||
continue
|
||||
|
||||
Loading…
Reference in New Issue
Block a user