Merge pull request #474 from allenai/052_upgrade
Update to the latest UMLS version
dakinggg authored Apr 29, 2023
2 parents 4f9ba09 + d1aabb3 commit 5368cc3
Showing 17 changed files with 189 additions and 72 deletions.
11 changes: 3 additions & 8 deletions .flake8
@@ -2,21 +2,16 @@
max-line-length = 115

ignore =
# these rules don't play well with black
E203 # whitespace before :
W503 # line break before binary operator
W504 # line break after binary operator
E203
W503
W504

exclude =
build/**
docs/**

per-file-ignores =
# __init__.py files are allowed to have unused imports and lines-too-long
scispacy/__init__.py:F401
scispacy/**/__init__.py:F401,E501

# scripts don't have to respect
# E501: line length
# E402: imports not at top of file (because we mess with sys.path)
scripts/**:E501,E402
97 changes: 97 additions & 0 deletions evaluation/evaluate_linker.py
@@ -0,0 +1,97 @@
import spacy
from scispacy.linking import EntityLinker
from scispacy.data_util import read_full_med_mentions
import os
from tqdm import tqdm

EVALUATION_FOLDER_PATH = os.path.dirname(os.path.abspath(__file__))


def main():
nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe(
"scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"}
)
linker = nlp.get_pipe("scispacy_linker")

med_mentions = read_full_med_mentions(
os.path.join(EVALUATION_FOLDER_PATH, os.pardir, "data", "med_mentions"),
use_umls_ids=True,
)

test_data = med_mentions[2]

total_entities = 0
correct_at_1 = 0
correct_at_2 = 0
correct_at_10 = 0
correct_at_40 = 0
correct_at_60 = 0
correct_at_80 = 0
correct_at_100 = 0
for text_doc, entities in tqdm(test_data):
for start, end, label in entities["entities"]:
text_span = text_doc[start:end]
candidates = linker.candidate_generator([text_span], 40)[0]
sorted_candidates = sorted(
candidates, reverse=True, key=lambda x: max(x.similarities)
)
candidate_ids = [c.concept_id for c in sorted_candidates]
if label in candidate_ids[:1]:
correct_at_1 += 1
if label in candidate_ids[:2]:
correct_at_2 += 1
if label in candidate_ids[:10]:
correct_at_10 += 1
if label in candidate_ids[:40]:
correct_at_40 += 1
# if label in candidate_ids[:60]:
# correct_at_60 += 1
# if label in candidate_ids[:80]:
# correct_at_80 += 1
# if label in candidate_ids[:100]:
# correct_at_100 += 1

total_entities += 1

print("Total entities: ", total_entities)
print(
"Correct at 1: ", correct_at_1, "Recall at 1: ", correct_at_1 / total_entities
)
print(
"Correct at 2: ", correct_at_2, "Recall at 2: ", correct_at_2 / total_entities
)
print(
"Correct at 10: ",
correct_at_10,
"Recall at 10: ",
correct_at_10 / total_entities,
)
print(
"Correct at 40: ",
correct_at_40,
"Recall at 40: ",
correct_at_40 / total_entities,
)
# print(
# "Correct at 60: ",
# correct_at_60,
# "Recall at 60: ",
# correct_at_60 / total_entities,
# )
# print(
# "Correct at 80: ",
# correct_at_80,
# "Recall at 80: ",
# correct_at_80 / total_entities,
# )
# print(
# "Correct at 100: ",
# correct_at_100,
# "Recall at 100: ",
# correct_at_100 / total_entities,
# )


if __name__ == "__main__":
main()
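For orientation, the new script above measures candidate-generation recall: for each gold mention it asks the linker's candidate generator for up to 40 UMLS candidates, ranks them by their best TF-IDF alias similarity, and checks whether the gold CUI appears in the top k. A minimal sketch of the same lookup for a single mention, assuming en_core_sci_sm and the UMLS linker artifacts are installed (the example sentence and the use of the first detected entity are illustrative, not part of the commit):

import spacy
from scispacy.linking import EntityLinker  # imported so the "scispacy_linker" factory is registered

nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe(
    "scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"}
)
linker = nlp.get_pipe("scispacy_linker")

doc = nlp("The patient presented with myocardial infarction.")
mention = doc.ents[0]  # assumes the NER model tags the mention as an entity

# Retrieve up to 10 candidates and rank them by their best alias similarity,
# mirroring what evaluate_linker.py does before computing recall@k.
candidates = linker.candidate_generator([mention.text], 10)[0]
ranked = sorted(candidates, reverse=True, key=lambda c: max(c.similarities))
for candidate in ranked[:5]:
    print(candidate.concept_id, round(max(candidate.similarities), 3))

Running the full evaluation presumably only requires the MedMentions archive to be available under data/med_mentions and then invoking the script with python evaluation/evaluate_linker.py.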
2 changes: 1 addition & 1 deletion project.yml
@@ -2,7 +2,7 @@ title: "scispaCy pipeline"
description: "All the steps needed in the scispaCy pipeline"

vars:
version_string: "0.5.1"
version_string: "0.5.2"
gpu_id: 0
freqs_loc_s3: "s3://ai2-s2-scispacy/data/gorc_subset.freqs"
freqs_loc_local: "assets/gorc_subset.freqs"
2 changes: 2 additions & 0 deletions requirements.in
@@ -18,6 +18,8 @@ flake8
black
mypy
types-requests
types-setuptools
types-tabulate

# Required for releases.
twine
22 changes: 20 additions & 2 deletions scispacy/abbreviation.py
@@ -82,6 +82,20 @@ def find_abbreviation(
return short_form_candidate, long_form_candidate[starting_index:]


def span_contains_unbalanced_parentheses(span: Span) -> bool:
stack_counter = 0
for token in span:
if token.text == "(":
stack_counter += 1
elif token.text == ")":
if stack_counter > 0:
stack_counter -= 1
else:
return True

return stack_counter != 0


def filter_matches(
matcher_output: List[Tuple[int, int, int]], doc: Doc
) -> List[Tuple[Span, Span]]:
@@ -100,6 +114,10 @@ def filter_matches(
# Take one word before.
short_form_candidate = doc[start - 2 : start - 1]
long_form_candidate = doc[start:end]

# make sure any parentheses inside long form are balanced
if span_contains_unbalanced_parentheses(long_form_candidate):
continue
else:
# Normal case.
# Short form is inside the parens.
@@ -190,7 +208,7 @@ def __call__(self, doc: Doc) -> Doc:
filtered = filter_matches(matches_no_brackets, doc)
occurences = self.find_matches_for(filtered, doc)

for (long_form, short_forms) in occurences:
for long_form, short_forms in occurences:
for short in short_forms:
short._.long_form = long_form
doc._.abbreviations.append(short)
@@ -209,7 +227,7 @@ def find_matches_for(
all_occurences: Dict[Span, Set[Span]] = defaultdict(set)
already_seen_long: Set[str] = set()
already_seen_short: Set[str] = set()
for (long_candidate, short_candidate) in filtered:
for long_candidate, short_candidate in filtered:
short, long = find_abbreviation(long_candidate, short_candidate)
# We need the long and short form definitions to be unique, because we need
# to store them so we can look them up later. This is a bit of a
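The new span_contains_unbalanced_parentheses helper is a single counter scan over the candidate span: a ")" with no matching "(" before it, or any "(" still open at the end, disqualifies the long-form candidate. A standalone sketch of the same logic on plain token strings (hypothetical helper name, not part of scispacy):

def has_unbalanced_parentheses(tokens):
    open_parens = 0
    for tok in tokens:
        if tok == "(":
            open_parens += 1
        elif tok == ")":
            if open_parens > 0:
                open_parens -= 1
            else:
                return True  # closing paren with no opener
    return open_parens != 0  # an opener was never closed

assert has_unbalanced_parentheses(["magnetic", "resonance", "(", "imaging"])
assert has_unbalanced_parentheses([")", "magnetic", "resonance", "imaging"])
assert not has_unbalanced_parentheses(["magnetic", "resonance", "(", "MR", ")", "imaging"])

The remaining edits in this file simply drop the now-redundant parentheses around the tuples unpacked in the two for loops.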
55 changes: 27 additions & 28 deletions scispacy/candidate_generation.py
@@ -1,4 +1,4 @@
from typing import List, Dict, Tuple, NamedTuple, Type
from typing import Optional, List, Dict, Tuple, NamedTuple, Type
import json
import datetime
from collections import defaultdict
@@ -41,38 +41,38 @@ class LinkerPaths(NamedTuple):


UmlsLinkerPaths = LinkerPaths(
ann_index="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/nmslib_index.bin", # noqa
tfidf_vectorizer="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/concept_aliases.json", # noqa
ann_index="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/nmslib_index.bin", # noqa
tfidf_vectorizer="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/concept_aliases.json", # noqa
)

MeshLinkerPaths = LinkerPaths(
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/mesh/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/mesh/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/mesh/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/mesh/concept_aliases.json", # noqa
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/mesh/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/mesh/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/mesh/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/mesh/concept_aliases.json", # noqa
)

GeneOntologyLinkerPaths = LinkerPaths(
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/go/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/go/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/go/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/go/concept_aliases.json", # noqa
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/go/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/go/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/go/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/go/concept_aliases.json", # noqa
)

HumanPhenotypeOntologyLinkerPaths = LinkerPaths(
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/hpo/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/hpo/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/hpo/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/hpo/concept_aliases.json", # noqa
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/hpo/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/hpo/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/hpo/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/hpo/concept_aliases.json", # noqa
)

RxNormLinkerPaths = LinkerPaths(
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/rxnorm/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/rxnorm/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/rxnorm/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2020-10-09/rxnorm/concept_aliases.json", # noqa
ann_index="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/rxnorm/nmslib_index.bin", # noqa
tfidf_vectorizer="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/rxnorm/tfidf_vectorizer.joblib", # noqa
tfidf_vectors="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/rxnorm/tfidf_vectors_sparse.npz", # noqa
concept_aliases_list="https://ai2-s2-scispacy.s3-us-west-2.amazonaws.com/data/linkers/2023-04-23/rxnorm/concept_aliases.json", # noqa
)


@@ -196,15 +196,14 @@ class CandidateGenerator:

def __init__(
self,
ann_index: FloatIndex = None,
tfidf_vectorizer: TfidfVectorizer = None,
ann_concept_aliases_list: List[str] = None,
kb: KnowledgeBase = None,
ann_index: Optional[FloatIndex] = None,
tfidf_vectorizer: Optional[TfidfVectorizer] = None,
ann_concept_aliases_list: Optional[List[str]] = None,
kb: Optional[KnowledgeBase] = None,
verbose: bool = False,
ef_search: int = 200,
name: str = None,
name: Optional[str] = None,
) -> None:

if name is not None and any(
[ann_index, tfidf_vectorizer, ann_concept_aliases_list, kb]
):
@@ -363,7 +362,7 @@ def __call__(


def create_tfidf_ann_index(
out_path: str, kb: KnowledgeBase = None
out_path: str, kb: Optional[KnowledgeBase] = None
) -> Tuple[List[str], TfidfVectorizer, FloatIndex]:
"""
Build tfidf vectorizer and ann index.
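Aside from the implicit-Optional fixes (parameters that default to None are now annotated Optional[...]), the substantive change in this file is that every linker's artifact URLs move from the 2020-10-09 release to the 2023-04-23 release. A sketch of pulling the new artifacts through the by-name constructor, assuming the download-and-cache behaviour and the call signature work as in previous scispacy releases:

from scispacy.candidate_generation import CandidateGenerator

# First construction downloads (then caches) the 2023-04-23 UMLS ann index,
# tf-idf vectorizer/vectors, and concept alias list referenced above.
generator = CandidateGenerator(name="umls")

# Request up to 10 candidates for a raw mention string.
candidates = generator(["heart attack"], 10)[0]
print([c.concept_id for c in candidates[:5]])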
12 changes: 9 additions & 3 deletions scispacy/data_util.py
@@ -1,4 +1,4 @@
from typing import NamedTuple, List, Iterator, Dict, Tuple
from typing import Optional, NamedTuple, List, Iterator, Dict, Tuple
import tarfile
import atexit
import os
@@ -148,9 +148,10 @@ def remove_overlapping_entities(

def read_full_med_mentions(
directory_path: str,
label_mapping: Dict[str, str] = None,
label_mapping: Optional[Dict[str, str]] = None,
span_only: bool = False,
spacy_format: bool = True,
use_umls_ids: bool = False,
):
def _cleanup_dir(dir_path: str):
if os.path.exists(dir_path):
@@ -209,7 +210,12 @@ def label_function(label):

for example in examples:
spacy_format_entities = [
(x.start, x.end, label_function(x.mention_type)) for x in example.entities
(
x.start,
x.end,
label_function(x.mention_type) if not use_umls_ids else x.umls_id,
)
for x in example.entities
]
spacy_format_entities = remove_overlapping_entities(
sorted(spacy_format_entities, key=lambda x: x[0])
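The new use_umls_ids flag makes read_full_med_mentions emit each mention's UMLS identifier as the entity label instead of the (optionally mapped) mention type, which is what evaluate_linker.py needs in order to score candidates against gold concept IDs. A sketch of reading the corpus that way, assuming the MedMentions data lives under data/med_mentions and that the function returns the train/dev/test splits in that order:

from scispacy.data_util import read_full_med_mentions

splits = read_full_med_mentions("data/med_mentions", use_umls_ids=True)
test = splits[2]

text, annotations = test[0]
for start, end, umls_id in annotations["entities"][:3]:
    # The third element is now the gold UMLS identifier rather than a type label.
    print(text[start:end], umls_id)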
12 changes: 7 additions & 5 deletions scispacy/file_cache.py
@@ -8,7 +8,7 @@
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Tuple, Union, IO
from typing import Optional, Tuple, Union, IO
from hashlib import sha256

import requests
@@ -17,7 +17,9 @@
DATASET_CACHE = str(CACHE_ROOT / "datasets")


def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
def cached_path(
url_or_filename: Union[str, Path], cache_dir: Optional[str] = None
) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@@ -47,7 +49,7 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
)


def url_to_filename(url: str, etag: str = None) -> str:
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
@@ -68,7 +70,7 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename


def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
def filename_to_url(filename: str, cache_dir: Optional[str] = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
@@ -99,7 +101,7 @@ def http_get(url: str, temp_file: IO) -> None:
temp_file.write(chunk)


def get_from_cache(url: str, cache_dir: str = None) -> str:
def get_from_cache(url: str, cache_dir: Optional[str] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
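All of the file_cache.py edits are typing-only: parameters that default to None are now spelled Optional[...], with no behaviour change. For reference, cached_path is the helper that fetches the linker artifact URLs updated above; a minimal sketch of calling it directly (the URL shown is one of the new 2023-04-23 artifacts):

from scispacy.file_cache import cached_path

# Returns a local path, downloading into scispacy's cache directory on the
# first call and reusing the cached copy on subsequent calls.
local_path = cached_path(
    "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2023-04-23/umls/concept_aliases.json"
)
print(local_path)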
3 changes: 0 additions & 3 deletions scispacy/hyponym_detector.py
@@ -36,7 +36,6 @@ class HyponymDetector:
def __init__(
self, nlp: Language, name: str = "hyponym_detector", extended: bool = False
):

self.nlp = nlp

self.patterns = BASE_PATTERNS
@@ -91,7 +90,6 @@ def expand_to_noun_compound(self, token: Token, doc: Doc):
return doc[start:end]

def find_noun_compound_head(self, token: Token):

while token.head.pos_ in {"PROPN", "NOUN", "PRON"} and token.dep_ == "compound":
token = token.head
return token
@@ -135,7 +133,6 @@ def __call__(self, doc: Doc):
)

for token in hyponym.conjuncts:

token_extended = self.expand_to_noun_compound(token, doc)
if token != hypernym and token is not None:
doc._.hearst_patterns.append(
(Diff truncated; the remaining changed files are not shown in this view.)