Skip to content

Commit

Permalink
NCES deals with unexpected numbers of examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-KOUAGOU committed Aug 10, 2023
1 parent 8f8d3c5 commit 76b49bc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
13 changes: 9 additions & 4 deletions examples/quality_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ def quality(KB, solution, pos, neg):
fn=len(pos.difference(instances))
fp=len(neg.intersection(instances))
tn=len(neg.difference(instances))
print("Accuracy: {}%".format(100*accuracy(tp, fn, fp, tn)[-1]))
print("Precision: {}%".format(100*precision(tp, fn, fp, tn)[-1]))
print("Recall: {}%".format(100*recall(tp, fn, fp, tn)[-1]))
print("F1: {}%".format(100*f1(tp, fn, fp, tn)[-1]))
acc = 100*accuracy(tp, fn, fp, tn)[-1]
prec = 100*precision(tp, fn, fp, tn)[-1]
rec = 100*recall(tp, fn, fp, tn)[-1]
f_1 = 100*f1(tp, fn, fp, tn)[-1]
print("Accuracy: {}%".format(acc))
print("Precision: {}%".format(prec))
print("Recall: {}%".format(rec))
print("F1: {}%".format(f_1))
return acc, prec, rec, f_1
2 changes: 2 additions & 0 deletions ontolearn/base_nces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(self, knowledge_base_path, learner_name, path_of_embeddings, batch_
vocab = atomic_concept_names + role_names + ['⊔', '⊓', '∃', '∀', '¬', '⊤', '⊥', '.', ' ', '(', ')']
vocab = sorted(vocab) + ['PAD']
self.knowledge_base_path = knowledge_base_path
self.kb = kb
self.all_individuals = set([ind.get_iri().as_str().split("/")[-1] for ind in kb.individuals()])
self.inv_vocab = np.array(vocab, dtype='object')
self.vocab = {vocab[i]:i for i in range(len(vocab))}
self.learner_name = learner_name
Expand Down
44 changes: 22 additions & 22 deletions ontolearn/concept_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,36 +1657,36 @@ def load_model(learner_name, load_pretrained):
def refresh(self):
self.model = self.get_synthesizer()

def sample_examples(self, pos, neg, all_inds=set()):
def sample_examples(self, pos, neg):
assert type(pos[0]) == type(neg[0]), "The two interables pos and neg must be of same type"
if all_inds:
assert type(list(all_inds)[0]) == type(pos[0]) == type(neg[0]), "The three iterables pos, neg, all_inds must be of same type"
num_ex = self.num_examples
resample_negs = False
oversample = False
if min(len(pos),len(neg)) >= num_ex//2:
if len(pos) > len(neg):
num_neg_ex = num_ex//2
num_pos_ex = num_ex-num_neg_ex
else:
num_pos_ex = num_ex//2
num_neg_ex = num_ex-num_pos_ex
elif len(pos) > len(neg):
elif len(pos) + len(neg) >= num_ex and len(pos) > len(neg):
num_neg_ex = len(neg)
num_pos_ex = num_ex-num_neg_ex
elif len(pos) < len(neg):
elif len(pos) + len(neg) >= num_ex and len(pos) < len(neg):
num_pos_ex = len(pos)
num_neg_ex = num_ex-num_pos_ex
elif len(pos) + len(neg) < num_ex and len(all_inds):
num_pos_ex = len(pos)
num_neg_ex = num_ex-num_pos_ex
resample_negs = True
elif len(pos) + len(neg) < num_ex:
num_pos_ex = max(num_ex//3, len(pos))
num_neg_ex = max(num_ex-num_pos_ex, len(neg))
oversample = True
else:
num_pos_ex = len(pos)
num_neg_ex = len(neg)
if resample_negs:
positive = pos
remaining = list((all_inds-set(pos))-set(neg))
negative = neg + random.sample(remaining, min(num_neg_ex, len(remaining)))
if oversample:
print("Over sampling...")
remaining = list(self.all_individuals.difference(set(pos).union(set(neg))))
positive = pos + random.sample(remaining, min(max(0,num_pos_ex-len(pos)), len(remaining)))
remaining = list(set(remaining).difference(set(positive)))
negative = neg + random.sample(remaining, min(max(0,num_neg_ex-len(neg)), len(remaining)))
else:
positive = random.sample(pos, min(num_pos_ex, len(pos)))
negative = random.sample(neg, min(num_neg_ex, len(neg)))
Expand All @@ -1705,15 +1705,15 @@ def get_prediction(models, x1, x2):
prediction = model.inv_vocab[scores.argmax(1)]
return prediction

def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWLNamedIndividual], Set[str]], all_inds=set(), shuffle_examples=False, verbose=True, **kwargs):
def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWLNamedIndividual], Set[str]], shuffle_examples=False, verbose=True, **kwargs):
pos = list(pos)
neg = list(neg)
if isinstance(pos[0], OWLNamedIndividual):
pos_str = [ind.get_iri().as_str().split("/")[-1] for ind in pos]
neg_str = [ind.get_iri().as_str().split("/")[-1] for ind in neg]
pos_str, neg_str = self.sample_examples(pos_str, neg_str, all_inds)
pos_str, neg_str = self.sample_examples(pos_str, neg_str)
elif isinstance(pos[0], str):
pos_str, neg_str = self.sample_examples(pos, neg, all_inds)
pos_str, neg_str = self.sample_examples(pos, neg)
else:
raise ValueError(f"Invalid input type, was expecting OWLNamedIndividual or str but found {type(pos[0])}")
if self.sorted_examples:
Expand All @@ -1739,15 +1739,15 @@ def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWL
print("Prediction: ", prediction_str)
return prediction_as_owl_class_expression

