- Add missing subprocess import (NameError on any quantize path) - Replace broken optimum-cli quantize calls with direct onnxruntime: Q4 uses MatMulNBitsQuantizer, Q8 uses quantize_dynamic - Add onnxconverter-common to deps for FP16 (was silently swallowed) - Make FP16 fail loudly on missing dep instead of silently uploading FP32 - README and transformers_js_config now reflect actual quantize_type instead of always hardcoding Q4 - Remove dead _convert_fp16_external function
456 lines
15 KiB
Python
456 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
# /// script
|
|
# requires-python = ">=3.10"
|
|
# dependencies = [
|
|
# "transformers>=4.36.0",
|
|
# "peft>=0.7.0",
|
|
# "torch>=2.0.0",
|
|
# "accelerate>=0.24.0",
|
|
# "huggingface_hub>=0.20.0",
|
|
# "sentencepiece>=0.1.99",
|
|
# "protobuf>=3.20.0",
|
|
# "numpy",
|
|
# "optimum[onnxruntime]",
|
|
# "onnx>=1.15.0",
|
|
# "onnxruntime>=1.17.0",
|
|
# "onnxconverter-common>=1.14.0",
|
|
# ]
|
|
# ///
|
|
"""
|
|
Convert QMD query expansion model to ONNX format for Transformers.js.
|
|
|
|
Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
|
|
with quantization for browser deployment via Transformers.js + WebGPU.
|
|
|
|
Usage:
|
|
uv run convert_onnx.py --size 1.7B
|
|
uv run convert_onnx.py --size 1.7B --no-upload
|
|
uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
|
|
--sft tobil/qmd-query-expansion-1.7B-sft \
|
|
--grpo tobil/qmd-query-expansion-1.7B-grpo \
|
|
--output tobil/qmd-query-expansion-1.7B-ONNX
|
|
|
|
Quantization options:
|
|
--quantize q4 MatMulNBits 4-bit (default, smallest)
|
|
--quantize q8 8-bit dynamic quantization
|
|
--quantize fp16 FP16 weights (converted after the CPU export)
|
|
--quantize none No quantization (FP32, ~7GB)
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from huggingface_hub import HfApi, login
|
|
from peft import PeftModel
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# Each preset pairs a Qwen3 base checkpoint with the matching tobil/ SFT and
# GRPO adapter repos and the Hub repo the ONNX export is pushed to. All presets
# follow the same naming scheme, so they are generated from the size label.
PRESETS = {
    size: {
        "base": f"Qwen/Qwen3-{size}",
        "sft": f"tobil/qmd-query-expansion-{size}-sft",
        "grpo": f"tobil/qmd-query-expansion-{size}-grpo",
        "output": f"tobil/qmd-query-expansion-{size}-ONNX",
    }
    for size in ("1.7B", "4B")
}
|
|
|
|
|
|
def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
    """Load the base model in FP32 and fold the SFT then GRPO LoRA adapters into it.

    Each adapter is applied with ``PeftModel.from_pretrained`` and immediately
    merged with ``merge_and_unload()``, so the GRPO adapter is stacked on top of
    the already-merged SFT weights.

    Args:
        base_model: Hub id / path of the base causal LM.
        sft_model: Hub id / path of the SFT adapter.
        grpo_model: Hub id / path of the GRPO adapter.

    Returns:
        ``(model, tokenizer)``, where the tokenizer comes from ``base_model``.
    """
    print(f"\nStep 1: Loading base model {base_model}...")
    merged = AutoModelForCausalLM.from_pretrained(
        base_model, dtype=torch.float32, trust_remote_code=True,
    )

    # Order matters: GRPO was trained on top of the SFT-merged weights.
    stages = (
        ("2", "SFT", sft_model),
        ("3", "GRPO", grpo_model),
    )
    for step, label, adapter in stages:
        print(f"Step {step}: Merging {label} adapter {adapter}...")
        merged = PeftModel.from_pretrained(merged, adapter).merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    return merged, tokenizer
|
|
|
|
|
|
def export_onnx(model, tokenizer, output_dir: str):
    """Export the merged model to ONNX using Optimum.

    Optimum's exporter reads an HF-format checkpoint from disk, so the model
    and tokenizer are first saved to a private temporary directory. That
    directory is removed afterwards even when saving or exporting fails.

    Args:
        model: Merged transformers model (supports ``save_pretrained``).
        tokenizer: Matching tokenizer (supports ``save_pretrained``).
        output_dir: Directory that receives the ONNX files.
    """
    import tempfile

    from optimum.exporters.onnx import main_export

    # A fresh directory per run: the old fixed /tmp/merged_model_onnx could
    # collide with a concurrent run, and because cleanup only ran on success
    # a failed export left the whole FP32 checkpoint behind in /tmp.
    merged_dir = tempfile.mkdtemp(prefix="merged_model_onnx_")
    try:
        print(f"\nStep 4: Saving merged model to {merged_dir}...")
        model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
        # no_post_process=True avoids the 2GB protobuf serialization limit
        # that occurs during tied-weight deduplication on large FP32 models.
        # The exported model still works correctly — the tied weights just
        # aren't deduplicated in the graph, which is fine since we quantize next.
        main_export(
            model_name_or_path=merged_dir,
            output=output_dir,
            task="text-generation-with-past",
            device="cpu",
            fp16=False,
            no_post_process=True,
        )
    finally:
        shutil.rmtree(merged_dir, ignore_errors=True)
|
|
|
|
|
|
def _find_onnx_model(onnx_dir: str) -> Path:
|
|
"""Find the main ONNX model file in the output directory."""
|
|
model_path = Path(onnx_dir) / "model.onnx"
|
|
if model_path.exists():
|
|
return model_path
|
|
candidates = list(Path(onnx_dir).glob("*.onnx"))
|
|
if not candidates:
|
|
raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
|
|
return candidates[0]
|
|
|
|
|
|
def quantize_onnx(onnx_dir: str, quantize_type: str):
    """Quantize (or convert) the exported ONNX model in *onnx_dir* in place.

    Args:
        onnx_dir: Directory produced by ``export_onnx``.
        quantize_type: ``"q4"``, ``"q8"``, ``"fp16"``, or ``"none"`` to keep
            the FP32 export untouched.

    Raises:
        ValueError: If ``quantize_type`` is not one of the values above.
        FileNotFoundError: If ``onnx_dir`` contains no ``.onnx`` file.
    """
    if quantize_type == "none":
        print("\nSkipping quantization (FP32).")
        return
    if quantize_type not in ("q4", "q8", "fp16"):
        # Previously an unknown type fell through every branch and the FP32
        # model was shipped as if it had been quantized.
        raise ValueError(
            f"Unknown quantize_type {quantize_type!r}; "
            "expected 'q4', 'q8', 'fp16' or 'none'"
        )

    model_path = _find_onnx_model(onnx_dir)
    print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")

    if quantize_type == "q4":
        _quantize_q4(model_path)
    elif quantize_type == "q8":
        _quantize_q8(model_path)
    else:
        _convert_fp16(model_path)
|
|
|
|
|
|
def _quantize_q4(model_path: Path):
    """4-bit MatMulNBits quantization via onnxruntime, rewriting *model_path* in place.

    Needs ~16GB RAM for 1.7B models. The quantized graph replaces
    ``model_path``, and its weights are stored in ``<model_path.name>_data``
    next to it, the same layout the FP32 export used.
    """
    import onnx
    from onnxruntime.quantization import matmul_nbits_quantizer

    quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
        model=str(model_path),
        block_size=32,
        is_symmetric=True,
        bits=4,
    )
    quant.process()
    # quant.model is onnxruntime's ONNXModel wrapper, which has no save()
    # method (the old call raised AttributeError); .model is the onnx ModelProto.
    quantized = quant.model.model

    data_path = model_path.with_name(model_path.name + "_data")
    # The quantizer loaded every external tensor into memory, so the FP32
    # files are no longer needed. Remove them *before* saving: onnx appends to
    # an existing external-data file instead of truncating it.
    model_path.unlink(missing_ok=True)
    data_path.unlink(missing_ok=True)

    # MatMulNBits only rewrites MatMul weights. Other initializers (e.g. the
    # embedding table) stay FP32, so the result can still exceed protobuf's
    # 2GB limit. Always write the weights to an external data file.
    onnx.save_model(
        quantized,
        str(model_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_path.name,
    )
    _report_size(model_path)
    if data_path.exists():
        _report_size(data_path)
|
|
|
|
|
|
def _quantize_q8(model_path: Path):
    """8-bit dynamic quantization via onnxruntime, rewriting *model_path* in place.

    The quantized graph replaces ``model_path``, and its weights are stored in
    ``<model_path.name>_data`` next to it, the same layout the FP32 export used.
    """
    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic

    q_path = model_path.with_name(model_path.stem + "_q8" + model_path.suffix)
    data_path = model_path.with_name(model_path.name + "_data")

    def _remove_q8_artifacts():
        # Matches the intermediate graph plus whatever external-data file
        # onnxruntime writes next to it (its name derives from q_path.name).
        for leftover in q_path.parent.glob(q_path.name + "*"):
            leftover.unlink(missing_ok=True)

    # onnx appends to an existing external-data file instead of truncating
    # it, so clear leftovers from an earlier run before writing.
    _remove_q8_artifacts()

    # At one byte per weight the 4B preset alone exceeds protobuf's 2GB
    # limit, so let onnxruntime write the weights externally.
    quantize_dynamic(
        model_input=str(model_path),
        model_output=str(q_path),
        weight_type=QuantType.QUInt8,
        use_external_data_format=True,
    )

    # Re-save under the final names. Renaming only the .onnx file would
    # leave it pointing at the q8-named data file.
    quantized = onnx.load(str(q_path), load_external_data=True)
    _remove_q8_artifacts()
    model_path.unlink(missing_ok=True)
    data_path.unlink(missing_ok=True)
    onnx.save_model(
        quantized,
        str(model_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_path.name,
    )
    _report_size(model_path)
    if data_path.exists():
        _report_size(data_path)
|
|
|
|
|
|
def _convert_fp16(model_path: Path):
    """Convert ONNX model weights to FP16, rewriting *model_path* in place.

    Graph inputs and outputs keep their FP32 types (``keep_io_types=True``).
    The converted weights are stored in ``<model_path.name>_data`` next to the
    graph, the same layout the FP32 export used.
    """
    import onnx
    from onnx import shape_inference
    from onnxconverter_common import float16

    print(" Converting to FP16...")
    data_path = model_path.with_name(model_path.name + "_data")

    # By default convert_float_to_float16 runs shape inference on the
    # in-memory proto, which must be serialized and therefore fails above
    # protobuf's 2GB limit (the FP32 export here is several GB).
    # infer_shapes_path works on the file instead; run it in place (the
    # external-data references stay valid), then skip the converter's own pass.
    shape_inference.infer_shapes_path(str(model_path))
    model = onnx.load(str(model_path), load_external_data=True)
    model_fp16 = float16.convert_float_to_float16(
        model, keep_io_types=True, disable_shape_infer=True,
    )
    del model

    # Remove the FP32 files first: onnx appends to an existing external-data
    # file rather than truncating it.
    model_path.unlink(missing_ok=True)
    data_path.unlink(missing_ok=True)
    # At ~2 bytes per weight the model is still far beyond 2GB, so the
    # weights go to an external data file (plain onnx.save would fail).
    onnx.save_model(
        model_fp16,
        str(model_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_path.name,
    )
    _report_size(model_path)
    if data_path.exists():
        _report_size(data_path)
|
|
|
|
|
|
def _report_size(path: Path):
|
|
"""Print file size in MB."""
|
|
size_mb = path.stat().st_size / (1024 * 1024)
|
|
print(f" {path.name}: {size_mb:.1f} MB")
|
|
|
|
|
|
|
|
def validate_onnx(onnx_dir: str, base_model: str):
|
|
"""Run a sample inference through the ONNX model to verify it works."""
|
|
import onnxruntime as ort
|
|
import numpy as np
|
|
|
|
model_path = _find_onnx_model(onnx_dir)
|
|
print(f"\nValidation: loading {model_path.name}...")
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
|
|
session = ort.InferenceSession(
|
|
str(model_path),
|
|
providers=["CPUExecutionProvider"],
|
|
)
|
|
|
|
# Tokenize a test prompt
|
|
test_query = "/no_think Expand this search query: distributed consensus"
|
|
chat_prompt = tokenizer.apply_chat_template(
|
|
[{"role": "user", "content": test_query}],
|
|
add_generation_prompt=True,
|
|
tokenize=False,
|
|
)
|
|
inputs = tokenizer(chat_prompt, return_tensors="np")
|
|
input_ids = inputs["input_ids"].astype(np.int64)
|
|
attention_mask = inputs["attention_mask"].astype(np.int64)
|
|
|
|
# Build feed dict with all required inputs
|
|
seq_len = input_ids.shape[1]
|
|
feed = {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
|
|
# Add position_ids if needed
|
|
all_inputs = {inp.name: inp for inp in session.get_inputs()}
|
|
if "position_ids" in all_inputs:
|
|
feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)
|
|
|
|
# Initialize past_key_values to zeros if the model expects them
|
|
for name, inp in sorted(all_inputs.items()):
|
|
if name.startswith("past_key_values"):
|
|
shape = []
|
|
for dim in inp.shape:
|
|
shape.append(dim if isinstance(dim, int) else 0)
|
|
# batch dim = 1
|
|
if shape and shape[0] == 0:
|
|
shape[0] = 1
|
|
feed[name] = np.zeros(shape, dtype=np.float32)
|
|
|
|
# Run inference
|
|
output_names = [o.name for o in session.get_outputs()]
|
|
results = session.run(output_names, feed)
|
|
|
|
# Check logits shape
|
|
logits = results[0]
|
|
print(f" Input tokens: {input_ids.shape[1]}")
|
|
print(f" Output logits shape: {logits.shape}")
|
|
print(f" Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
|
|
|
|
# Greedy decode next token
|
|
next_token_id = int(np.argmax(logits[0, -1, :]))
|
|
next_token = tokenizer.decode([next_token_id])
|
|
print(f" Next token: {repr(next_token)} (id={next_token_id})")
|
|
|
|
# Check KV cache outputs exist
|
|
kv_outputs = [n for n in output_names if n.startswith("present")]
|
|
if kv_outputs:
|
|
print(f" KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
|
|
else:
|
|
print(" WARNING: No KV cache outputs — model may not support efficient generation")
|
|
|
|
# Sanity checks
|
|
assert logits.shape[0] == 1, "Batch size mismatch"
|
|
assert logits.shape[1] == input_ids.shape[1], "Sequence length mismatch"
|
|
assert logits.max() > logits.min(), "Logits are constant (broken model)"
|
|
assert not np.isnan(logits).any(), "Logits contain NaN"
|
|
assert not np.isinf(logits).any(), "Logits contain Inf"
|
|
|
|
print(" Validation PASSED")
|
|
|
|
|
|
def write_transformers_js_config(onnx_dir: str, quantize_type: str = "q4"):
    """Write ``transformers_js_config.json`` into *onnx_dir*.

    The file records the model type and whether a quantization step was
    applied (anything other than ``"none"``).
    """
    target = Path(onnx_dir).joinpath("transformers_js_config.json")
    payload = dict(
        model_type="text-generation",
        quantized=(quantize_type != "none"),
    )
    target.write_text(f"{json.dumps(payload, indent=2)}\n")
    print(f" Wrote {target.name}")
|
|
|
|
|
|
def upload_to_hub(
|
|
onnx_dir: str,
|
|
output_repo: str,
|
|
base_model: str,
|
|
sft_model: str,
|
|
grpo_model: str,
|
|
quantize_type: str = "q4",
|
|
):
|
|
"""Upload ONNX model to HuggingFace Hub."""
|
|
print(f"\nStep 7: Uploading to {output_repo}...")
|
|
api = HfApi()
|
|
api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
|
|
|
|
api.upload_folder(
|
|
folder_path=onnx_dir,
|
|
repo_id=output_repo,
|
|
commit_message="Upload ONNX model",
|
|
)
|
|
|
|
readme = f"""---
|
|
base_model: {base_model}
|
|
tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
|
|
library_name: transformers.js
|
|
---
|
|
# {output_repo.split("/")[-1]}
|
|
|
|
ONNX conversion of the QMD Query Expansion model for use with
|
|
[Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.
|
|
|
|
## Details
|
|
- **Base:** {base_model}
|
|
- **SFT:** {sft_model}
|
|
- **GRPO:** {grpo_model}
|
|
- **Task:** Query expansion (lex/vec/hyde format)
|
|
- **Format:** ONNX with {quantize_type.upper()} quantization
|
|
|
|
## Usage with Transformers.js
|
|
|
|
```javascript
|
|
import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";
|
|
|
|
const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
|
|
const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
|
|
dtype: "{quantize_type}",
|
|
device: "webgpu",
|
|
}});
|
|
```
|
|
|
|
## Prompt Format
|
|
```
|
|
<|im_start|>user
|
|
/no_think Expand this search query: your query here<|im_end|>
|
|
<|im_start|>assistant
|
|
```
|
|
"""
|
|
api.upload_file(
|
|
path_or_fileobj=readme.encode(),
|
|
path_in_repo="README.md",
|
|
repo_id=output_repo,
|
|
)
|
|
|
|
|
|
def main():
    """CLI entry point: merge adapters, export to ONNX, quantize, validate, upload."""
    import gc

    parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
    parser.add_argument(
        "--size", choices=PRESETS.keys(), help="Use preset config for model size",
    )
    parser.add_argument("--base", help="Base model (overrides preset)")
    parser.add_argument("--sft", help="SFT adapter (overrides preset)")
    parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
    parser.add_argument("--output", help="Output HF repo (overrides preset)")
    parser.add_argument(
        "--quantize",
        choices=["q4", "q8", "fp16", "none"],
        default="q4",
        help="Quantization type (default: q4)",
    )
    parser.add_argument(
        "--no-upload", action="store_true", help="Don't upload to HF Hub",
    )
    parser.add_argument(
        "--validate", action="store_true",
        help="Run inference validation on exported model",
    )
    parser.add_argument(
        "--validate-only", metavar="DIR",
        help="Skip export, only validate an existing ONNX dir",
    )
    args = parser.parse_args()

    # Validate-only mode: skip export, just run validation
    if args.validate_only:
        validate_onnx(args.validate_only, "")
        return

    # Resolve config: a preset supplies defaults that individual flags override.
    if args.size:
        preset = PRESETS[args.size]
        base_model = args.base or preset["base"]
        sft_model = args.sft or preset["sft"]
        grpo_model = args.grpo or preset["grpo"]
        output_repo = args.output or preset["output"]
    elif args.base and args.sft and args.grpo and args.output:
        base_model = args.base
        sft_model = args.sft
        grpo_model = args.grpo
        output_repo = args.output
    else:
        parser.error(
            "Either --size or all of --base/--sft/--grpo/--output are required",
        )

    model_name = output_repo.split("/")[-1]
    # onnx_dir is built from model_name and wiped below; refuse names that
    # would make it /tmp/onnx_output itself or its parent directory.
    if model_name in ("", ".", ".."):
        parser.error(f"--output must end in a repository name, got {output_repo!r}")
    print(f"QMD ONNX Conversion: {model_name}")
    print("=" * 60)

    # Login
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to HuggingFace...")
        login(token=hf_token)

    # Merge adapters
    model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)

    # Export to ONNX into a clean directory. upload_folder() publishes
    # everything in it, and leftovers from an earlier run (other .onnx files,
    # external-data files) would otherwise be uploaded or picked by
    # _find_onnx_model.
    onnx_dir = f"/tmp/onnx_output/{model_name}"
    shutil.rmtree(onnx_dir, ignore_errors=True)
    os.makedirs(onnx_dir, exist_ok=True)
    export_onnx(model, tokenizer, onnx_dir)

    # The merged PyTorch model is not used again. Release it before the
    # quantizer loads the full ONNX graph, so both don't sit in RAM at once.
    del model, tokenizer
    gc.collect()

    # Quantize
    quantize_onnx(onnx_dir, args.quantize)

    # Write Transformers.js config
    write_transformers_js_config(onnx_dir, args.quantize)

    # Validate
    if args.validate:
        validate_onnx(onnx_dir, base_model)

    # Upload
    if not args.no_upload:
        upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model, args.quantize)

    print(f"\nDone! ONNX files at: {onnx_dir}")
    if not args.no_upload:
        print(f"Repository: https://huggingface.co/{output_repo}")
|
|
|
|
|
|
# Script entry point; main() returns None, so the process exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|