#30: add type hints
madaan committed Dec 16, 2021
1 parent 9435858 commit 3307119
Showing 17 changed files with 93 additions and 62 deletions.
9 changes: 6 additions & 3 deletions gem_metrics/bertscore.py
@@ -1,8 +1,11 @@
#!/usr/bin/env python3

from .texts import Predictions, References
from .metric import ReferencedMetric

from typing import Dict, List
from datasets import load_metric
import numpy as np



class BERTScore(ReferencedMetric):
@@ -16,11 +19,11 @@ def __init__(self):
def _initialize(self):
self.metric = load_metric("bertscore", batch_size=64)

def _make_serializable(self, score_entry):
def _make_serializable(self, score_entry) -> List[float]:
"""Convert from tensor object to list of floats."""
return [float(score) for score in score_entry]

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
"""Run BERTScore."""
self.metric.add_batch(
predictions=predictions.untokenized, references=references.untokenized
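
A minimal standalone sketch of what the newly typed helpers express, using the datasets bertscore metric directly; the lang="en" argument and the example sentences are assumptions for illustration, not the module's actual configuration.

# Sketch only: exercising the typed helper signatures outside the class.
from typing import Dict, List

from datasets import load_metric


def make_serializable(score_entry) -> List[float]:
    """Convert a tensor-like sequence of scores into a plain list of floats."""
    return [float(score) for score in score_entry]


def bertscore_example() -> Dict:
    metric = load_metric("bertscore")
    metric.add_batch(
        predictions=["the cat sat on the mat"],
        references=[["a cat is sitting on the mat"]],
    )
    raw = metric.compute(lang="en")  # assumption: English; returns precision/recall/f1 sequences
    return {key: make_serializable(raw[key]) for key in ("precision", "recall", "f1")}
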
5 changes: 4 additions & 1 deletion gem_metrics/bleu.py
@@ -1,6 +1,9 @@
#!/usr/bin/env python3

from .metric import ReferencedMetric
from .texts import Predictions, References

from typing import Dict
import sacrebleu
from itertools import zip_longest

@@ -12,7 +15,7 @@ def support_caching(self):
# BLEU is corpus-level, so individual examples can't be aggregated.
return False

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
ref_streams = list(zip_longest(*references.untokenized))
bleu = sacrebleu.corpus_bleu(
predictions.untokenized, ref_streams, lowercase=True
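
The typed compute() above transposes per-item reference lists into reference streams before handing them to sacrebleu. A short sketch of that pattern with made-up sentences:

# Sketch only: corpus-level BLEU over multi-reference data.
from itertools import zip_longest

import sacrebleu

predictions = ["the cat sat on the mat", "hello world"]
references = [
    ["a cat sat on the mat", "the cat is on the mat"],
    ["hello world", "hi world"],
]

# One stream per reference "column"; zip_longest pads ragged cases with None.
ref_streams = list(zip_longest(*references))
bleu = sacrebleu.corpus_bleu(predictions, ref_streams, lowercase=True)
print(round(bleu.score, 5))
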
6 changes: 4 additions & 2 deletions gem_metrics/bleurt.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python3
from .data import ensure_download
from .metric import ReferencedMetric
from .texts import Predictions, References

from bleurt import score
import numpy as np
import tensorflow as tf

from typing import Dict

class BLEURT(ReferencedMetric):
"""BLEURT uses the base checkpoint for efficient runtime."""
@@ -22,7 +24,7 @@ def _initialize(self):
)
self.metric = score.BleurtScorer(ckpt_path)

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
"""Compute the BLEURT score. Multi-ref will be averaged."""
# Use untokenized here since the module uses its own tokenizer.
if isinstance(references.untokenized[0], list):
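
A hedged sketch of the multi-reference averaging the docstring describes, assuming a locally available checkpoint path; the module's exact averaging scheme is collapsed in this diff and may differ.

# Sketch only: score each prediction against all of its references and average.
from typing import Dict, List

import numpy as np
from bleurt import score


def bleurt_multi_ref(ckpt_path: str, predictions: List[str], references: List[List[str]]) -> Dict:
    scorer = score.BleurtScorer(ckpt_path)
    per_item = []
    for pred, refs in zip(predictions, references):
        item_scores = scorer.score(references=refs, candidates=[pred] * len(refs))
        per_item.append(float(np.mean(item_scores)))
    return {"bleurt": float(np.mean(per_item))}
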
5 changes: 3 additions & 2 deletions gem_metrics/chrf.py
@@ -1,10 +1,11 @@
#!/usr/bin/env python3

from .metric import ReferencedMetric
from .texts import Predictions, References

from itertools import zip_longest

import sacrebleu
from typing import Dict

class CHRF(ReferencedMetric):
"""
@@ -40,7 +41,7 @@ def support_caching(self):
# corpus-level, so individual examples won't be aggregated.
return False

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
ref_streams = list(zip_longest(*references.untokenized))
scores = {}

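
A companion sketch for the typed compute(), assuming sacrebleu >= 2.0 where corpus_chrf accepts multiple reference streams; parameter choices (char/word order, beta) are left at sacrebleu's defaults here and may differ from the module's.

# Sketch only: corpus-level chrF with the same reference-stream transposition.
from itertools import zip_longest

import sacrebleu

predictions = ["the cat sat on the mat"]
references = [["a cat sat on the mat", "the cat is on the mat"]]

ref_streams = list(zip_longest(*references))
chrf = sacrebleu.corpus_chrf(predictions, ref_streams)
print(round(chrf.score, 5))
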
6 changes: 3 additions & 3 deletions gem_metrics/data.py
@@ -21,7 +21,7 @@
nltk.data.path.insert(0, _NLTK_DATA_PATH)


def nltk_ensure_download(package):
def nltk_ensure_download(package: str):
"""Check if the given package is available, download if needed."""
try:
nltk.data.find(package)
@@ -30,7 +30,7 @@ def nltk_ensure_download(package):
nltk.download(package_id, download_dir=_NLTK_DATA_PATH)


def _urlretrieve_reporthook(count, block_size, total_size, start_time):
def _urlretrieve_reporthook(count: int, block_size: int, total_size: int, start_time: int):
"""Helper function -- progress indicator."""
# adapted from https://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python
duration = time.time() - start_time
@@ -44,7 +44,7 @@ def _urlretrieve_reporthook(count, block_size, total_size, start_time):
sys.stderr.flush()


def ensure_download(subdir, target_file, url):
def ensure_download(subdir: str, target_file: str, url: str) -> str:
"""Check if the given data file is available, download if needed."""

target_dir = os.path.join(_BASE_DIR, subdir)
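
urllib's urlretrieve only passes (count, block_size, total_size) to its reporthook, so the extra start_time parameter of the typed hook has to be bound up front; how the module actually wires this is not visible in the diff. A sketch assuming functools.partial is used (start_time is typed as float here, since time.time() returns one):

# Sketch only: a progress reporthook with start_time pre-bound via partial.
import sys
import time
import urllib.request
from functools import partial


def reporthook(count: int, block_size: int, total_size: int, start_time: float):
    duration = max(time.time() - start_time, 1e-6)
    downloaded_kb = count * block_size // 1024
    speed = downloaded_kb / duration
    sys.stderr.write(f"\r{downloaded_kb} kB downloaded ({speed:.0f} kB/s)")
    sys.stderr.flush()


def download(url: str, target_file: str) -> str:
    urllib.request.urlretrieve(url, target_file, partial(reporthook, start_time=time.time()))
    return target_file
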
16 changes: 10 additions & 6 deletions gem_metrics/local_recall.py
@@ -1,4 +1,8 @@

from .metric import ReferencedMetric
from .texts import Predictions, References

from typing import Any, Dict, List, Set, Tuple, Union
from collections import Counter, defaultdict


@@ -26,7 +30,7 @@ def support_caching(self):
# LocalRecall is corpus-level, so individual examples can't be aggregated.
return False

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
results = LocalRecall.local_recall_scores(
predictions.list_tokenized_lower_nopunct,
references.list_tokenized_lower_nopunct,
@@ -35,7 +39,7 @@ def compute(self, cache, predictions, references):
return {"local_recall": results}

@staticmethod
def build_reference_index(refs):
def build_reference_index(refs: List[List[str]]) -> Dict[int, Set]:
"""
Build reference index for a given item.
Input: list of lists (list of sentences, where each sentence is a list of string tokens).
@@ -50,7 +54,7 @@ def build_reference_index(refs):
return importance_index

@staticmethod
def check_item(prediction, refs):
def check_item(prediction: Union[List[str], Set[str]], refs: List[List[str]]) -> Dict:
"""
Check whether the predictions capture words that are frequently mentioned.
@@ -76,14 +80,14 @@ def check_item(prediction, refs):
return results

@staticmethod
def replace(a_list, to_replace, replacement):
def replace(a_list: List[Any], to_replace: Any, replacement: Any) -> List[Any]:
"""
Returns a_list with all occurrences of to_replace replaced with replacement.
"""
return [replacement if x == to_replace else x for x in a_list]

@staticmethod
def aggregate_score(outcomes):
def aggregate_score(outcomes: List[Tuple[int, int]]) -> float:
"""
Produce an aggregate score based on a list of tuples: [(size_overlap, size_refs)]
"""
@@ -93,7 +97,7 @@ def aggregate_score(outcomes):
return score

@staticmethod
def local_recall_scores(predictions, full_references):
def local_recall_scores(predictions: List[List[str]], full_references: List[List[List[str]]]) -> Dict:
"""
Compute local recall scores.
"""
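
The method bodies behind these signatures are collapsed, so the following is only one plausible reading of them: index tokens by how many references mention them, then micro-average (size_overlap, size_refs) tuples.

# Sketch only: plausible implementations matching the typed signatures above.
from collections import Counter
from typing import Dict, List, Set, Tuple


def build_reference_index(refs: List[List[str]]) -> Dict[int, Set]:
    """Group tokens by the number of references that mention them."""
    counts = Counter(tok for ref in refs for tok in set(ref))
    index: Dict[int, Set] = {}
    for tok, n in counts.items():
        index.setdefault(n, set()).add(tok)
    return index


def aggregate_score(outcomes: List[Tuple[int, int]]) -> float:
    """Micro-average a list of (size_overlap, size_refs) tuples."""
    total_overlap = sum(overlap for overlap, _ in outcomes)
    total_refs = sum(refs for _, refs in outcomes)
    return total_overlap / total_refs if total_refs else 0.0
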
5 changes: 4 additions & 1 deletion gem_metrics/meteor.py
@@ -2,6 +2,9 @@

from .metric import ReferencedMetric
from .impl.meteor import PyMeteorWrapper
from .texts import Predictions, References

from typing import Dict
from logzero import logger


@@ -14,7 +17,7 @@ def support_caching(self):
# While individual scores can be computed, the overall score is different.
return False

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
try:
m = PyMeteorWrapper(predictions.language.alpha_2)
except Exception as e:
12 changes: 7 additions & 5 deletions gem_metrics/metric.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python3
from .texts import Predictions, References, Sources

from copy import copy
import numpy as np
from typing import List
from typing import List, Dict
from logzero import logger


@@ -47,7 +49,7 @@ def _aggregate_scores(self, score_list: List):
"Please add to this function an aggregator for your data format."
)

def compute_cached(self, cache, predictions, *args):
def compute_cached(self, cache, predictions: Predictions, *args):
"""Loops through the predictions to check for cache hits before computing."""
original_order = copy(predictions.ids)

@@ -100,19 +102,19 @@ def compute_cached(self, cache, predictions, *args):
class ReferencelessMetric(AbstractMetric):
"""Base class for all referenceless metrics."""

def compute(self, cache, predictions):
def compute(self, cache, predictions: Predictions) -> Dict:
raise NotImplementedError


class ReferencedMetric(AbstractMetric):
"""Base class for all referenced metrics."""

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
raise NotImplementedError


class SourceAndReferencedMetric(AbstractMetric):
"""Base class for all metrics that require source and reference sentences."""

def compute(self, cache, predictions, references, sources):
def compute(self, cache, predictions: Predictions, references: References, sources: Sources) -> Dict:
raise NotImplementedError
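
With the base classes typed, a downstream metric can be written against the annotated signature. A hypothetical example (LengthRatio is not part of gem_metrics; it only uses the untokenized attribute that appears elsewhere in this diff and assumes multi-reference lists):

# Sketch only: a toy corpus-level metric against the typed ReferencedMetric API.
from typing import Dict

from gem_metrics.metric import ReferencedMetric
from gem_metrics.texts import Predictions, References


class LengthRatio(ReferencedMetric):
    """Mean ratio of prediction length to first-reference length."""

    def support_caching(self):
        # corpus-level toy metric, so skip per-example caching
        return False

    def compute(self, cache, predictions: Predictions, references: References) -> Dict:
        ratios = [
            len(pred.split()) / max(len(refs[0].split()), 1)
            for pred, refs in zip(predictions.untokenized, references.untokenized)
        ]
        return {"length_ratio": sum(ratios) / len(ratios)}
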
21 changes: 10 additions & 11 deletions gem_metrics/msttr.py
@@ -1,11 +1,10 @@
#!/usr/bin/env python3
import itertools
from .metric import ReferencelessMetric
from .texts import Predictions

from string import punctuation
from nltk import ngrams
import itertools
import random

from .metric import ReferencelessMetric
from typing import Dict, List


class MSTTR(ReferencelessMetric):
@@ -17,7 +16,7 @@ class MSTTR(ReferencelessMetric):
https://github.com/evanmiltenburg/NLG-diversity/blob/main/diversity.py
"""

def __init__(self, window_size=100):
def __init__(self, window_size: int = 100):
# use MSTTR-100 by default.
self.rnd = random.Random(1234)
self.window_size = window_size
@@ -26,7 +25,7 @@ def support_caching(self):
# MSTTR is corpus-level, so individual examples can't be aggregated.
return False

def compute(self, cache, predictions):
def compute(self, cache, predictions: Predictions) -> Dict:
return {
f"msttr-{self.window_size}": round(
self._MSTTR(predictions.list_tokenized_lower, self.window_size)[
@@ -42,11 +41,11 @@ def compute(self, cache, predictions):
),
}

def _TTR(self, list_of_words):
def _TTR(self, list_of_words: List[str]) -> float:
"Compute type-token ratio."
return len(set(list_of_words)) / len(list_of_words)

def _MSTTR(self, tokenized_data, window_size):
def _MSTTR(self, tokenized_data: List[List[str]], window_size: int) -> Dict:
"""
Computes Mean-Segmental Type-Token Ratio (MSTTR; Johnson, 1944)
by dividing the concatenated texts into non-overlapping segments of equal
@@ -56,7 +55,7 @@ def _MSTTR(self, tokenized_data, window_size):
"""
ttrs = []
concatenated = list(itertools.chain.from_iterable(tokenized_data))

for i in range(0, len(concatenated), window_size):
window = concatenated[i: i+window_size]
# removes the last segment from the computation
@@ -71,7 +70,7 @@ def _repeated_MSTTR(self, tokenized_data, window_size, repeats=5):
}
return results

def _repeated_MSTTR(self, tokenized_data, window_size, repeats=5):
def _repeated_MSTTR(self, tokenized_data: List[List[str]], window_size: int, repeats: int = 5) -> float:
"Repeated MSTTR to obtain a more robust MSTTR value."
msttrs = []
for i in range(repeats):
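
A standalone sketch of the computation the typed methods describe: concatenate all tokens, cut them into fixed-size windows, take the type-token ratio of each full window, and average (the trailing partial window is dropped, as in the loop above).

# Sketch only: MSTTR over pre-tokenized, lowercased texts.
import itertools
from typing import List


def ttr(tokens: List[str]) -> float:
    """Type-token ratio of a token list."""
    return len(set(tokens)) / len(tokens)


def msttr(tokenized_data: List[List[str]], window_size: int = 100) -> float:
    concatenated = list(itertools.chain.from_iterable(tokenized_data))
    windows = [
        concatenated[i : i + window_size]
        for i in range(0, len(concatenated), window_size)
    ]
    ttrs = [ttr(w) for w in windows if len(w) == window_size]
    return sum(ttrs) / len(ttrs) if ttrs else 0.0
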
14 changes: 8 additions & 6 deletions gem_metrics/ngrams.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python3

from typing import Dict, List, Tuple
from .metric import ReferencelessMetric
from .texts import Predictions

import numpy as np
from nltk import ngrams

from .metric import ReferencelessMetric


class NGramStats(ReferencelessMetric):
"""Ngram basic statistics and entropy, working with tokenized & lowercased data (+ variant excluding punctuation):
@@ -28,7 +30,7 @@ def support_caching(self):
# NGramStats is corpus-level, so individual examples can't be aggregated.
return False

def compute(self, cache, predictions):
def compute(self, cache, predictions: Predictions) -> Dict:

results = {}
for data_id, data in [
@@ -64,7 +66,7 @@ def compute(self, cache, predictions):

return results

def _ngram_stats(self, data, N):
def _ngram_stats(self, data: List[List[str]], N: int) -> Tuple[Dict, int, int]:
"""Return basic ngram statistics, as well as a dict of all ngrams and their freqsuencies."""
ngram_freqs = {} # ngrams with frequencies
ngram_len = 0 # total number of ngrams
@@ -76,7 +78,7 @@ def _ngram_stats(self, data, N):
uniq_ngrams = len([val for val in ngram_freqs.values() if val == 1])
return ngram_freqs, uniq_ngrams, ngram_len

def _entropy(self, ngram_freqs):
def _entropy(self, ngram_freqs: Dict) -> float:
"""Shannon entropy over ngram frequencies"""
total_freq = sum(ngram_freqs.values())
return -sum(
@@ -86,7 +88,7 @@ def _entropy(self, ngram_freqs):
]
)

def _cond_entropy(self, joint, ctx):
def _cond_entropy(self, joint: Dict, ctx: Dict) -> float:
"""Conditional/next-word entropy (language model style), using ngrams (joint) and n-1-grams (ctx)."""
total_joint = sum(joint.values())
total_ctx = sum(ctx.values())
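
A sketch of the entropy quantities behind the typed signatures, assuming base-2 logarithms and tuple ngram keys; the module's exact choices are not visible in this diff.

# Sketch only: Shannon and next-word entropy over ngram frequency dicts.
import math
from typing import Dict


def entropy(ngram_freqs: Dict) -> float:
    """Shannon entropy over ngram frequencies."""
    total = sum(ngram_freqs.values())
    return -sum((f / total) * math.log2(f / total) for f in ngram_freqs.values())


def cond_entropy(joint: Dict, ctx: Dict) -> float:
    """Next-word entropy from ngram (joint) and (n-1)-gram (ctx) frequencies."""
    total_joint = sum(joint.values())
    return -sum(
        (freq / total_joint) * math.log2(freq / ctx[ngram[:-1]])
        for ngram, freq in joint.items()
    )
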
4 changes: 3 additions & 1 deletion gem_metrics/nist.py
@@ -1,8 +1,10 @@
#!/usr/bin/env python3

from .texts import Predictions, References
from .metric import ReferencedMetric
from .impl.pymteval import NISTScore

from typing import Dict

class NIST(ReferencedMetric):
"""NIST from e2e-metrics."""
@@ -11,7 +13,7 @@ def support_caching(self):
# NIST is corpus-level, so individual examples can't be aggregated.
return False

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
nist = NISTScore()
for pred, refs in zip(predictions.untokenized, references.untokenized):
nist.append(pred, refs)
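
The typed compute() accumulates hypothesis/reference pairs one item at a time and reads out a single corpus-level score. A sketch of that pattern; the final score() call is an assumption based on the pymteval scorers from e2e-metrics and may not match the wrapper's actual accessor.

# Sketch only: corpus-level NIST accumulation.
from gem_metrics.impl.pymteval import NISTScore

predictions = ["the cat sat on the mat"]
references = [["a cat sat on the mat", "the cat is on the mat"]]

nist = NISTScore()
for pred, refs in zip(predictions, references):
    nist.append(pred, refs)
print(nist.score())  # assumption: score() exposes the accumulated corpus score
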
8 changes: 5 additions & 3 deletions gem_metrics/nubia.py
@@ -1,7 +1,9 @@
from nubia_score import Nubia
import numpy as np
from .metric import ReferencedMetric
from .texts import Predictions, References

from nubia_score import Nubia
import numpy as np
from typing import Dict

class NUBIA(ReferencedMetric):
def __init__(self):
@@ -15,7 +17,7 @@ def _initialize(self):
self.metric.roberta_MNLI.to("cuda")
self.metric.gpt_model.to("cuda")

def compute(self, cache, predictions, references):
def compute(self, cache, predictions: Predictions, references: References) -> Dict:
"""Run Nubia"""
scores = {}
for ref, pred, pred_id in zip(
Diffs for the remaining changed files are not shown here.