diff --git a/finetune/convert_onnx.py b/finetune/convert_onnx.py index 144a683..9dedfc8 100644 --- a/finetune/convert_onnx.py +++ b/finetune/convert_onnx.py @@ -321,12 +321,18 @@ def upload_to_hub( commit_message="Upload ONNX model", ) + # Map quantize_type to Transformers.js dtype values + dtype_map = {"q4": "q4", "q8": "q8", "fp16": "fp16", "none": "fp32"} + tj_dtype = dtype_map.get(quantize_type, "fp32") + format_desc = "FP32 (no quantization)" if quantize_type == "none" else f"{quantize_type.upper()} quantization" + repo_name = output_repo.split("/")[-1] + readme = f"""--- base_model: {base_model} tags: [onnx, transformers.js, webgpu, query-expansion, qmd] library_name: transformers.js --- -# {output_repo.split("/")[-1]} +# {repo_name} ONNX conversion of the QMD Query Expansion model for use with [Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU. @@ -336,7 +342,7 @@ ONNX conversion of the QMD Query Expansion model for use with - **SFT:** {sft_model} - **GRPO:** {grpo_model} - **Task:** Query expansion (lex/vec/hyde format) -- **Format:** ONNX with {quantize_type.upper()} quantization +- **Format:** ONNX with {format_desc} ## Usage with Transformers.js @@ -345,7 +351,7 @@ import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}"); const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{ - dtype: "{quantize_type}", + dtype: "{tj_dtype}", device: "webgpu", }}); ```