From 76b49bc57fb5cbf2a4bc2c3a37042105ac4c312f Mon Sep 17 00:00:00 2001
From: Jean-KOUAGOU
Date: Thu, 10 Aug 2023 16:47:12 +0200
Subject: [PATCH] nces deals with unexpected numbers of examples

---
 examples/quality_functions.py | 13 +++++++----
 ontolearn/base_nces.py        |  2 ++
 ontolearn/concept_learner.py  | 44 +++++++++++++++++------------------
 3 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/examples/quality_functions.py b/examples/quality_functions.py
index 36b9777f..570fe8fa 100644
--- a/examples/quality_functions.py
+++ b/examples/quality_functions.py
@@ -9,7 +9,12 @@ def quality(KB, solution, pos, neg):
     fn=len(pos.difference(instances))
     fp=len(neg.intersection(instances))
     tn=len(neg.difference(instances))
-    print("Accuracy: {}%".format(100*accuracy(tp, fn, fp, tn)[-1]))
-    print("Precision: {}%".format(100*precision(tp, fn, fp, tn)[-1]))
-    print("Recall: {}%".format(100*recall(tp, fn, fp, tn)[-1]))
-    print("F1: {}%".format(100*f1(tp, fn, fp, tn)[-1]))
\ No newline at end of file
+    acc = 100*accuracy(tp, fn, fp, tn)[-1]
+    prec = 100*precision(tp, fn, fp, tn)[-1]
+    rec = 100*recall(tp, fn, fp, tn)[-1]
+    f_1 = 100*f1(tp, fn, fp, tn)[-1]
+    print("Accuracy: {}%".format(acc))
+    print("Precision: {}%".format(prec))
+    print("Recall: {}%".format(rec))
+    print("F1: {}%".format(f_1))
+    return acc, prec, rec, f_1
\ No newline at end of file
diff --git a/ontolearn/base_nces.py b/ontolearn/base_nces.py
index 4b702f58..7a3bd334 100644
--- a/ontolearn/base_nces.py
+++ b/ontolearn/base_nces.py
@@ -20,6 +20,8 @@ def __init__(self, knowledge_base_path, learner_name, path_of_embeddings, batch_
         vocab = atomic_concept_names + role_names + ['⊔', '⊓', '∃', '∀', '¬', '⊤', '⊥', '.', ' ', '(', ')']
         vocab = sorted(vocab) + ['PAD']
         self.knowledge_base_path = knowledge_base_path
+        self.kb = kb
+        self.all_individuals = set([ind.get_iri().as_str().split("/")[-1] for ind in kb.individuals()])
         self.inv_vocab = np.array(vocab, dtype='object')
         self.vocab = {vocab[i]:i for i in range(len(vocab))}
         self.learner_name = learner_name
diff --git a/ontolearn/concept_learner.py b/ontolearn/concept_learner.py
index 8db38211..721d9ae0 100644
--- a/ontolearn/concept_learner.py
+++ b/ontolearn/concept_learner.py
@@ -1657,12 +1657,10 @@ def load_model(learner_name, load_pretrained):
     def refresh(self):
         self.model = self.get_synthesizer()
 
-    def sample_examples(self, pos, neg, all_inds=set()):
+    def sample_examples(self, pos, neg):
         assert type(pos[0]) == type(neg[0]), "The two interables pos and neg must be of same type"
-        if all_inds:
-            assert type(list(all_inds)[0]) == type(pos[0]) == type(neg[0]), "The three iterables pos, neg, all_inds must be of same type"
         num_ex = self.num_examples
-        resample_negs = False
+        oversample = False
         if min(len(pos),len(neg)) >= num_ex//2:
             if len(pos) > len(neg):
                 num_neg_ex = num_ex//2
@@ -1670,23 +1668,25 @@
             else:
                 num_pos_ex = num_ex//2
                 num_neg_ex = num_ex-num_pos_ex
-        elif len(pos) > len(neg):
+        elif len(pos) + len(neg) >= num_ex and len(pos) > len(neg):
             num_neg_ex = len(neg)
             num_pos_ex = num_ex-num_neg_ex
-        elif len(pos) < len(neg):
+        elif len(pos) + len(neg) >= num_ex and len(pos) < len(neg):
             num_pos_ex = len(pos)
             num_neg_ex = num_ex-num_pos_ex
-        elif len(pos) + len(neg) < num_ex and len(all_inds):
-            num_pos_ex = len(pos)
-            num_neg_ex = num_ex-num_pos_ex
-            resample_negs = True
+        elif len(pos) + len(neg) < num_ex:
+            num_pos_ex = max(num_ex//3, len(pos))
+            num_neg_ex = max(num_ex-num_pos_ex, len(neg))
+            oversample = True
         else:
             num_pos_ex = len(pos)
             num_neg_ex = len(neg)
-        if resample_negs:
-            positive = pos
-            remaining = list((all_inds-set(pos))-set(neg))
-            negative = neg + random.sample(remaining, min(num_neg_ex, len(remaining)))
+        if oversample:
+            print("Over sampling...")
+            remaining = list(self.all_individuals.difference(set(pos).union(set(neg))))
+            positive = pos + random.sample(remaining, min(max(0,num_pos_ex-len(pos)), len(remaining)))
+            remaining = list(set(remaining).difference(set(positive)))
+            negative = neg + random.sample(remaining, min(max(0,num_neg_ex-len(neg)), len(remaining)))
         else:
             positive = random.sample(pos, min(num_pos_ex, len(pos)))
             negative = random.sample(neg, min(num_neg_ex, len(neg)))
@@ -1705,15 +1705,15 @@ def get_prediction(models, x1, x2):
         prediction = model.inv_vocab[scores.argmax(1)]
         return prediction
 
-    def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWLNamedIndividual], Set[str]], all_inds=set(), shuffle_examples=False, verbose=True, **kwargs):
+    def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWLNamedIndividual], Set[str]], shuffle_examples=False, verbose=True, **kwargs):
         pos = list(pos)
         neg = list(neg)
         if isinstance(pos[0], OWLNamedIndividual):
             pos_str = [ind.get_iri().as_str().split("/")[-1] for ind in pos]
             neg_str = [ind.get_iri().as_str().split("/")[-1] for ind in neg]
