Harden embedding overflow handling

This commit is contained in:
Bek 2026-04-10 16:02:46 -04:00
parent 171e9e3e65
commit e4990e470e
4 changed files with 132 additions and 37 deletions

View File

@ -889,20 +889,30 @@ export class LlamaCpp implements LLM {
* detokenizes back to text if truncation is needed.
* Returns the (possibly truncated) text, whether truncation occurred, and the token limit it was checked against.
*/
private async truncateToContextSize(text: string): Promise<{ text: string; truncated: boolean }> {
if (!this.embedModel) return { text, truncated: false };
/**
 * Work out the token budget to use when truncating embedding input.
 *
 * The budget never exceeds `LlamaCpp.EMBED_CONTEXT_SIZE`. When the loaded
 * embedding model reports a usable `trainContextSize` (a finite, positive
 * number), the smaller of the two values is used, floored at 1. Otherwise
 * (no model loaded, or an unusable value) `LlamaCpp.EMBED_CONTEXT_SIZE` is
 * returned unchanged.
 *
 * @returns The maximum number of tokens an embedding input may contain.
 */
private resolveEmbedTokenLimit(): number {
  const configuredLimit = LlamaCpp.EMBED_CONTEXT_SIZE;
  const trained = this.embedModel?.trainContextSize;

  // Anything that is not a finite, positive number cannot tighten the limit.
  if (typeof trained !== "number" || !Number.isFinite(trained) || trained <= 0) {
    return configuredLimit;
  }

  const tightest = Math.min(configuredLimit, trained);
  return Math.max(1, tightest);
}
const maxTokens = this.embedModel.trainContextSize;
if (maxTokens <= 0) return { text, truncated: false };
private async truncateToContextSize(
text: string
): Promise<{ text: string; truncated: boolean; limit: number }> {
if (!this.embedModel) return { text, truncated: false, limit: LlamaCpp.EMBED_CONTEXT_SIZE };
const maxTokens = this.resolveEmbedTokenLimit();
if (maxTokens <= 0) return { text, truncated: false, limit: maxTokens };
const tokens = this.embedModel.tokenize(text);
if (tokens.length <= maxTokens) return { text, truncated: false };
if (tokens.length <= maxTokens) return { text, truncated: false, limit: maxTokens };
// Leave a small margin (4 tokens) for BOS/EOS overhead
const safeLimit = Math.max(1, maxTokens - 4);
const truncatedTokens = tokens.slice(0, safeLimit);
const truncatedText = this.embedModel.detokenize(truncatedTokens);
return { text: truncatedText, truncated: true };
return { text: truncatedText, truncated: true, limit: maxTokens };
}
async embed(text: string, options: EmbedOptions = {}): Promise<EmbeddingResult | null> {
@ -913,9 +923,9 @@ export class LlamaCpp implements LLM {
const context = await this.ensureEmbedContext();
// Guard: truncate text that exceeds model context window to prevent GGML crash
const { text: safeText, truncated } = await this.truncateToContextSize(text);
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
if (truncated) {
console.warn(`⚠ Text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
console.warn(`⚠ Text truncated to fit embedding context (${limit} tokens)`);
}
const embedding = await context.getEmbeddingFor(safeText);
@ -951,9 +961,9 @@ export class LlamaCpp implements LLM {
const embeddings: ({ embedding: number[]; model: string } | null)[] = [];
for (const text of texts) {
try {
const { text: safeText, truncated } = await this.truncateToContextSize(text);
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
if (truncated) {
console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
}
const embedding = await context.getEmbeddingFor(safeText);
this.touchActivity();
@ -978,9 +988,9 @@ export class LlamaCpp implements LLM {
const results: (EmbeddingResult | null)[] = [];
for (const text of chunk) {
try {
const { text: safeText, truncated } = await this.truncateToContextSize(text);
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
if (truncated) {
console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
}
const embedding = await ctx.getEmbeddingFor(safeText);
this.touchActivity();

View File

@ -2228,33 +2228,67 @@ export async function chunkDocumentByTokens(
// Tokenize and split any chunks that still exceed limit
const results: { text: string; pos: number; tokens: number }[] = [];
/**
 * Clamp a requested overlap (in characters) so it is a whole number in the
 * range `[0, maxChars - 1]`, i.e. always strictly smaller than the chunk size.
 *
 * @param value    Requested overlap; fractional values are floored.
 * @param maxChars Chunk size the overlap must stay below.
 * @returns The clamped overlap, or 0 when no overlap is possible
 *          (`maxChars <= 1`) or when either input is NaN.
 */
const clampOverlapChars = (value: number, maxChars: number): number => {
  // `!(maxChars > 1)` also catches NaN, which `maxChars <= 1` would let through.
  if (!(maxChars > 1)) return 0;
  // Math.min/Math.max propagate NaN, so an unusable request means "no overlap".
  if (Number.isNaN(value)) return 0;
  return Math.max(0, Math.min(maxChars - 1, Math.floor(value)));
};
/**
 * Append `text` to `results`, recursively splitting it until every emitted
 * piece tokenizes to at most `maxTokens` tokens (or is a single character).
 *
 * Does nothing once the abort signal has fired.
 *
 * @param text Text to emit or split further.
 * @param pos  Offset recorded for `text`; each sub-piece is recorded at
 *             `pos` plus its offset within `text`.
 */
const pushChunkWithinTokenLimit = async (text: string, pos: number): Promise<void> => {
  if (signal?.aborted) return;

  const tokens = await llm.tokenize(text);
  const fitsLimit = tokens.length <= maxTokens;
  if (fitsLimit || text.length <= 1) {
    results.push({ text, pos, tokens: tokens.length });
    return;
  }

  // Use the measured chars-per-token ratio to size sub-chunks, keeping a 5% margin.
  const actualCharsPerToken = text.length / tokens.length;
  const estimatedChars = Math.floor(maxTokens * actualCharsPerToken * 0.95);
  const usableEstimate =
    Number.isFinite(estimatedChars) && estimatedChars >= 1
      ? estimatedChars
      : Math.floor(text.length / 2);
  // Always strictly shorter than the input so a split can make progress.
  const charBudget = Math.max(1, Math.min(text.length - 1, usableEstimate));

  const overlapBudget = clampOverlapChars(
    overlapChars * actualCharsPerToken / 2,
    charBudget,
  );
  const windowBudget = Math.max(0, Math.floor(windowChars * actualCharsPerToken / 2));

  let pieces = chunkDocument(text, charBudget, overlapBudget, windowBudget);

  // True when the split produced at most one piece, or its first piece is as
  // long as the whole input.
  const madeNoProgress = (): boolean =>
    pieces.length <= 1 || pieces[0]?.text.length === text.length;

  // Pathological single-line blobs can produce no meaningful breakpoint progress.
  // Retry as a plain half split (no overlap, no window) so each recursion step shrinks.
  if (madeNoProgress()) {
    const halfLength = Math.max(1, Math.floor(text.length / 2));
    pieces = chunkDocument(text, halfLength, 0, 0);
  }

  // Last resort: keep only the leading tokens and emit them as a single chunk.
  if (madeNoProgress()) {
    const keptTokens = tokens.slice(0, Math.max(1, maxTokens));
    const keptText = await llm.detokenize(keptTokens);
    results.push({
      text: keptText,
      pos,
      tokens: keptTokens.length,
    });
    return;
  }

  for (const piece of pieces) {
    const start = piece.pos;
    const end = start + piece.text.length;
    await pushChunkWithinTokenLimit(text.slice(start, end), pos + start);
  }
};
for (const chunk of charChunks) {
// Respect abort signal to avoid runaway tokenization
if (signal?.aborted) break;
const tokens = await llm.tokenize(chunk.text);
if (tokens.length <= maxTokens) {
results.push({ text: chunk.text, pos: chunk.pos, tokens: tokens.length });
} else {
// Chunk is still too large - split it further
// Use actual token count to estimate better char limit
const actualCharsPerToken = chunk.text.length / tokens.length;
const safeMaxChars = Math.floor(maxTokens * actualCharsPerToken * 0.95); // 5% safety margin
const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
for (const subChunk of subChunks) {
if (signal?.aborted) break;
const subTokens = await llm.tokenize(subChunk.text);
results.push({
text: subChunk.text,
pos: chunk.pos + subChunk.pos,
tokens: subTokens.length,
});
}
}
await pushChunkWithinTokenLimit(chunk.text, chunk.pos);
}
return results;

View File

@ -194,6 +194,32 @@ describe("LlamaCpp model resolution (config > env > default)", () => {
});
});
describe("LlamaCpp embedding truncation", () => {
  test("truncates against the active embedding context limit, not the model train context", async () => {
    // Fake embedding context: echoes the text it was given next to a fixed vector.
    const getEmbeddingFor = vi.fn(async (text: string) => ({
      vector: new Float32Array([0.25, 0.5]),
      text,
    }));

    // Fake model: one token per character, and a train context (8192) larger
    // than the 3000-character input, so the train context alone would not truncate.
    const fakeEmbedModel = {
      trainContextSize: 8192,
      tokenize: (text: string) => new Array<number>(text.length).fill(1),
      detokenize: (tokens: readonly number[]) => "x".repeat(tokens.length),
    };

    const llm = new LlamaCpp({}) as any;
    llm.touchActivity = vi.fn();
    llm.embedModel = fakeEmbedModel;
    llm.ensureEmbedContext = vi.fn().mockResolvedValue({ getEmbeddingFor });

    const longInput = "x".repeat(3000);
    const result = await llm.embed(longInput);

    // NOTE(review): 2044 presumably equals the active context limit minus the
    // 4-token BOS/EOS margin — confirm against LlamaCpp.EMBED_CONTEXT_SIZE.
    const expectedInput = "x".repeat(2044);
    expect(getEmbeddingFor).toHaveBeenCalledWith(expectedInput);
    expect(result).toEqual({
      embedding: [0.25, 0.5],
      model: llm.embedModelUri,
    });
  });
});
describe("LlamaCpp rerank deduping", () => {
test("deduplicates identical document texts before scoring", async () => {
const llm = new LlamaCpp({}) as any;

View File

@ -2805,6 +2805,31 @@ describe("Embedding batching", () => {
});
});
describe("Token chunking guardrails", () => {
  test("chunkDocumentByTokens keeps pathological single-line blobs under the token limit", async () => {
    // Stub tokenizer: exactly one token per character, so token counts equal lengths.
    const oneTokenPerChar = {
      async tokenize(text: string) {
        return new Array<number>(text.length).fill(1);
      },
      async detokenize(tokens: readonly number[]) {
        return "x".repeat(tokens.length);
      },
    };
    setDefaultLlamaCpp(oneTokenPerChar as any);

    try {
      // A 1200-character blob with no whitespace or line breaks to split on.
      const blob = "x".repeat(1200);
      const chunks = await chunkDocumentByTokens(blob, 100, 15, 20);

      expect(chunks.length).toBeGreaterThan(1);
      expect(chunks.every(({ tokens }) => tokens <= 100)).toBe(true);

      // Every chunk must start strictly after the one before it.
      chunks.slice(1).forEach((chunk, previousIndex) => {
        expect(chunk.pos).toBeGreaterThan(chunks[previousIndex]!.pos);
      });
    } finally {
      // Always restore the default so later tests do not see the stub.
      setDefaultLlamaCpp(null);
    }
  });
});
// =============================================================================
// Content-Addressable Storage Tests
// =============================================================================