686 lines
22 KiB
Python
686 lines
22 KiB
Python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "trackio",
#     "nvidia-ml-py",
#     "datasets",
#     "bitsandbytes",
#     "pyyaml",
#     "gguf",
# ]
# ///
|
|
"""
|
|
Unified training script for QMD query expansion models.
|
|
|
|
Primary pipeline is SFT-only:
|
|
sft - Supervised fine-tuning on labeled examples
|
|
|
|
GRPO was moved to `experiments/grpo/` and is not part of the main training
|
|
pipeline by default.
|
|
|
|
Usage:
|
|
uv run train.py sft --config configs/sft.yaml
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from transformers import TrainerCallback
|
|
|
|
|
|
def _run_quiet(cmd, description):
    """Run *cmd* with captured output and report whether it succeeded.

    A failure (non-zero exit status, or the executable not being startable) is
    printed together with the tail of the captured stderr.

    Args:
        cmd: Argument list passed to ``subprocess.run``.
        description: Human-readable name of the step, used in messages.

    Returns:
        True when the command ran and exited with status 0, otherwise False.
    """
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except OSError as exc:
        # Raised when the executable itself (git, cmake, ...) cannot be started.
        print(f"{description} failed: {exc}")
        return False
    if result.returncode != 0:
        print(f"{description} failed (exit code {result.returncode}).")
        stderr_tail = (result.stderr or "").strip()[-2000:]
        if stderr_tail:
            print(stderr_tail)
        return False
    return True


def export_gguf(model, tokenizer, output_dir: str, model_name: str):
    """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations.

    The model and tokenizer are saved to a temporary directory, converted to
    an FP16 GGUF file with llama.cpp's ``convert_hf_to_gguf.py`` and then
    quantized with ``llama-quantize``. llama.cpp is cloned into
    ``/tmp/llama.cpp`` and the quantize tool is built there when missing.

    The export is best-effort: a failing external step is reported on stdout
    and the function returns without raising.

    Args:
        model: Object providing ``save_pretrained(path, safe_serialization=True)``.
        tokenizer: Object providing ``save_pretrained(path)``.
        output_dir: Directory under which the ``gguf/`` folder is created.
        model_name: Prefix of the generated ``.gguf`` file names.
    """
    import tempfile

    gguf_dir = Path(output_dir) / "gguf"
    # parents=True: the output directory itself may not exist yet.
    gguf_dir.mkdir(parents=True, exist_ok=True)

    # Save merged model to temp dir
    print("Saving merged model for GGUF conversion...")
    with tempfile.TemporaryDirectory() as tmp:
        merged_path = Path(tmp) / "merged"
        model.save_pretrained(merged_path, safe_serialization=True)
        tokenizer.save_pretrained(merged_path)

        # Setup llama.cpp
        llama_cpp = Path("/tmp/llama.cpp")
        if not llama_cpp.exists():
            print("Cloning llama.cpp...")
            cloned = _run_quiet(
                [
                    "git",
                    "clone",
                    "--depth",
                    "1",
                    "https://github.com/ggerganov/llama.cpp.git",
                    str(llama_cpp),
                ],
                "Cloning llama.cpp",
            )
            if not cloned:
                # Nothing below can work without the conversion script.
                return
            # Still non-fatal, as before: a failed install is reported but
            # does not stop the export, because the required packages may
            # already be installed in this environment.
            _run_quiet(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "-q",
                    "-r",
                    str(llama_cpp / "requirements.txt"),
                ],
                "Installing llama.cpp requirements",
            )

        # Build quantize tool if needed
        quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
        if not quantize_bin.exists():
            print("Building llama-quantize...")
            build_dir = llama_cpp / "build"
            build_dir.mkdir(exist_ok=True)
            configured = _run_quiet(
                [
                    "cmake",
                    "-B",
                    str(build_dir),
                    "-S",
                    str(llama_cpp),
                    "-DGGML_CUDA=OFF",
                ],
                "Configuring llama.cpp build",
            )
            if configured:
                _run_quiet(
                    [
                        "cmake",
                        "--build",
                        str(build_dir),
                        "--target",
                        "llama-quantize",
                        "-j",
                        "4",
                    ],
                    "Building llama-quantize",
                )

        # Convert to FP16 first
        fp16_file = gguf_dir / f"{model_name}-f16.gguf"
        print(f"Converting to FP16: {fp16_file}")
        log_out = Path("/tmp/qmd-gguf-convert.log")
        log_err = Path("/tmp/qmd-gguf-convert.err")
        with log_out.open("w") as out_f, log_err.open("w") as err_f:
            result = subprocess.run(
                [
                    sys.executable,
                    str(llama_cpp / "convert_hf_to_gguf.py"),
                    str(merged_path),
                    "--outfile",
                    str(fp16_file),
                    "--outtype",
                    "f16",
                ],
                stdout=out_f,
                stderr=err_f,
                text=True,
            )
        if result.returncode != 0:
            print("GGUF conversion failed.")
            print(f"stdout: {log_out}")
            print(f"stderr: {log_err}")
            return

        if not quantize_bin.exists():
            # The build failed above; keep the FP16 file so the conversion
            # work is not lost.
            print(f"llama-quantize is unavailable; keeping FP16 file: {fp16_file}")
            return

        # Quantize to 4, 6, 8 bit
        failed = []
        for quant_type in ("Q4_K_M", "Q6_K", "Q8_0"):
            out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
            print(f"Quantizing {quant_type}: {out_file}")
            succeeded = _run_quiet(
                [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
                f"Quantizing {quant_type}",
            )
            if succeeded and out_file.exists():
                size_mb = out_file.stat().st_size / (1024 * 1024)
                print(f" {quant_type}: {size_mb:.1f} MB")
            else:
                failed.append(quant_type)

        if failed:
            # Keep the FP16 source so the failed quantizations can be retried.
            print(
                f"Keeping FP16 file because quantization failed for: {', '.join(failed)}"
            )
        elif fp16_file.exists():
            # Remove FP16 to save space
            fp16_file.unlink()

    print(f"GGUF files saved to: {gguf_dir}")
|
|
|
|
|
|
class TimedSaveCallback(TrainerCallback):
    """Trigger periodic checkpoint saves based on elapsed time.

    After every training step the callback checks how much time has passed
    since the last save it requested (or since construction) and sets
    ``control.should_save`` once the configured interval is reached.

    Args:
        interval_minutes: Minimum number of minutes between requested saves.
    """

    def __init__(self, interval_minutes: float):
        self.interval_seconds = float(interval_minutes) * 60.0
        # A monotonic clock is used so that system clock adjustments can
        # neither trigger an early save nor postpone one.
        self.last_save_time = time.monotonic()

    def on_step_end(self, args, state, control, **kwargs):
        """Request a checkpoint when the save interval has elapsed.

        Only acts when ``state.is_world_process_zero`` is true; otherwise the
        control object is returned unchanged.
        """
        if not getattr(state, "is_world_process_zero", False):
            return control

        now = time.monotonic()
        if now - self.last_save_time >= self.interval_seconds:
            control.should_save = True
            self.last_save_time = now
        return control
|
|
|
|
|
|
def run_eval(model_path: str) -> float | None:
|
|
"""Run eval.py on the trained model and return average score."""
|
|
print("\n" + "=" * 60)
|
|
print("Running evaluation...")
|
|
print("=" * 60)
|
|
|
|
eval_script = Path(__file__).parent / "eval.py"
|
|
result = subprocess.run(
|
|
[sys.executable, str(eval_script), model_path],
|
|
cwd=str(Path(__file__).parent),
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
if result.stdout:
|
|
print(result.stdout, end="")
|
|
if result.stderr:
|
|
print(result.stderr, end="")
|
|
|
|
avg = None
|
|
for line in (result.stdout or "").splitlines():
|
|
if line.strip().startswith("Average:"):
|
|
try:
|
|
avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
|
|
except ValueError:
|
|
pass
|
|
break
|
|
return avg
|
|
|
|
|
|
def cmd_sft(args):
    """Run supervised fine-tuning, then export GGUF files and run the eval.

    The YAML file at ``args.config`` is loaded and read for these sections:
    ``dataset`` (``name``, ``split``, ``eval_split``), ``model`` (``base``,
    ``output``, optional ``push_to_hub``), ``training``, ``lora`` and the
    optional ``tracking``. With ``args.dry_run`` set, the configuration is
    printed and nothing else happens.

    After training, only the process whose ``RANK`` environment variable is
    "0" (or unset) saves or pushes the model, merges the adapter into the base
    model, exports GGUF files and runs the evaluation.

    Args:
        args: Parsed CLI namespace providing ``config`` and ``dry_run``.
    """
    import torch
    from datasets import load_dataset
    import torch.distributed as dist
    from peft import LoraConfig
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()
    from trl import SFTTrainer, SFTConfig

    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")

    if args.dry_run:
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    dataset_name = cfg["dataset"]["name"]
    print(f"Loading dataset: {dataset_name}...")

    # Support local JSONL files and glob patterns.
    # NOTE: ``Path`` must come from the module-level import. Importing it
    # inside this branch would make it a function-local name and break the
    # ``Path(output_dir)`` call further down whenever this branch is skipped.
    if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
        import glob

        # Handle glob patterns like "data/*.jsonl"
        if "*" in dataset_name:
            jsonl_files = sorted(glob.glob(dataset_name))
            if not jsonl_files:
                raise ValueError(f"No files found matching: {dataset_name}")
            print(
                f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
            )
            dataset = load_dataset("json", data_files=jsonl_files, split="train")
        else:
            data_path = Path(dataset_name)
            if data_path.is_dir():
                train_file = data_path / "train.jsonl"
                dataset = load_dataset(
                    "json", data_files=str(train_file), split="train"
                )
            else:
                dataset = load_dataset("json", data_files=dataset_name, split="train")
    else:
        dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
    print(f"Dataset loaded: {len(dataset)} examples")

    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
    train_dataset = split["train"]
    eval_dataset = split["test"]
    print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Check if output looks like a HF Hub path (contains /)
    output_name = cfg["model"]["output"]
    push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
    if "push_to_hub" in cfg["model"]:
        # An explicit setting overrides the name-based guess.
        push_to_hub = bool(cfg["model"]["push_to_hub"])
    output_dir = output_name.split("/")[-1] if push_to_hub else output_name

    # Tracking is only enabled when HF_TOKEN is set and trackio imports.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401

            report_to = "trackio"
        except Exception:
            print("Trackio not installed; disabling tracking.")

    # ``or {}`` also covers a ``tracking:`` key that is present but empty.
    tracking = cfg.get("tracking") or {}
    if report_to == "trackio":
        project = tracking.get("project")
        if project:
            os.environ.setdefault("TRACKIO_PROJECT", project)

    run_name = tracking.get("run_name")
    if run_name and "{" in run_name:
        from datetime import datetime

        now = datetime.now()
        run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
            "{time}", now.strftime("%H:%M")
        )

    learning_rate = cfg["training"]["learning_rate"]
    if isinstance(learning_rate, str):
        # YAML values such as "2e-4" (no decimal point) are loaded as
        # strings; convert them the same way cmd_grpo does.
        learning_rate = float(learning_rate)

    save_interval_minutes = cfg["training"].get("save_interval_minutes")
    save_steps = cfg["training"].get("save_steps", 200)
    save_total_limit = cfg["training"].get("save_total_limit", 2)
    callbacks = []
    if save_interval_minutes:
        # Prefer wall-clock checkpointing (for long jobs / preemption safety):
        # push step-based saving out of reach and let the callback decide.
        save_steps = max(save_steps, 10_000_000)
        try:
            interval_value = float(save_interval_minutes)
        except (TypeError, ValueError):
            interval_value = None
        if interval_value and interval_value > 0:
            callbacks.append(TimedSaveCallback(interval_value))

    config = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=cfg["training"]["epochs"],
        per_device_train_batch_size=cfg["training"]["batch_size"],
        gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
        learning_rate=learning_rate,
        max_length=cfg["training"]["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        eval_strategy="steps",
        eval_steps=cfg["training"].get("eval_steps", 200),
        warmup_ratio=cfg["training"]["warmup_ratio"],
        lr_scheduler_type=cfg["training"]["lr_scheduler"],
        ddp_find_unused_parameters=cfg["training"].get(
            "ddp_find_unused_parameters", False
        ),
        bf16=True,
        report_to=report_to,
        run_name=run_name if report_to == "trackio" else None,
    )

    # LoRA config with modules_to_save for embedding layers
    # This prevents token ID mismatches during inference
    peft_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
        modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
        ensure_weight_tying=True,
    )

    print("Loading tokenizer...")
    base_model = cfg["model"]["base"]
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
        processing_class=tokenizer,
        callbacks=callbacks,
    )

    print("Starting SFT training...")
    trainer.train()

    is_main = os.environ.get("RANK", "0") == "0"
    if dist.is_available() and dist.is_initialized():
        # Wait for every rank to finish training before post-processing.
        dist.barrier()

    if not is_main:
        return

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_dir}")

    # Export GGUF
    print("\nExporting to GGUF...")
    # Need to get the merged model for GGUF
    print("Loading model for GGUF export...")
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    base.config.tie_word_embeddings = False
    model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
    model = model.merge_and_unload()
    export_gguf(model, tokenizer, output_dir, Path(output_dir).name)

    # Run eval
    eval_avg = run_eval(output_dir)
    if report_to == "trackio" and eval_avg is not None:
        try:
            import trackio

            trackio.log({"eval.avg": eval_avg})
        except Exception as exc:
            # Logging the score is best-effort; report instead of hiding it.
            print(f"Warning: could not log eval average to Trackio: {exc}")
|
|
|
|
|
|
def cmd_grpo(args):
    """Explain that GRPO training is no longer run from this script.

    GRPO was moved to ``experiments/grpo/``; this command only prints where
    to find it and how to start it. No configuration is read and no training
    is performed.

    Args:
        args: Parsed CLI namespace. Accepted for interface compatibility with
            the other stage handlers; not used.
    """
    # The former in-file GRPO implementation sat behind an unconditional
    # return and could never execute, so it has been dropped from this
    # function.
    print(
        "GRPO is not part of the main training pipeline and has been moved to `experiments/grpo/`."
    )
    print("To run experimental GRPO, use:")
    print(" cd finetune && uv run python experiments/grpo/grpo.py")
    print("Or, if you have local config wiring ready:")
    print(" uv run train.py grpo --config experiments/grpo/grpo.yaml")
|
|
|
|
|
|
def main():
    """Parse the command line and hand over to the selected training stage.

    Two sub-commands are offered, ``sft`` and ``grpo``. Both require
    ``--config`` and accept an optional ``--dry-run`` flag.
    """
    cli = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
uv run train.py sft --config configs/sft.yaml
""",
    )
    stages = cli.add_subparsers(dest="stage", required=True)

    # Supervised fine-tuning sub-command.
    sft_cli = stages.add_parser("sft", help="Supervised fine-tuning")
    sft_cli.add_argument("--config", required=True, help="Path to SFT config YAML")
    sft_cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Print config and exit",
    )

    # GRPO sub-command.
    grpo_cli = stages.add_parser(
        "grpo",
        help="Experimental: GRPO reinforcement learning (moved to experiments/grpo/)",
    )
    grpo_cli.add_argument("--config", required=True, help="Path to GRPO config YAML")
    grpo_cli.add_argument(
        "--dry-run",
        action="store_true",
        help="Print config, test reward, and exit",
    )

    parsed = cli.parse_args()
    stage = parsed.stage

    if stage == "grpo":
        cmd_grpo(parsed)
    elif stage == "sft":
        cmd_sft(parsed)
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|