fix(reward): tighten entity detection, add filler penalty, stricter diversity

- Compound entity chaining now stops one level deep. Previously "TDS
  motorsports team history" would inflate the expected entity set with
  "team" and "history", causing false-positive entity-preservation
  penalties during GRPO. Now only {tds, motorsports} are detected.

- Add INTERIOR_FILLER_WORDS penalty (-3/line): lex lines containing
  "overview" or "basics" absent from the original query are penalised.
  Targets template-generator noise, e.g. "ancient overview rome timeline".

- Raise is_diverse threshold 2→3: requires 3 unique words between lex
  lines before they count as diverse. Reduces reward for near-duplicate
  pairs like "auth setup" / "auth configuration".

- Broaden quoted-phrase bonus: was gated on named entities existing;
  now any multi-word query earns +3 for using quotes in lex lines.
  Better incentivises BM25-aware syntax like "memory leak" python.

Fixes scoring noise identified while working on issue #247.
This commit is contained in:
rkbadhan 2026-02-24 19:15:03 +05:30
parent d6f3688d91
commit 4511b9bd4d

View File

@ -72,6 +72,10 @@ GENERIC_LEX_PHRASES = frozenset({
'what is', 'how to', 'guide to', 'help with',
})
# Words commonly injected as filler/noise into lex lines by template generators
# (e.g. "ancient overview rome timeline"). Penalized when absent from the query.
INTERIOR_FILLER_WORDS = frozenset({'overview', 'basics'})
# Chat template tokens that indicate a broken output
CHAT_TEMPLATE_TOKENS = frozenset({
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
@ -142,47 +146,49 @@ def extract_named_entities(query: str) -> set:
Position-0 words are also detected as entities if they are capitalized and
not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob").
Compound chaining extends one level from a directly-detected entity:
"TDS motorsports" -> {tds, motorsports}; "TDS motorsports team" -> {tds, motorsports}.
"""
entities = set()
words = query.split()
prev_was_entity = False
prev_was_base_entity = False
for i, word in enumerate(words):
clean = word.strip('.,!?:;()[]"\'')
if not clean:
prev_was_entity = False
prev_was_base_entity = False
continue
is_entity = False
is_base_entity = False
# ALL-CAPS acronyms: TDS, API, GPU, AWS
if clean.isupper() and len(clean) >= 2:
entities.add(clean.lower())
is_entity = True
is_base_entity = True
# Capitalized proper nouns (any position, including first word)
elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
if i > 0:
# Non-first words: always treat as entity
entities.add(clean.lower())
is_entity = True
is_base_entity = True
elif clean.lower() not in QUERY_VERB_STOPWORDS:
# First word: also entity if not a common query verb
entities.add(clean.lower())
is_entity = True
is_base_entity = True
# Technical terms with special chars: node.js, C++, .NET
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
entities.add(clean.lower())
is_entity = True
is_base_entity = True
# CamelCase: JavaScript, TypeScript
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
entities.add(clean.lower())
is_entity = True
# Compound names: word following an entity (TDS motorsports)
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
is_base_entity = True
# Compound names: word following a BASE entity only (one level deep).
elif prev_was_base_entity and clean.lower() not in KEY_TERM_STOPWORDS:
entities.add(clean.lower())
is_entity = True
prev_was_entity = is_entity
prev_was_base_entity = is_base_entity
return entities
@ -208,6 +214,13 @@ def lex_preserves_entities(line: str, entities: set) -> bool:
return any(e in lower for e in entities)
def lex_has_filler(lex_line: str, query: str) -> bool:
"""Does the lex line contain an INTERIOR_FILLER_WORDS word absent from the query?"""
query_words = set(query.lower().split())
return any(w in INTERIOR_FILLER_WORDS and w not in query_words
for w in lex_line.lower().split())
def lex_is_generic(lex_line: str) -> bool:
"""Is this lex line a useless generic filler phrase?"""
lower = lex_line.lower().strip()
@ -280,13 +293,14 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
# --- Diversity (0-30) ---
diversity_score = 0
div_threshold = 3 if len(base_query.split()) >= 5 else 2
if len(expected_items) >= 2:
diversity_score += 15
# Check for diversity among items
div_score = 15
for i, a in enumerate(expected_items):
for b in expected_items[i+1:]:
if not is_diverse(a, b, 2):
if not is_diverse(a, b, div_threshold):
div_score -= 5
deductions.append(f"{only_type} duplicate: {a[:20]}...")
diversity_score += max(0, div_score)
@ -315,6 +329,11 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
quality_score += 5
else:
deductions.append(f"{generic} generic lex phrases")
# Penalty: lex lines containing filler words absent from the query
filler_count = sum(1 for l in expected_items if lex_has_filler(l, base_query))
if filler_count > 0:
quality_score -= filler_count * 3
deductions.append(f"{filler_count} lex line(s) with filler words")
elif only_type == "vec":
# Vec should be natural language sentences
@ -444,10 +463,11 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
diversity_score += 5
div_threshold = 3 if len(query.split()) >= 5 else 2
lex_div = 5
for i, a in enumerate(parsed["lex"]):
for b in parsed["lex"][i+1:]:
if not is_diverse(a, b, 2):
if not is_diverse(a, b, div_threshold):
lex_div -= 2
deductions.append(f"lex duplicate: {a[:20]}...")
diversity_score += max(0, lex_div)
@ -455,7 +475,7 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
vec_div = 5
for i, a in enumerate(parsed["vec"]):
for b in parsed["vec"][i+1:]:
if not is_diverse(a, b, 3):
if not is_diverse(a, b, div_threshold):
vec_div -= 2
deductions.append(f"vec duplicate: {a[:20]}...")
diversity_score += max(0, vec_div)
@ -517,13 +537,18 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
else:
deductions.append("lex missing key terms")
# Bonus: lex uses quoted phrases for multi-word entities (+3)
if entities and parsed["lex"]:
multi_word_entities = [e for e in entities if " " in e or len(e) > 6]
if multi_word_entities:
lex_joined = " ".join(parsed["lex"])
if '"' in lex_joined:
quality_score += 3
# Penalty: lex lines containing filler words absent from the query
if parsed["lex"]:
filler_count = sum(1 for l in parsed["lex"] if lex_has_filler(l, query))
if filler_count > 0:
quality_score -= filler_count * 3
deductions.append(f"{filler_count} lex line(s) with filler words")
# Bonus: lex uses quoted phrases for multi-word queries (+3)
if parsed["lex"] and len(query.split()) >= 2:
lex_joined = " ".join(parsed["lex"])
if '"' in lex_joined:
quality_score += 3
# --- Entity Preservation (-45 to +20) ---
entity_score = 0