diff --git a/finetune/reward.py b/finetune/reward.py index df1e3d8..9074a8a 100644 --- a/finetune/reward.py +++ b/finetune/reward.py @@ -72,6 +72,10 @@ GENERIC_LEX_PHRASES = frozenset({ 'what is', 'how to', 'guide to', 'help with', }) +# Words commonly injected as filler/noise into lex lines by template generators +# (e.g. "ancient overview rome timeline"). Penalized when absent from the query. +INTERIOR_FILLER_WORDS = frozenset({'overview', 'basics'}) + # Chat template tokens that indicate a broken output CHAT_TEMPLATE_TOKENS = frozenset({ '<|im_start|>', '<|im_end|>', '<|endoftext|>', @@ -142,47 +146,49 @@ def extract_named_entities(query: str) -> set: Position-0 words are also detected as entities if they are capitalized and not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob"). + + Compound chaining extends one level from a directly-detected entity: + "TDS motorsports" -> {tds, motorsports}; "TDS motorsports team" -> {tds, motorsports}. """ entities = set() words = query.split() - prev_was_entity = False + prev_was_base_entity = False for i, word in enumerate(words): clean = word.strip('.,!?:;()[]"\'') if not clean: - prev_was_entity = False + prev_was_base_entity = False continue - is_entity = False + is_base_entity = False # ALL-CAPS acronyms: TDS, API, GPU, AWS if clean.isupper() and len(clean) >= 2: entities.add(clean.lower()) - is_entity = True + is_base_entity = True # Capitalized proper nouns (any position, including first word) elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS: if i > 0: # Non-first words: always treat as entity entities.add(clean.lower()) - is_entity = True + is_base_entity = True elif clean.lower() not in QUERY_VERB_STOPWORDS: # First word: also entity if not a common query verb entities.add(clean.lower()) - is_entity = True + is_base_entity = True # Technical terms with special chars: node.js, C++, .NET elif any(c in clean for c in '.+-#@') and len(clean) >= 2: entities.add(clean.lower()) - is_entity = True + is_base_entity = True # CamelCase: JavaScript, TypeScript elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper(): entities.add(clean.lower()) - is_entity = True - # Compound names: word following an entity (TDS motorsports) - elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS: + is_base_entity = True + # Compound names: word following a BASE entity only (one level deep). + elif prev_was_base_entity and clean.lower() not in KEY_TERM_STOPWORDS: entities.add(clean.lower()) - is_entity = True - prev_was_entity = is_entity + prev_was_base_entity = is_base_entity return entities @@ -208,6 +214,13 @@ def lex_preserves_entities(line: str, entities: set) -> bool: return any(e in lower for e in entities) +def lex_has_filler(lex_line: str, query: str) -> bool: + """Does the lex line contain an INTERIOR_FILLER_WORDS word absent from the query?""" + query_words = set(query.lower().split()) + return any(w in INTERIOR_FILLER_WORDS and w not in query_words + for w in lex_line.lower().split()) + + def lex_is_generic(lex_line: str) -> bool: """Is this lex line a useless generic filler phrase?""" lower = lex_line.lower().strip() @@ -280,13 +293,14 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool # --- Diversity (0-30) --- diversity_score = 0 + div_threshold = 3 if len(base_query.split()) >= 5 else 2 if len(expected_items) >= 2: diversity_score += 15 # Check for diversity among items div_score = 15 for i, a in enumerate(expected_items): for b in expected_items[i+1:]: - if not is_diverse(a, b, 2): + if not is_diverse(a, b, div_threshold): div_score -= 5 deductions.append(f"{only_type} duplicate: {a[:20]}...") diversity_score += max(0, div_score) @@ -315,6 +329,11 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool quality_score += 5 else: deductions.append(f"{generic} generic lex phrases") + # Penalty: lex lines containing filler words absent from the query + filler_count = sum(1 for l in expected_items if lex_has_filler(l, base_query)) + if filler_count > 0: + quality_score -= filler_count * 3 + deductions.append(f"{filler_count} lex line(s) with filler words") elif only_type == "vec": # Vec should be natural language sentences @@ -444,10 +463,11 @@ def score_expansion_detailed(query: str, expansion: str) -> dict: if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5 + div_threshold = 3 if len(query.split()) >= 5 else 2 lex_div = 5 for i, a in enumerate(parsed["lex"]): for b in parsed["lex"][i+1:]: - if not is_diverse(a, b, 2): + if not is_diverse(a, b, div_threshold): lex_div -= 2 deductions.append(f"lex duplicate: {a[:20]}...") diversity_score += max(0, lex_div) @@ -455,7 +475,7 @@ def score_expansion_detailed(query: str, expansion: str) -> dict: vec_div = 5 for i, a in enumerate(parsed["vec"]): for b in parsed["vec"][i+1:]: - if not is_diverse(a, b, 3): + if not is_diverse(a, b, div_threshold): vec_div -= 2 deductions.append(f"vec duplicate: {a[:20]}...") diversity_score += max(0, vec_div) @@ -517,13 +537,18 @@ def score_expansion_detailed(query: str, expansion: str) -> dict: else: deductions.append("lex missing key terms") - # Bonus: lex uses quoted phrases for multi-word entities (+3) - if entities and parsed["lex"]: - multi_word_entities = [e for e in entities if " " in e or len(e) > 6] - if multi_word_entities: - lex_joined = " ".join(parsed["lex"]) - if '"' in lex_joined: - quality_score += 3 + # Penalty: lex lines containing filler words absent from the query + if parsed["lex"]: + filler_count = sum(1 for l in parsed["lex"] if lex_has_filler(l, query)) + if filler_count > 0: + quality_score -= filler_count * 3 + deductions.append(f"{filler_count} lex line(s) with filler words") + + # Bonus: lex uses quoted phrases for multi-word queries (+3) + if parsed["lex"] and len(query.split()) >= 2: + lex_joined = " ".join(parsed["lex"]) + if '"' in lex_joined: + quality_score += 3 # --- Entity Preservation (-45 to +20) --- entity_score = 0