Fix GRPO reward function to handle think blocks and end tokens
- Strip <|im_end|> token from completions (model output includes it)
- Change think_penalty to skipped_think bonus (+20 for not using think)
- Adjust max_possible to account for bonus (120/140)
- Fix typo in chat template artifact check

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
66bb8ed963
commit
891f3262cf
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
|
||||
"""Score expansion. Returns 0.0-1.0 for RL reward."""
|
||||
text = expansion.strip()
|
||||
|
||||
# Strip end token if present
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
|
||||
# Check for <think>...</think> blocks - strip and mark as not skipped
|
||||
skipped_think = 20 # Bonus for not using thinking mode
|
||||
if '<think>' in text and '</think>' in text:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
skipped_think = 0 # No bonus if thinking was used
|
||||
|
||||
# HARD FAIL: Chat template artifacts
|
||||
if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
|
||||
'\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
||||
if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
||||
return 0.0
|
||||
|
||||
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
|
||||
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
|
||||
elif not entities:
|
||||
entity_score = 10
|
||||
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score
|
||||
max_possible = 120 if parsed["hyde"] else 100
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
|
||||
max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
|
||||
return max(0.0, min(1.0, total / max_possible))
|
||||
|
||||
|
||||
|
||||
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
|
||||
"""Score expansion. Returns 0.0-1.0 for RL reward."""
|
||||
text = expansion.strip()
|
||||
|
||||
# Strip end token if present
|
||||
text = text.replace('<|im_end|>', '').strip()
|
||||
|
||||
# Check for <think>...</think> blocks - strip and mark as not skipped
|
||||
skipped_think = 20 # Bonus for not using thinking mode
|
||||
if '<think>' in text and '</think>' in text:
|
||||
text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
|
||||
skipped_think = 0 # No bonus if thinking was used
|
||||
|
||||
# HARD FAIL: Chat template artifacts
|
||||
if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
|
||||
'\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
||||
if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
|
||||
return 0.0
|
||||
|
||||
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
|
||||
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
|
||||
elif not entities:
|
||||
entity_score = 10
|
||||
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score
|
||||
max_possible = 120 if parsed["hyde"] else 100
|
||||
total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
|
||||
max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
|
||||
return max(0.0, min(1.0, total / max_possible))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user