From 2648512b7c2ba5aa484ed2a26b371d797dcfc8a1 Mon Sep 17 00:00:00 2001 From: Tobi Lutke Date: Sun, 25 Jan 2026 00:47:59 -0500 Subject: [PATCH] Fix TUI to load GRPO models with SFT base first GRPO adapters were trained on merged SFT weights, so they need SFT loaded and merged first before applying the GRPO adapter. Updated MODELS config to include sft_base path for GRPO models, and load_model() now handles the SFT -> merge -> GRPO flow. Co-Authored-By: Claude Opus 4.5 --- finetune/tui.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/finetune/tui.py b/finetune/tui.py index 5cd75d3..66e5996 100755 --- a/finetune/tui.py +++ b/finetune/tui.py @@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer # CONFIGURATION # ═══════════════════════════════════════════════════════════════════════════════ +# Model configs: (name, path, version, sft_base) +# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO MODELS = { - "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"), - "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"), - "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"), - "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"), + "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None), + "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"), + "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None), + "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"), } BASE_MODEL = "Qwen/Qwen3-0.6B" @@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str: console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]") console.print(f"[{DIM}]{'─' * 50}[/]") - for key, (name, path, version) in MODELS.items(): + for key, (name, path, version, sft_base) in MODELS.items(): marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]" - console.print(f" {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]") + sft_note = f" [{DIM}](+SFT)[/]" if sft_base else "" + console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]") console.print(f" [{DIM}]{path}[/]") console.print(f"[{DIM}]{'─' * 50}[/]") - return prompt(" Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip() + return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip() def render_expansion(expansion: str, scores: dict) -> Panel: @@ -328,8 +331,12 @@ class LoadedModel: version: str # "v1" or "v3" - determines prompt template -def load_model(model_path: str, model_name: str, version: str) -> LoadedModel: - """Load model with progress indicator.""" +def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel: + """Load model with progress indicator. + + For GRPO models, sft_base must be provided - the SFT adapter is loaded first, + merged into the base model, then the GRPO adapter is applied on top. + """ with Progress( SpinnerColumn(spinner_name="dots", style=CYAN), TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"), @@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel: tokenizer.pad_token = tokenizer.eos_token progress.update(task, description="base model") - base = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto", ) + # For GRPO models: load SFT first, merge, then apply GRPO + if sft_base: + progress.update(task, description="SFT adapter") + model = PeftModel.from_pretrained(model, sft_base) + progress.update(task, description="merging SFT") + model = model.merge_and_unload() + progress.update(task, description="adapter") - model = PeftModel.from_pretrained(base, model_path) + model = PeftModel.from_pretrained(model, model_path) model.eval() return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version) @@ -428,11 +442,11 @@ def main(): console.print(f"[{RED}]Invalid choice. Exiting.[/]") return - model_name, model_path, model_version = MODELS[choice] + model_name, model_path, model_version, sft_base = MODELS[choice] console.print() try: - loaded = load_model(model_path, model_name, model_version) + loaded = load_model(model_path, model_name, model_version, sft_base) except Exception as e: console.print(f"[{RED}]Failed to load model: {e}[/]") return @@ -469,10 +483,10 @@ def main(): show_banner() choice = show_model_menu(loaded.path) if choice in MODELS: - new_name, new_path, new_version = MODELS[choice] + new_name, new_path, new_version, new_sft_base = MODELS[choice] if new_path != loaded.path: console.print() - loaded = load_model(new_path, new_name, new_version) + loaded = load_model(new_path, new_name, new_version, new_sft_base) console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]") console.print() continue