Fix qmd embed model selection (#494)

This commit is contained in:
LJY 2026-04-06 04:45:04 +08:00 committed by GitHub
parent 021236378b
commit 698b44fe87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 17 deletions

View File

@ -1652,7 +1652,7 @@ function parseChunkStrategy(value: unknown): ChunkStrategy | undefined {
}
async function vectorIndex(
model: string = DEFAULT_EMBED_MODEL,
model: string = DEFAULT_EMBED_MODEL_URI,
force: boolean = false,
batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number; chunkStrategy?: ChunkStrategy },
): Promise<void> {
@ -3002,7 +3002,7 @@ if (isMain) {
const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]);
const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]);
const embedChunkStrategy = parseChunkStrategy(cli.values["chunk-strategy"]);
await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, {
await vectorIndex(DEFAULT_EMBED_MODEL_URI, !!cli.values.force, {
maxDocsPerBatch,
maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024,
chunkStrategy: embedChunkStrategy,

View File

@ -155,7 +155,7 @@ export type LLMSessionOptions = {
*/
export interface ILLMSession {
embed(text: string, options?: EmbedOptions): Promise<EmbeddingResult | null>;
embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]>;
embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>;
expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise<Queryable[]>;
rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise<RerankResult>;
/** Whether this session is still valid (not released or aborted) */
@ -880,7 +880,7 @@ export class LlamaCpp implements LLM {
return {
embedding: Array.from(embedding.vector),
model: this.embedModelUri,
model: options.model ?? this.embedModelUri,
};
} catch (error) {
console.error("Embedding error:", error);
@ -892,7 +892,7 @@ export class LlamaCpp implements LLM {
* Batch embed multiple texts efficiently
* Uses Promise.all for parallel embedding - node-llama-cpp handles batching internally
*/
async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
async embedBatch(texts: string[], options: EmbedOptions = {}): Promise<(EmbeddingResult | null)[]> {
if (this._ciMode) throw new Error("LLM operations are disabled in CI (set CI=true)");
// Ping activity at start to keep models alive during this operation
this.touchActivity();
@ -915,7 +915,7 @@ export class LlamaCpp implements LLM {
}
const embedding = await context.getEmbeddingFor(safeText);
this.touchActivity();
embeddings.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
embeddings.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
} catch (err) {
console.error("Embedding error for text:", err);
embeddings.push(null);
@ -942,7 +942,7 @@ export class LlamaCpp implements LLM {
}
const embedding = await ctx.getEmbeddingFor(safeText);
this.touchActivity();
results.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
results.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
} catch (err) {
console.error("Embedding error for text:", err);
results.push(null);
@ -1431,8 +1431,8 @@ class LLMSession implements ILLMSession {
return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options));
}
async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts));
async embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> {
return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts, options));
}
async expandQuery(

View File

@ -1466,8 +1466,8 @@ export async function generateEmbeddings(
if (!vectorTableInitialized) {
const firstChunk = batchChunks[0]!;
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title);
const firstResult = await session.embed(firstText);
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, model);
const firstResult = await session.embed(firstText, { model });
if (!firstResult) {
throw new Error("Failed to get embedding dimensions from first chunk");
}
@ -1498,10 +1498,10 @@ export async function generateEmbeddings(
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, model));
try {
const embeddings = await session.embedBatch(texts);
const embeddings = await session.embedBatch(texts, { model });
for (let i = 0; i < chunkBatch.length; i++) {
const chunk = chunkBatch[i]!;
const embedding = embeddings[i];
@ -1522,8 +1522,8 @@ export async function generateEmbeddings(
} else {
for (const chunk of chunkBatch) {
try {
const text = formatDocForEmbedding(chunk.text, chunk.title);
const result = await session.embed(text);
const text = formatDocForEmbedding(chunk.text, chunk.title, model);
const result = await session.embed(text, { model });
if (result) {
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
chunksEmbedded++;

View File

@ -2753,13 +2753,19 @@ describe("Embedding batching", () => {
function createFakeEmbedLlm() {
const embedBatchCalls: string[][] = [];
const embedCalls: { text: string; options?: { model?: string } }[] = [];
const embedBatchModelCalls: ({ model?: string } | undefined)[] = [];
return {
embedBatchCalls,
async embed(_text: string) {
embedCalls,
embedBatchModelCalls,
async embed(text: string, options?: { model?: string }) {
embedCalls.push({ text, options });
return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" };
},
async embedBatch(texts: string[]) {
async embedBatch(texts: string[], options?: { model?: string }) {
embedBatchCalls.push([...texts]);
embedBatchModelCalls.push(options);
return texts.map((_text, index) => ({
embedding: [index + 1, index + 2, index + 3],
model: "fake-embed",
@ -2832,6 +2838,30 @@ describe("Embedding batching", () => {
}
});
test("generateEmbeddings passes the selected model through to embed calls and metadata", async () => {
const store = await createTestStore();
const db = store.db;
const fakeLlm = createFakeEmbedLlm();
const model = "hf:Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf";
setDefaultLlamaCpp(createFakeTokenizer() as any);
store.llm = fakeLlm as any;
try {
await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" });
const result = await generateEmbeddings(store, { model });
expect(result.chunksEmbedded).toBe(1);
expect(fakeLlm.embedCalls[0]?.options?.model).toBe(model);
expect(fakeLlm.embedBatchModelCalls).toEqual([{ model }]);
expect(db.prepare(`SELECT DISTINCT model FROM content_vectors`).all()).toEqual([{ model }]);
} finally {
setDefaultLlamaCpp(null);
await cleanupTestDb(store);
}
});
test("generateEmbeddings rejects invalid batch limits", async () => {
const store = await createTestStore();