Fix qmd embed model selection (#494)
This commit is contained in:
parent
021236378b
commit
698b44fe87
@ -1652,7 +1652,7 @@ function parseChunkStrategy(value: unknown): ChunkStrategy | undefined {
|
||||
}
|
||||
|
||||
async function vectorIndex(
|
||||
model: string = DEFAULT_EMBED_MODEL,
|
||||
model: string = DEFAULT_EMBED_MODEL_URI,
|
||||
force: boolean = false,
|
||||
batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number; chunkStrategy?: ChunkStrategy },
|
||||
): Promise<void> {
|
||||
@ -3002,7 +3002,7 @@ if (isMain) {
|
||||
const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]);
|
||||
const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]);
|
||||
const embedChunkStrategy = parseChunkStrategy(cli.values["chunk-strategy"]);
|
||||
await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, {
|
||||
await vectorIndex(DEFAULT_EMBED_MODEL_URI, !!cli.values.force, {
|
||||
maxDocsPerBatch,
|
||||
maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024,
|
||||
chunkStrategy: embedChunkStrategy,
|
||||
|
||||
14
src/llm.ts
14
src/llm.ts
@ -155,7 +155,7 @@ export type LLMSessionOptions = {
|
||||
*/
|
||||
export interface ILLMSession {
|
||||
embed(text: string, options?: EmbedOptions): Promise<EmbeddingResult | null>;
|
||||
embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]>;
|
||||
embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>;
|
||||
expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise<Queryable[]>;
|
||||
rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise<RerankResult>;
|
||||
/** Whether this session is still valid (not released or aborted) */
|
||||
@ -880,7 +880,7 @@ export class LlamaCpp implements LLM {
|
||||
|
||||
return {
|
||||
embedding: Array.from(embedding.vector),
|
||||
model: this.embedModelUri,
|
||||
model: options.model ?? this.embedModelUri,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Embedding error:", error);
|
||||
@ -892,7 +892,7 @@ export class LlamaCpp implements LLM {
|
||||
* Batch embed multiple texts efficiently
|
||||
* Uses Promise.all for parallel embedding - node-llama-cpp handles batching internally
|
||||
*/
|
||||
async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
|
||||
async embedBatch(texts: string[], options: EmbedOptions = {}): Promise<(EmbeddingResult | null)[]> {
|
||||
if (this._ciMode) throw new Error("LLM operations are disabled in CI (set CI=true)");
|
||||
// Ping activity at start to keep models alive during this operation
|
||||
this.touchActivity();
|
||||
@ -915,7 +915,7 @@ export class LlamaCpp implements LLM {
|
||||
}
|
||||
const embedding = await context.getEmbeddingFor(safeText);
|
||||
this.touchActivity();
|
||||
embeddings.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
|
||||
embeddings.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
|
||||
} catch (err) {
|
||||
console.error("Embedding error for text:", err);
|
||||
embeddings.push(null);
|
||||
@ -942,7 +942,7 @@ export class LlamaCpp implements LLM {
|
||||
}
|
||||
const embedding = await ctx.getEmbeddingFor(safeText);
|
||||
this.touchActivity();
|
||||
results.push({ embedding: Array.from(embedding.vector), model: this.embedModelUri });
|
||||
results.push({ embedding: Array.from(embedding.vector), model: options.model ?? this.embedModelUri });
|
||||
} catch (err) {
|
||||
console.error("Embedding error for text:", err);
|
||||
results.push(null);
|
||||
@ -1431,8 +1431,8 @@ class LLMSession implements ILLMSession {
|
||||
return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options));
|
||||
}
|
||||
|
||||
async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> {
|
||||
return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts));
|
||||
async embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]> {
|
||||
return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts, options));
|
||||
}
|
||||
|
||||
async expandQuery(
|
||||
|
||||
12
src/store.ts
12
src/store.ts
@ -1466,8 +1466,8 @@ export async function generateEmbeddings(
|
||||
|
||||
if (!vectorTableInitialized) {
|
||||
const firstChunk = batchChunks[0]!;
|
||||
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title);
|
||||
const firstResult = await session.embed(firstText);
|
||||
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title, model);
|
||||
const firstResult = await session.embed(firstText, { model });
|
||||
if (!firstResult) {
|
||||
throw new Error("Failed to get embedding dimensions from first chunk");
|
||||
}
|
||||
@ -1498,10 +1498,10 @@ export async function generateEmbeddings(
|
||||
|
||||
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
|
||||
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
|
||||
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
||||
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title, model));
|
||||
|
||||
try {
|
||||
const embeddings = await session.embedBatch(texts);
|
||||
const embeddings = await session.embedBatch(texts, { model });
|
||||
for (let i = 0; i < chunkBatch.length; i++) {
|
||||
const chunk = chunkBatch[i]!;
|
||||
const embedding = embeddings[i];
|
||||
@ -1522,8 +1522,8 @@ export async function generateEmbeddings(
|
||||
} else {
|
||||
for (const chunk of chunkBatch) {
|
||||
try {
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||
const result = await session.embed(text);
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title, model);
|
||||
const result = await session.embed(text, { model });
|
||||
if (result) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
|
||||
@ -2753,13 +2753,19 @@ describe("Embedding batching", () => {
|
||||
|
||||
function createFakeEmbedLlm() {
|
||||
const embedBatchCalls: string[][] = [];
|
||||
const embedCalls: { text: string; options?: { model?: string } }[] = [];
|
||||
const embedBatchModelCalls: ({ model?: string } | undefined)[] = [];
|
||||
return {
|
||||
embedBatchCalls,
|
||||
async embed(_text: string) {
|
||||
embedCalls,
|
||||
embedBatchModelCalls,
|
||||
async embed(text: string, options?: { model?: string }) {
|
||||
embedCalls.push({ text, options });
|
||||
return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" };
|
||||
},
|
||||
async embedBatch(texts: string[]) {
|
||||
async embedBatch(texts: string[], options?: { model?: string }) {
|
||||
embedBatchCalls.push([...texts]);
|
||||
embedBatchModelCalls.push(options);
|
||||
return texts.map((_text, index) => ({
|
||||
embedding: [index + 1, index + 2, index + 3],
|
||||
model: "fake-embed",
|
||||
@ -2832,6 +2838,30 @@ describe("Embedding batching", () => {
|
||||
}
|
||||
});
|
||||
|
||||
test("generateEmbeddings passes the selected model through to embed calls and metadata", async () => {
|
||||
const store = await createTestStore();
|
||||
const db = store.db;
|
||||
const fakeLlm = createFakeEmbedLlm();
|
||||
const model = "hf:Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf";
|
||||
|
||||
setDefaultLlamaCpp(createFakeTokenizer() as any);
|
||||
store.llm = fakeLlm as any;
|
||||
|
||||
try {
|
||||
await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" });
|
||||
|
||||
const result = await generateEmbeddings(store, { model });
|
||||
|
||||
expect(result.chunksEmbedded).toBe(1);
|
||||
expect(fakeLlm.embedCalls[0]?.options?.model).toBe(model);
|
||||
expect(fakeLlm.embedBatchModelCalls).toEqual([{ model }]);
|
||||
expect(db.prepare(`SELECT DISTINCT model FROM content_vectors`).all()).toEqual([{ model }]);
|
||||
} finally {
|
||||
setDefaultLlamaCpp(null);
|
||||
await cleanupTestDb(store);
|
||||
}
|
||||
});
|
||||
|
||||
test("generateEmbeddings rejects invalid batch limits", async () => {
|
||||
const store = await createTestStore();
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user