qmd/test/llm.test.ts
2026-04-07 18:49:22 +02:00

802 lines
30 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* llm.test.ts - Unit tests for the LLM abstraction layer (node-llama-cpp)
*
* Run with: bun test test/llm.test.ts
*
* These tests require the actual models to be downloaded. Run the embed or
* rerank functions first to trigger model downloads.
*/
import { describe, test, expect, beforeAll, afterAll, vi } from "vitest";
import {
LlamaCpp,
getDefaultLlamaCpp,
disposeDefaultLlamaCpp,
resolveLlamaGpuMode,
withLLMSession,
canUnloadLLM,
SessionReleasedError,
type RerankDocument,
type ILLMSession,
} from "../src/llm.js";
// =============================================================================
// Singleton Tests (no model loading required)
// =============================================================================
describe("Default LlamaCpp Singleton", () => {
  // The singleton is intentionally never reset here, so no orphan instance is left behind.
  test("getDefaultLlamaCpp returns same instance on subsequent calls", () => {
    const first = getDefaultLlamaCpp();
    const second = getDefaultLlamaCpp();

    expect(first).toBeInstanceOf(LlamaCpp);
    // Identity (not structural) equality: both calls must hand back the same object.
    expect(second).toBe(first);
  });
});
// =============================================================================
// Model Existence Tests
// =============================================================================
describe("LlamaCpp.modelExists", () => {
  test("returns exists:true for HuggingFace model URIs", async () => {
    const uri = "hf:org/repo/model.gguf";
    const { exists, name } = await getDefaultLlamaCpp().modelExists(uri);

    expect(exists).toBe(true);
    // The reported name echoes the URI that was passed in.
    expect(name).toBe(uri);
  });

  test("returns exists:false for non-existent local paths", async () => {
    const missingPath = "/nonexistent/path/model.gguf";
    const { exists, name } = await getDefaultLlamaCpp().modelExists(missingPath);

    expect(exists).toBe(false);
    expect(name).toBe(missingPath);
  });
});
describe("QMD_LLAMA_GPU resolution", () => {
  test("uses auto when unset or blank", () => {
    for (const raw of [undefined, " "]) {
      expect(resolveLlamaGpuMode(raw)).toBe("auto");
    }
  });

  test("maps CPU disable values to false", () => {
    // Mixed casing and surrounding whitespace are part of what is being exercised.
    const disableValues = ["false", "OFF", " none ", "disabled", "0"];
    for (const raw of disableValues) {
      expect(resolveLlamaGpuMode(raw)).toBe(false);
    }
  });

  test("passes through supported GPU backends", () => {
    // [raw input, expected resolved backend]
    const backendCases: Array<[string, string]> = [
      ["metal", "metal"],
      ["VULKAN", "vulkan"],
      [" cuda ", "cuda"],
    ];
    for (const [raw, expected] of backendCases) {
      expect(resolveLlamaGpuMode(raw)).toBe(expected);
    }
  });

  test("warns and falls back to auto for unsupported values", () => {
    // Replace stderr.write with a mock so the warning can be inspected.
    const stderrWrite = vi.spyOn(process.stderr, "write").mockReturnValue(true);
    try {
      const mode = resolveLlamaGpuMode("rocm");
      expect(mode).toBe("auto");
      expect(stderrWrite).toHaveBeenCalled();

      const firstChunk = String(stderrWrite.mock.calls[0]?.[0] || "");
      expect(firstChunk).toContain("QMD_LLAMA_GPU");
    } finally {
      // Always put the real stderr.write back, even when an assertion fails.
      stderrWrite.mockRestore();
    }
  });
});
describe("LlamaCpp expand context size config", () => {
  const ENV_KEY = "QMD_EXPAND_CONTEXT_SIZE";
  const defaultExpandContextSize = 2048;

  /**
   * Runs `body` with the env var set to `value` (or removed when `value` is
   * undefined), then restores whatever was there before, even if `body` throws.
   */
  const withExpandEnv = (value: string | undefined, body: () => void): void => {
    const saved = process.env[ENV_KEY];
    if (value === undefined) {
      delete process.env[ENV_KEY];
    } else {
      process.env[ENV_KEY] = value;
    }
    try {
      body();
    } finally {
      if (saved === undefined) {
        delete process.env[ENV_KEY];
      } else {
        process.env[ENV_KEY] = saved;
      }
    }
  };

  test("uses default expand context size when no config or env is set", () => {
    withExpandEnv(undefined, () => {
      // `as any`: the test reads the instance's expandContextSize field directly.
      const instance = new LlamaCpp({}) as any;
      expect(instance.expandContextSize).toBe(defaultExpandContextSize);
    });
  });

  test("uses QMD_EXPAND_CONTEXT_SIZE when set to a positive integer", () => {
    withExpandEnv("3072", () => {
      const instance = new LlamaCpp({}) as any;
      expect(instance.expandContextSize).toBe(3072);
    });
  });

  test("config value overrides QMD_EXPAND_CONTEXT_SIZE", () => {
    withExpandEnv("4096", () => {
      const instance = new LlamaCpp({ expandContextSize: 1536 }) as any;
      expect(instance.expandContextSize).toBe(1536);
    });
  });

  test("falls back to default and warns when QMD_EXPAND_CONTEXT_SIZE is invalid", () => {
    withExpandEnv("bad", () => {
      const stderrWrite = vi.spyOn(process.stderr, "write").mockReturnValue(true);
      try {
        const instance = new LlamaCpp({}) as any;
        expect(instance.expandContextSize).toBe(defaultExpandContextSize);
        expect(stderrWrite).toHaveBeenCalled();

        const firstChunk = String(stderrWrite.mock.calls[0]?.[0] || "");
        expect(firstChunk).toContain("QMD_EXPAND_CONTEXT_SIZE");
      } finally {
        stderrWrite.mockRestore();
      }
    });
  });

  test("throws when config expandContextSize is invalid", () => {
    const construct = () => new LlamaCpp({ expandContextSize: 0 });
    expect(construct).toThrow(
      "Invalid expandContextSize: 0. Must be a positive integer."
    );
  });
});
describe("LlamaCpp model resolution (config > env > default)", () => {
  const HARDCODED_EMBED = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
  const HARDCODED_RERANK = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf";
  const HARDCODED_GENERATE = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf";

  const EMBED_ENV = "QMD_EMBED_MODEL";
  // NOTE(review): the rerank/generate variable names are assumed by analogy
  // with QMD_EMBED_MODEL — confirm against src/llm. Clearing a variable the
  // implementation does not read is harmless.
  const RERANK_ENV = "QMD_RERANK_MODEL";
  const GENERATE_ENV = "QMD_GENERATE_MODEL";

  /**
   * Runs `body` with each listed env var set to its value (or removed when the
   * value is undefined), then restores the previous state of every listed
   * variable, even if `body` throws.
   */
  const withEnv = (vars: Record<string, string | undefined>, body: () => void): void => {
    const saved = Object.keys(vars).map((key) => [key, process.env[key]] as const);
    for (const [key, value] of Object.entries(vars)) {
      if (value === undefined) delete process.env[key];
      else process.env[key] = value;
    }
    try {
      body();
    } finally {
      for (const [key, previous] of saved) {
        if (previous === undefined) delete process.env[key];
        else process.env[key] = previous;
      }
    }
  };

  test("uses hardcoded default when no config or env is set", () => {
    // All three defaults are asserted, so all three overrides must be absent;
    // otherwise an ambient override in the developer's shell fails the test.
    withEnv({ [EMBED_ENV]: undefined, [RERANK_ENV]: undefined, [GENERATE_ENV]: undefined }, () => {
      // `as any`: the test reads the resolved model URI fields directly.
      const llm = new LlamaCpp({}) as any;
      expect(llm.embedModelUri).toBe(HARDCODED_EMBED);
      expect(llm.rerankModelUri).toBe(HARDCODED_RERANK);
      expect(llm.generateModelUri).toBe(HARDCODED_GENERATE);
    });
  });

  test("env var overrides hardcoded default", () => {
    withEnv({ [EMBED_ENV]: "hf:custom/embed-model.gguf" }, () => {
      const llm = new LlamaCpp({}) as any;
      expect(llm.embedModelUri).toBe("hf:custom/embed-model.gguf");
    });
  });

  test("config overrides env var", () => {
    withEnv({ [EMBED_ENV]: "hf:env/model.gguf" }, () => {
      const llm = new LlamaCpp({ embedModel: "hf:config/model.gguf" }) as any;
      expect(llm.embedModelUri).toBe("hf:config/model.gguf");
    });
  });
});
describe("LlamaCpp rerank deduping", () => {
  test("deduplicates identical document texts before scoring", async () => {
    // Stub scorer: 0.9 for the duplicated text, 0.2 for anything else.
    const rankAll = vi.fn(async (_query: string, docs: string[]) =>
      docs.map((doc) => (doc === "shared chunk" ? 0.9 : 0.2))
    );
    // Character-level stand-in for the tokenizer, so no real model is needed.
    const stubModel = {
      tokenize: (text: string) => Array.from(text),
      detokenize: (tokens: string[]) => tokens.join(""),
    };

    const llm = new LlamaCpp({}) as any;
    llm._ciMode = false; // allow unit test even in CI (mocked, no real models)
    llm.touchActivity = vi.fn();
    llm.ensureRerankModel = vi.fn().mockResolvedValue(stubModel);
    llm.ensureRerankContexts = vi.fn().mockResolvedValue([{ rankAll }]);

    const { results } = await llm.rerank("query", [
      { file: "a.md", text: "shared chunk" },
      { file: "b.md", text: "shared chunk" },
      { file: "c.md", text: "different chunk" },
    ]);

    // The scorer ran once and saw each distinct text only once...
    expect(rankAll).toHaveBeenCalledTimes(1);
    expect(rankAll).toHaveBeenCalledWith("query", ["shared chunk", "different chunk"]);

    // ...yet every input document still gets a result carrying its text's score.
    expect(results).toHaveLength(3);
    const scoreByFile = new Map<string, number>();
    for (const { file, score } of results) {
      scoreByFile.set(file, score);
    }
    expect(scoreByFile.get("a.md")).toBe(0.9);
    expect(scoreByFile.get("b.md")).toBe(0.9);
    expect(scoreByFile.get("c.md")).toBe(0.2);
  });
});
// =============================================================================
// Integration Tests (require actual models)
// =============================================================================
describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => {
// Shared instance for every test in this suite.
// Use the singleton to avoid multiple Metal contexts
const llm = getDefaultLlamaCpp();
// Suite-level teardown: runs once after all tests in the enclosing describe.
afterAll(async () => {
// Ensure native resources are released to avoid ggml-metal asserts on process exit.
await disposeDefaultLlamaCpp();
});
describe("embed", () => {
  test("returns embedding with correct dimensions", async () => {
    const embedded = await llm.embed("Hello world");
    expect(embedded).not.toBeNull();

    const vector = embedded!.embedding;
    expect(vector).toBeInstanceOf(Array);
    expect(vector.length).toBeGreaterThan(0);
    // embeddinggemma outputs 768 dimensions
    expect(vector.length).toBe(768);
  });

  test("returns consistent embeddings for same input", async () => {
    const first = await llm.embed("test text");
    const second = await llm.embed("test text");
    expect(first).not.toBeNull();
    expect(second).not.toBeNull();

    // Same input must give the same vector, component by component (5 decimal places).
    first!.embedding.forEach((value, i) => {
      expect(value).toBeCloseTo(second!.embedding[i]!, 5);
    });
  });

  test("returns different embeddings for different inputs", async () => {
    const first = await llm.embed("cats are great");
    const second = await llm.embed("database optimization");
    expect(first).not.toBeNull();
    expect(second).not.toBeNull();

    // Cosine similarity, accumulated in a single pass over the first vector's indices.
    const a = first!.embedding;
    const b = second!.embedding;
    let dot = 0;
    let sumSqA = 0;
    let sumSqB = 0;
    a.forEach((x, i) => {
      const y = b[i]!;
      dot += x * y;
      sumSqA += x ** 2;
      sumSqB += y ** 2;
    });
    const cosine = dot / (Math.sqrt(sumSqA) * Math.sqrt(sumSqB));

    expect(cosine).toBeLessThan(0.95); // Should be meaningfully different
  });
});
describe("embedBatch", () => {
  test("returns embeddings for multiple texts", async () => {
    const texts = ["Hello world", "Test text", "Another document"];
    const results = await llm.embedBatch(texts);
    expect(results).toHaveLength(3);
    for (const result of results) {
      expect(result).not.toBeNull();
      expect(result!.embedding.length).toBe(768);
    }
  });

  test("returns same results as individual embed calls", async () => {
    const texts = ["cats are great", "dogs are awesome"];
    const batchResults = await llm.embedBatch(texts);
    const individualResults = await Promise.all(texts.map((t) => llm.embed(t)));

    // Batch and one-at-a-time embedding must agree component by component.
    for (let i = 0; i < texts.length; i++) {
      const batched = batchResults[i];
      const single = individualResults[i];
      expect(batched).not.toBeNull();
      expect(single).not.toBeNull();
      for (let j = 0; j < batched!.embedding.length; j++) {
        expect(batched!.embedding[j]).toBeCloseTo(single!.embedding[j]!, 5);
      }
    }
  });

  test("handles empty array", async () => {
    const results = await llm.embedBatch([]);
    expect(results).toHaveLength(0);
  });

  test("batch is faster than sequential", async () => {
    const texts = Array.from({ length: 10 }, (_, i) => `Document number ${i} with content`);

    const batchStart = Date.now();
    await llm.embedBatch(texts);
    const batchTime = Date.now() - batchStart;

    const seqStart = Date.now();
    for (const text of texts) {
      await llm.embed(text);
    }
    const seqTime = Date.now() - seqStart;

    console.log(`Batch: ${batchTime}ms, Sequential: ${seqTime}ms`);
    // Performance is machine/load dependent. We only assert batch isn't drastically worse.
    expect(batchTime).toBeLessThanOrEqual(seqTime * 3);
  });

  test("handles concurrent embedBatch calls on fresh instance without race condition", async () => {
    // Regression test for a race where concurrent embedBatch calls on a fresh
    // instance each saw "no context yet" and created their own, leaking
    // resources and risking "Context is disposed" errors. The fix guards
    // context creation with a shared promise.
    //
    // See: https://github.com/tobi/qmd/pull/54
    const freshLlm = new LlamaCpp({});
    try {
      let contextCreateCount = 0;

      // Wrap ensureEmbedModel so that, the first time the model is handed out,
      // its createEmbeddingContext is replaced by a counting pass-through.
      const originalEnsureEmbedModel = (freshLlm as any).ensureEmbedModel.bind(freshLlm);
      let modelInstrumented = false;
      (freshLlm as any).ensureEmbedModel = async function () {
        const model = await originalEnsureEmbedModel();
        if (!modelInstrumented) {
          modelInstrumented = true;
          const originalCreate = model.createEmbeddingContext.bind(model);
          model.createEmbeddingContext = async function (...args: any[]) {
            contextCreateCount++;
            return originalCreate(...args);
          };
        }
        return model;
      };

      const texts = Array.from({ length: 10 }, (_, i) => `Document ${i}`);

      // Five overlapping calls, each taking a disjoint pair of texts.
      const batches = await Promise.all([
        freshLlm.embedBatch(texts.slice(0, 2)),
        freshLlm.embedBatch(texts.slice(2, 4)),
        freshLlm.embedBatch(texts.slice(4, 6)),
        freshLlm.embedBatch(texts.slice(6, 8)),
        freshLlm.embedBatch(texts.slice(8, 10)),
      ]);

      const allResults = batches.flat();
      expect(allResults).toHaveLength(10);
      const successCount = allResults.filter((r) => r !== null).length;
      expect(successCount).toBe(10);

      // THE KEY ASSERTION: one set of contexts is created, not one set per
      // concurrent caller. The size of that set depends on available VRAM
      // (computeParallelism) and is capped at 8, so only that bound is asserted.
      // NOTE(review): this bound is loose — if the intended count is 1, a
      // regression creating 5 contexts (one per call) would still pass.
      console.log(`Context creation count: ${contextCreateCount} (expected: ≤ 8, not 5× duplicated)`);
      expect(contextCreateCount).toBeGreaterThanOrEqual(1);
      expect(contextCreateCount).toBeLessThanOrEqual(8);
    } finally {
      // Dispose even when an assertion above throws, so the contexts this
      // instance created do not outlive the test.
      await freshLlm.dispose();
    }
  }, 60000);
});
describe("rerank", () => {
  test("scores capital of France question correctly", async () => {
    const documents: RerankDocument[] = [
      { file: "butterflies.txt", text: "Butterflies indeed fly through the garden." },
      { file: "france.txt", text: "The capital of France is Paris." },
      { file: "canada.txt", text: "The capital of Canada is Ottawa." },
    ];
    const { results } = await llm.rerank("What is the capital of France?", documents);
    expect(results).toHaveLength(3);

    const [best, middle, worst] = results;
    // Direct answer first, with a confident score.
    expect(best!.file).toBe("france.txt");
    expect(best!.score).toBeGreaterThan(0.7);
    // Another "capital of ..." document comes second.
    expect(middle!.file).toBe("canada.txt");
    // The unrelated document lands last, with a low score.
    expect(worst!.file).toBe("butterflies.txt");
    expect(worst!.score).toBeLessThan(0.6);
  });

  test("scores authentication query correctly", async () => {
    const documents: RerankDocument[] = [
      { file: "weather.md", text: "The weather today is sunny with mild temperatures." },
      { file: "auth.md", text: "Authentication can be configured by setting the AUTH_SECRET environment variable." },
      { file: "pizza.md", text: "Our restaurant serves the best pizza in town." },
      { file: "jwt.md", text: "JWT authentication requires a secret key and expiration time." },
    ];
    const { results } = await llm.rerank("How do I configure authentication?", documents);
    expect(results).toHaveLength(4);

    const rankedFiles = results.map((r) => r.file);
    // Both auth-related documents occupy the top half, in either order...
    const topTwo = rankedFiles.slice(0, 2);
    expect(topTwo).toContain("auth.md");
    expect(topTwo).toContain("jwt.md");
    // ...and both unrelated documents the bottom half.
    const bottomTwo = rankedFiles.slice(2);
    expect(bottomTwo).toContain("weather.md");
    expect(bottomTwo).toContain("pizza.md");
  });

  test("handles programming queries correctly", async () => {
    const documents: RerankDocument[] = [
      { file: "cooking.md", text: "To make a good pasta, boil water and add salt." },
      { file: "errors.md", text: "Use try-catch blocks to handle JavaScript errors gracefully." },
      { file: "python.md", text: "Python uses try-except for exception handling." },
    ];
    const { results } = await llm.rerank("How do I handle errors in JavaScript?", documents);

    // Only the first and last positions are pinned; the Python document
    // (same concept, different language) may land anywhere in between.
    expect(results[0]!.file).toBe("errors.md");
    expect(results[0]!.score).toBeGreaterThan(0.7);
    expect(results[2]!.file).toBe("cooking.md");
  });

  test("handles empty document list", async () => {
    const { results } = await llm.rerank("test query", []);
    expect(results).toHaveLength(0);
  });

  test("handles single document", async () => {
    const onlyDoc: RerankDocument = { file: "doc.md", text: "content" };
    const { results } = await llm.rerank("test", [onlyDoc]);
    expect(results).toHaveLength(1);
    expect(results[0]!.file).toBe("doc.md");
  });

  test("preserves original file paths", async () => {
    const documents: RerankDocument[] = [
      { file: "path/to/doc1.md", text: "content one" },
      { file: "another/path/doc2.md", text: "content two" },
    ];
    const { results } = await llm.rerank("query", documents);

    // Sort so the check does not depend on ranking order.
    const returnedFiles = results.map((r) => r.file);
    returnedFiles.sort();
    expect(returnedFiles).toEqual(["another/path/doc2.md", "path/to/doc1.md"]);
  });

  test("returns scores between 0 and 1", async () => {
    const documents: RerankDocument[] = [
      { file: "a.md", text: "The quick brown fox jumps over the lazy dog." },
      { file: "b.md", text: "Machine learning algorithms process data efficiently." },
      { file: "c.md", text: "React components use JSX syntax for rendering." },
    ];
    const { results } = await llm.rerank("Tell me about animals", documents);
    for (const { score } of results) {
      expect(score).toBeGreaterThanOrEqual(0);
      expect(score).toBeLessThanOrEqual(1);
    }
  });

  test("batch reranks multiple documents efficiently", async () => {
    // Ten documents, cycling through three topics.
    const documents: RerankDocument[] = Array.from({ length: 10 }, (_, i) => ({
      file: `doc${i}.md`,
      text: `Document number ${i} with some content about topic ${i % 3}`,
    }));

    const startedAt = Date.now();
    const { results } = await llm.rerank("topic 1", documents);
    const elapsed = Date.now() - startedAt;

    expect(results).toHaveLength(10);
    for (const { score } of results) {
      expect(score).toBeGreaterThanOrEqual(0);
      expect(score).toBeLessThanOrEqual(1);
    }
    // Timing is logged for monitoring only; nothing is asserted about it.
    console.log(`Batch rerank of 10 docs took ${elapsed}ms`);
  });

  test("uses fewer active rerank contexts for small batches", async () => {
    // Records the index of each stub context whose rankAll gets invoked.
    const usedContexts: number[] = [];
    const stubModel = {
      tokenize: (text: string) => Array.from(text),
      detokenize: (tokens: string[]) => tokens.join(""),
    };
    // Four contexts are on offer; the test expects only the first two to be used.
    const stubContexts = [0, 1, 2, 3].map((idx) => ({
      rankAll: async (_query: string, docs: string[]) => {
        usedContexts.push(idx);
        return docs.map(() => 0.5);
      },
    }));

    const freshLlm = new LlamaCpp({});
    (freshLlm as any).ensureRerankModel = async () => stubModel;
    (freshLlm as any).ensureRerankContexts = async () => stubContexts;

    const documents: RerankDocument[] = [];
    for (let i = 0; i < 20; i++) {
      documents.push({ file: `doc${i}.md`, text: `Document number ${i}` });
    }

    const { results } = await freshLlm.rerank("topic 1", documents);
    expect(results).toHaveLength(20);
    expect(usedContexts).toEqual([0, 1]);
  });

  test("truncates and reranks document exceeding 2048 token context size", async () => {
    // The reranker context is created with contextSize=2048. Documents that
    // exceed the token budget (contextSize - template overhead - query tokens)
    // should be silently truncated rather than crashing.
    const sentences = [
      "The quick brown fox jumps over the lazy dog near the riverbank. ",
      "Authentication tokens must be validated on every request to ensure security. ",
      "Database queries should use prepared statements to prevent SQL injection attacks. ",
      "The deployment pipeline includes linting, testing, building, and publishing stages. ",
    ];
    // 40 repetitions of the paragraph: roughly 12k chars, well past the budget.
    const longText = sentences.join("").repeat(40);

    const documents: RerankDocument[] = [
      { file: "short-relevant.md", text: "Authentication can be configured by setting AUTH_SECRET." },
      { file: "long-doc.md", text: longText },
      { file: "short-irrelevant.md", text: "The weather is sunny today." },
    ];
    console.log(`Long doc length: ${longText.length} chars (~${Math.round(longText.length / 4)} tokens)`);

    const { results } = await llm.rerank("How do I configure authentication?", documents);

    // All three documents come back, each with a real score in [0, 1].
    expect(results).toHaveLength(3);
    for (const { score } of results) {
      expect(score).toBeGreaterThanOrEqual(0);
      expect(score).toBeLessThanOrEqual(1);
      expect(Number.isNaN(score)).toBe(false);
    }

    // The ranking is printed for manual inspection; its order is not asserted.
    console.log("Rerank results for long doc test:");
    for (const { file, score } of results) {
      console.log(` ${file}: ${score.toFixed(4)}`);
    }
  });
});
describe("expandQuery", () => {
  test("returns query expansions with correct types", async () => {
    const expansions = await llm.expandQuery("test query");

    // At least one expansion comes back, each a non-empty lex, vec or hyde entry.
    expect(expansions.length).toBeGreaterThanOrEqual(1);
    const allowedTypes = ["lex", "vec", "hyde"];
    for (const { type, text } of expansions) {
      expect(allowedTypes).toContain(type);
      expect(text.length).toBeGreaterThan(0);
    }
  }, 30000); // 30s timeout for model loading

  test("can exclude lexical queries", async () => {
    const expansions = await llm.expandQuery("authentication setup", { includeLexical: false });

    // With includeLexical off, no 'lex' entry may remain.
    const lexCount = expansions.filter((q) => q.type === "lex").length;
    expect(lexCount).toBe(0);
  });
});
});
// =============================================================================
// Session Management Tests
// =============================================================================
describe.skipIf(!!process.env.CI)("LLM Session Management", () => {
  describe("withLLMSession", () => {
    test("session provides access to LLM operations", async () => {
      const result = await withLLMSession(async (session) => {
        expect(session.isValid).toBe(true);
        const embedding = await session.embed("test text");
        expect(embedding).not.toBeNull();
        expect(embedding!.embedding.length).toBe(768);
        return "success";
      });
      expect(result).toBe("success");
    });

    test("session is invalid after release", async () => {
      // Captured via an array rather than `let x: ILLMSession | null = null`:
      // an assignment made inside the callback is invisible to TypeScript's
      // control-flow analysis, so the variable would stay narrowed to `null`
      // and `x!.isValid` would be a property access on `never`.
      const captured: ILLMSession[] = [];
      await withLLMSession(async (session) => {
        captured.push(session);
        expect(session.isValid).toBe(true);
      });
      // The callback ran exactly once, and its session is released afterwards.
      expect(captured).toHaveLength(1);
      expect(captured[0]!.isValid).toBe(false);
    });

    test("session prevents idle unload during operations", async () => {
      await withLLMSession(async (session) => {
        // While inside a session, canUnloadLLM should return false
        expect(canUnloadLLM()).toBe(false);
        await session.embed("test");
        // Still not unloadable after an operation has completed
        expect(canUnloadLLM()).toBe(false);
      });
      // After session ends, should be able to unload
      expect(canUnloadLLM()).toBe(true);
    });

    test("nested sessions increment ref count", async () => {
      await withLLMSession(async (outerSession) => {
        expect(canUnloadLLM()).toBe(false);
        await withLLMSession(async (innerSession) => {
          expect(canUnloadLLM()).toBe(false);
          expect(innerSession.isValid).toBe(true);
          expect(outerSession.isValid).toBe(true);
        });
        // Inner session released, but outer still active
        expect(canUnloadLLM()).toBe(false);
        expect(outerSession.isValid).toBe(true);
      });
      // All sessions released
      expect(canUnloadLLM()).toBe(true);
    });

    test("session embedBatch works correctly", async () => {
      await withLLMSession(async (session) => {
        const texts = ["Hello world", "Test text", "Another document"];
        const results = await session.embedBatch(texts);
        expect(results).toHaveLength(3);
        for (const result of results) {
          expect(result).not.toBeNull();
          expect(result!.embedding.length).toBe(768);
        }
      });
    });

    test("session rerank works correctly", async () => {
      await withLLMSession(async (session) => {
        const documents: RerankDocument[] = [
          { file: "a.txt", text: "The capital of France is Paris." },
          { file: "b.txt", text: "Dogs are great pets." },
        ];
        const result = await session.rerank("What is the capital of France?", documents);
        expect(result.results).toHaveLength(2);
        expect(result.results[0]!.file).toBe("a.txt");
        expect(result.results[0]!.score).toBeGreaterThan(result.results[1]!.score);
      });
    });

    test("max duration aborts session after timeout", async () => {
      // The callback outlives the 50ms budget, so the embed call made after
      // the wait must reject with SessionReleasedError. Any other outcome
      // (a different error, or no error at all) fails the test.
      const run = withLLMSession(async (session) => {
        await new Promise((resolve) => setTimeout(resolve, 150));
        await session.embed("test");
      }, { maxDuration: 50 });
      await expect(run).rejects.toBeInstanceOf(SessionReleasedError);
    }, 5000);

    test("external abort signal propagates to session", async () => {
      const abortController = new AbortController();
      let sessionAborted = false;
      const promise = withLLMSession(async (session) => {
        // Wait until well after the abort below has fired, then inspect the session.
        await new Promise((resolve) => setTimeout(resolve, 100));
        if (!session.isValid) {
          sessionAborted = true;
          throw new SessionReleasedError("Session aborted");
        }
        return "should not reach";
      }, { signal: abortController.signal });

      // Abort after 20ms
      setTimeout(() => abortController.abort(), 20);

      try {
        await promise;
      } catch {
        // A rejection is the expected outcome here. What is verified is the
        // flag set inside the callback, not the error that surfaced.
      }
      expect(sessionAborted).toBe(true);
    }, 5000);

    test("session provides abort signal for monitoring", async () => {
      await withLLMSession(async (session) => {
        expect(session.signal).toBeInstanceOf(AbortSignal);
        expect(session.signal.aborted).toBe(false);
      });
    });

    test("returns value from callback", async () => {
      const result = await withLLMSession(async (session) => {
        await session.embed("test");
        return { status: "complete", count: 42 };
      });
      expect(result).toEqual({ status: "complete", count: 42 });
    });

    test("propagates errors from callback", async () => {
      const customError = new Error("Custom test error");
      await expect(
        withLLMSession(async () => {
          throw customError;
        })
      ).rejects.toThrow("Custom test error");
    });
  });
});