def convert_to_list_str_from_iterable(self, data, all_inds=set()):
def convert_to_list_str_from_iterable(self, data):
target_concept_str, examples = data[0], data[1:]
pos = list(examples[0]); neg = list(examples[1])
if isinstance(pos[0], OWLNamedIndividual):
pos_str = [ind.get_iri().as_str().split("/")[-1] for ind in pos]
neg_str = [ind.get_iri().as_str().split("/")[-1] for ind in neg]
pos_str, neg_str = self.sample_examples(pos_str, neg_str, all_inds)
pos_str, neg_str = self.sample_examples(pos_str, neg_str)
elif isinstance(pos[0], str):
pos_str, neg_str = self.sample_examples(list(pos), list(neg), all_inds)
pos_str, neg_str = self.sample_examples(list(pos), list(neg))
else:
raise ValueError(f"Invalid input type, was expecting OWLNamedIndividual or str but found {type(pos[0])}")
if self.sorted_examples:
Expand All @@ -1756,12 +1756,12 @@ def convert_to_list_str_from_iterable(self, data, all_inds=set()):


def fit_from_iterable(self, dataset: Union[List[Tuple[str, Set[OWLNamedIndividual], Set[OWLNamedIndividual]]],
List[Tuple[str, Set[str], Set[str]]]], all_inds=set(), shuffle_examples=False, verbose=False, **kwargs) -> List:
List[Tuple[str, Set[str], Set[str]]]], shuffle_examples=False, verbose=False, **kwargs) -> List:
"""
dataset is a list of tuples where the first items are strings corresponding to target concepts
"""
assert self.load_pretrained and self.pretrained_model_name, "No pretrained model found. Please first train NCES, see the <<train>> method"
dataset = [self.convert_to_list_str_from_iterable(datapoint, all_inds) for datapoint in dataset]
dataset = [self.convert_to_list_str_from_iterable(datapoint) for datapoint in dataset]
dataset = NCESDataLoaderInference(dataset, self.instance_embeddings, self.vocab, self.inv_vocab, shuffle_examples)
dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_batch_inference, shuffle=False)
simpleSolution = SimpleSolution(list(self.vocab), self.atomic_concept_names)
Expand Down

0 comments on commit 76b49bc

Please sign in to comment.