Merge pull request #395 from ProgramCaiCai/fix/embed-batching-memory
Bound qmd embed memory usage with default batched processing
This commit is contained in:
commit
7ab1497ebb
@ -25,7 +25,6 @@ import {
|
||||
isDocid,
|
||||
matchFilesByGlob,
|
||||
getHashesNeedingEmbedding,
|
||||
getHashesForEmbedding,
|
||||
clearAllEmbeddings,
|
||||
insertEmbedding,
|
||||
getStatus,
|
||||
@ -65,6 +64,8 @@ import {
|
||||
type ExpandedQuery,
|
||||
type HybridQueryExplain,
|
||||
DEFAULT_EMBED_MODEL,
|
||||
DEFAULT_EMBED_MAX_BATCH_BYTES,
|
||||
DEFAULT_EMBED_MAX_DOCS_PER_BATCH,
|
||||
DEFAULT_RERANK_MODEL,
|
||||
DEFAULT_GLOB,
|
||||
DEFAULT_MULTI_GET_MAX_BYTES,
|
||||
@ -1607,7 +1608,20 @@ function renderProgressBar(percent: number, width: number = 30): string {
|
||||
return bar;
|
||||
}
|
||||
|
||||
async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean = false): Promise<void> {
|
||||
/**
 * Parse an optional `qmd embed` batch-limit option into a positive integer.
 *
 * @param name - Option name used as the prefix of the validation error message.
 * @param value - Raw option value; `undefined` means the option was not supplied.
 * @returns The parsed integer, or `undefined` when `value` is `undefined`.
 * @throws Error with message `"<name> must be a positive integer"` when the value
 *   is not a plain decimal integer >= 1 inside the safe-integer range.
 */
function parseEmbedBatchOption(name: string, value: unknown): number | undefined {
  if (value === undefined) return undefined;

  // Start from NaN so every input that is not explicitly accepted is rejected.
  // A bare Number(value) would coerce `true` to 1 and `[5]` to 5.
  let parsed = Number.NaN;
  if (typeof value === "number") {
    parsed = value;
  } else if (typeof value === "string") {
    const text = value.trim();
    // Digits only: Number() alone also accepts "0x10", "1e3", "0b11", etc.
    if (/^\+?\d+$/.test(text)) {
      parsed = Number(text);
    }
  }

  // Safe-integer check: the caller scales megabytes to bytes, which must stay exact.
  if (!Number.isSafeInteger(parsed) || parsed < 1) {
    throw new Error(`${name} must be a positive integer`);
  }
  return parsed;
}
|
||||
|
||||
async function vectorIndex(
|
||||
model: string = DEFAULT_EMBED_MODEL,
|
||||
force: boolean = false,
|
||||
batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number },
|
||||
): Promise<void> {
|
||||
const storeInstance = getStore();
|
||||
const db = storeInstance.db;
|
||||
|
||||
@ -1616,14 +1630,19 @@ async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean =
|
||||
}
|
||||
|
||||
// Check if there's work to do before starting
|
||||
const hashesToEmbed = getHashesForEmbedding(db);
|
||||
if (hashesToEmbed.length === 0 && !force) {
|
||||
const hashesToEmbed = getHashesNeedingEmbedding(db);
|
||||
if (hashesToEmbed === 0 && !force) {
|
||||
console.log(`${c.green}✓ All content hashes already have embeddings.${c.reset}`);
|
||||
closeDb();
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`${c.dim}Model: ${model}${c.reset}\n`);
|
||||
if (batchOptions?.maxDocsPerBatch !== undefined || batchOptions?.maxBatchBytes !== undefined) {
|
||||
const maxDocsPerBatch = batchOptions.maxDocsPerBatch ?? DEFAULT_EMBED_MAX_DOCS_PER_BATCH;
|
||||
const maxBatchBytes = batchOptions.maxBatchBytes ?? DEFAULT_EMBED_MAX_BATCH_BYTES;
|
||||
console.log(`${c.dim}Batch: ${maxDocsPerBatch} docs / ${formatBytes(maxBatchBytes)}${c.reset}\n`);
|
||||
}
|
||||
cursor.hide();
|
||||
progress.indeterminate();
|
||||
|
||||
@ -1632,6 +1651,8 @@ async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean =
|
||||
const result = await generateEmbeddings(storeInstance, {
|
||||
force,
|
||||
model,
|
||||
maxDocsPerBatch: batchOptions?.maxDocsPerBatch,
|
||||
maxBatchBytes: batchOptions?.maxBatchBytes,
|
||||
onProgress: (info) => {
|
||||
if (info.totalBytes === 0) return;
|
||||
const percent = (info.bytesProcessed / info.totalBytes) * 100;
|
||||
@ -2334,6 +2355,8 @@ function parseCLI() {
|
||||
mask: { type: "string" }, // glob pattern
|
||||
// Embed options
|
||||
force: { type: "boolean", short: "f" },
|
||||
"max-docs-per-batch": { type: "string" },
|
||||
"max-batch-mb": { type: "string" },
|
||||
// Update options
|
||||
pull: { type: "boolean" }, // git pull before update
|
||||
refresh: { type: "boolean" },
|
||||
@ -2547,6 +2570,8 @@ function showHelp(): void {
|
||||
console.log(" qmd status - View index + collection health");
|
||||
console.log(" qmd update [--pull] - Re-index collections (optionally git pull first)");
|
||||
console.log(" qmd embed [-f] - Generate/refresh vector embeddings");
|
||||
console.log(" --max-docs-per-batch <n> - Cap docs loaded into memory per embedding batch");
|
||||
console.log(" --max-batch-mb <n> - Cap UTF-8 MB loaded into memory per embedding batch");
|
||||
console.log(" qmd cleanup - Clear caches, vacuum DB");
|
||||
console.log("");
|
||||
console.log("Query syntax (qmd query):");
|
||||
@ -2923,7 +2948,17 @@ if (isMain) {
|
||||
break;
|
||||
|
||||
case "embed":
|
||||
await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force);
|
||||
try {
|
||||
const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]);
|
||||
const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]);
|
||||
await vectorIndex(DEFAULT_EMBED_MODEL, !!cli.values.force, {
|
||||
maxDocsPerBatch,
|
||||
maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(error instanceof Error ? error.message : String(error));
|
||||
process.exit(1);
|
||||
}
|
||||
break;
|
||||
|
||||
case "pull": {
|
||||
|
||||
@ -286,6 +286,8 @@ export interface QMDStore {
|
||||
embed(options?: {
|
||||
force?: boolean;
|
||||
model?: string;
|
||||
maxDocsPerBatch?: number;
|
||||
maxBatchBytes?: number;
|
||||
onProgress?: (info: EmbedProgress) => void;
|
||||
}): Promise<EmbedResult>;
|
||||
|
||||
@ -502,6 +504,8 @@ export async function createStore(options: StoreOptions): Promise<QMDStore> {
|
||||
return generateEmbeddings(internal, {
|
||||
force: embedOpts?.force,
|
||||
model: embedOpts?.model,
|
||||
maxDocsPerBatch: embedOpts?.maxDocsPerBatch,
|
||||
maxBatchBytes: embedOpts?.maxBatchBytes,
|
||||
onProgress: embedOpts?.onProgress,
|
||||
});
|
||||
},
|
||||
|
||||
285
src/store.ts
285
src/store.ts
@ -24,7 +24,6 @@ import {
|
||||
formatQueryForEmbedding,
|
||||
formatDocForEmbedding,
|
||||
withLLMSessionForLlm,
|
||||
type LLMSessionOptions,
|
||||
type RerankDocument,
|
||||
type ILLMSession,
|
||||
} from "./llm.js";
|
||||
@ -45,6 +44,8 @@ export const DEFAULT_RERANK_MODEL = "ExpedientFalcon/qwen3-reranker:0.6b-q8_0";
|
||||
export const DEFAULT_QUERY_MODEL = "Qwen/Qwen3-1.7B";
|
||||
export const DEFAULT_GLOB = "**/*.md";
|
||||
export const DEFAULT_MULTI_GET_MAX_BYTES = 10 * 1024; // 10KB
|
||||
export const DEFAULT_EMBED_MAX_DOCS_PER_BATCH = 64;
|
||||
export const DEFAULT_EMBED_MAX_BATCH_BYTES = 64 * 1024 * 1024; // 64MB
|
||||
|
||||
// Chunking: 900 tokens per chunk with 15% overlap
|
||||
// Increased from 800 to accommodate smart chunking finding natural break points
|
||||
@ -1179,6 +1180,109 @@ export type EmbedResult = {
|
||||
durationMs: number;
|
||||
};
|
||||
|
||||
export type EmbedOptions = {
|
||||
force?: boolean;
|
||||
model?: string;
|
||||
maxDocsPerBatch?: number;
|
||||
maxBatchBytes?: number;
|
||||
onProgress?: (info: EmbedProgress) => void;
|
||||
};
|
||||
|
||||
type PendingEmbeddingDoc = {
|
||||
hash: string;
|
||||
path: string;
|
||||
bytes: number;
|
||||
};
|
||||
|
||||
type EmbeddingDoc = PendingEmbeddingDoc & {
|
||||
body: string;
|
||||
};
|
||||
|
||||
type ChunkItem = {
|
||||
hash: string;
|
||||
title: string;
|
||||
text: string;
|
||||
seq: number;
|
||||
pos: number;
|
||||
tokens: number;
|
||||
bytes: number;
|
||||
};
|
||||
|
||||
/**
 * Resolve an optional numeric option to a positive integer.
 *
 * @param name - Option name used as the prefix of the error message.
 * @param value - Caller-supplied value, or `undefined` when not supplied.
 * @param fallback - Returned as-is (not validated) when `value` is `undefined`.
 * @returns `value` when it is an integer >= 1, otherwise `fallback` for `undefined`.
 * @throws Error with message `"<name> must be a positive integer"` for any other value.
 */
function validatePositiveIntegerOption(name: string, value: number | undefined, fallback: number): number {
  if (value !== undefined) {
    const isPositiveInteger = Number.isInteger(value) && value >= 1;
    if (isPositiveInteger) {
      return value;
    }
    throw new Error(`${name} must be a positive integer`);
  }
  return fallback;
}
|
||||
|
||||
/**
 * Resolve the two batch limits of an embed call.
 *
 * Each limit is passed through validatePositiveIntegerOption: a supplied value must be
 * a positive integer, and a missing one falls back to DEFAULT_EMBED_MAX_DOCS_PER_BATCH
 * or DEFAULT_EMBED_MAX_BATCH_BYTES respectively.
 *
 * @param options - Embed options; only `maxDocsPerBatch` and `maxBatchBytes` are read.
 * @returns Both limits, always defined.
 * @throws Error when a supplied limit is not a positive integer.
 */
function resolveEmbedOptions(options?: EmbedOptions): Required<Pick<EmbedOptions, "maxDocsPerBatch" | "maxBatchBytes">> {
  const maxDocsPerBatch = validatePositiveIntegerOption(
    "maxDocsPerBatch",
    options?.maxDocsPerBatch,
    DEFAULT_EMBED_MAX_DOCS_PER_BATCH,
  );
  const maxBatchBytes = validatePositiveIntegerOption(
    "maxBatchBytes",
    options?.maxBatchBytes,
    DEFAULT_EMBED_MAX_BATCH_BYTES,
  );

  return { maxDocsPerBatch, maxBatchBytes };
}
|
||||
|
||||
/**
 * List the documents that still need embeddings, without loading their bodies.
 *
 * The query returns one row per content hash of active documents that have no
 * `content_vectors` row with `seq = 0`. Each row carries the hash, the smallest path
 * among documents sharing that hash, and the byte length of the stored content
 * (`doc` cast to BLOB). Rows are ordered by that path.
 *
 * @param db - Database handle to query.
 * @returns Pending document metadata rows.
 */
function getPendingEmbeddingDocs(db: Database): PendingEmbeddingDoc[] {
  const pendingDocsQuery = db.prepare(`
    SELECT d.hash, MIN(d.path) as path, length(CAST(c.doc AS BLOB)) as bytes
    FROM documents d
    JOIN content c ON d.hash = c.hash
    LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0
    WHERE d.active = 1 AND v.hash IS NULL
    GROUP BY d.hash
    ORDER BY MIN(d.path)
  `);

  const pendingDocs = pendingDocsQuery.all() as PendingEmbeddingDoc[];
  return pendingDocs;
}
|
||||
|
||||
/**
 * Split pending documents into consecutive batches bounded by count and bytes.
 *
 * Input order is preserved. A batch is closed before adding a document when it
 * already holds `maxDocsPerBatch` documents, or when it is non-empty and adding the
 * document would push its byte total above `maxBatchBytes`. Documents are never
 * split, so a single document larger than `maxBatchBytes` forms its own batch.
 * Negative byte counts are counted as 0.
 *
 * @param docs - Pending documents with their byte sizes.
 * @param maxDocsPerBatch - Maximum number of documents per batch.
 * @param maxBatchBytes - Maximum summed `bytes` per batch.
 * @returns The batches, in order.
 */
function buildEmbeddingBatches(
  docs: PendingEmbeddingDoc[],
  maxDocsPerBatch: number,
  maxBatchBytes: number,
): PendingEmbeddingDoc[][] {
  const result: PendingEmbeddingDoc[][] = [];
  let open: PendingEmbeddingDoc[] = [];
  let openBytes = 0;

  // Close the batch being filled and start a fresh one.
  const flush = (): void => {
    result.push(open);
    open = [];
    openBytes = 0;
  };

  for (let index = 0; index < docs.length; index++) {
    const doc = docs[index]!;
    const size = Math.max(0, doc.bytes);

    const docLimitReached = open.length >= maxDocsPerBatch;
    // The byte check only applies to a non-empty batch, so oversized docs still get placed.
    const byteLimitExceeded = open.length !== 0 && openBytes + size > maxBatchBytes;
    if (docLimitReached || byteLimitExceeded) {
      flush();
    }

    open.push(doc);
    openBytes += size;
  }

  if (open.length !== 0) {
    flush();
  }

  return result;
}
|
||||
|
||||
/**
 * Load the bodies for one batch of pending documents.
 *
 * @param db - Database handle to query.
 * @param batch - Metadata rows whose `hash` identifies the content to load.
 * @returns One entry per input doc, in input order, with `body` attached. A hash
 *   with no matching `content` row (or a null body) yields an empty string body.
 */
function getEmbeddingDocsForBatch(db: Database, batch: PendingEmbeddingDoc[]): EmbeddingDoc[] {
  if (batch.length === 0) return [];

  // SQLite caps the number of bound parameters per statement (999 by default before
  // 3.32.0) and the batch size has no upper bound, so a single IN (...) list could
  // fail for large batches. Look the bodies up in fixed-size slices instead.
  const MAX_HASHES_PER_QUERY = 500;
  const bodyByHash = new Map<string, string>();

  for (let start = 0; start < batch.length; start += MAX_HASHES_PER_QUERY) {
    const hashes = batch.slice(start, start + MAX_HASHES_PER_QUERY).map(doc => doc.hash);
    const placeholders = hashes.map(() => "?").join(",");
    const rows = db.prepare(`
      SELECT hash, doc as body
      FROM content
      WHERE hash IN (${placeholders})
    `).all(...hashes) as { hash: string; body: string }[];

    for (const row of rows) {
      bodyByHash.set(row.hash, row.body);
    }
  }

  return batch.map((doc) => ({
    ...doc,
    body: bodyByHash.get(doc.hash) ?? "",
  }));
}
|
||||
|
||||
/**
|
||||
* Generate vector embeddings for documents that need them.
|
||||
* Pure function — no console output, no db lifecycle management.
|
||||
@ -1186,120 +1290,141 @@ export type EmbedResult = {
|
||||
*/
|
||||
export async function generateEmbeddings(
|
||||
store: Store,
|
||||
options?: {
|
||||
force?: boolean;
|
||||
model?: string;
|
||||
onProgress?: (info: EmbedProgress) => void;
|
||||
}
|
||||
options?: EmbedOptions
|
||||
): Promise<EmbedResult> {
|
||||
const db = store.db;
|
||||
const model = options?.model ?? DEFAULT_EMBED_MODEL;
|
||||
const now = new Date().toISOString();
|
||||
const { maxDocsPerBatch, maxBatchBytes } = resolveEmbedOptions(options);
|
||||
const encoder = new TextEncoder();
|
||||
|
||||
if (options?.force) {
|
||||
clearAllEmbeddings(db);
|
||||
}
|
||||
|
||||
const hashesToEmbed = getHashesForEmbedding(db);
|
||||
const docsToEmbed = getPendingEmbeddingDocs(db);
|
||||
|
||||
if (hashesToEmbed.length === 0) {
|
||||
if (docsToEmbed.length === 0) {
|
||||
return { docsProcessed: 0, chunksEmbedded: 0, errors: 0, durationMs: 0 };
|
||||
}
|
||||
|
||||
// Chunk all documents
|
||||
type ChunkItem = { hash: string; title: string; text: string; seq: number; pos: number; tokens: number; bytes: number };
|
||||
const allChunks: ChunkItem[] = [];
|
||||
|
||||
for (const item of hashesToEmbed) {
|
||||
const encoder = new TextEncoder();
|
||||
const bodyBytes = encoder.encode(item.body).length;
|
||||
if (bodyBytes === 0) continue;
|
||||
|
||||
const title = extractTitle(item.body, item.path);
|
||||
const chunks = await chunkDocumentByTokens(item.body);
|
||||
|
||||
for (let seq = 0; seq < chunks.length; seq++) {
|
||||
allChunks.push({
|
||||
hash: item.hash,
|
||||
title,
|
||||
text: chunks[seq]!.text,
|
||||
seq,
|
||||
pos: chunks[seq]!.pos,
|
||||
tokens: chunks[seq]!.tokens,
|
||||
bytes: encoder.encode(chunks[seq]!.text).length,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (allChunks.length === 0) {
|
||||
return { docsProcessed: 0, chunksEmbedded: 0, errors: 0, durationMs: 0 };
|
||||
}
|
||||
|
||||
const totalBytes = allChunks.reduce((sum, chk) => sum + chk.bytes, 0);
|
||||
const totalChunks = allChunks.length;
|
||||
const totalDocs = hashesToEmbed.length;
|
||||
const totalBytes = docsToEmbed.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0);
|
||||
const totalDocs = docsToEmbed.length;
|
||||
const startTime = Date.now();
|
||||
|
||||
// Use store's LlamaCpp or global singleton, wrapped in a session
|
||||
const llm = getLlm(store);
|
||||
const sessionOptions: LLMSessionOptions = { maxDuration: 30 * 60 * 1000, name: 'generateEmbeddings' };
|
||||
|
||||
// Create a session manager for this llm instance
|
||||
const result = await withLLMSessionForLlm(llm, async (session) => {
|
||||
// Get embedding dimensions from first chunk
|
||||
const firstChunk = allChunks[0]!;
|
||||
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title);
|
||||
const firstResult = await session.embed(firstText);
|
||||
if (!firstResult) {
|
||||
throw new Error("Failed to get embedding dimensions from first chunk");
|
||||
}
|
||||
store.ensureVecTable(firstResult.embedding.length);
|
||||
|
||||
let chunksEmbedded = 0, errors = 0, bytesProcessed = 0;
|
||||
let chunksEmbedded = 0;
|
||||
let errors = 0;
|
||||
let bytesProcessed = 0;
|
||||
let totalChunks = 0;
|
||||
let vectorTableInitialized = false;
|
||||
const BATCH_SIZE = 32;
|
||||
const batches = buildEmbeddingBatches(docsToEmbed, maxDocsPerBatch, maxBatchBytes);
|
||||
|
||||
for (let batchStart = 0; batchStart < allChunks.length; batchStart += BATCH_SIZE) {
|
||||
const batchEnd = Math.min(batchStart + BATCH_SIZE, allChunks.length);
|
||||
const batch = allChunks.slice(batchStart, batchEnd);
|
||||
const texts = batch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
||||
for (const batchMeta of batches) {
|
||||
const batchDocs = getEmbeddingDocsForBatch(db, batchMeta);
|
||||
const batchChunks: ChunkItem[] = [];
|
||||
const batchBytes = batchMeta.reduce((sum, doc) => sum + Math.max(0, doc.bytes), 0);
|
||||
|
||||
try {
|
||||
const embeddings = await session.embedBatch(texts);
|
||||
for (let i = 0; i < batch.length; i++) {
|
||||
const chunk = batch[i]!;
|
||||
const embedding = embeddings[i];
|
||||
if (embedding) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(embedding.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
} else {
|
||||
errors++;
|
||||
}
|
||||
bytesProcessed += chunk.bytes;
|
||||
for (const doc of batchDocs) {
|
||||
if (!doc.body.trim()) continue;
|
||||
|
||||
const title = extractTitle(doc.body, doc.path);
|
||||
const chunks = await chunkDocumentByTokens(doc.body);
|
||||
|
||||
for (let seq = 0; seq < chunks.length; seq++) {
|
||||
batchChunks.push({
|
||||
hash: doc.hash,
|
||||
title,
|
||||
text: chunks[seq]!.text,
|
||||
seq,
|
||||
pos: chunks[seq]!.pos,
|
||||
tokens: chunks[seq]!.tokens,
|
||||
bytes: encoder.encode(chunks[seq]!.text).length,
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
// Batch failed — try individual embeddings as fallback
|
||||
for (const chunk of batch) {
|
||||
try {
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||
const result = await session.embed(text);
|
||||
if (result) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||
}
|
||||
|
||||
totalChunks += batchChunks.length;
|
||||
|
||||
if (batchChunks.length === 0) {
|
||||
bytesProcessed += batchBytes;
|
||||
options?.onProgress?.({ chunksEmbedded, totalChunks, bytesProcessed, totalBytes, errors });
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!vectorTableInitialized) {
|
||||
const firstChunk = batchChunks[0]!;
|
||||
const firstText = formatDocForEmbedding(firstChunk.text, firstChunk.title);
|
||||
const firstResult = await session.embed(firstText);
|
||||
if (!firstResult) {
|
||||
throw new Error("Failed to get embedding dimensions from first chunk");
|
||||
}
|
||||
store.ensureVecTable(firstResult.embedding.length);
|
||||
vectorTableInitialized = true;
|
||||
}
|
||||
|
||||
const totalBatchChunkBytes = batchChunks.reduce((sum, chunk) => sum + chunk.bytes, 0);
|
||||
let batchChunkBytesProcessed = 0;
|
||||
|
||||
for (let batchStart = 0; batchStart < batchChunks.length; batchStart += BATCH_SIZE) {
|
||||
const batchEnd = Math.min(batchStart + BATCH_SIZE, batchChunks.length);
|
||||
const chunkBatch = batchChunks.slice(batchStart, batchEnd);
|
||||
const texts = chunkBatch.map(chunk => formatDocForEmbedding(chunk.text, chunk.title));
|
||||
|
||||
try {
|
||||
const embeddings = await session.embedBatch(texts);
|
||||
for (let i = 0; i < chunkBatch.length; i++) {
|
||||
const chunk = chunkBatch[i]!;
|
||||
const embedding = embeddings[i];
|
||||
if (embedding) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(embedding.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
} else {
|
||||
errors++;
|
||||
}
|
||||
} catch {
|
||||
errors++;
|
||||
batchChunkBytesProcessed += chunk.bytes;
|
||||
}
|
||||
} catch {
|
||||
// Batch failed — try individual embeddings as fallback
|
||||
for (const chunk of chunkBatch) {
|
||||
try {
|
||||
const text = formatDocForEmbedding(chunk.text, chunk.title);
|
||||
const result = await session.embed(text);
|
||||
if (result) {
|
||||
insertEmbedding(db, chunk.hash, chunk.seq, chunk.pos, new Float32Array(result.embedding), model, now);
|
||||
chunksEmbedded++;
|
||||
} else {
|
||||
errors++;
|
||||
}
|
||||
} catch {
|
||||
errors++;
|
||||
}
|
||||
batchChunkBytesProcessed += chunk.bytes;
|
||||
}
|
||||
bytesProcessed += chunk.bytes;
|
||||
}
|
||||
|
||||
const proportionalBytes = totalBatchChunkBytes === 0
|
||||
? batchBytes
|
||||
: Math.min(batchBytes, Math.round((batchChunkBytesProcessed / totalBatchChunkBytes) * batchBytes));
|
||||
options?.onProgress?.({
|
||||
chunksEmbedded,
|
||||
totalChunks,
|
||||
bytesProcessed: bytesProcessed + proportionalBytes,
|
||||
totalBytes,
|
||||
errors,
|
||||
});
|
||||
}
|
||||
|
||||
bytesProcessed += batchBytes;
|
||||
options?.onProgress?.({ chunksEmbedded, totalChunks, bytesProcessed, totalBytes, errors });
|
||||
}
|
||||
|
||||
return { chunksEmbedded, errors };
|
||||
}, sessionOptions);
|
||||
}, { maxDuration: 30 * 60 * 1000, name: 'generateEmbeddings' });
|
||||
|
||||
return {
|
||||
docsProcessed: totalDocs,
|
||||
|
||||
@ -241,6 +241,20 @@ describe("CLI Help", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// `qmd embed` must exit with code 1 and name the offending option on stderr
// when a batch-limit flag is given an invalid value.
describe("CLI Embed", () => {
  test("rejects invalid --max-docs-per-batch", async () => {
    const run = await runQmd(["embed", "--max-docs-per-batch", "0"]);

    expect(run.exitCode).toBe(1);
    expect(run.stderr).toContain("maxDocsPerBatch");
  });

  test("rejects invalid --max-batch-mb", async () => {
    const run = await runQmd(["embed", "--max-batch-mb", "0"]);

    expect(run.exitCode).toBe(1);
    expect(run.stderr).toContain("maxBatchBytes");
  });
});
|
||||
|
||||
describe("CLI Skill Commands", () => {
|
||||
test("shows embedded skill with --skill alias", async () => {
|
||||
const { stdout, exitCode } = await runQmd(["--skill"]);
|
||||
|
||||
@ -22,6 +22,7 @@ import {
|
||||
type VectorSearchOptions,
|
||||
type ExpandQueryOptions,
|
||||
} from "../src/index.js";
|
||||
import { setDefaultLlamaCpp } from "../src/llm.js";
|
||||
|
||||
// =============================================================================
|
||||
// Test Helpers
|
||||
@ -924,6 +925,79 @@ describe("update", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// store.embed(): batch-limit options are forwarded to the embedding pipeline and
// invalid limits are rejected.
describe("embed", () => {
  /** Tokenizer stub: one token per 16 characters of input, never fewer than one. */
  function createFakeTokenizer() {
    return {
      async tokenize(text: string) {
        const tokenCount = Math.max(1, Math.ceil(text.length / 16));
        return new Array(tokenCount).fill(1);
      },
    };
  }

  /** Embedding stub that records the texts passed to each embedBatch call. */
  function createFakeEmbedLlm() {
    const embedBatchCalls: string[][] = [];

    async function embed(_text: string) {
      return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" };
    }

    async function embedBatch(texts: string[]) {
      embedBatchCalls.push([...texts]);
      return texts.map((_text, index) => ({
        embedding: [index + 1, index + 2, index + 3],
        model: "fake-embed",
      }));
    }

    return { embedBatchCalls, embed, embedBatch };
  }

  test("store.embed forwards batch limit options", async () => {
    const store = await createStore({
      dbPath: freshDbPath(),
      config: {
        collections: {
          docs: { path: docsDir, pattern: "**/*.md" },
        },
      },
    });

    const llmStub = createFakeEmbedLlm();
    setDefaultLlamaCpp(createFakeTokenizer() as any);
    store.internal.llm = llmStub as any;

    try {
      await store.update();
      const outcome = await store.embed({
        maxDocsPerBatch: 1,
        maxBatchBytes: 1024 * 1024,
      });

      // With one doc per batch, every embedBatch call carries a single text.
      expect(llmStub.embedBatchCalls).toHaveLength(3);
      const callSizes = llmStub.embedBatchCalls.map(call => call.length);
      expect(callSizes).toEqual([1, 1, 1]);
      expect(outcome.docsProcessed).toBe(3);
      expect(outcome.chunksEmbedded).toBe(3);
    } finally {
      setDefaultLlamaCpp(null);
      await store.close();
    }
  });

  test("store.embed rejects invalid batch limits", async () => {
    const store = await createStore({
      dbPath: freshDbPath(),
      config: { collections: {} },
    });

    try {
      const zeroDocs = store.embed({ maxDocsPerBatch: 0 });
      await expect(zeroDocs).rejects.toThrow("maxDocsPerBatch");

      const zeroBytes = store.embed({ maxBatchBytes: 0 });
      await expect(zeroBytes).rejects.toThrow("maxBatchBytes");
    } finally {
      setDefaultLlamaCpp(null);
      await store.close();
    }
  });
});
|
||||
|
||||
// =============================================================================
|
||||
// Lifecycle Tests
|
||||
// =============================================================================
|
||||
|
||||
@ -14,7 +14,7 @@ import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import YAML from "yaml";
|
||||
import * as llmModule from "../src/llm.js";
|
||||
import { disposeDefaultLlamaCpp } from "../src/llm.js";
|
||||
import { disposeDefaultLlamaCpp, setDefaultLlamaCpp } from "../src/llm.js";
|
||||
import {
|
||||
createStore,
|
||||
verifySqliteVecLoaded,
|
||||
@ -47,6 +47,7 @@ import {
|
||||
syncConfigToDb,
|
||||
STRONG_SIGNAL_MIN_SCORE,
|
||||
STRONG_SIGNAL_MIN_GAP,
|
||||
generateEmbeddings,
|
||||
type Store,
|
||||
type DocumentResult,
|
||||
type SearchResult,
|
||||
@ -2589,6 +2590,113 @@ describe("Edge Cases", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// generateEmbeddings(): documents are flushed to the embedder in batches bounded by
// maxDocsPerBatch and maxBatchBytes, and invalid limits are rejected.
describe("Embedding batching", () => {
  /** Tokenizer stub: one token per 16 characters of input, never fewer than one. */
  function createFakeTokenizer() {
    return {
      async tokenize(text: string) {
        const tokenCount = Math.max(1, Math.ceil(text.length / 16));
        return new Array(tokenCount).fill(1);
      },
    };
  }

  /** Embedding stub that records the texts passed to each embedBatch call. */
  function createFakeEmbedLlm() {
    const embedBatchCalls: string[][] = [];

    async function embed(_text: string) {
      return { embedding: [0.1, 0.2, 0.3], model: "fake-embed" };
    }

    async function embedBatch(texts: string[]) {
      embedBatchCalls.push([...texts]);
      return texts.map((_text, index) => ({
        embedding: [index + 1, index + 2, index + 3],
        model: "fake-embed",
      }));
    }

    return { embedBatchCalls, embed, embedBatch };
  }

  test("generateEmbeddings flushes batches when maxDocsPerBatch is reached", async () => {
    const store = await createTestStore();
    const db = store.db;
    const llmStub = createFakeEmbedLlm();

    setDefaultLlamaCpp(createFakeTokenizer() as any);
    store.llm = llmStub as any;

    try {
      await insertTestDocument(db, "docs", { name: "one", body: "# One\n\nAlpha" });
      await insertTestDocument(db, "docs", { name: "two", body: "# Two\n\nBeta" });
      await insertTestDocument(db, "docs", { name: "three", body: "# Three\n\nGamma" });

      const outcome = await generateEmbeddings(store, {
        maxDocsPerBatch: 1,
        maxBatchBytes: 1024 * 1024,
      });

      // One doc per batch: three embedBatch calls with a single text each.
      expect(llmStub.embedBatchCalls).toHaveLength(3);
      const callSizes = llmStub.embedBatchCalls.map(call => call.length);
      expect(callSizes).toEqual([1, 1, 1]);
      expect(outcome.docsProcessed).toBe(3);
      expect(outcome.chunksEmbedded).toBe(3);

      const vectorRows = db.prepare(`SELECT COUNT(*) as count FROM content_vectors`).get();
      expect(vectorRows).toEqual({ count: 3 });
    } finally {
      setDefaultLlamaCpp(null);
      await cleanupTestDb(store);
    }
  });

  test("generateEmbeddings flushes batches when maxBatchBytes is reached", async () => {
    const store = await createTestStore();
    const db = store.db;
    const llmStub = createFakeEmbedLlm();

    setDefaultLlamaCpp(createFakeTokenizer() as any);
    store.llm = llmStub as any;

    const docOne = "# One\n\n" + "A".repeat(36);
    const docTwo = "# Two\n\n" + "B".repeat(36);
    const docThree = "# Three\n\n" + "C".repeat(36);

    // Limit set one byte above the first two docs combined, so the third doc
    // cannot join their batch.
    const utf8Length = (text: string): number => new TextEncoder().encode(text).length;
    const batchLimit = utf8Length(docOne) + utf8Length(docTwo) + 1;

    try {
      await insertTestDocument(db, "docs", { name: "a-one", body: docOne });
      await insertTestDocument(db, "docs", { name: "b-two", body: docTwo });
      await insertTestDocument(db, "docs", { name: "c-three", body: docThree });

      const outcome = await generateEmbeddings(store, {
        maxDocsPerBatch: 64,
        maxBatchBytes: batchLimit,
      });

      expect(llmStub.embedBatchCalls).toHaveLength(2);
      const callSizes = llmStub.embedBatchCalls.map(call => call.length);
      expect(callSizes).toEqual([2, 1]);
      expect(outcome.docsProcessed).toBe(3);
      expect(outcome.chunksEmbedded).toBe(3);
    } finally {
      setDefaultLlamaCpp(null);
      await cleanupTestDb(store);
    }
  });

  test("generateEmbeddings rejects invalid batch limits", async () => {
    const store = await createTestStore();

    try {
      const zeroDocs = generateEmbeddings(store, { maxDocsPerBatch: 0 });
      await expect(zeroDocs).rejects.toThrow("maxDocsPerBatch");

      const zeroBytes = generateEmbeddings(store, { maxBatchBytes: 0 });
      await expect(zeroBytes).rejects.toThrow("maxBatchBytes");
    } finally {
      setDefaultLlamaCpp(null);
      await cleanupTestDb(store);
    }
  });
});
|
||||
|
||||
// =============================================================================
|
||||
// Content-Addressable Storage Tests
|
||||
// =============================================================================
|
||||
|
||||
Loading…
Reference in New Issue
Block a user