From 3307119173fc24b90bd49a1a30b9ea901e8d31ef Mon Sep 17 00:00:00 2001 From: Aman Madaan Date: Thu, 16 Dec 2021 18:39:59 -0500 Subject: [PATCH] #30: add type hints --- gem_metrics/bertscore.py | 9 ++++++--- gem_metrics/bleu.py | 5 ++++- gem_metrics/bleurt.py | 6 ++++-- gem_metrics/chrf.py | 5 +++-- gem_metrics/data.py | 6 +++--- gem_metrics/local_recall.py | 16 ++++++++++------ gem_metrics/meteor.py | 5 ++++- gem_metrics/metric.py | 12 +++++++----- gem_metrics/msttr.py | 21 ++++++++++----------- gem_metrics/ngrams.py | 14 ++++++++------ gem_metrics/nist.py | 4 +++- gem_metrics/nubia.py | 8 +++++--- gem_metrics/questeval.py | 6 ++++-- gem_metrics/rouge.py | 6 ++++-- gem_metrics/sari.py | 16 +++++++++------- gem_metrics/texts.py | 9 +++++---- gem_metrics/tokenize.py | 7 ++++--- 17 files changed, 93 insertions(+), 62 deletions(-) diff --git a/gem_metrics/bertscore.py b/gem_metrics/bertscore.py index ab585e2..2de0ca4 100644 --- a/gem_metrics/bertscore.py +++ b/gem_metrics/bertscore.py @@ -1,8 +1,11 @@ #!/usr/bin/env python3 +from .texts import Predictions, References from .metric import ReferencedMetric + +from typing import Dict, List from datasets import load_metric -import numpy as np + class BERTScore(ReferencedMetric): @@ -16,11 +19,11 @@ def __init__(self): def _initialize(self): self.metric = load_metric("bertscore", batch_size=64) - def _make_serializable(self, score_entry): + def _make_serializable(self, score_entry) -> List[float]: """Convert from tensor object to list of floats.""" return [float(score) for score in score_entry] - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: """Run BERTScore.""" self.metric.add_batch( predictions=predictions.untokenized, references=references.untokenized diff --git a/gem_metrics/bleu.py b/gem_metrics/bleu.py index 0a9a903..5f13db7 100644 --- a/gem_metrics/bleu.py +++ b/gem_metrics/bleu.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 from .metric import ReferencedMetric +from .texts import Predictions, References + +from typing import Dict import sacrebleu from itertools import zip_longest @@ -12,7 +15,7 @@ def support_caching(self): # BLEU is corpus-level, so individual examples can't be aggregated. return False - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: ref_streams = list(zip_longest(*references.untokenized)) bleu = sacrebleu.corpus_bleu( predictions.untokenized, ref_streams, lowercase=True diff --git a/gem_metrics/bleurt.py b/gem_metrics/bleurt.py index 6886576..2ace13d 100644 --- a/gem_metrics/bleurt.py +++ b/gem_metrics/bleurt.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 from .data import ensure_download from .metric import ReferencedMetric +from .texts import Predictions, References + from bleurt import score import numpy as np import tensorflow as tf - +from typing import Dict class BLEURT(ReferencedMetric): """BLEURT uses the base checkpoint for efficient runtime.""" @@ -22,7 +24,7 @@ def _initialize(self): ) self.metric = score.BleurtScorer(ckpt_path) - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: """Compute the BLEURT score. Multi-ref will be averaged.""" # Use untokenized here since the module uses its own tokenizer. if isinstance(references.untokenized[0], list): diff --git a/gem_metrics/chrf.py b/gem_metrics/chrf.py index b744e77..4851396 100644 --- a/gem_metrics/chrf.py +++ b/gem_metrics/chrf.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 from .metric import ReferencedMetric +from .texts import Predictions, References from itertools import zip_longest - import sacrebleu +from typing import Dict class CHRF(ReferencedMetric): """ @@ -40,7 +41,7 @@ def support_caching(self): # corpus-level, so individual examples won't be aggregated. return False - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: ref_streams = list(zip_longest(*references.untokenized)) scores = {} diff --git a/gem_metrics/data.py b/gem_metrics/data.py index d646c3a..9b1ae0d 100644 --- a/gem_metrics/data.py +++ b/gem_metrics/data.py @@ -21,7 +21,7 @@ nltk.data.path.insert(0, _NLTK_DATA_PATH) -def nltk_ensure_download(package): +def nltk_ensure_download(package: str): """Check if the given package is available, download if needed.""" try: nltk.data.find(package) @@ -30,7 +30,7 @@ def nltk_ensure_download(package): nltk.download(package_id, download_dir=_NLTK_DATA_PATH) -def _urlretrieve_reporthook(count, block_size, total_size, start_time): +def _urlretrieve_reporthook(count: int, block_size: int, total_size: int, start_time: int): """Helper function -- progress indicator.""" # adapted from https://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python duration = time.time() - start_time @@ -44,7 +44,7 @@ def _urlretrieve_reporthook(count, block_size, total_size, start_time): sys.stderr.flush() -def ensure_download(subdir, target_file, url): +def ensure_download(subdir: str, target_file: str, url: str) -> str: """Check if the given data file is available, download if needed.""" target_dir = os.path.join(_BASE_DIR, subdir) diff --git a/gem_metrics/local_recall.py b/gem_metrics/local_recall.py index 2cbd5e9..2b28ead 100644 --- a/gem_metrics/local_recall.py +++ b/gem_metrics/local_recall.py @@ -1,4 +1,8 @@ + from .metric import ReferencedMetric +from .texts import Predictions, References + +from typing import Any, Dict, List, Set, Tuple, Union from collections import Counter, defaultdict @@ -26,7 +30,7 @@ def support_caching(self): # LocalRecall is corpus-level, so individual examples can't be aggregated. return False - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: results = LocalRecall.local_recall_scores( predictions.list_tokenized_lower_nopunct, references.list_tokenized_lower_nopunct, @@ -35,7 +39,7 @@ def compute(self, cache, predictions, references): return {"local_recall": results} @staticmethod - def build_reference_index(refs): + def build_reference_index(refs: List[List[str]]) -> Dict[int, Set]: """ Build reference index for a given item. Input: list of lists (list of sentences, where each sentence is a list of string tokens). @@ -50,7 +54,7 @@ def build_reference_index(refs): return importance_index @staticmethod - def check_item(prediction, refs): + def check_item(prediction: Union[List[str], Set[str]], refs: List[List[str]]) -> Dict: """ Check whether the predictions capture words that are frequently mentioned. @@ -76,14 +80,14 @@ def check_item(prediction, refs): return results @staticmethod - def replace(a_list, to_replace, replacement): + def replace(a_list: List[Any], to_replace: Any, replacement: Any) -> List[Any]: """ Returns a_list with all occurrences of to_replace replaced with replacement. """ return [replacement if x == to_replace else x for x in a_list] @staticmethod - def aggregate_score(outcomes): + def aggregate_score(outcomes: List[Tuple[int, int]]) -> float: """ Produce an aggregate score based on a list of tuples: [(size_overlap, size_refs)] """ @@ -93,7 +97,7 @@ def aggregate_score(outcomes): return score @staticmethod - def local_recall_scores(predictions, full_references): + def local_recall_scores(predictions: List[List[str]], full_references: List[List[List[str]]]) -> Dict: """ Compute local recall scores. """ diff --git a/gem_metrics/meteor.py b/gem_metrics/meteor.py index eccd232..5a1ae80 100644 --- a/gem_metrics/meteor.py +++ b/gem_metrics/meteor.py @@ -2,6 +2,9 @@ from .metric import ReferencedMetric from .impl.meteor import PyMeteorWrapper +from .texts import Predictions, References + +from typing import Dict from logzero import logger @@ -14,7 +17,7 @@ def support_caching(self): # While individual scores can be computed, the overall score is different. return False - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: try: m = PyMeteorWrapper(predictions.language.alpha_2) except Exception as e: diff --git a/gem_metrics/metric.py b/gem_metrics/metric.py index cb4ae1f..36231e8 100644 --- a/gem_metrics/metric.py +++ b/gem_metrics/metric.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 +from .texts import Predictions, References, Sources + from copy import copy import numpy as np -from typing import List +from typing import List, Dict from logzero import logger @@ -47,7 +49,7 @@ def _aggregate_scores(self, score_list: List): "Please add to this function an aggregator for your data format." ) - def compute_cached(self, cache, predictions, *args): + def compute_cached(self, cache, predictions: Predictions, *args): """Loops through the predictions to check for cache hits before computing.""" original_order = copy(predictions.ids) @@ -100,19 +102,19 @@ def compute_cached(self, cache, predictions, *args): class ReferencelessMetric(AbstractMetric): """Base class for all referenceless metrics.""" - def compute(self, cache, predictions): + def compute(self, cache, predictions: Predictions) -> Dict: raise NotImplementedError class ReferencedMetric(AbstractMetric): """Base class for all referenced metrics.""" - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: raise NotImplementedError class SourceAndReferencedMetric(AbstractMetric): """Base class for all metrics that require source and reference sentences.""" - def compute(self, cache, predictions, references, sources): + def compute(self, cache, predictions: Predictions, references: References, sources: Sources) -> Dict: raise NotImplementedError diff --git a/gem_metrics/msttr.py b/gem_metrics/msttr.py index 040baa8..2b71bf9 100644 --- a/gem_metrics/msttr.py +++ b/gem_metrics/msttr.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 -import itertools +from .metric import ReferencelessMetric +from .texts import Predictions -from string import punctuation -from nltk import ngrams +import itertools import random - -from .metric import ReferencelessMetric +from typing import Dict, List class MSTTR(ReferencelessMetric): @@ -17,7 +16,7 @@ class MSTTR(ReferencelessMetric): https://github.com/evanmiltenburg/NLG-diversity/blob/main/diversity.py """ - def __init__(self, window_size=100): + def __init__(self, window_size: int = 100): # use MSTTR-100 by default. self.rnd = random.Random(1234) self.window_size = window_size @@ -26,7 +25,7 @@ def support_caching(self): # MSTTR is corpus-level, so individual examples can't be aggregated. return False - def compute(self, cache, predictions): + def compute(self, cache, predictions: Predictions) -> Dict: return { f"msttr-{self.window_size}": round( self._MSTTR(predictions.list_tokenized_lower, self.window_size)[ @@ -42,11 +41,11 @@ def compute(self, cache, predictions): ), } - def _TTR(self, list_of_words): + def _TTR(self, list_of_words: List[str]) -> float: "Compute type-token ratio." return len(set(list_of_words)) / len(list_of_words) - def _MSTTR(self, tokenized_data, window_size): + def _MSTTR(self, tokenized_data: List[List[str]], window_size: int) -> Dict: """ Computes Mean-Segmental Type-Token Ratio (MSTTR; Johnson, 1944) by dividing the concatenated texts into non-overlapping segments of equal @@ -56,7 +55,7 @@ def _MSTTR(self, tokenized_data, window_size): """ ttrs = [] concatenated = list(itertools.chain.from_iterable(tokenized_data)) - + for i in range(0, len(concatenated), window_size): window = concatenated[i: i+window_size] # removes the last segment from the computation @@ -71,7 +70,7 @@ def _MSTTR(self, tokenized_data, window_size): } return results - def _repeated_MSTTR(self, tokenized_data, window_size, repeats=5): + def _repeated_MSTTR(self, tokenized_data: List[List[str]], window_size: int, repeats: int = 5) -> float: "Repeated MSTTR to obtain a more robust MSTTR value." msttrs = [] for i in range(repeats): diff --git a/gem_metrics/ngrams.py b/gem_metrics/ngrams.py index bd024e6..2c609a6 100644 --- a/gem_metrics/ngrams.py +++ b/gem_metrics/ngrams.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 +from typing import Dict, List, Tuple +from .metric import ReferencelessMetric +from .texts import Predictions + import numpy as np from nltk import ngrams -from .metric import ReferencelessMetric - class NGramStats(ReferencelessMetric): """Ngram basic statistics and entropy, working with tokenized & lowercased data (+ variant excluding punctuation): @@ -28,7 +30,7 @@ def support_caching(self): # NGramStats is corpus-level, so individual examples can't be aggregated. return False - def compute(self, cache, predictions): + def compute(self, cache, predictions: Predictions) -> Dict: results = {} for data_id, data in [ @@ -64,7 +66,7 @@ def compute(self, cache, predictions): return results - def _ngram_stats(self, data, N): + def _ngram_stats(self, data: List[List[str]], N: int) -> Tuple[Dict, int, int]: """Return basic ngram statistics, as well as a dict of all ngrams and their freqsuencies.""" ngram_freqs = {} # ngrams with frequencies ngram_len = 0 # total number of ngrams @@ -76,7 +78,7 @@ def _ngram_stats(self, data, N): uniq_ngrams = len([val for val in ngram_freqs.values() if val == 1]) return ngram_freqs, uniq_ngrams, ngram_len - def _entropy(self, ngram_freqs): + def _entropy(self, ngram_freqs: Dict) -> float: """Shannon entropy over ngram frequencies""" total_freq = sum(ngram_freqs.values()) return -sum( @@ -86,7 +88,7 @@ def _entropy(self, ngram_freqs): ] ) - def _cond_entropy(self, joint, ctx): + def _cond_entropy(self, joint: Dict, ctx: Dict) -> float: """Conditional/next-word entropy (language model style), using ngrams (joint) and n-1-grams (ctx).""" total_joint = sum(joint.values()) total_ctx = sum(ctx.values()) diff --git a/gem_metrics/nist.py b/gem_metrics/nist.py index 8afd57a..b2779b1 100644 --- a/gem_metrics/nist.py +++ b/gem_metrics/nist.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 +from .texts import Predictions, References from .metric import ReferencedMetric from .impl.pymteval import NISTScore +from typing import Dict class NIST(ReferencedMetric): """NIST from e2e-metrics.""" @@ -11,7 +13,7 @@ def support_caching(self): # NIST is corpus-level, so individual examples can't be aggregated. return False - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: nist = NISTScore() for pred, refs in zip(predictions.untokenized, references.untokenized): nist.append(pred, refs) diff --git a/gem_metrics/nubia.py b/gem_metrics/nubia.py index 771e06b..7583ba6 100644 --- a/gem_metrics/nubia.py +++ b/gem_metrics/nubia.py @@ -1,7 +1,9 @@ -from nubia_score import Nubia -import numpy as np from .metric import ReferencedMetric +from .texts import Predictions, References +from nubia_score import Nubia +import numpy as np +from typing import Dict class NUBIA(ReferencedMetric): def __init__(self): @@ -15,7 +17,7 @@ def _initialize(self): self.metric.roberta_MNLI.to("cuda") self.metric.gpt_model.to("cuda") - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: """Run Nubia""" scores = {} for ref, pred, pred_id in zip( diff --git a/gem_metrics/questeval.py b/gem_metrics/questeval.py index bbf975e..cb0790b 100644 --- a/gem_metrics/questeval.py +++ b/gem_metrics/questeval.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 - from .metric import SourceAndReferencedMetric +from .texts import Predictions, References, Sources + +from typing import Dict from questeval.questeval_metric import QuestEval as QuestEvalMetric from logzero import logger @@ -21,7 +23,7 @@ def support_caching(self): # We are using corpus-level QuestEval which is aggregated. return True - def compute(self, cache, predictions, references, sources): + def compute(self, cache, predictions: Predictions, references: References, sources: Sources) -> Dict: # If task or language is different, we must change QA / QG models for questeval if predictions.task != self.task or predictions.language.alpha_2 != self.language: self.task = predictions.task diff --git a/gem_metrics/rouge.py b/gem_metrics/rouge.py index b37b89f..c6f564c 100644 --- a/gem_metrics/rouge.py +++ b/gem_metrics/rouge.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 - +from .texts import Predictions, References from .metric import ReferencedMetric + +from typing import Dict import numpy as np from rouge_score import rouge_scorer, scoring @@ -12,7 +14,7 @@ class ROUGE(ReferencedMetric): the jackknifing follows the description of the ROUGE paper. """ - def compute(self, cache, predictions, references): + def compute(self, cache, predictions: Predictions, references: References) -> Dict: rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] rouge = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) scores = {} diff --git a/gem_metrics/sari.py b/gem_metrics/sari.py index cd1b8ab..a216c97 100644 --- a/gem_metrics/sari.py +++ b/gem_metrics/sari.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 - +from .texts import Predictions, References, Sources from .metric import SourceAndReferencedMetric + +from typing import Dict, Tuple, List from collections import Counter import sacrebleu import sacremoses @@ -24,7 +26,7 @@ class SARI(SourceAndReferencedMetric): [3] https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py """ - def compute(self, cache, predictions, references, sources): + def compute(self, cache, predictions: Predictions, references: References, sources: Sources) -> Dict: srcs = [self.normalize(sent) for sent in sources.untokenized] preds = [self.normalize(sent) for sent in predictions.untokenized] @@ -48,7 +50,7 @@ def compute(self, cache, predictions, references, sources): return sari_scores - def SARIngram(self, sgrams, cgrams, rgramslist, numref): + def SARIngram(self, sgrams: List[str], cgrams: List[str], rgramslist: List[List[str]], numref: int) -> Tuple[float, float, float]: rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams] rgramcounter = Counter(rgramsall) @@ -153,7 +155,7 @@ def SARIngram(self, sgrams, cgrams, rgramslist, numref): return (keepscore, delscore_precision, addscore) - def SARIsent(self, ssent, csent, rsents): + def SARIsent(self, ssent: str, csent: str, rsents: List[str]) -> float: numref = len(rsents) s1grams = ssent.split(" ") @@ -255,11 +257,11 @@ def SARIsent(self, ssent, csent, rsents): def normalize( self, - sentence, + sentence: str, lowercase: bool = True, tokenizer: str = "13a", return_str: bool = True, - ): + ) -> List[str]: # Normalization is requried for the ASSET dataset to allow using space # to split the sentence. Even though Wiki-Auto and TURK datasets, @@ -277,7 +279,7 @@ def normalize( return sentence - def tokenize(self, sentence, tokenizer): + def tokenize(self, sentence: str, tokenizer: str) -> List[str]: if tokenizer in ["intl", "13a"]: sentence = sacrebleu.metrics.bleu._get_tokenizer(tokenizer)()(sentence) elif tokenizer == "moses": diff --git a/gem_metrics/texts.py b/gem_metrics/texts.py index 5567c21..8507294 100644 --- a/gem_metrics/texts.py +++ b/gem_metrics/texts.py @@ -1,13 +1,14 @@ #!/usr/bin/env python3 +from .tokenize import default_tokenize_func +from gem_metrics.config import get_language_for_dataset, get_task_type_for_dataset import functools -from gem_metrics.config import get_language_for_dataset, get_task_type_for_dataset -from typing import List, Optional +from typing import List, Optional, Union, Dict import json import string from pycountry import languages from logzero import logger -from .tokenize import default_tokenize_func + class Texts: @@ -15,7 +16,7 @@ class Texts: PUNCTUATION = set(string.punctuation) - def __init__(self, data_key, data, language="en"): + def __init__(self, data_key: str, data: Union[str, Dict], language="en"): self.data_key = data_key # TODO allow other data formats. if not isinstance(data, dict): diff --git a/gem_metrics/tokenize.py b/gem_metrics/tokenize.py index 2e59f7d..547ac5c 100644 --- a/gem_metrics/tokenize.py +++ b/gem_metrics/tokenize.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 +from .data import nltk_ensure_download +from typing import List import re from functools import partial import nltk -from .data import nltk_ensure_download -def default_tokenize_func(lang): +def default_tokenize_func(lang: str): """Return the default tokenizer function for a given language (Punkt, backoff to dumb_tokenize). @param lang: pycountry.db.Language object representing the language (result of pycountry.languages.get) @return tokenizer function, taking one parameter (text) and returning list of tokens. @@ -24,7 +25,7 @@ def default_tokenize_func(lang): return func -def dumb_tokenize(text): +def dumb_tokenize(text: str) -> List[str]: """Tokenize text (separate tokens by spaces), language-agnostic failsafe version. @param text: String to be tokenized @return list of tokens