Harden embedding overflow handling
This commit is contained in:
parent
171e9e3e65
commit
e4990e470e
34
src/llm.ts
34
src/llm.ts
@ -889,20 +889,30 @@ export class LlamaCpp implements LLM {
|
|||||||
* detokenizes back to text if truncation is needed.
|
* detokenizes back to text if truncation is needed.
|
||||||
* Returns the (possibly truncated) text and whether truncation occurred.
|
* Returns the (possibly truncated) text and whether truncation occurred.
|
||||||
*/
|
*/
|
||||||
private async truncateToContextSize(
  text: string
): Promise<{ text: string; truncated: boolean; limit: number }> {
  // Without a loaded embedding model there is no tokenizer to measure with,
  // so the text is passed through and the configured context size is reported.
  if (!this.embedModel) return { text, truncated: false, limit: LlamaCpp.EMBED_CONTEXT_SIZE };

  const maxTokens = this.resolveEmbedTokenLimit();
  if (maxTokens <= 0) return { text, truncated: false, limit: maxTokens };

  // Leave a small margin (4 tokens) for BOS/EOS overhead. The margin is used for
  // the fit check as well as for the cut: otherwise a text of exactly `maxTokens`
  // tokens would be returned untouched and could still overflow the context.
  // NOTE(review): assumes tokenize() does not already count BOS/EOS — confirm.
  const safeLimit = Math.max(1, maxTokens - 4);

  const tokens = this.embedModel.tokenize(text);
  if (tokens.length <= safeLimit) return { text, truncated: false, limit: maxTokens };

  const truncatedTokens = tokens.slice(0, safeLimit);
  const truncatedText = this.embedModel.detokenize(truncatedTokens);
  return { text: truncatedText, truncated: true, limit: maxTokens };
}

/**
 * Resolve the token budget used when guarding embedding input.
 *
 * @returns The smaller of `LlamaCpp.EMBED_CONTEXT_SIZE` and the embedding model's
 *   `trainContextSize` (never below 1) when that size is a positive finite number;
 *   otherwise `LlamaCpp.EMBED_CONTEXT_SIZE`.
 */
private resolveEmbedTokenLimit(): number {
  const trainedContextSize = this.embedModel?.trainContextSize;
  if (typeof trainedContextSize === "number" && Number.isFinite(trainedContextSize) && trainedContextSize > 0) {
    return Math.max(1, Math.min(LlamaCpp.EMBED_CONTEXT_SIZE, trainedContextSize));
  }
  return LlamaCpp.EMBED_CONTEXT_SIZE;
}
|
||||||
|
|
||||||
async embed(text: string, options: EmbedOptions = {}): Promise<EmbeddingResult | null> {
|
async embed(text: string, options: EmbedOptions = {}): Promise<EmbeddingResult | null> {
|
||||||
@ -913,9 +923,9 @@ export class LlamaCpp implements LLM {
|
|||||||
const context = await this.ensureEmbedContext();
|
const context = await this.ensureEmbedContext();
|
||||||
|
|
||||||
// Guard: truncate text that exceeds model context window to prevent GGML crash
|
// Guard: truncate text that exceeds model context window to prevent GGML crash
|
||||||
const { text: safeText, truncated } = await this.truncateToContextSize(text);
|
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
|
||||||
if (truncated) {
|
if (truncated) {
|
||||||
console.warn(`⚠ Text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
|
console.warn(`⚠ Text truncated to fit embedding context (${limit} tokens)`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const embedding = await context.getEmbeddingFor(safeText);
|
const embedding = await context.getEmbeddingFor(safeText);
|
||||||
@ -951,9 +961,9 @@ export class LlamaCpp implements LLM {
|
|||||||
const embeddings: ({ embedding: number[]; model: string } | null)[] = [];
|
const embeddings: ({ embedding: number[]; model: string } | null)[] = [];
|
||||||
for (const text of texts) {
|
for (const text of texts) {
|
||||||
try {
|
try {
|
||||||
const { text: safeText, truncated } = await this.truncateToContextSize(text);
|
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
|
||||||
if (truncated) {
|
if (truncated) {
|
||||||
console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
|
console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
|
||||||
}
|
}
|
||||||
const embedding = await context.getEmbeddingFor(safeText);
|
const embedding = await context.getEmbeddingFor(safeText);
|
||||||
this.touchActivity();
|
this.touchActivity();
|
||||||
@ -978,9 +988,9 @@ export class LlamaCpp implements LLM {
|
|||||||
const results: (EmbeddingResult | null)[] = [];
|
const results: (EmbeddingResult | null)[] = [];
|
||||||
for (const text of chunk) {
|
for (const text of chunk) {
|
||||||
try {
|
try {
|
||||||
const { text: safeText, truncated } = await this.truncateToContextSize(text);
|
const { text: safeText, truncated, limit } = await this.truncateToContextSize(text);
|
||||||
if (truncated) {
|
if (truncated) {
|
||||||
console.warn(`⚠ Batch text truncated to fit embedding context (${this.embedModel?.trainContextSize} tokens)`);
|
console.warn(`⚠ Batch text truncated to fit embedding context (${limit} tokens)`);
|
||||||
}
|
}
|
||||||
const embedding = await ctx.getEmbeddingFor(safeText);
|
const embedding = await ctx.getEmbeddingFor(safeText);
|
||||||
this.touchActivity();
|
this.touchActivity();
|
||||||
|
|||||||
84
src/store.ts
84
src/store.ts
@ -2228,33 +2228,67 @@ export async function chunkDocumentByTokens(
|
|||||||
|
|
||||||
// Tokenize and split any chunks that still exceed limit
|
// Tokenize and split any chunks that still exceed limit
|
||||||
const results: { text: string; pos: number; tokens: number }[] = [];
|
const results: { text: string; pos: number; tokens: number }[] = [];
|
||||||
|
/**
 * Clamp a requested overlap (in characters) so it is a whole number in
 * `[0, maxChars - 1]`.
 *
 * @param value - Requested overlap; fractional values are floored.
 * @param maxChars - Size of the chunk the overlap applies to.
 * @returns The clamped overlap, or 0 when `maxChars` leaves no room for overlap
 *   (`<= 1` or NaN) or when `value` is NaN.
 */
const clampOverlapChars = (value: number, maxChars: number): number => {
  // `!(x > 1)` also catches NaN, which `x <= 1` would let through.
  if (!(maxChars > 1)) return 0;
  // NaN would otherwise pass through Math.floor/min/max and be returned as NaN.
  if (Number.isNaN(value)) return 0;
  return Math.max(0, Math.min(maxChars - 1, Math.floor(value)));
};
|
||||||
|
|
||||||
|
/**
 * Append `text` (found at character offset `pos`) to `results`, splitting it
 * recursively while its token count exceeds `maxTokens`.
 *
 * Does nothing once the abort signal has fired. Text of at most one character is
 * recorded as-is even when it is over the limit. If no split attempt breaks the
 * text up, the first `maxTokens` tokens are detokenized and recorded instead.
 */
const pushChunkWithinTokenLimit = async (text: string, pos: number): Promise<void> => {
  if (signal?.aborted) return;

  const tokens = await llm.tokenize(text);
  const tokenCount = tokens.length;
  const withinLimit = tokenCount <= maxTokens;
  if (withinLimit || text.length <= 1) {
    results.push({ text, pos, tokens: tokenCount });
    return;
  }

  // A split attempt stalled when it produced at most one piece, or when its
  // first piece is as long as the input.
  const stalled = (parts: readonly { text: string }[]): boolean =>
    parts.length <= 1 || parts[0]?.text.length === text.length;

  // Character budget estimated from the measured chars-per-token ratio (with a
  // 5% margin); half the text is used when the estimate is unusable.
  const charsPerToken = text.length / tokenCount;
  const estimatedChars = Math.floor(maxTokens * charsPerToken * 0.95);
  const estimateUsable = Number.isFinite(estimatedChars) && estimatedChars >= 1;
  const charBudget = Math.max(
    1,
    Math.min(text.length - 1, estimateUsable ? estimatedChars : Math.floor(text.length / 2)),
  );

  let pieces = chunkDocument(
    text,
    charBudget,
    clampOverlapChars(overlapChars * charsPerToken / 2, charBudget),
    Math.max(0, Math.floor(windowChars * charsPerToken / 2)),
  );

  // Single-line blobs may yield no usable breakpoints; retry as a plain half
  // split with no overlap and no window so each recursion level shrinks.
  if (stalled(pieces)) {
    const halfLength = Math.max(1, Math.floor(text.length / 2));
    pieces = chunkDocument(text, halfLength, 0, 0);
  }

  // Still no progress: keep only the leading tokens that fit.
  if (stalled(pieces)) {
    const keptTokens = tokens.slice(0, Math.max(1, maxTokens));
    const keptText = await llm.detokenize(keptTokens);
    results.push({ text: keptText, pos, tokens: keptTokens.length });
    return;
  }

  for (const piece of pieces) {
    const start = piece.pos;
    const end = start + piece.text.length;
    await pushChunkWithinTokenLimit(text.slice(start, end), pos + start);
  }
};
|
||||||
|
|
||||||
for (const chunk of charChunks) {
|
for (const chunk of charChunks) {
|
||||||
// Respect abort signal to avoid runaway tokenization
|
await pushChunkWithinTokenLimit(chunk.text, chunk.pos);
|
||||||
if (signal?.aborted) break;
|
|
||||||
|
|
||||||
const tokens = await llm.tokenize(chunk.text);
|
|
||||||
|
|
||||||
if (tokens.length <= maxTokens) {
|
|
||||||
results.push({ text: chunk.text, pos: chunk.pos, tokens: tokens.length });
|
|
||||||
} else {
|
|
||||||
// Chunk is still too large - split it further
|
|
||||||
// Use actual token count to estimate better char limit
|
|
||||||
const actualCharsPerToken = chunk.text.length / tokens.length;
|
|
||||||
const safeMaxChars = Math.floor(maxTokens * actualCharsPerToken * 0.95); // 5% safety margin
|
|
||||||
|
|
||||||
const subChunks = chunkDocument(chunk.text, safeMaxChars, Math.floor(overlapChars * actualCharsPerToken / 2), Math.floor(windowChars * actualCharsPerToken / 2));
|
|
||||||
|
|
||||||
for (const subChunk of subChunks) {
|
|
||||||
if (signal?.aborted) break;
|
|
||||||
const subTokens = await llm.tokenize(subChunk.text);
|
|
||||||
results.push({
|
|
||||||
text: subChunk.text,
|
|
||||||
pos: chunk.pos + subChunk.pos,
|
|
||||||
tokens: subTokens.length,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
|
|||||||
@ -194,6 +194,32 @@ describe("LlamaCpp model resolution (config > env > default)", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Checks that embed() cuts over-long input to the embedding context limit
// rather than to the (larger) train context reported by the model.
describe("LlamaCpp embedding truncation", () => {
  test("truncates against the active embedding context limit, not the model train context", async () => {
    // One token per character, so token counts equal string lengths.
    const fakeEmbedModel = {
      trainContextSize: 8192,
      tokenize: (input: string) => new Array<number>(input.length).fill(1),
      detokenize: (ids: readonly number[]) => "x".repeat(ids.length),
    };
    const getEmbeddingFor = vi.fn(async (input: string) => {
      return { vector: new Float32Array([0.25, 0.5]), text: input };
    });

    const llm = new LlamaCpp({}) as any;
    llm.touchActivity = vi.fn();
    llm.embedModel = fakeEmbedModel;
    llm.ensureEmbedContext = vi.fn().mockResolvedValue({ getEmbeddingFor });

    const oversized = "x".repeat(3000);
    const result = await llm.embed(oversized);

    const expectedInput = "x".repeat(2044);
    expect(getEmbeddingFor).toHaveBeenCalledWith(expectedInput);
    expect(result).toEqual({
      embedding: [0.25, 0.5],
      model: llm.embedModelUri,
    });
  });
});
|
||||||
|
|
||||||
describe("LlamaCpp rerank deduping", () => {
|
describe("LlamaCpp rerank deduping", () => {
|
||||||
test("deduplicates identical document texts before scoring", async () => {
|
test("deduplicates identical document texts before scoring", async () => {
|
||||||
const llm = new LlamaCpp({}) as any;
|
const llm = new LlamaCpp({}) as any;
|
||||||
|
|||||||
@ -2805,6 +2805,31 @@ describe("Embedding batching", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Checks that a long blob with no natural breakpoints is still split into
// chunks that respect the token limit and advance through the document.
describe("Token chunking guardrails", () => {
  test("chunkDocumentByTokens keeps pathological single-line blobs under the token limit", async () => {
    // One token per character, so token counts equal string lengths.
    const oneTokenPerChar = {
      tokenize: async (input: string) => new Array<number>(input.length).fill(1),
      detokenize: async (ids: readonly number[]) => "x".repeat(ids.length),
    };
    setDefaultLlamaCpp(oneTokenPerChar as any);

    try {
      const blob = "x".repeat(1200);
      const chunks = await chunkDocumentByTokens(blob, 100, 15, 20);

      expect(chunks.length).toBeGreaterThan(1);

      const allWithinLimit = chunks.every((entry) => entry.tokens <= 100);
      expect(allWithinLimit).toBe(true);

      // Each chunk must start after the one before it.
      chunks.slice(1).forEach((entry, previousIndex) => {
        expect(entry.pos).toBeGreaterThan(chunks[previousIndex]!.pos);
      });
    } finally {
      // Always restore the default so later tests are unaffected.
      setDefaultLlamaCpp(null);
    }
  });
});
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Content-Addressable Storage Tests
|
// Content-Addressable Storage Tests
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user