diff --git a/finetune/train_1.7B_grpo.py b/finetune/train_1.7B_grpo.py
index 7ea02a1..bbf3f2f 100644
--- a/finetune/train_1.7B_grpo.py
+++ b/finetune/train_1.7B_grpo.py
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
"""Score expansion. Returns 0.0-1.0 for RL reward."""
text = expansion.strip()
+ # Strip end token if present
+ text = text.replace('<|im_end|>', '').strip()
+
+ # Check for <think>...</think> blocks - strip them and mark thinking as not skipped
+ skipped_think = 20 # Bonus for not using thinking mode
+ if '<think>' in text and '</think>' in text:
+ text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+ skipped_think = 0 # No bonus if thinking was used
+
# HARD FAIL: Chat template artifacts
- if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
- '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+ if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
return 0.0
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
elif not entities:
entity_score = 10
- total = format_score + diversity_score + hyde_score + quality_score + entity_score
- max_possible = 120 if parsed["hyde"] else 100
+ total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
+ max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
return max(0.0, min(1.0, total / max_possible))
diff --git a/finetune/train_4B_grpo.py b/finetune/train_4B_grpo.py
index c50aab4..bc9aeb6 100644
--- a/finetune/train_4B_grpo.py
+++ b/finetune/train_4B_grpo.py
@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
"""Score expansion. Returns 0.0-1.0 for RL reward."""
text = expansion.strip()
+ # Strip end token if present
+ text = text.replace('<|im_end|>', '').strip()
+
+ # Check for <think>...</think> blocks - strip them and mark thinking as not skipped
+ skipped_think = 20 # Bonus for not using thinking mode
+ if '<think>' in text and '</think>' in text:
+ text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+ skipped_think = 0 # No bonus if thinking was used
+
# HARD FAIL: Chat template artifacts
- if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
- '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+ if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
return 0.0
# HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
elif not entities:
entity_score = 10
- total = format_score + diversity_score + hyde_score + quality_score + entity_score
- max_possible = 120 if parsed["hyde"] else 100
+ total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
+ max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
return max(0.0, min(1.0, total / max_possible))