Fix GRPO reward function to handle think blocks and end tokens

- Strip <|im_end|> token from completions (model output includes it)
- Change think_penalty to skipped_think bonus (+20 for not using think)
- Adjust max_possible to account for bonus (120/140)
- Fix typo in chat template artifact check

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Tobi Lutke 2026-01-25 16:32:13 -05:00
parent 66bb8ed963
commit 891f3262cf
No known key found for this signature in database
2 changed files with 24 additions and 8 deletions

View File

@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
"""Score expansion. Returns 0.0-1.0 for RL reward."""
text = expansion.strip()
# Strip end token if present
text = text.replace('<|im_end|>', '').strip()
# Check for <think>...</think> blocks - strip and mark as not skipped
skipped_think = 20 # Bonus for not using thinking mode
if '<think>' in text and '</think>' in text:
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
skipped_think = 0 # No bonus if thinking was used
# HARD FAIL: Chat template artifacts
if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
'\nassistant\n', '\nuser\n', '<|endoftext|>']):
if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
return 0.0
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
elif not entities:
entity_score = 10
total = format_score + diversity_score + hyde_score + quality_score + entity_score
max_possible = 120 if parsed["hyde"] else 100
total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
return max(0.0, min(1.0, total / max_possible))

View File

@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
"""Score expansion. Returns 0.0-1.0 for RL reward."""
text = expansion.strip()
# Strip end token if present
text = text.replace('<|im_end|>', '').strip()
# Check for <think>...</think> blocks - strip and mark as not skipped
skipped_think = 20 # Bonus for not using thinking mode
if '<think>' in text and '</think>' in text:
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
skipped_think = 0 # No bonus if thinking was used
# HARD FAIL: Chat template artifacts
if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
'\nassistant\n', '\nuser\n', '<|endoftext|>']):
if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
return 0.0
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
elif not entities:
entity_score = 10
total = format_score + diversity_score + hyde_score + quality_score + entity_score
max_possible = 120 if parsed["hyde"] else 100
total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
return max(0.0, min(1.0, total / max_possible))