Fix TUI to load GRPO models with SFT base first
GRPO adapters were trained on merged SFT weights, so they need SFT loaded and merged first before applying the GRPO adapter. Updated MODELS config to include sft_base path for GRPO models, and load_model() now handles the SFT -> merge -> GRPO flow. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f96766cce8
commit
2648512b7c
@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
# CONFIGURATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Model configs: (name, path, version, sft_base)
|
||||
# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
|
||||
MODELS = {
|
||||
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
|
||||
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
|
||||
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
|
||||
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
|
||||
"1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
|
||||
"2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
||||
"3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
|
||||
"4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
||||
}
|
||||
|
||||
BASE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str:
|
||||
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
|
||||
console.print(f"[{DIM}]{'─' * 50}[/]")
|
||||
|
||||
for key, (name, path, version) in MODELS.items():
|
||||
for key, (name, path, version, sft_base) in MODELS.items():
|
||||
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
|
||||
console.print(f" {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
|
||||
sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
|
||||
console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
|
||||
console.print(f" [{DIM}]{path}[/]")
|
||||
|
||||
console.print(f"[{DIM}]{'─' * 50}[/]")
|
||||
return prompt(" Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
|
||||
return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
|
||||
|
||||
|
||||
def render_expansion(expansion: str, scores: dict) -> Panel:
|
||||
@ -328,8 +331,12 @@ class LoadedModel:
|
||||
version: str # "v1" or "v3" - determines prompt template
|
||||
|
||||
|
||||
def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
|
||||
"""Load model with progress indicator."""
|
||||
def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
|
||||
"""Load model with progress indicator.
|
||||
|
||||
For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
|
||||
merged into the base model, then the GRPO adapter is applied on top.
|
||||
"""
|
||||
with Progress(
|
||||
SpinnerColumn(spinner_name="dots", style=CYAN),
|
||||
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
|
||||
@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
progress.update(task, description="base model")
|
||||
base = AutoModelForCausalLM.from_pretrained(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# For GRPO models: load SFT first, merge, then apply GRPO
|
||||
if sft_base:
|
||||
progress.update(task, description="SFT adapter")
|
||||
model = PeftModel.from_pretrained(model, sft_base)
|
||||
progress.update(task, description="merging SFT")
|
||||
model = model.merge_and_unload()
|
||||
|
||||
progress.update(task, description="adapter")
|
||||
model = PeftModel.from_pretrained(base, model_path)
|
||||
model = PeftModel.from_pretrained(model, model_path)
|
||||
model.eval()
|
||||
|
||||
return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
|
||||
@ -428,11 +442,11 @@ def main():
|
||||
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
|
||||
return
|
||||
|
||||
model_name, model_path, model_version = MODELS[choice]
|
||||
model_name, model_path, model_version, sft_base = MODELS[choice]
|
||||
console.print()
|
||||
|
||||
try:
|
||||
loaded = load_model(model_path, model_name, model_version)
|
||||
loaded = load_model(model_path, model_name, model_version, sft_base)
|
||||
except Exception as e:
|
||||
console.print(f"[{RED}]Failed to load model: {e}[/]")
|
||||
return
|
||||
@ -469,10 +483,10 @@ def main():
|
||||
show_banner()
|
||||
choice = show_model_menu(loaded.path)
|
||||
if choice in MODELS:
|
||||
new_name, new_path, new_version = MODELS[choice]
|
||||
new_name, new_path, new_version, new_sft_base = MODELS[choice]
|
||||
if new_path != loaded.path:
|
||||
console.print()
|
||||
loaded = load_model(new_path, new_name, new_version)
|
||||
loaded = load_model(new_path, new_name, new_version, new_sft_base)
|
||||
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
|
||||
console.print()
|
||||
continue
|
||||
|
||||
Loading…
Reference in New Issue
Block a user