From 809aa361724941db1b32bb271b75788b4d661af0 Mon Sep 17 00:00:00 2001 From: programcaicai Date: Fri, 13 Mar 2026 17:39:17 +0800 Subject: [PATCH] fix: bound memory usage during embed --- src/cli/qmd.ts | 45 ++++++- src/index.ts | 4 + src/store.ts | 285 ++++++++++++++++++++++++++++++++------------- test/cli.test.ts | 14 +++ test/sdk.test.ts | 74 ++++++++++++ test/store.test.ts | 110 ++++++++++++++++- 6 files changed, 446 insertions(+), 86 deletions(-) diff --git a/src/cli/qmd.ts b/src/cli/qmd.ts index 52a076d..2570408 100755 --- a/src/cli/qmd.ts +++ b/src/cli/qmd.ts @@ -25,7 +25,6 @@ import { isDocid, matchFilesByGlob, getHashesNeedingEmbedding, - getHashesForEmbedding, clearAllEmbeddings, insertEmbedding, getStatus, @@ -65,6 +64,8 @@ import { type ExpandedQuery, type HybridQueryExplain, DEFAULT_EMBED_MODEL, + DEFAULT_EMBED_MAX_BATCH_BYTES, + DEFAULT_EMBED_MAX_DOCS_PER_BATCH, DEFAULT_RERANK_MODEL, DEFAULT_GLOB, DEFAULT_MULTI_GET_MAX_BYTES, @@ -1607,7 +1608,20 @@ function renderProgressBar(percent: number, width: number = 30): string { return bar; } -async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean = false): Promise { +function parseEmbedBatchOption(name: string, value: unknown): number | undefined { + if (value === undefined) return undefined; + const parsed = Number(value); + if (!Number.isInteger(parsed) || parsed < 1) { + throw new Error(`${name} must be a positive integer`); + } + return parsed; +} + +async function vectorIndex( + model: string = DEFAULT_EMBED_MODEL, + force: boolean = false, + batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number }, +): Promise { const storeInstance = getStore(); const db = storeInstance.db; @@ -1616,14 +1630,19 @@ async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean = } // Check if there's work to do before starting - const hashesToEmbed = getHashesForEmbedding(db); - if (hashesToEmbed.length === 0 && !force) { + const hashesToEmbed = getHashesNeedingEmbedding(db); + if (hashesToEmbed === 0 && !force) { console.log(`${c.green}✓ All content hashes already have embeddings.${c.reset}`); closeDb(); return; } console.log(`${c.dim}Model: ${model}${c.reset}\n`); + if (batchOptions?.maxDocsPerBatch !== undefined || batchOptions?.maxBatchBytes !== undefined) { + const maxDocsPerBatch = batchOptions.maxDocsPerBatch ?? DEFAULT_EMBED_MAX_DOCS_PER_BATCH; + const maxBatchBytes = batchOptions.maxBatchBytes ?? DEFAULT_EMBED_MAX_BATCH_BYTES; + console.log(`${c.dim}Batch: ${maxDocsPerBatch} docs / ${formatBytes(maxBatchBytes)}${c.reset}\n`); + } cursor.hide(); progress.indeterminate(); @@ -1632,6 +1651,8 @@ async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean = const result = await generateEmbeddings(storeInstance, { force, model, + maxDocsPerBatch: batchOptions?.maxDocsPerBatch, + maxBatchBytes: batchOptions?.maxBatchBytes, onProgress: (info) => { if (info.totalBytes === 0) return; const percent = (info.bytesProcessed / info.totalBytes) * 100; @@ -2334,6 +2355,8 @@ function parseCLI() { mask: { type: "string" }, // glob pattern // Embed options force: { type: "boolean", short: "f" }, + "max-docs-per-batch": { type: "string" }, + "max-batch-mb": { type: "string" }, // Update options pull: { type: "boolean" }, // git pull before update refresh: { type: "boolean" }, @@ -2547,6 +2570,8 @@ function showHelp(): void { console.log(" qmd status - View index + collection health"); console.log(" qmd update [--pull] - Re-index collections (optionally git pull first)"); console.log(" qmd embed [-f] - Generate/refresh vector embeddings"); + console.log(" --max-docs-per-batch - Cap docs loaded into memory per embedding batch"); + console.log(" --max-batch-mb - Cap UTF-8 MB loaded into memory per embedding batch"); console.log(" qmd cleanup - Clear caches, vacuum DB"); console.log(""); console.log("Query syntax (qmd query):"); @@ -2923,7 +2948,17 @@ if (isMain) { break; case "embed": - await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force); + try { + const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]); + const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]); + await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, { + maxDocsPerBatch, + maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024, + }); + } catch (error) { + console.error(error instanceof Error ? error.message : String(error)); + process.exit(1); + } break; case "pull": { diff --git a/src/index.ts b/src/index.ts index b921b51..22f3fa3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -286,6 +286,8 @@ export interface QMDStore { embed(options?: { force?: boolean; model?: string; + maxDocsPerBatch?: number; + maxBatchBytes?: number; onProgress?: (info: EmbedProgress) => void; }): Promise; @@ -502,6 +504,8 @@ export async function createStore(options: StoreOptions): Promise { return generateEmbeddings(internal, { force: embedOpts?.force, model: embedOpts?.model, + maxDocsPerBatch: embedOpts?.maxDocsPerBatch, + maxBatchBytes: embedOpts?.maxBatchBytes, onProgress: embedOpts?.onProgress, }); }, diff --git a/src/store.ts b/src/store.ts index aa5fae4..10cbec6 100644 --- a/src/store.ts +++ b/src/store.ts @@ -24,7 +24,6 @@ import { formatQueryForEmbedding, formatDocForEmbedding, withLLMSessionForLlm, - type LLMSessionOptions, type RerankDocument, type ILLMSession, } from "./llm.js"; @@ -45,6 +44,8 @@ export const DEFAULT_RERANK_MODEL = "ExpedientFalcon/qwen3-reranker:0.6b-q8_0"; export const DEFAULT_QUERY_MODEL = "Qwen/Qwen3-1.7B"; export const DEFAULT_GLOB = "**/*.md"; export const DEFAULT_MULTI_GET_MAX_BYTES = 10 * 1024; // 10KB +export const DEFAULT_EMBED_MAX_DOCS_PER_BATCH = 64; +export const DEFAULT_EMBED_MAX_BATCH_BYTES = 64 * 1024 * 1024; // 64MB // Chunking: 900 tokens per chunk with 15% overlap // Increased from 800 to accommodate smart chunking finding natural break points @@ -1179,6 +1180,109 @@ export type EmbedResult = { durationMs: number; }; +export type EmbedOptions = { + force?: boolean; + model?: string; + maxDocsPerBatch?: number; + maxBatchBytes?: number; + onProgress?: (info: EmbedProgress) => void; +}; + +type PendingEmbeddingDoc = { + hash: string; + path: string; + bytes: number; +}; + +type EmbeddingDoc = PendingEmbeddingDoc & { + body: string; +}; + +type ChunkItem = { + hash: string; + title: string; + text: string; + seq: number; + pos: number; + tokens: number; + bytes: number; +}; + +function validatePositiveIntegerOption(name: string, value: number | undefined, fallback: number): number { + if (value === undefined) return fallback; + if (!Number.isInteger(value) || value < 1) { + throw new Error(`${name} must be a positive integer`); + } + return value; +} + +function resolveEmbedOptions(options?: EmbedOptions): Required> { + return { + maxDocsPerBatch: validatePositiveIntegerOption("maxDocsPerBatch", options?.maxDocsPerBatch, DEFAULT_EMBED_MAX_DOCS_PER_BATCH), + maxBatchBytes: validatePositiveIntegerOption("maxBatchBytes", options?.maxBatchBytes, DEFAULT_EMBED_MAX_BATCH_BYTES), + }; +} + +function getPendingEmbeddingDocs(db: Database): PendingEmbeddingDoc[] { + return db.prepare(` + SELECT d.hash, MIN(d.path) as path, length(CAST(c.doc AS BLOB)) as bytes + FROM documents d + JOIN content c ON d.hash = c.hash + LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0 + WHERE d.active = 1 AND v.hash IS NULL + GROUP BY d.hash + ORDER BY MIN(d.path) + `).all() as PendingEmbeddingDoc[]; +} + +function buildEmbeddingBatches( + docs: PendingEmbeddingDoc[], + maxDocsPerBatch: number, + maxBatchBytes: number, +): PendingEmbeddingDoc[][] { + const batches: PendingEmbeddingDoc[][] = []; + let currentBatch: PendingEmbeddingDoc[] = []; + let currentBytes = 0; + + for (const doc of docs) { + const docBytes = Math.max(0, doc.bytes); + const wouldExceedDocs = currentBatch.length >= maxDocsPerBatch; + const wouldExceedBytes = currentBatch.length > 0 && (currentBytes + docBytes) > maxBatchBytes; + + if (wouldExceedDocs || wouldExceedBytes) { + batches.push(currentBatch); + currentBatch = []; + currentBytes = 0; + } + + currentBatch.push(doc); + currentBytes += docBytes; + } + + if (currentBatch.length > 0) { + batches.push(currentBatch); + } + + return batches; +} + +function getEmbeddingDocsForBatch(db: Database, batch: PendingEmbeddingDoc[]): EmbeddingDoc[] { + if (batch.length === 0) return []; + + const placeholders = batch.map(() => "?").join(","); + const rows = db.prepare(` + SELECT hash, doc as body + FROM content + WHERE hash IN (${placeholders}) + `).all(...batch.map(doc => doc.hash)) as { hash: string; body: string }[]; + const bodyByHash = new Map(rows.map(row => [row.hash, row.body])); + + return batch.map((doc) => ({ + ...doc, + body: bodyByHash.get(doc.hash) ?? "", + })); +} + /** * Generate vector embeddings for documents that need them. * Pure function — no console output, no db lifecycle management. @@ -1186,120 +1290,141 @@ export type EmbedResult = { */ export async function generateEmbeddings( store: Store, - options?: { - force?: boolean; - model?: string; - onProgress?: (info: EmbedProgress) => void; - } + options?: EmbedOptions ): Promise { const db = store.db; const model = options?.model ?? DEFAULT_EMBED_MODEL; const now = new Date().toISOString(); + const { maxDocsPerBatch, maxBatchBytes } = resolveEmbedOptions(options); + const encoder = new TextEncoder(); if (options?.force) { clearAllEmbeddings(db); } - const hashesToEmbed = getHashesForEmbedding(db); + const docsToEmbed = getPendingEmbeddingDocs(db); - if (hashesToEmbed.length === 0) { + if (docsToEmbed.length === 0) { return { docsProcessed: 0, chunksEmbedded: 0, errors: 0, durationMs: 0 }; } - - // Chunk all documents - type ChunkItem = { hash: string; title: string; text: string; seq: number; pos: number; tokens: number; bytes: number }; - const allChunks: ChunkItem[] = []; - - for (const item of hashesToEmbed) { - const encoder = new TextEncoder(); - const bodyBytes = encoder.encode(item.body).length; - if (bodyBytes === 0) continue; - - const title = extractTitle(item.body, item.path); - const chunks = await chunkDocumentByTokens(item.body); - - for (let seq = 0; seq < chunks.length; seq++) { - allChunks.push({ - hash: item.hash, - title, - text: chunks[seq]!.text, - seq, - pos: chunks[seq]!.pos, - tokens: chunks[seq]!.tokens, - bytes: encoder.encode(chunks[seq]!.text).length, - }); - } - } - - if (allChunks.length === 0) { - return { docsProcessed: 0, chunksEmbedded: 0, errors: 0, durationMs: 0 }; - } - - const totalBytes = allChunks.reduce((sum, chk) => sum + chk.bytes, 0); - const totalChunks = allChunks.length; - const totalDocs = hashesToEmbed.length; + const totalBytes = docsToEmbed.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0); + const totalDocs = docsToEmbed.length; const startTime = Date.now(); // Use store's LlamaCpp or global singleton, wrapped in a session const llm = getLlm(store); - const sessionOptions: LLMSessionOptions = { maxDuration: 30 * 60 * 1000, name: 'generateEmbeddings' }; // Create a session manager for this llm instance const result = await withLLMSessionForLlm(llm, async (session) => { - // Get embedding dimensions from first chunk - const firstChunk = allChunks[0]!; - const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title); - const firstResult = await session.embed(firstText); - if (!firstResult) { - throw new Error("Failed to get embedding dimensions from first chunk"); - } - store.ensureVecTable(firstResult.embedding.length); - - let chunksEmbedded = 0, errors = 0, bytesProcessed = 0; + let chunksEmbedded = 0; + let errors = 0; + let bytesProcessed = 0; + let totalChunks = 0; + let vectorTableInitialized = false; const BATCH_SIZE = 32; + const batches = buildEmbeddingBatches(docsToEmbed, maxDocsPerBatch, maxBatchBytes); - for (let batchStart = 0; batchStart < allChunks.length; batchStart += BATCH_SIZE) { - const batchEnd = Math.min(batchStart + BATCH_SIZE, allChunks.length); - const batch = allChunks.slice(batchStart, batchEnd); - const texts = batch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title)); + for (const batchMeta of batches) { + const batchDocs = getEmbeddingDocsForBatch(db, batchMeta); + const batchChunks: ChunkItem[] = []; + const batchBytes = batchMeta.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0); - try { - const embeddings = await session.embedBatch(texts); - for (let i = 0; i < batch.length; i++) { - const chunk = batch[i]!; - const embedding = embeddings[i]; - if (embedding) { - insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(embedding.embedding), model, now); - chunksEmbedded++; - } else { - errors++; - } - bytesProcessed += chunk.bytes; + for (const doc of batchDocs) { + if (!doc.body.trim()) continue; + + const title = extractTitle(doc.body, doc.path); + const chunks = await chunkDocumentByTokens(doc.body); + + for (let seq = 0; seq < chunks.length; seq++) { + batchChunks.push({ + hash: doc.hash, + title, + text: chunks[seq]!.text, + seq, + pos: chunks[seq]!.pos, + tokens: chunks[seq]!.tokens, + bytes: encoder.encode(chunks[seq]!.text).length, + }); } - } catch { - // Batch failed — try individual embeddings as fallback - for (const chunk of batch) { - try { - const text = formatDocForEmbedding(chunk.text, chunk.title); - const result = await session.embed(text); - if (result) { - insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now); + } + + totalChunks += batchChunks.length; + + if (batchChunks.length === 0) { + bytesProcessed += batchBytes; + options?.onProgress?.({ chunksEmbedded, totalChunks, bytesProcessed, totalBytes, errors }); + continue; + } + + if (!vectorTableInitialized) { + const firstChunk = batchChunks[0]!; + const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title); + const firstResult = await session.embed(firstText); + if (!firstResult) { + throw new Error("Failed to get embedding dimensions from first chunk"); + } + store.ensureVecTable(firstResult.embedding.length); + vectorTableInitialized = true; + } + + const totalBatchChunkBytes = batchChunks.reduce((sum, chunk) => sum + chunk.bytes, 0); + let batchChunkBytesProcessed = 0; + + for (let batchStart = 0; batchStart < batchChunks.length; batchStart += BATCH_SIZE) { + const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length); + const chunkBatch = batchChunks.slice(batchStart, batchEnd); + const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title)); + + try { + const embeddings = await session.embedBatch(texts); + for (let i = 0; i < chunkBatch.length; i++) { + const chunk = chunkBatch[i]!; + const embedding = embeddings[i]; + if (embedding) { + insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(embedding.embedding), model, now); chunksEmbedded++; } else { errors++; } - } catch { - errors++; + batchChunkBytesProcessed += chunk.bytes; + } + } catch { + // Batch failed — try individual embeddings as fallback + for (const chunk of chunkBatch) { + try { + const text = formatDocForEmbedding(chunk.text, chunk.title); + const result = await session.embed(text); + if (result) { + insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now); + chunksEmbedded++; + } else { + errors++; + } + } catch { + errors++; + } + batchChunkBytesProcessed += chunk.bytes; } - bytesProcessed += chunk.bytes; } + + const proportionalBytes = totalBatchChunkBytes === 0 + ? batchBytes + : Math.min(batchBytes, Math.round((batchChunkBytesProcessed / totalBatchChunkBytes) * batchBytes)); + options?.onProgress?.({ + chunksEmbedded, + totalChunks, + bytesProcessed: bytesProcessed + proportionalBytes, + totalBytes, + errors, + }); } + bytesProcessed += batchBytes; options?.onProgress?.({ chunksEmbedded, totalChunks, bytesProcessed, totalBytes, errors }); } return { chunksEmbedded, errors }; - }, sessionOptions); + }, { maxDuration: 30 * 60 * 1000, name: 'generateEmbeddings' }); return { docsProcessed: totalDocs, diff --git a/test/cli.test.ts b/test/cli.test.ts index 834ac18..7d6f526 100644 --- a/test/cli.test.ts +++ b/test/cli.test.ts @@ -241,6 +241,20 @@ describe("CLI Help", () => { }); }); +describe("CLI Embed", () => { + test("rejects invalid --max-docs-per-batch", async () => { + const { stderr, exitCode } = await runQmd(["embed", "--max-docs-per-batch", "0"]); + expect(exitCode).toBe(1); + expect(stderr).toContain("maxDocsPerBatch"); + }); + + test("rejects invalid --max-batch-mb", async () => { + const { stderr, exitCode } = await runQmd(["embed", "--max-batch-mb", "0"]); + expect(exitCode).toBe(1); + expect(stderr).toContain("maxBatchBytes"); + }); +}); + describe("CLI Skill Commands", () => { test("shows embedded skill with --skill alias", async () => { const { stdout, exitCode } = await runQmd(["--skill"]); diff --git a/test/sdk.test.ts b/test/sdk.test.ts index d246bc4..6672316 100644 --- a/test/sdk.test.ts +++ b/test/sdk.test.ts @@ -22,6 +22,7 @@ import { type VectorSearchOptions, type ExpandQueryOptions, } from "../src/index.js"; +import { setDefaultLlamaCpp } from "../src/llm.js"; // ============================================================================= // Test Helpers @@ -924,6 +925,79 @@ describe("update", () => { }); }); +describe("embed", () => { + function createFakeTokenizer() { + return { + async tokenize(text: string) { + return new Array(Math.max(1, Math.ceil(text.length / 16))).fill(1); + }, + }; + } + + function createFakeEmbedLlm() { + const embedBatchCalls: string[][] = []; + return { + embedBatchCalls, + async embed(_text: string) { + return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" }; + }, + async embedBatch(texts: string[]) { + embedBatchCalls.push([...texts]); + return texts.map((_text, index) => ({ + embedding: [index + 1, index + 2, index + 3], + model: "fake-embed", + })); + }, + }; + } + + test("store.embed forwards batch limit options", async () => { + const store = await createStore({ + dbPath: freshDbPath(), + config: { + collections: { + docs: { path: docsDir, pattern: "**/*.md" }, + }, + }, + }); + + const fakeLlm = createFakeEmbedLlm(); + setDefaultLlamaCpp(createFakeTokenizer() as any); + store.internal.llm = fakeLlm as any; + + try { + await store.update(); + const result = await store.embed({ + maxDocsPerBatch: 1, + maxBatchBytes: 1024 * 1024, + }); + + expect(fakeLlm.embedBatchCalls).toHaveLength(3); + expect(fakeLlm.embedBatchCalls.map(call => call.length)).toEqual([1, 1, 1]); + expect(result.docsProcessed).toBe(3); + expect(result.chunksEmbedded).toBe(3); + } finally { + setDefaultLlamaCpp(null); + await store.close(); + } + }); + + test("store.embed rejects invalid batch limits", async () => { + const store = await createStore({ + dbPath: freshDbPath(), + config: { collections: {} }, + }); + + try { + await expect(store.embed({ maxDocsPerBatch: 0 })).rejects.toThrow("maxDocsPerBatch"); + await expect(store.embed({ maxBatchBytes: 0 })).rejects.toThrow("maxBatchBytes"); + } finally { + setDefaultLlamaCpp(null); + await store.close(); + } + }); +}); + // ============================================================================= // Lifecycle Tests // ============================================================================= diff --git a/test/store.test.ts b/test/store.test.ts index d64bc0d..c5755f8 100644 --- a/test/store.test.ts +++ b/test/store.test.ts @@ -14,7 +14,7 @@ import { tmpdir } from "node:os"; import { join } from "node:path"; import YAML from "yaml"; import * as llmModule from "../src/llm.js"; -import { disposeDefaultLlamaCpp } from "../src/llm.js"; +import { disposeDefaultLlamaCpp, setDefaultLlamaCpp } from "../src/llm.js"; import { createStore, verifySqliteVecLoaded, @@ -47,6 +47,7 @@ import { syncConfigToDb, STRONG_SIGNAL_MIN_SCORE, STRONG_SIGNAL_MIN_GAP, + generateEmbeddings, type Store, type DocumentResult, type SearchResult, @@ -2589,6 +2590,113 @@ describe("Edge Cases", () => { }); }); +describe("Embedding batching", () => { + function createFakeTokenizer() { + return { + async tokenize(text: string) { + return new Array(Math.max(1, Math.ceil(text.length / 16))).fill(1); + }, + }; + } + + function createFakeEmbedLlm() { + const embedBatchCalls: string[][] = []; + return { + embedBatchCalls, + async embed(_text: string) { + return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" }; + }, + async embedBatch(texts: string[]) { + embedBatchCalls.push([...texts]); + return texts.map((_text, index) => ({ + embedding: [index + 1, index + 2, index + 3], + model: "fake-embed", + })); + }, + }; + } + + test("generateEmbeddings flushes batches when maxDocsPerBatch is reached", async () => { + const store = await createTestStore(); + const db = store.db; + const fakeLlm = createFakeEmbedLlm(); + + setDefaultLlamaCpp(createFakeTokenizer() as any); + store.llm = fakeLlm as any; + + try { + await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" }); + await insertTestDocument(db, "docs", { name: "two", body: "# Two\n\nBeta" }); + await insertTestDocument(db, "docs", { name: "three", body: "# Three\n\nGamma" }); + + const result = await generateEmbeddings(store, { + maxDocsPerBatch: 1, + maxBatchBytes: 1024 * 1024, + }); + + expect(fakeLlm.embedBatchCalls).toHaveLength(3); + expect(fakeLlm.embedBatchCalls.map(call => call.length)).toEqual([1, 1, 1]); + expect(result.docsProcessed).toBe(3); + expect(result.chunksEmbedded).toBe(3); + expect(db.prepare(`SELECT COUNT(*) as count FROM content_vectors`).get()).toEqual({ count: 3 }); + } finally { + setDefaultLlamaCpp(null); + await cleanupTestDb(store); + } + }); + + test("generateEmbeddings flushes batches when maxBatchBytes is reached", async () => { + const store = await createTestStore(); + const db = store.db; + const fakeLlm = createFakeEmbedLlm(); + + setDefaultLlamaCpp(createFakeTokenizer() as any); + store.llm = fakeLlm as any; + + const docOne = "# One\n\n" + "A".repeat(36); + const docTwo = "# Two\n\n" + "B".repeat(36); + const docThree = "# Three\n\n" + "C".repeat(36); + const batchLimit = new TextEncoder().encode(docOne).length + + new TextEncoder().encode(docTwo).length + + 1; + + try { + await insertTestDocument(db, "docs", { name: "a-one", body: docOne }); + await insertTestDocument(db, "docs", { name: "b-two", body: docTwo }); + await insertTestDocument(db, "docs", { name: "c-three", body: docThree }); + + const result = await generateEmbeddings(store, { + maxDocsPerBatch: 64, + maxBatchBytes: batchLimit, + }); + + expect(fakeLlm.embedBatchCalls).toHaveLength(2); + expect(fakeLlm.embedBatchCalls.map(call => call.length)).toEqual([2, 1]); + expect(result.docsProcessed).toBe(3); + expect(result.chunksEmbedded).toBe(3); + } finally { + setDefaultLlamaCpp(null); + await cleanupTestDb(store); + } + }); + + test("generateEmbeddings rejects invalid batch limits", async () => { + const store = await createTestStore(); + + try { + await expect(generateEmbeddings(store, { maxDocsPerBatch: 0 })).rejects.toThrow( + "maxDocsPerBatch" + ); + await expect(generateEmbeddings(store, { maxBatchBytes: 0 })).rejects.toThrow( + "maxBatchBytes" + ); + } finally { + setDefaultLlamaCpp(null); + await cleanupTestDb(store); + } + }); +}); + // ============================================================================= // Content-Addressable Storage Tests // =============================================================================