diff --git a/finetune/tui.py b/finetune/tui.py index 66e5996..62f579c 100755 --- a/finetune/tui.py +++ b/finetune/tui.py @@ -7,6 +7,7 @@ # "peft>=0.7.0", # "torch", # "prompt_toolkit>=3.0.0", +# "huggingface_hub>=0.20.0", # ] # /// """ @@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models. from collections import deque from dataclasses import dataclass from typing import Optional +import re import torch +from huggingface_hub import HfApi from peft import PeftModel from prompt_toolkit import prompt from prompt_toolkit.history import InMemoryHistory @@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer # CONFIGURATION # ═══════════════════════════════════════════════════════════════════════════════ -# Model configs: (name, path, version, sft_base) -# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO -MODELS = { - "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None), - "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"), - "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None), - "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"), +# Base models by size +BASE_MODELS = { + "0.6B": "Qwen/Qwen3-0.6B", + "1.7B": "Qwen/Qwen3-1.7B", + "4B": "Qwen/Qwen3-4B", } -BASE_MODEL = "Qwen/Qwen3-0.6B" + +def get_model_size(model_id: str) -> str: + """Extract model size from model ID (e.g., '0.6B', '1.7B', '4B').""" + match = re.search(r'(\d+\.?\d*B)', model_id) + return match.group(1) if match else "0.6B" + + +def fetch_available_models() -> dict: + """Dynamically fetch available qmd-query-expansion models from Hub.""" + api = HfApi() + models = {} + idx = 1 + + try: + # Search for all qmd-query-expansion models + hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion")) + + # Group by size and type (SFT vs GRPO) + sft_models = [] + grpo_models = [] + + for m in hub_models: + model_id = m.id + # Skip GGUF repos + if "gguf" in model_id.lower(): + continue + if "grpo" in model_id.lower(): + grpo_models.append(model_id) + elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]): + sft_models.append(model_id) + + # Sort by size (0.6B, 1.7B, 4B) + def size_sort_key(m): + size = get_model_size(m) + return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3) + + sft_models.sort(key=size_sort_key) + grpo_models.sort(key=size_sort_key) + + # Add SFT models + for model_id in sft_models: + size = get_model_size(model_id) + models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size) + idx += 1 + + # Add GRPO models (need to find matching SFT base) + for model_id in grpo_models: + size = get_model_size(model_id) + # Find matching SFT model + sft_base = None + for sft in sft_models: + if get_model_size(sft) == size: + sft_base = sft + break + models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size) + idx += 1 + + except Exception as e: + # Fallback to default models if Hub fetch fails + models = { + "1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"), + "2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"), + } + + return models + + +# Will be populated on startup +MODELS = {} # v1 used simple format (before proper chat template) PROMPT_TEMPLATE_V1 = """Expand this search query: @@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str: console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]") console.print(f"[{DIM}]{'─' * 50}[/]") - for key, (name, path, version, sft_base) in MODELS.items(): + for key, model_info in MODELS.items(): + name, path, version, sft_base = model_info[:4] marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]" sft_note = f" [{DIM}](+SFT)[/]" if sft_base else "" console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]") @@ -331,12 +401,14 @@ class LoadedModel: version: str # "v1" or "v3" - determines prompt template -def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel: +def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel: """Load model with progress indicator. For GRPO models, sft_base must be provided - the SFT adapter is loaded first, merged into the base model, then the GRPO adapter is applied on top. """ + base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"]) + with Progress( SpinnerColumn(spinner_name="dots", style=CYAN), TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"), @@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona transient=True, ) as progress: task = progress.add_task("tokenizer", total=None) - tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) + tokenizer = AutoTokenizer.from_pretrained(base_model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - progress.update(task, description="base model") + progress.update(task, description=f"base model ({size})") model = AutoModelForCausalLM.from_pretrained( - BASE_MODEL, + base_model, torch_dtype=torch.bfloat16, device_map="auto", ) @@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str: # ═══════════════════════════════════════════════════════════════════════════════ def main(): + global MODELS console.clear() show_banner() + # Fetch available models from Hub + console.print(f"[{DIM}]Fetching available models...[/]") + MODELS = fetch_available_models() + + if not MODELS: + console.print(f"[{RED}]No models found. Exiting.[/]") + return + # Model selection choice = show_model_menu() if choice not in MODELS: console.print(f"[{RED}]Invalid choice. Exiting.[/]") return - model_name, model_path, model_version, sft_base = MODELS[choice] + model_info = MODELS[choice] + model_name, model_path, model_version, sft_base = model_info[:4] + model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path) console.print() try: - loaded = load_model(model_path, model_name, model_version, sft_base) + loaded = load_model(model_path, model_name, model_version, sft_base, model_size) except Exception as e: console.print(f"[{RED}]Failed to load model: {e}[/]") return @@ -483,10 +566,12 @@ def main(): show_banner() choice = show_model_menu(loaded.path) if choice in MODELS: - new_name, new_path, new_version, new_sft_base = MODELS[choice] + new_info = MODELS[choice] + new_name, new_path, new_version, new_sft_base = new_info[:4] + new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path) if new_path != loaded.path: console.print() - loaded = load_model(new_path, new_name, new_version, new_sft_base) + loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size) console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]") console.print() continue