fix: use Qwen3 recommended sampling params to prevent repetition loops
- Changed temperature from 0/0.1 to 0.7 (Qwen3 non-thinking mode default)
- Added topK=20, topP=0.8 per Qwen3 docs
- Added repeatPenalty with presencePenalty=0.5 for query expansion
- Fixes infinite loop on acronyms like DHH, BFCM

Qwen3 docs explicitly warn: 'DO NOT use greedy decoding, as it can lead to performance degradation and endless repetitions'
This commit is contained in:
parent
479b68bbf1
commit
102ff861d3
@ -29,3 +29,15 @@ train-local:
|
||||
grpo-local:
|
||||
CUDA_VISIBLE_DEVICES=1,2,3 HF_TOKEN=${HF_TOKEN} uv run torchrun --standalone --nproc_per_node 3 \
|
||||
train.py grpo --config configs/grpo.yaml |& tee /tmp/qmd-grpo-train.log
|
||||
|
||||
gepa-local:
|
||||
UV_CACHE_DIR=/tmp/uv-cache LITELLM_CACHE_DIR=/tmp/litellm-cache OLLAMA_API_BASE=http://localhost:11434 \
|
||||
uv run python gepa/dspy_gepa.py \
|
||||
--input data/qmd_expansion_v2.jsonl \
|
||||
--model ollama/glm-4.7-flash:Q8_0 \
|
||||
--reflection-model ollama/glm-4.7-flash:Q8_0 \
|
||||
--max-metric-calls 100 --limit 20 \
|
||||
--valset data/qmd_expansion_handcrafted.jsonl --val-limit 20 \
|
||||
--max-tokens 512 --reflection-max-tokens 512 \
|
||||
--emit gepa/gepa_outputs_glm.jsonl \
|
||||
--save-prompt gepa/best_prompt_glm.txt
|
||||
|
||||
1
finetune/gepa/best_prompt_glm.txt
Normal file
1
finetune/gepa/best_prompt_glm.txt
Normal file
@ -0,0 +1 @@
|
||||
Expand a search query into lex/vec/hyde lines.
|
||||
@ -27,7 +27,7 @@ repo_root = Path(__file__).parent.parent
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
|
||||
from dataset.schema import parse_output_text
|
||||
from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
|
||||
from reward import score_expansion_detailed
|
||||
|
||||
|
||||
@ -35,11 +35,12 @@ class ExpandSignature(dspy.Signature):
|
||||
"""Expand a search query into lex/vec/hyde lines."""
|
||||
|
||||
query = dspy.InputField(desc="User search query")
|
||||
expansion = dspy.OutputField(
|
||||
output = dspy.OutputField(
|
||||
desc=(
|
||||
"Multi-line text with prefixes: 2-3 lex:, 2-3 vec:, optional 0-1 hyde:. "
|
||||
"Lex lines are short keywords and must not echo the query. "
|
||||
"Vec lines are natural language search phrases. "
|
||||
"JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
|
||||
"Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
|
||||
"Lex items are short keywords and must not echo the query. "
|
||||
"Vec items are natural language search phrases. "
|
||||
"Hyde is 50-200 chars, single line."
|
||||
)
|
||||
)
|
||||
@ -55,7 +56,7 @@ class Expander(dspy.Module):
|
||||
|
||||
|
||||
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
|
||||
expansion = getattr(pred, "expansion", "") or ""
|
||||
expansion = output_items_to_text(_coerce_output_items(pred))
|
||||
detail = score_expansion_detailed(gold.query, expansion)
|
||||
score = detail["percentage"] / 100.0
|
||||
feedback = "; ".join(detail.get("deductions", [])) or f"score={detail['percentage']:.1f}"
|
||||
@ -80,10 +81,31 @@ def to_examples(queries: list[str]) -> list[dspy.Example]:
|
||||
return [dspy.Example(query=q).with_inputs("query") for q in queries]
|
||||
|
||||
|
||||
def write_jsonl(path: Path, queries: list[str], outputs: list[str]) -> None:
|
||||
def _coerce_output_items(pred) -> list[list[str]]:
|
||||
raw_output = getattr(pred, "output", None)
|
||||
if isinstance(raw_output, (list, tuple)):
|
||||
return normalize_output_items(raw_output)
|
||||
|
||||
raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
|
||||
if not raw_text:
|
||||
return []
|
||||
|
||||
if raw_text[0] in ("[", "{"):
|
||||
try:
|
||||
obj = json.loads(raw_text)
|
||||
if isinstance(obj, dict) and "output" in obj:
|
||||
obj = obj["output"]
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return normalize_output_items(obj)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return parse_output_text(raw_text)
|
||||
|
||||
|
||||
def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
|
||||
with path.open("w", encoding="utf-8") as f:
|
||||
for query, output_text in zip(queries, outputs, strict=True):
|
||||
output = parse_output_text(output_text)
|
||||
for query, output in zip(queries, outputs, strict=True):
|
||||
f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
@ -102,6 +124,8 @@ def main() -> int:
|
||||
default="grok-4-1-fast-reasoning",
|
||||
help="LM string in provider/model format (e.g., openai/gpt-4o)",
|
||||
)
|
||||
parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
|
||||
parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
|
||||
parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
|
||||
parser.add_argument("--max-full-evals", type=int, default=None)
|
||||
parser.add_argument("--max-metric-calls", type=int, default=None)
|
||||
@ -134,8 +158,8 @@ def main() -> int:
|
||||
val_queries = val_queries[: args.val_limit]
|
||||
valset = to_examples(val_queries)
|
||||
|
||||
lm = dspy.LM(model=args.model)
|
||||
reflection_lm = dspy.LM(model=args.reflection_model)
|
||||
lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
|
||||
reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
|
||||
|
||||
student = Expander()
|
||||
student.set_lm(lm)
|
||||
@ -163,7 +187,8 @@ def main() -> int:
|
||||
outputs = []
|
||||
for q in queries:
|
||||
pred = optimized(query=q)
|
||||
outputs.append(getattr(pred, "expansion", "") or "")
|
||||
items = _coerce_output_items(pred)
|
||||
outputs.append(items)
|
||||
write_jsonl(Path(args.emit), queries, outputs)
|
||||
print(f"Wrote {args.emit}")
|
||||
|
||||
|
||||
20
finetune/gepa/gepa_outputs_glm.jsonl
Normal file
20
finetune/gepa/gepa_outputs_glm.jsonl
Normal file
@ -0,0 +1,20 @@
|
||||
{"query": "how tourism affects local cultures", "output": []}
|
||||
{"query": "how to ferment foods at home", "output": []}
|
||||
{"query": "how to mix modern and vintage decor", "output": []}
|
||||
{"query": "how to perform a scientific experiment", "output": []}
|
||||
{"query": "web mail", "output": []}
|
||||
{"query": "what does the quran cover", "output": []}
|
||||
{"query": "web config", "output": []}
|
||||
{"query": "how to choose farm equipment", "output": []}
|
||||
{"query": "how do thought experiments aid philosophical reasoning", "output": []}
|
||||
{"query": "what is the significance of logic in philosophy", "output": []}
|
||||
{"query": "how to train for a 5k run", "output": []}
|
||||
{"query": "how to engage with political dialogues", "output": []}
|
||||
{"query": "what is competitive analysis", "output": []}
|
||||
{"query": "how does the united nations operate", "output": []}
|
||||
{"query": "what are the crusades?", "output": []}
|
||||
{"query": "what is a literary theme?", "output": []}
|
||||
{"query": "what is the ethical significance of consent", "output": []}
|
||||
{"query": "paint mix", "output": []}
|
||||
{"query": "how to conserve energy in the office?", "output": []}
|
||||
{"query": "how to test soil ph?", "output": []}
|
||||
17
src/llm.ts
17
src/llm.ts
@ -742,13 +742,17 @@ export class LlamaCpp implements LLM {
|
||||
const session = new LlamaChatSession({ contextSequence: sequence });
|
||||
|
||||
const maxTokens = options.maxTokens ?? 150;
|
||||
const temperature = options.temperature ?? 0;
|
||||
// Qwen3 recommends temp=0.7, topP=0.8, topK=20 for non-thinking mode
|
||||
// DO NOT use greedy decoding (temp=0) - causes repetition loops
|
||||
const temperature = options.temperature ?? 0.7;
|
||||
|
||||
let result = "";
|
||||
try {
|
||||
await session.prompt(prompt, {
|
||||
maxTokens,
|
||||
temperature,
|
||||
topK: 20,
|
||||
topP: 0.8,
|
||||
onTextChunk: (text) => {
|
||||
result += text;
|
||||
},
|
||||
@ -811,10 +815,19 @@ export class LlamaCpp implements LLM {
|
||||
const session = new LlamaChatSession({ contextSequence: sequence });
|
||||
|
||||
try {
|
||||
// Qwen3 recommended settings for non-thinking mode:
|
||||
// temp=0.7, topP=0.8, topK=20, presence_penalty for repetition
|
||||
// DO NOT use greedy decoding (temp=0) - causes infinite loops
|
||||
const result = await session.prompt(prompt, {
|
||||
grammar,
|
||||
maxTokens: 600,
|
||||
temperature: 0.1,
|
||||
temperature: 0.7,
|
||||
topK: 20,
|
||||
topP: 0.8,
|
||||
repeatPenalty: {
|
||||
lastTokens: 64,
|
||||
presencePenalty: 0.5,
|
||||
},
|
||||
});
|
||||
|
||||
const lines = result.trim().split("\n");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user