fix(reward): tighten entity detection, add filler penalty, stricter diversity
- Compound entity chaining now stops one level deep. Previously "TDS
motorsports team history" would inflate the expected entity set with
"team" and "history", causing false-positive entity-preservation
penalties during GRPO. Now only {tds, motorsports} are detected.
- Add INTERIOR_FILLER_WORDS penalty (-3/line): lex lines containing
"overview" or "basics" absent from the original query are penalised.
Targets template-generator noise, e.g. "ancient overview rome timeline".
- Raise is_diverse threshold 2→3: requires 3 unique words between lex
lines before they count as diverse. Reduces reward for near-duplicate
pairs like "auth setup" / "auth configuration".
- Broaden quoted-phrase bonus: was gated on named entities existing;
now any multi-word query earns +3 for using quotes in lex lines.
Better incentivises BM25-aware syntax like "memory leak" python.
Fixes scoring noise identified while working on issue #247.
This commit is contained in:
parent
d6f3688d91
commit
4511b9bd4d
@ -72,6 +72,10 @@ GENERIC_LEX_PHRASES = frozenset({
|
||||
'what is', 'how to', 'guide to', 'help with',
|
||||
})
|
||||
|
||||
# Words commonly injected as filler/noise into lex lines by template generators
|
||||
# (e.g. "ancient overview rome timeline"). Penalized when absent from the query.
|
||||
INTERIOR_FILLER_WORDS = frozenset({'overview', 'basics'})
|
||||
|
||||
# Chat template tokens that indicate a broken output
|
||||
CHAT_TEMPLATE_TOKENS = frozenset({
|
||||
'<|im_start|>', '<|im_end|>', '<|endoftext|>',
|
||||
@ -142,47 +146,49 @@ def extract_named_entities(query: str) -> set:
|
||||
|
||||
Position-0 words are also detected as entities if they are capitalized and
|
||||
not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob").
|
||||
|
||||
Compound chaining extends one level from a directly-detected entity:
|
||||
"TDS motorsports" -> {tds, motorsports}; "TDS motorsports team" -> {tds, motorsports}.
|
||||
"""
|
||||
entities = set()
|
||||
words = query.split()
|
||||
prev_was_entity = False
|
||||
prev_was_base_entity = False
|
||||
|
||||
for i, word in enumerate(words):
|
||||
clean = word.strip('.,!?:;()[]"\'')
|
||||
if not clean:
|
||||
prev_was_entity = False
|
||||
prev_was_base_entity = False
|
||||
continue
|
||||
|
||||
is_entity = False
|
||||
is_base_entity = False
|
||||
|
||||
# ALL-CAPS acronyms: TDS, API, GPU, AWS
|
||||
if clean.isupper() and len(clean) >= 2:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
is_base_entity = True
|
||||
# Capitalized proper nouns (any position, including first word)
|
||||
elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
if i > 0:
|
||||
# Non-first words: always treat as entity
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
is_base_entity = True
|
||||
elif clean.lower() not in QUERY_VERB_STOPWORDS:
|
||||
# First word: also entity if not a common query verb
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
is_base_entity = True
|
||||
# Technical terms with special chars: node.js, C++, .NET
|
||||
elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
is_base_entity = True
|
||||
# CamelCase: JavaScript, TypeScript
|
||||
elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
# Compound names: word following an entity (TDS motorsports)
|
||||
elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
is_base_entity = True
|
||||
# Compound names: word following a BASE entity only (one level deep).
|
||||
elif prev_was_base_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
||||
entities.add(clean.lower())
|
||||
is_entity = True
|
||||
|
||||
prev_was_entity = is_entity
|
||||
prev_was_base_entity = is_base_entity
|
||||
|
||||
return entities
|
||||
|
||||
@ -208,6 +214,13 @@ def lex_preserves_entities(line: str, entities: set) -> bool:
|
||||
return any(e in lower for e in entities)
|
||||
|
||||
|
||||
def lex_has_filler(lex_line: str, query: str) -> bool:
|
||||
"""Does the lex line contain an INTERIOR_FILLER_WORDS word absent from the query?"""
|
||||
query_words = set(query.lower().split())
|
||||
return any(w in INTERIOR_FILLER_WORDS and w not in query_words
|
||||
for w in lex_line.lower().split())
|
||||
|
||||
|
||||
def lex_is_generic(lex_line: str) -> bool:
|
||||
"""Is this lex line a useless generic filler phrase?"""
|
||||
lower = lex_line.lower().strip()
|
||||
@ -280,13 +293,14 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
|
||||
|
||||
# --- Diversity (0-30) ---
|
||||
diversity_score = 0
|
||||
div_threshold = 3 if len(base_query.split()) >= 5 else 2
|
||||
if len(expected_items) >= 2:
|
||||
diversity_score += 15
|
||||
# Check for diversity among items
|
||||
div_score = 15
|
||||
for i, a in enumerate(expected_items):
|
||||
for b in expected_items[i+1:]:
|
||||
if not is_diverse(a, b, 2):
|
||||
if not is_diverse(a, b, div_threshold):
|
||||
div_score -= 5
|
||||
deductions.append(f"{only_type} duplicate: {a[:20]}...")
|
||||
diversity_score += max(0, div_score)
|
||||
@ -315,6 +329,11 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
|
||||
quality_score += 5
|
||||
else:
|
||||
deductions.append(f"{generic} generic lex phrases")
|
||||
# Penalty: lex lines containing filler words absent from the query
|
||||
filler_count = sum(1 for l in expected_items if lex_has_filler(l, base_query))
|
||||
if filler_count > 0:
|
||||
quality_score -= filler_count * 3
|
||||
deductions.append(f"{filler_count} lex line(s) with filler words")
|
||||
|
||||
elif only_type == "vec":
|
||||
# Vec should be natural language sentences
|
||||
@ -444,10 +463,11 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
|
||||
if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
|
||||
diversity_score += 5
|
||||
|
||||
div_threshold = 3 if len(query.split()) >= 5 else 2
|
||||
lex_div = 5
|
||||
for i, a in enumerate(parsed["lex"]):
|
||||
for b in parsed["lex"][i+1:]:
|
||||
if not is_diverse(a, b, 2):
|
||||
if not is_diverse(a, b, div_threshold):
|
||||
lex_div -= 2
|
||||
deductions.append(f"lex duplicate: {a[:20]}...")
|
||||
diversity_score += max(0, lex_div)
|
||||
@ -455,7 +475,7 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
|
||||
vec_div = 5
|
||||
for i, a in enumerate(parsed["vec"]):
|
||||
for b in parsed["vec"][i+1:]:
|
||||
if not is_diverse(a, b, 3):
|
||||
if not is_diverse(a, b, div_threshold):
|
||||
vec_div -= 2
|
||||
deductions.append(f"vec duplicate: {a[:20]}...")
|
||||
diversity_score += max(0, vec_div)
|
||||
@ -517,13 +537,18 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
|
||||
else:
|
||||
deductions.append("lex missing key terms")
|
||||
|
||||
# Bonus: lex uses quoted phrases for multi-word entities (+3)
|
||||
if entities and parsed["lex"]:
|
||||
multi_word_entities = [e for e in entities if " " in e or len(e) > 6]
|
||||
if multi_word_entities:
|
||||
lex_joined = " ".join(parsed["lex"])
|
||||
if '"' in lex_joined:
|
||||
quality_score += 3
|
||||
# Penalty: lex lines containing filler words absent from the query
|
||||
if parsed["lex"]:
|
||||
filler_count = sum(1 for l in parsed["lex"] if lex_has_filler(l, query))
|
||||
if filler_count > 0:
|
||||
quality_score -= filler_count * 3
|
||||
deductions.append(f"{filler_count} lex line(s) with filler words")
|
||||
|
||||
# Bonus: lex uses quoted phrases for multi-word queries (+3)
|
||||
if parsed["lex"] and len(query.split()) >= 2:
|
||||
lex_joined = " ".join(parsed["lex"])
|
||||
if '"' in lex_joined:
|
||||
quality_score += 3
|
||||
|
||||
# --- Entity Preservation (-45 to +20) ---
|
||||
entity_score = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user