diff --git a/finetune/train_1.7B_grpo.py b/finetune/train_1.7B_grpo.py index 7ea02a1..bbf3f2f 100644 --- a/finetune/train_1.7B_grpo.py +++ b/finetune/train_1.7B_grpo.py @@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float: """Score expansion. Returns 0.0-1.0 for RL reward.""" text = expansion.strip() + # Strip end token if present + text = text.replace('<|im_end|>', '').strip() + + # Check for ... blocks - strip and mark as not skipped + skipped_think = 20 # Bonus for not using thinking mode + if '' in text and '' in text: + text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() + skipped_think = 0 # No bonus if thinking was used + # HARD FAIL: Chat template artifacts - if any(token in text for token in ['<|im_start|>', '<|im_end|>', '', '', - '\nassistant\n', '\nuser\n', '<|endoftext|>']): + if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']): return 0.0 # HARD FAIL: EVERY line must start with lex:, vec:, or hyde: @@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float: elif not entities: entity_score = 10 - total = format_score + diversity_score + hyde_score + quality_score + entity_score - max_possible = 120 if parsed["hyde"] else 100 + total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think + max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus return max(0.0, min(1.0, total / max_possible)) diff --git a/finetune/train_4B_grpo.py b/finetune/train_4B_grpo.py index c50aab4..bc9aeb6 100644 --- a/finetune/train_4B_grpo.py +++ b/finetune/train_4B_grpo.py @@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float: """Score expansion. Returns 0.0-1.0 for RL reward.""" text = expansion.strip() + # Strip end token if present + text = text.replace('<|im_end|>', '').strip() + + # Check for ... blocks - strip and mark as not skipped + skipped_think = 20 # Bonus for not using thinking mode + if '' in text and '' in text: + text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() + skipped_think = 0 # No bonus if thinking was used + # HARD FAIL: Chat template artifacts - if any(token in text for token in ['<|im_start|>', '<|im_end|>', '', '', - '\nassistant\n', '\nuser\n', '<|endoftext|>']): + if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']): return 0.0 # HARD FAIL: EVERY line must start with lex:, vec:, or hyde: @@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float: elif not entities: entity_score = 10 - total = format_score + diversity_score + hyde_score + quality_score + entity_score - max_possible = 120 if parsed["hyde"] else 100 + total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think + max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus return max(0.0, min(1.0, total / max_possible))