diff options
Diffstat (limited to 'data/style_features.py')
| -rw-r--r-- | data/style_features.py | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/data/style_features.py b/data/style_features.py new file mode 100644 index 0000000..d674419 --- /dev/null +++ b/data/style_features.py @@ -0,0 +1,138 @@ +"""Style feature extraction for SFD (Style Feature Distance) metric.""" + +import nltk +from nltk.tokenize import sent_tokenize, word_tokenize + +FEATURE_NAMES = [ + 'length', 'avg_sent_len', 'TTR', 'newline_rate', 'exclaim_rate', + 'first_person_rate', 'adj_adv_rate', 'sentiment_score' +] + + +def extract_style_features(text: str) -> list: + """Extract style feature vector from text. + + Returns: + [length, avg_sent_len, TTR, newline_rate, exclaim_rate, + first_person_rate, adj_adv_rate, sentiment_score] + """ + if not text or not text.strip(): + return [0.0] * 8 + + words = word_tokenize(text) + num_words = max(len(words), 1) + + # 1. Length (word count) + length = float(num_words) + + # 2. Average sentence length + sentences = sent_tokenize(text) + num_sents = max(len(sentences), 1) + avg_sent_len = num_words / num_sents + + # 3. Type-Token Ratio + unique_words = set(w.lower() for w in words if w.isalpha()) + alpha_words = [w for w in words if w.isalpha()] + ttr = len(unique_words) / max(len(alpha_words), 1) + + # 4. Newline rate + newline_count = text.count('\n') + newline_rate = newline_count / num_sents + + # 5. Exclamation rate + exclaim_count = text.count('!') + exclaim_rate = exclaim_count / num_sents + + # 6. First-person rate + first_person = {'i', 'me', 'my', 'mine', 'myself', 'we', 'us', 'our', 'ours', 'ourselves'} + fp_count = sum(1 for w in words if w.lower() in first_person) + first_person_rate = fp_count / num_words + + # 7. Adjective/Adverb rate + try: + tagged = nltk.pos_tag(words) + adj_adv_tags = {'JJ', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'} + adj_adv_count = sum(1 for _, tag in tagged if tag in adj_adv_tags) + adj_adv_rate = adj_adv_count / num_words + except Exception: + adj_adv_rate = 0.0 + + # 8. Simple sentiment score (positive - negative word ratio) + positive_words = { + 'good', 'great', 'excellent', 'amazing', 'wonderful', 'fantastic', + 'love', 'loved', 'best', 'perfect', 'awesome', 'beautiful', + 'enjoy', 'enjoyed', 'happy', 'glad', 'nice', 'brilliant', + 'outstanding', 'superb', 'delightful', 'pleasant', 'favorite', + 'recommend', 'recommended', 'impressive', 'incredible', + } + negative_words = { + 'bad', 'terrible', 'awful', 'horrible', 'worst', 'poor', + 'hate', 'hated', 'boring', 'disappointing', 'disappointed', + 'ugly', 'annoying', 'waste', 'useless', 'mediocre', 'dull', + 'pathetic', 'garbage', 'rubbish', 'disgusting', 'dreadful', + } + pos_count = sum(1 for w in words if w.lower() in positive_words) + neg_count = sum(1 for w in words if w.lower() in negative_words) + sentiment_score = (pos_count - neg_count) / num_words + + return [length, avg_sent_len, ttr, newline_rate, exclaim_rate, + first_person_rate, adj_adv_rate, sentiment_score] + + +def compute_sfd(generated_text: str, support_texts: list, exclude_length: bool = False) -> float: + """Compute Style Feature Distance. + + Args: + generated_text: The model's generated output. + support_texts: List of the user's support set output texts. + exclude_length: If True, exclude length feature (index 0) from SFD. + + Returns: + L1 distance between generated style and user style prototype. + """ + gen_features = extract_style_features(generated_text) + support_features_list = [extract_style_features(t) for t in support_texts] + num_support = len(support_features_list) + if num_support == 0: + return 0.0 + + prototype = [0.0] * len(gen_features) + for sf in support_features_list: + for i in range(len(prototype)): + prototype[i] += sf[i] + prototype = [p / num_support for p in prototype] + + start_idx = 1 if exclude_length else 0 + sfd = 0.0 + for i in range(start_idx, len(gen_features)): + g, p = gen_features[i], prototype[i] + scale = max(abs(p), 1.0) + sfd += abs(g - p) / scale + return sfd + + +def compute_feature_deltas(generated_text: str, support_texts: list) -> dict: + """Compute per-feature deltas between generated text and user style prototype. + + Returns dict mapping feature_name -> (gen_value, proto_value, delta). + """ + gen_features = extract_style_features(generated_text) + support_features_list = [extract_style_features(t) for t in support_texts] + num_support = len(support_features_list) + if num_support == 0: + return {} + + prototype = [0.0] * len(gen_features) + for sf in support_features_list: + for i in range(len(prototype)): + prototype[i] += sf[i] + prototype = [p / num_support for p in prototype] + + deltas = {} + for i, name in enumerate(FEATURE_NAMES): + deltas[name] = { + 'gen': gen_features[i], + 'proto': prototype[i], + 'delta': gen_features[i] - prototype[i], + } + return deltas |
