From ccefa83a8f5a7ebd9689939ef982e46fbfa0c338 Mon Sep 17 00:00:00 2001
From: Imene Kerboua
Date: Sat, 4 May 2024 13:59:36 +0200
Subject: [PATCH 1/4] compute scores for probes

---
 .vscode/launch.json                 |   8 ++
 scripts/analysis/probes_analysis.py | 121 ++++++++++++++++++++++++++++
 scripts/constants.py                |   3 +
 3 files changed, 132 insertions(+)
 create mode 100644 scripts/analysis/probes_analysis.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 3e30ac2..3b8aafa 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -41,6 +41,14 @@
             "module": "scripts.make_activation_dataset",
             "console": "integratedTerminal",
             "justMyCode": false
+        },
+        {
+            "name": "Script Probes Analysis",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "scripts.analysis.probes_analysis",
+            "console": "integratedTerminal",
+            "justMyCode": false
         }
     ]
 }
diff --git a/scripts/analysis/probes_analysis.py b/scripts/analysis/probes_analysis.py
new file mode 100644
index 0000000..d72a30c
--- /dev/null
+++ b/scripts/analysis/probes_analysis.py
@@ -0,0 +1,121 @@
+"""Script to compute evaluation scores for probes.
+
+Run with:
+```
+poetry run python -m scripts.analysis.probes_analysis
+```
+"""
+
+import argparse
+
+import einops
+import torch
+from datasets import load_dataset
+from huggingface_hub import HfApi
+from loguru import logger
+
+from sklearn.metrics import f1_score, recall_score, precision_score
+
+from mulsi.adversarial import LRClfLoss
+from scripts.constants import HF_TOKEN, ASSETS_FOLDER
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+LAYER_NAMES = ["layers.0", "layers.6", "layers.11"]
+CONCEPTS = ["yellow", "red", "sphere", "ovaloid"]
+GOOD_INDICES = {
+    "banana": [],  # None for all
+    "lemon": [0, 6, 8],
+    "tomato": [],
+}
+
+hf_api = HfApi(token=HF_TOKEN)
+
+
+def eval_probe(probe: LRClfLoss, inputs: torch.Tensor, targets: torch.Tensor):
+    # TODO: implement this for each pixel activation
+    predictions = probe(inputs) > 0
+    precision = precision_score(targets, predictions)
+    recall = recall_score(targets, predictions)
+    f1 = f1_score(targets, predictions)
+    return precision, recall, f1
+
+
+def map_fn(s_batched):
+    b, p, h = s_batched["activation"].shape
+    new_s_batched = {}
+    new_s_batched["pixel_activation"] = einops.rearrange(s_batched["activation"], "b p h -> (b p) h")
+    new_s_batched["pixel_label"] = einops.repeat(s_batched["label"], "b -> (b p)", p=p)
+    new_s_batched["pixel_index"] = einops.repeat(torch.arange(p), "p -> (b p)", b=b)
+    return new_s_batched
+
+
+def main(args: argparse.Namespace):
+    logger.info(f"Running on {DEVICE}")
+    dataset_name = args.dataset_name
+
+    # Download probes dataset
+    hf_api.snapshot_download(
+        repo_id=dataset_name.replace("concepts", "probes"),
+        repo_type="model",
+        local_dir=ASSETS_FOLDER / dataset_name.replace("concepts", "probes"),
+        revision=args.probe_ref,
+    )
+
+    probes, metrics = {}, {}
+    for layer_name in LAYER_NAMES:
+        probes[layer_name] = {}
+        metrics[layer_name] = {}
+
+        # Download activations dataset
+        ds_activations = load_dataset(
+            args.dataset_name.replace("concepts", "activations"), split="test", name=layer_name
+        )
+
+        for concept in CONCEPTS:
+            filtered_ds = ds_activations.filter(lambda s: s[concept] is not None)
+            labeled_ds = filtered_ds.rename_column(concept, "label")
+            labeled_ds = labeled_ds.class_encode_column("label")
+            torch_ds = labeled_ds.select_columns(["activation", "label"]).with_format("torch")
+            pre_dataset = torch_ds.map(map_fn, remove_columns=["activation", "label"], batched=True)
+
+            with open(
+                ASSETS_FOLDER / f"{dataset_name.replace('concepts', 'probes')}/data/{layer_name}/{concept}/clf.pt",
+                "rb",
+            ) as f:
+                probes[layer_name][concept] = torch.load(f)
+
+            precision, recall, f1 = eval_probe(
+                probes[layer_name][concept],
+                pre_dataset["pixel_activation"],
+                pre_dataset["pixel_label"],
+            )
+            metrics[layer_name][concept] = {
+                metric_name: value
+                for metric_name, value in zip(["precision", "recall", "f1"], [precision, recall, f1])
+            }
+            logger.info(f"Layer: {layer_name}, Concept: {concept}, Metrics: {metrics[layer_name][concept]}")
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser("fgsm-probing")
+    parser.add_argument("--mode", type=str, default="torch_clf")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="mulsi/fruit-vegetable-concepts",
+    )
+    parser.add_argument("--probe_ref", type=str, default=None)
+    parser.add_argument("--epsilon", type=int, default=3)
+    parser.add_argument("--n_iter", type=int, default=10)
+    parser.add_argument("--concept", type=str, default="yellow")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
+
+    # Probe refs
+    # 29c94861ed9922843d4821f23e7e44fbb30f2de4 -> 3 CLF pre-labeling
+    # ? -> 12 CLF all post-labeling
+    # ? -> 12 CLF only_labeled post-labeling
diff --git a/scripts/constants.py b/scripts/constants.py
index 256f96c..a0fb81d 100644
--- a/scripts/constants.py
+++ b/scripts/constants.py
@@ -123,6 +123,9 @@
     "cabbage",
     "bell pepper",
     "carrot",
+    "turnip",
+    "mango",
+    "capsicum" "sweetpotato",
 ]
 
 assert len(CLASSES) == len(CLASS_CONCEPTS_VALUES.keys())

From 77de61dba8703fcfd15b3b8136f5af98d914520a Mon Sep 17 00:00:00 2001
From: Xmaster6y <66315201+Xmaster6y@users.noreply.github.com>
Date: Sat, 4 May 2024 18:25:02 +0200
Subject: [PATCH 2/4] box plots

---
 .vscode/launch.json                      |   4 +-
 scripts/analysis/probes_analysis.py      | 121 -------------------
 scripts/analysis/probes_sanity_checks.py | 147 +++++++++++++++++++++++
 scripts/constants.py                     |   3 +-
 src/mulsi/analysis.py                    |  18 +++
 5 files changed, 169 insertions(+), 124 deletions(-)
 delete mode 100644 scripts/analysis/probes_analysis.py
 create mode 100644 scripts/analysis/probes_sanity_checks.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 3b8aafa..a89c6f4 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -43,10 +43,10 @@
             "justMyCode": false
         },
         {
-            "name": "Script Probes Analysis",
+            "name": "Script Probes Sanity Checks",
             "type": "debugpy",
             "request": "launch",
-            "module": "scripts.analysis.probes_analysis",
+            "module": "scripts.analysis.probes_sanity_checks",
             "console": "integratedTerminal",
             "justMyCode": false
         }
diff --git a/scripts/analysis/probes_analysis.py b/scripts/analysis/probes_analysis.py
deleted file mode 100644
index d72a30c..0000000
--- a/scripts/analysis/probes_analysis.py
+++ /dev/null
@@ -1,121 +0,0 @@
-"""Script to compute evaluation scores for probes.
-
-Run with:
-```
-poetry run python -m scripts.analysis.probes_analysis
-```
-"""
-
-import argparse
-
-import einops
-import torch
-from datasets import load_dataset
-from huggingface_hub import HfApi
-from loguru import logger
-
-from sklearn.metrics import f1_score, recall_score, precision_score
-
-from mulsi.adversarial import LRClfLoss
-from scripts.constants import HF_TOKEN, ASSETS_FOLDER
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-LAYER_NAMES = ["layers.0", "layers.6", "layers.11"]
-CONCEPTS = ["yellow", "red", "sphere", "ovaloid"]
-GOOD_INDICES = {
-    "banana": [],  # None for all
-    "lemon": [0, 6, 8],
-    "tomato": [],
-}
-
-hf_api = HfApi(token=HF_TOKEN)
-
-
-def eval_probe(probe: LRClfLoss, inputs: torch.Tensor, targets: torch.Tensor):
-    # TODO: implement this for each pixel activation
-    predictions = probe(inputs) > 0
-    precision = precision_score(targets, predictions)
-    recall = recall_score(targets, predictions)
-    f1 = f1_score(targets, predictions)
-    return precision, recall, f1
-
-
-def map_fn(s_batched):
-    b, p, h = s_batched["activation"].shape
-    new_s_batched = {}
-    new_s_batched["pixel_activation"] = einops.rearrange(s_batched["activation"], "b p h -> (b p) h")
-    new_s_batched["pixel_label"] = einops.repeat(s_batched["label"], "b -> (b p)", p=p)
-    new_s_batched["pixel_index"] = einops.repeat(torch.arange(p), "p -> (b p)", b=b)
-    return new_s_batched
-
-
-def main(args: argparse.Namespace):
-    logger.info(f"Running on {DEVICE}")
-    dataset_name = args.dataset_name
-
-    # Download probes dataset
-    hf_api.snapshot_download(
-        repo_id=dataset_name.replace("concepts", "probes"),
-        repo_type="model",
-        local_dir=ASSETS_FOLDER / dataset_name.replace("concepts", "probes"),
-        revision=args.probe_ref,
-    )
-
-    probes, metrics = {}, {}
-    for layer_name in LAYER_NAMES:
-        probes[layer_name] = {}
-        metrics[layer_name] = {}
-
-        # Download activations dataset
-        ds_activations = load_dataset(
-            args.dataset_name.replace("concepts", "activations"), split="test", name=layer_name
-        )
-
-        for concept in CONCEPTS:
-            filtered_ds = ds_activations.filter(lambda s: s[concept] is not None)
-            labeled_ds = filtered_ds.rename_column(concept, "label")
-            labeled_ds = labeled_ds.class_encode_column("label")
-            torch_ds = labeled_ds.select_columns(["activation", "label"]).with_format("torch")
-            pre_dataset = torch_ds.map(map_fn, remove_columns=["activation", "label"], batched=True)
-
-            with open(
-                ASSETS_FOLDER / f"{dataset_name.replace('concepts', 'probes')}/data/{layer_name}/{concept}/clf.pt",
-                "rb",
-            ) as f:
-                probes[layer_name][concept] = torch.load(f)
-
-            precision, recall, f1 = eval_probe(
-                probes[layer_name][concept],
-                pre_dataset["pixel_activation"],
-                pre_dataset["pixel_label"],
-            )
-            metrics[layer_name][concept] = {
-                metric_name: value
-                for metric_name, value in zip(["precision", "recall", "f1"], [precision, recall, f1])
-            }
-            logger.info(f"Layer: {layer_name}, Concept: {concept}, Metrics: {metrics[layer_name][concept]}")
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser("fgsm-probing")
-    parser.add_argument("--mode", type=str, default="torch_clf")
-    parser.add_argument(
-        "--dataset_name",
-        type=str,
-        default="mulsi/fruit-vegetable-concepts",
-    )
-    parser.add_argument("--probe_ref", type=str, default=None)
-    parser.add_argument("--epsilon", type=int, default=3)
-    parser.add_argument("--n_iter", type=int, default=10)
-    parser.add_argument("--concept", type=str, default="yellow")
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    main(args)
-
-    # Probe refs
-    # 29c94861ed9922843d4821f23e7e44fbb30f2de4 -> 3 CLF pre-labeling
-    # ? -> 12 CLF all post-labeling
-    # ? -> 12 CLF only_labeled post-labeling
diff --git a/scripts/analysis/probes_sanity_checks.py b/scripts/analysis/probes_sanity_checks.py
new file mode 100644
index 0000000..55e3489
--- /dev/null
+++ b/scripts/analysis/probes_sanity_checks.py
@@ -0,0 +1,147 @@
+"""Script to run sanity checks on trained probes.
+
+Run with:
+```
+poetry run python -m scripts.analysis.probes_sanity_checks
+```
+"""
+
+import argparse
+from typing import List
+import os
+
+import einops
+import torch
+from datasets import load_dataset
+from huggingface_hub import HfApi
+from loguru import logger
+
+from sklearn.metrics import f1_score, recall_score, precision_score
+
+from mulsi.adversarial import LRClfLoss
+from scripts.constants import HF_TOKEN, ASSETS_FOLDER, LABELED_CLASSES, CLASSES
+from mulsi import analysis
+
+LAYER_NAMES = ["layers.0", "layers.6", "layers.11"]
+CONCEPTS = ["yellow", "red", "sphere", "ovaloid"]
+
+hf_api = HfApi(token=HF_TOKEN)
+
+
+def probe_single_eval(y_true, y_pred):
+    metrics = {}
+    metrics["precision"] = precision_score(y_true, y_pred)
+    metrics["recall"] = recall_score(y_true, y_pred)
+    metrics["f1"] = f1_score(y_true, y_pred)
+    return metrics
+
+
+def eval_probe(
+    probe: LRClfLoss,
+    activation: torch.Tensor,
+    labels: torch.Tensor,
+    indices: torch.Tensor,
+    classes: List[str],
+    selected_classes,
+):
+    predictions = probe(activation) > 0
+    global_metrics = probe_single_eval(labels, predictions)
+
+    per_pixel_metrics = {}
+    for i in range(50):
+        bool_index = indices == i
+        per_pixel_metrics[i] = probe_single_eval(labels[bool_index], predictions[bool_index])
+
+    per_class_metrics = {}
+    # for class_name in selected_classes:
+    #     bool_index = torch.tensor([c == class_name for c in classes])
+    #     per_class_metrics[class_name] = probe_single_eval(labels[bool_index], predictions[bool_index])
+
+    return {"global": global_metrics, "per_pixel": per_pixel_metrics, "per_class": per_class_metrics}
+
+
+def map_fn(s_batched):
+    b, p, h = s_batched["activation"].shape
+    new_s_batched = {}
+    new_s_batched["pixel_activation"] = einops.rearrange(s_batched["activation"], "b p h -> (b p) h")
+    new_s_batched["pixel_label"] = einops.repeat(s_batched["label"], "b -> (b p)", p=p)
+    new_s_batched["pixel_class"] = [s_batched["class"][i] for i in range(b) for _ in range(p)]
+    new_s_batched["pixel_index"] = einops.repeat(torch.arange(p), "p -> (b p)", b=b)
+    return new_s_batched
+
+
+def main(args: argparse.Namespace):
+    dataset_name = args.dataset_name
+
+    # Download probes dataset
+    hf_api.snapshot_download(
+        repo_id=dataset_name.replace("concepts", "probes"),
+        repo_type="model",
+        local_dir=ASSETS_FOLDER / dataset_name.replace("concepts", "probes"),
+        revision=args.probe_ref,
+    )
+
+    os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks")
+    metrics = {}
+    for layer_name in LAYER_NAMES:
+        metrics[layer_name] = {}
+
+        # Download activations dataset
+        ds_activations = load_dataset(
+            args.dataset_name.replace("concepts", "activations"), split="test", name=layer_name
+        )
+        selected_classes = LABELED_CLASSES if args.only_labeled else CLASSES
+        init_ds = ds_activations.filter(lambda s: s["class"] in selected_classes)
+
+        for concept in CONCEPTS:
+            filtered_ds = init_ds.filter(lambda s: s[concept] is not None)
+            labeled_ds = filtered_ds.rename_column(concept, "label")
+            labeled_ds = labeled_ds.class_encode_column("label")
+            torch_ds = labeled_ds.select_columns(["activation", "label", "class"]).with_format("torch")
+            pred_dataset = torch_ds.map(map_fn, remove_columns=["activation", "label", "class"], batched=True)
+
+            with open(
+                ASSETS_FOLDER / f"{dataset_name.replace('concepts', 'probes')}/data/{layer_name}/{concept}/clf.pt",
+                "rb",
+            ) as f:
+                probe = torch.load(f)
+
+            metrics[layer_name][concept] = eval_probe(
+                probe,
+                pred_dataset["pixel_activation"],
+                pred_dataset["pixel_label"],
+                pred_dataset["pixel_index"],
+                pred_dataset["pixel_class"],
+                selected_classes,
+            )
+            logger.info(
+                f"Layer: {layer_name}, Concept: {concept}, Global metrics: {metrics[layer_name][concept]['global']}"
+            )
+            analysis.plot_metric_boxes(
+                metrics[layer_name][concept]["per_pixel"],
+                title=f"{layer_name}/{concept}",
+                save_to=ASSETS_FOLDER / "figures" / "sanity_checks" / f"{layer_name}_{concept}_pixel_boxes.png",
+            )
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser("fgsm-probing")
+    parser.add_argument("--mode", type=str, default="torch_clf")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="mulsi/fruit-vegetable-concepts",
+    )
+    parser.add_argument("--probe_ref", type=str, default=None)
+    parser.add_argument("--only_labeled", action=argparse.BooleanOptionalAction, default=False)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
+
+    # Probe refs
+    # 29c94861ed9922843d4821f23e7e44fbb30f2de4 -> 3 CLF pre-labeling
+    # ? -> 12 CLF all post-labeling
+    # ? -> 12 CLF only_labeled post-labeling
diff --git a/scripts/constants.py b/scripts/constants.py
index a0fb81d..ea71aff 100644
--- a/scripts/constants.py
+++ b/scripts/constants.py
@@ -125,7 +125,8 @@
     "carrot",
     "turnip",
     "mango",
-    "capsicum" "sweetpotato",
+    "capsicum",
+    "sweetpotato",
 ]
 
 assert len(CLASSES) == len(CLASS_CONCEPTS_VALUES.keys())
diff --git a/src/mulsi/analysis.py b/src/mulsi/analysis.py
index 03a7739..127b527 100644
--- a/src/mulsi/analysis.py
+++ b/src/mulsi/analysis.py
@@ -238,3 +238,21 @@ def plot_mean_proba_through_layers(
         plt.close()
     else:
         plt.show()
+
+
+def plot_metric_boxes(
+    data,
+    title=None,
+    save_to=None,
+):
+    labels = next(iter(data.values())).keys()
+    boxed_data = list(zip(*[m.values() for m in data.values()]))
+    plt.boxplot(boxed_data, notch=True, vert=True, patch_artist=True, labels=labels)
+    plt.legend()
+    plt.ylabel("Metric value")
+    plt.title(title)
+    if save_to is not None:
+        plt.savefig(save_to)
+        plt.close()
+    else:
+        plt.show()

From c84d3bc75a9efe64467825760a143fc64b151ec2 Mon Sep 17 00:00:00 2001
From: Xmaster6y <66315201+Xmaster6y@users.noreply.github.com>
Date: Sat, 4 May 2024 18:26:51 +0200
Subject: [PATCH 3/4] only_labeled save

---
 scripts/analysis/probes_sanity_checks.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/scripts/analysis/probes_sanity_checks.py b/scripts/analysis/probes_sanity_checks.py
index 55e3489..d54a973 100644
--- a/scripts/analysis/probes_sanity_checks.py
+++ b/scripts/analysis/probes_sanity_checks.py
@@ -81,7 +81,8 @@ def main(args: argparse.Namespace):
         revision=args.probe_ref,
     )
 
-    os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks")
+    subfolder = "only_labeled" if args.only_labeled else "all"
+    os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks" / subfolder)
     metrics = {}
     for layer_name in LAYER_NAMES:
         metrics[layer_name] = {}
@@ -120,7 +121,11 @@ def main(args: argparse.Namespace):
             analysis.plot_metric_boxes(
                 metrics[layer_name][concept]["per_pixel"],
                 title=f"{layer_name}/{concept}",
-                save_to=ASSETS_FOLDER / "figures" / "sanity_checks" / f"{layer_name}_{concept}_pixel_boxes.png",
+                save_to=ASSETS_FOLDER
+                / "figures"
+                / "sanity_checks"
+                / subfolder
+                / f"{layer_name}_{concept}_pixel_boxes.png",
             )
 
 

From 8563203c3ae50d208c8ea392d053ec10889ce40d Mon Sep 17 00:00:00 2001
From: Xmaster6y <66315201+Xmaster6y@users.noreply.github.com>
Date: Sat, 4 May 2024 18:32:02 +0200
Subject: [PATCH 4/4] exist_ok

---
 scripts/analysis/probes_sanity_checks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/analysis/probes_sanity_checks.py b/scripts/analysis/probes_sanity_checks.py
index d54a973..5950ea9 100644
--- a/scripts/analysis/probes_sanity_checks.py
+++ b/scripts/analysis/probes_sanity_checks.py
@@ -82,7 +82,7 @@ def main(args: argparse.Namespace):
     )
 
     subfolder = "only_labeled" if args.only_labeled else "all"
-    os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks" / subfolder)
+    os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks" / subfolder, exist_ok=True)
     metrics = {}
    for layer_name in LAYER_NAMES:
        metrics[layer_name] = {}
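
Reviewer note on the data flow: `eval_probe` in `probes_sanity_checks.py` returns a `per_pixel` dict mapping each of the 50 token positions (presumably 1 CLS token plus 49 patch tokens of the CLIP vision tower) to a `{"precision": ..., "recall": ..., "f1": ...}` dict, and `plot_metric_boxes` transposes that into one box per metric across positions. A minimal standalone sketch of that reshaping, with synthetic metric values standing in for real probe outputs:

```python
# Sketch only: per_pixel_metrics mirrors the shape of eval_probe's
# "per_pixel" output; the metric values below are synthetic, not real.
import random

import matplotlib.pyplot as plt

# {pixel_index: {metric_name: value}}, one entry per token position.
per_pixel_metrics = {
    i: {
        "precision": random.uniform(0.6, 0.9),
        "recall": random.uniform(0.5, 0.8),
        "f1": random.uniform(0.55, 0.85),
    }
    for i in range(50)
}

# Same reshaping as plot_metric_boxes: transpose {pixel: {metric: value}}
# into one sequence per metric, so each metric gets one box over 50 pixels.
labels = list(next(iter(per_pixel_metrics.values())).keys())
boxed_data = list(zip(*[m.values() for m in per_pixel_metrics.values()]))

plt.boxplot(boxed_data, notch=True, patch_artist=True, labels=labels)
plt.ylabel("Metric value")
plt.title("layers.11/yellow (synthetic values)")
plt.show()
```

A tight box means the probe scores uniformly across token positions, while a wide one flags positions where precision or recall collapses; comparing the boxes across `layers.0`, `layers.6`, and `layers.11` would show at which depth a concept becomes linearly decodable.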