-            pos_str, neg_str = self.sample_examples(pos_str, neg_str, all_inds)
+            pos_str, neg_str = self.sample_examples(pos_str, neg_str)
         elif isinstance(pos[0], str):
-            pos_str, neg_str = self.sample_examples(pos, neg, all_inds)
+            pos_str, neg_str = self.sample_examples(pos, neg)
         else:
             raise ValueError(f"Invalid input type, was expecting OWLNamedIndividual or str but found {type(pos[0])}")
         if self.sorted_examples:
@@ -1739,15 +1739,15 @@ def fit(self, pos: Union[Set[OWLNamedIndividual], Set[str]] , neg: Union[Set[OWL
             print("Prediction: ", prediction_str)
         return prediction_as_owl_class_expression
 
-    def convert_to_list_str_from_iterable(self, data, all_inds=set()):
+    def convert_to_list_str_from_iterable(self, data):
         target_concept_str, examples = data[0], data[1:]
         pos = list(examples[0]); neg = list(examples[1])
         if isinstance(pos[0], OWLNamedIndividual):
             pos_str = [ind.get_iri().as_str().split("/")[-1] for ind in pos]
             neg_str = [ind.get_iri().as_str().split("/")[-1] for ind in neg]
-            pos_str, neg_str = self.sample_examples(pos_str, neg_str, all_inds)
+            pos_str, neg_str = self.sample_examples(pos_str, neg_str)
         elif isinstance(pos[0], str):
-            pos_str, neg_str = self.sample_examples(list(pos), list(neg), all_inds)
+            pos_str, neg_str = self.sample_examples(list(pos), list(neg))
         else:
             raise ValueError(f"Invalid input type, was expecting OWLNamedIndividual or str but found {type(pos[0])}")
         if self.sorted_examples:
@@ -1756,12 +1756,12 @@
 
     def fit_from_iterable(self, dataset: Union[List[Tuple[str, Set[OWLNamedIndividual], Set[OWLNamedIndividual]]],
-                          List[Tuple[str, Set[str], Set[str]]]], all_inds=set(), shuffle_examples=False, verbose=False, **kwargs) -> List:
+                          List[Tuple[str, Set[str], Set[str]]]], shuffle_examples=False, verbose=False, **kwargs) -> List:
         """ dataset is a list of tuples where the first items are strings corresponding to target concepts """
         assert self.load_pretrained and self.pretrained_model_name, "No pretrained model found. Please first train NCES, see the <> method"
-        dataset = [self.convert_to_list_str_from_iterable(datapoint, all_inds) for datapoint in dataset]
+        dataset = [self.convert_to_list_str_from_iterable(datapoint) for datapoint in dataset]
         dataset = NCESDataLoaderInference(dataset, self.instance_embeddings, self.vocab, self.inv_vocab, shuffle_examples)
         dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_batch_inference, shuffle=False)
         simpleSolution = SimpleSolution(list(self.vocab), self.atomic_concept_names)