diff --git a/finetune/Justfile b/finetune/Justfile index 1563f87..6f12785 100644 --- a/finetune/Justfile +++ b/finetune/Justfile @@ -29,3 +29,15 @@ train-local: grpo-local: CUDA_VISIBLE_DEVICES=1,2,3 HF_TOKEN=${HF_TOKEN} uv run torchrun --standalone --nproc_per_node 3 \ train.py grpo --config configs/grpo.yaml |& tee /tmp/qmd-grpo-train.log + +gepa-local: + UV_CACHE_DIR=/tmp/uv-cache LITELLM_CACHE_DIR=/tmp/litellm-cache OLLAMA_API_BASE=http://localhost:11434 \ + uv run python gepa/dspy_gepa.py \ + --input data/qmd_expansion_v2.jsonl \ + --model ollama/glm-4.7-flash:Q8_0 \ + --reflection-model ollama/glm-4.7-flash:Q8_0 \ + --max-metric-calls 100 --limit 20 \ + --valset data/qmd_expansion_handcrafted.jsonl --val-limit 20 \ + --max-tokens 512 --reflection-max-tokens 512 \ + --emit gepa/gepa_outputs_glm.jsonl \ + --save-prompt gepa/best_prompt_glm.txt diff --git a/finetune/gepa/best_prompt_glm.txt b/finetune/gepa/best_prompt_glm.txt new file mode 100644 index 0000000..ee84bf0 --- /dev/null +++ b/finetune/gepa/best_prompt_glm.txt @@ -0,0 +1 @@ +Expand a search query into lex/vec/hyde lines. diff --git a/finetune/gepa/dspy_gepa.py b/finetune/gepa/dspy_gepa.py index de7f8ba..db8f944 100644 --- a/finetune/gepa/dspy_gepa.py +++ b/finetune/gepa/dspy_gepa.py @@ -27,7 +27,7 @@ repo_root = Path(__file__).parent.parent if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) -from dataset.schema import parse_output_text +from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text from reward import score_expansion_detailed @@ -35,11 +35,12 @@ class ExpandSignature(dspy.Signature): """Expand a search query into lex/vec/hyde lines.""" query = dspy.InputField(desc="User search query") - expansion = dspy.OutputField( + output = dspy.OutputField( desc=( - "Multi-line text with prefixes: 2-3 lex:, 2-3 vec:, optional 0-1 hyde:. " - "Lex lines are short keywords and must not echo the query. " - "Vec lines are natural language search phrases. " + "JSON array of [kind, text] pairs. kind is lex|vec|hyde. " + "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. " + "Lex items are short keywords and must not echo the query. " + "Vec items are natural language search phrases. " "Hyde is 50-200 chars, single line." ) ) @@ -55,7 +56,7 @@ class Expander(dspy.Module): def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None): - expansion = getattr(pred, "expansion", "") or "" + expansion = output_items_to_text(_coerce_output_items(pred)) detail = score_expansion_detailed(gold.query, expansion) score = detail["percentage"] / 100.0 feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}" @@ -80,10 +81,31 @@ def to_examples(queries: list[str]) -> list[dspy.Example]: return [dspy.Example(query=q).with_inputs("query") for q in queries] -def write_jsonl(path: Path, queries: list[str], outputs: list[str]) -> None: +def _coerce_output_items(pred) -> list[list[str]]: + raw_output = getattr(pred, "output", None) + if isinstance(raw_output, (list, tuple)): + return normalize_output_items(raw_output) + + raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip() + if not raw_text: + return [] + + if raw_text[0] in ("[", "{"): + try: + obj = json.loads(raw_text) + if isinstance(obj, dict) and "output" in obj: + obj = obj["output"] + if isinstance(obj, (list, tuple)): + return normalize_output_items(obj) + except Exception: + pass + + return parse_output_text(raw_text) + + +def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None: with path.open("w", encoding="utf-8") as f: - for query, output_text in zip(queries, outputs, strict=True): - output = parse_output_text(output_text) + for query, output in zip(queries, outputs, strict=True): f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n") @@ -102,6 +124,8 @@ def main() -> int: default="grok-4-1-fast-reasoning", help="LM string in provider/model format (e.g., openai/gpt-4o)", ) + parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM") + parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM") parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"]) parser.add_argument("--max-full-evals", type=int, default=None) parser.add_argument("--max-metric-calls", type=int, default=None) @@ -134,8 +158,8 @@ def main() -> int: val_queries = val_queries[: args.val_limit] valset = to_examples(val_queries) - lm = dspy.LM(model=args.model) - reflection_lm = dspy.LM(model=args.reflection_model) + lm = dspy.LM(model=args.model, max_tokens=args.max_tokens) + reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens) student = Expander() student.set_lm(lm) @@ -163,7 +187,8 @@ def main() -> int: outputs = [] for q in queries: pred = optimized(query=q) - outputs.append(getattr(pred, "expansion", "") or "") + items = _coerce_output_items(pred) + outputs.append(items) write_jsonl(Path(args.emit), queries, outputs) print(f"Wrote {args.emit}") diff --git a/finetune/gepa/gepa_outputs_glm.jsonl b/finetune/gepa/gepa_outputs_glm.jsonl new file mode 100644 index 0000000..4533a08 --- /dev/null +++ b/finetune/gepa/gepa_outputs_glm.jsonl @@ -0,0 +1,20 @@ +{"query": "how tourism affects local cultures", "output": []} +{"query": "how to ferment foods at home", "output": []} +{"query": "how to mix modern and vintage decor", "output": []} +{"query": "how to perform a scientific experiment", "output": []} +{"query": "web mail", "output": []} +{"query": "what does the quran cover", "output": []} +{"query": "web config", "output": []} +{"query": "how to choose farm equipment", "output": []} +{"query": "how do thought experiments aid philosophical reasoning", "output": []} +{"query": "what is the significance of logic in philosophy", "output": []} +{"query": "how to train for a 5k run", "output": []} +{"query": "how to engage with political dialogues", "output": []} +{"query": "what is competitive analysis", "output": []} +{"query": "how does the united nations operate", "output": []} +{"query": "what are the crusades?", "output": []} +{"query": "what is a literary theme?", "output": []} +{"query": "what is the ethical significance of consent", "output": []} +{"query": "paint mix", "output": []} +{"query": "how to conserve energy in the office?", "output": []} +{"query": "how to test soil ph?", "output": []} diff --git a/src/llm.ts b/src/llm.ts index 4f80fa1..ab39c86 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -742,13 +742,17 @@ export class LlamaCpp implements LLM { const session = new LlamaChatSession({ contextSequence: sequence }); const maxTokens = options.maxTokens ?? 150; - const temperature = options.temperature ?? 0; + // Qwen3 recommends temp=0.7, topP=0.8, topK=20 for non-thinking mode + // DO NOT use greedy decoding (temp=0) - causes repetition loops + const temperature = options.temperature ?? 0.7; let result = ""; try { await session.prompt(prompt, { maxTokens, temperature, + topK: 20, + topP: 0.8, onTextChunk: (text) => { result += text; }, @@ -811,10 +815,19 @@ export class LlamaCpp implements LLM { const session = new LlamaChatSession({ contextSequence: sequence }); try { + // Qwen3 recommended settings for non-thinking mode: + // temp=0.7, topP=0.8, topK=20, presence_penalty for repetition + // DO NOT use greedy decoding (temp=0) - causes infinite loops const result = await session.prompt(prompt, { grammar, maxTokens: 600, - temperature: 0.1, + temperature: 0.7, + topK: 20, + topP: 0.8, + repeatPenalty: { + lastTokens: 64, + presencePenalty: 0.5, + }, }); const lines = result.trim().split("\n");