Fix TUI to load GRPO models with SFT base first

GRPO adapters were trained on merged SFT weights, so they need SFT
loaded and merged first before applying the GRPO adapter.

Updated MODELS config to include sft_base path for GRPO models,
and load_model() now handles the SFT -> merge -> GRPO flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Tobi Lutke 2026-01-25 00:47:59 -05:00
parent f96766cce8
commit 2648512b7c
No known key found for this signature in database

View File

@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
# Model configs: (name, path, version, sft_base)
# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
MODELS = {
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
}
BASE_MODEL = "Qwen/Qwen3-0.6B"
@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str:
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
console.print(f"[{DIM}]{'' * 50}[/]")
for key, (name, path, version) in MODELS.items():
for key, (name, path, version, sft_base) in MODELS.items():
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
console.print(f" {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
console.print(f" [{DIM}]{path}[/]")
console.print(f"[{DIM}]{'' * 50}[/]")
return prompt(" Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
def render_expansion(expansion: str, scores: dict) -> Panel:
@ -328,8 +331,12 @@ class LoadedModel:
version: str # "v1" or "v3" - determines prompt template
def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
"""Load model with progress indicator."""
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
"""Load model with progress indicator.
For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
merged into the base model, then the GRPO adapter is applied on top.
"""
with Progress(
SpinnerColumn(spinner_name="dots", style=CYAN),
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
tokenizer.pad_token = tokenizer.eos_token
progress.update(task, description="base model")
base = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# For GRPO models: load SFT first, merge, then apply GRPO
if sft_base:
progress.update(task, description="SFT adapter")
model = PeftModel.from_pretrained(model, sft_base)
progress.update(task, description="merging SFT")
model = model.merge_and_unload()
progress.update(task, description="adapter")
model = PeftModel.from_pretrained(base, model_path)
model = PeftModel.from_pretrained(model, model_path)
model.eval()
return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
@ -428,11 +442,11 @@ def main():
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
return
model_name, model_path, model_version = MODELS[choice]
model_name, model_path, model_version, sft_base = MODELS[choice]
console.print()
try:
loaded = load_model(model_path, model_name, model_version)
loaded = load_model(model_path, model_name, model_version, sft_base)
except Exception as e:
console.print(f"[{RED}]Failed to load model: {e}[/]")
return
@ -469,10 +483,10 @@ def main():
show_banner()
choice = show_model_menu(loaded.path)
if choice in MODELS:
new_name, new_path, new_version = MODELS[choice]
new_name, new_path, new_version, new_sft_base = MODELS[choice]
if new_path != loaded.path:
console.print()
loaded = load_model(new_path, new_name, new_version)
loaded = load_model(new_path, new_name, new_version, new_sft_base)
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
console.print()
continue