From 698b44fe872e577c9d392904ea151beaedd6526d Mon Sep 17 00:00:00 2001 From: LJY <7yuny1@gmail.com> Date: Mon, 6 Apr 2026 04:45:04 +0800 Subject: [PATCH] Fix qmd embed model selection (#494) --- src/cli/qmd.ts | 4 ++-- src/llm.ts | 14 +++++++------- src/store.ts | 12 ++++++------ test/store.test.ts | 34 ++++++++++++++++++++++++++++++++-- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/src/cli/qmd.ts b/src/cli/qmd.ts index 7216965..2ac928a 100755 --- a/src/cli/qmd.ts +++ b/src/cli/qmd.ts @@ -1652,7 +1652,7 @@ function parseChunkStrategy(value: unknown): ChunkStrategy | undefined { } async function vectorIndex( - model: string = DEFAULT_EMBED_MODEL, + model: string = DEFAULT_EMBED_MODEL_URI, force: boolean = false, batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number; chunkStrategy?: ChunkStrategy }, ): Promise { @@ -3002,7 +3002,7 @@ if (isMain) { const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]); const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]); const embedChunkStrategy = parseChunkStrategy(cli.values["chunk-strategy"]); - await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, { + await vectorIndex(DEFAULT_EMBED_MODEL_URI, !!cli.values.force, { maxDocsPerBatch, maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024, chunkStrategy: embedChunkStrategy, diff --git a/src/llm.ts b/src/llm.ts index bd276bb..f67f18c 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -155,7 +155,7 @@ export type LLMSessionOptions = { */ export interface ILLMSession { embed(text: string, options?: EmbedOptions): Promise; - embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]>; + embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>; expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise; rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise; /** Whether this session is still valid (not released or aborted) */ @@ -880,7 +880,7 @@ export class LlamaCpp implements LLM { return { embedding: Array.from(embedding.vector), - model: this.embedModelUri, + model: options.model ?? this.embedModelUri, }; } catch (error) { console.error("Embedding error:", error); @@ -892,7 +892,7 @@ export class LlamaCpp implements LLM { * Batch embed multiple texts efficiently * Uses Promise.all for parallel embedding - node-llama-cpp handles batching internally */ - async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> { + async embedBatch(texts: string[], options: EmbedOptions = {}): Promise<(EmbeddingResult | null)[]> { if (this._ciMode) throw new Error("LLM operations are disabled in CI (set CI=true)"); // Ping activity at start to keep models alive during this operation this.touchActivity(); @@ -915,7 +915,7 @@ export class LlamaCpp implements LLM { } const embedding = await context.getEmbeddingFor(safeText); this.touchActivity(); - embeddings.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri }); + embeddings.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri }); } catch (err) { console.error("Embedding error for text:", err); embeddings.push(null); @@ -942,7 +942,7 @@ export class LlamaCpp implements LLM { } const embedding = await ctx.getEmbeddingFor(safeText); this.touchActivity(); - results.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri }); + results.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri }); } catch (err) { console.error("Embedding error for text:", err); results.push(null); @@ -1431,8 +1431,8 @@ class LLMSession implements ILLMSession { return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options)); } - async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> { - return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts)); + async embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> { + return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts, options)); } async expandQuery( diff --git a/src/store.ts b/src/store.ts index 34d1fd2..88075a9 100644 --- a/src/store.ts +++ b/src/store.ts @@ -1466,8 +1466,8 @@ export async function generateEmbeddings( if (!vectorTableInitialized) { const firstChunk = batchChunks[0]!; - const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title); - const firstResult = await session.embed(firstText); + const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, model); + const firstResult = await session.embed(firstText, { model }); if (!firstResult) { throw new Error("Failed to get embedding dimensions from first chunk"); } @@ -1498,10 +1498,10 @@ export async function generateEmbeddings( const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length); const chunkBatch = batchChunks.slice(batchStart, batchEnd); - const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title)); + const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, model)); try { - const embeddings = await session.embedBatch(texts); + const embeddings = await session.embedBatch(texts, { model }); for (let i = 0; i < chunkBatch.length; i++) { const chunk = chunkBatch[i]!; const embedding = embeddings[i]; @@ -1522,8 +1522,8 @@ export async function generateEmbeddings( } else { for (const chunk of chunkBatch) { try { - const text = formatDocForEmbedding(chunk.text, chunk.title); - const result = await session.embed(text); + const text = formatDocForEmbedding(chunk.text, chunk.title, model); + const result = await session.embed(text, { model }); if (result) { insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now); chunksEmbedded++; diff --git a/test/store.test.ts b/test/store.test.ts index 073297d..a0a47e6 100644 --- a/test/store.test.ts +++ b/test/store.test.ts @@ -2753,13 +2753,19 @@ describe("Embedding batching", () => { function createFakeEmbedLlm() { const embedBatchCalls: string[][] = []; + const embedCalls: { text: string; options?: { model?: string } }[] = []; + const embedBatchModelCalls: ({ model?: string } | undefined)[] = []; return { embedBatchCalls, - async embed(_text: string) { + embedCalls, + embedBatchModelCalls, + async embed(text: string, options?: { model?: string }) { + embedCalls.push({ text, options }); return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" }; }, - async embedBatch(texts: string[]) { + async embedBatch(texts: string[], options?: { model?: string }) { embedBatchCalls.push([...texts]); + embedBatchModelCalls.push(options); return texts.map((_text, index) => ({ embedding: [index + 1, index + 2, index + 3], model: "fake-embed", @@ -2832,6 +2838,30 @@ describe("Embedding batching", () => { } }); + test("generateEmbeddings passes the selected model through to embed calls and metadata", async () => { + const store = await createTestStore(); + const db = store.db; + const fakeLlm = createFakeEmbedLlm(); + const model = "hf:Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf"; + + setDefaultLlamaCpp(createFakeTokenizer() as any); + store.llm = fakeLlm as any; + + try { + await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" }); + + const result = await generateEmbeddings(store, { model }); + + expect(result.chunksEmbedded).toBe(1); + expect(fakeLlm.embedCalls[0]?.options?.model).toBe(model); + expect(fakeLlm.embedBatchModelCalls).toEqual([{ model }]); + expect(db.prepare(`SELECT DISTINCT model FROM content_vectors`).all()).toEqual([{ model }]); + } finally { + setDefaultLlamaCpp(null); + await cleanupTestDb(store); + } + }); + test("generateEmbeddings rejects invalid batch limits", async () => { const store = await createTestStore();