Compute scores for probes #33

Merged
Merged 4 commits on May 4, 2024
8 changes: 8 additions & 0 deletions .vscode/launch.json
@@ -41,6 +41,14 @@
"module": "scripts.make_activation_dataset",
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Script Probes Sanity Checks",
"type": "debugpy",
"request": "launch",
"module": "scripts.analysis.probes_sanity_checks",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
152 changes: 152 additions & 0 deletions scripts/analysis/probes_sanity_checks.py
@@ -0,0 +1,152 @@
"""Script analyse FGSM adversarial images.

Run with:
```
poetry run python -m scripts.analysis.probes_sanity_checks
```
"""

import argparse
from typing import List
import os

import einops
import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from loguru import logger

from sklearn.metrics import f1_score, recall_score, precision_score

from mulsi.adversarial import LRClfLoss
from scripts.constants import HF_TOKEN, ASSETS_FOLDER, LABELED_CLASSES, CLASSES
from mulsi import analysis

LAYER_NAMES = ["layers.0", "layers.6", "layers.11"]
CONCEPTS = ["yellow", "red", "sphere", "ovaloid"]

hf_api = HfApi(token=HF_TOKEN)


def probe_single_eval(y_true, y_pred):
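"""Compute precision, recall and F1 for a set of binary probe predictions."""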
metrics = {}
metrics["precision"] = precision_score(y_true, y_pred)
metrics["recall"] = recall_score(y_true, y_pred)
metrics["f1"] = f1_score(y_true, y_pred)
return metrics


def eval_probe(
probe: LRClfLoss,
activation: torch.Tensor,
labels: torch.Tensor,
indices: torch.Tensor,
classes: List[str],
selected_classes,
):
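"""Evaluate probe predictions globally, per pixel position and per class."""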
predictions = probe(activation) > 0
global_metrics = probe_single_eval(labels, predictions)

per_pixel_metrics = {}
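# One entry per pixel position; 50 is assumed to match the number of token positions in the activation grid.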
for i in range(50):
bool_index = indices == i
per_pixel_metrics[i] = probe_single_eval(labels[bool_index], predictions[bool_index])

per_class_metrics = {}
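# NOTE: per-class evaluation is currently disabled, so an empty dict is returned.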
# for class_name in selected_classes:
# bool_index = torch.tensor([c == class_name for c in classes])
# per_class_metrics[class_name] = probe_single_eval(labels[bool_index], predictions[bool_index])

return {"global": global_metrics, "per_pixel": per_pixel_metrics, "per_class": per_class_metrics}


def map_fn(s_batched):
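"""Flatten a batch of image-level activations of shape (b, p, h) into per-pixel rows."""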
b, p, h = s_batched["activation"].shape
new_s_batched = {}
new_s_batched["pixel_activation"] = einops.rearrange(s_batched["activation"], "b p h -> (b p) h")
new_s_batched["pixel_label"] = einops.repeat(s_batched["label"], "b -> (b p)", p=p)
new_s_batched["pixel_class"] = [s_batched["class"][i] for i in range(b) for _ in range(p)]
new_s_batched["pixel_index"] = einops.repeat(torch.arange(p), "p -> (b p)", b=b)
return new_s_batched


def main(args: argparse.Namespace):
dataset_name = args.dataset_name

# Download probes dataset
hf_api.snapshot_download(
repo_id=dataset_name.replace("concepts", "probes"),
repo_type="model",
local_dir=ASSETS_FOLDER / dataset_name.replace("concepts", "probes"),
revision=args.probe_ref,
)

subfolder = "only_labeled" if args.only_labeled else "all"
os.makedirs(ASSETS_FOLDER / "figures" / "sanity_checks" / subfolder, exist_ok=True)
metrics = {}
for layer_name in LAYER_NAMES:
metrics[layer_name] = {}

# Download activations dataset
ds_activations = load_dataset(
dataset_name.replace("concepts", "activations"), split="test", name=layer_name
)
selected_classes = LABELED_CLASSES if args.only_labeled else CLASSES
init_ds = ds_activations.filter(lambda s: s["class"] in selected_classes)

for concept in CONCEPTS:
filtered_ds = init_ds.filter(lambda s: s[concept] is not None)
labeled_ds = filtered_ds.rename_column(concept, "label")
labeled_ds = labeled_ds.class_encode_column("label")
torch_ds = labeled_ds.select_columns(["activation", "label", "class"]).with_format("torch")
pred_dataset = torch_ds.map(map_fn, remove_columns=["activation", "label", "class"], batched=True)

with open(
ASSETS_FOLDER / f"{dataset_name.replace('concepts', 'probes')}/data/{layer_name}/{concept}/clf.pt",
"rb",
) as f:
probe = torch.load(f)

metrics[layer_name][concept] = eval_probe(
probe,
pred_dataset["pixel_activation"],
pred_dataset["pixel_label"],
pred_dataset["pixel_index"],
pred_dataset["pixel_class"],
selected_classes,
)
logger.info(
f"Layer: {layer_name}, Concept: {concept}, Global metrics: {metrics[layer_name][concept]['global']}"
)
analysis.plot_metric_boxes(
metrics[layer_name][concept]["per_pixel"],
title=f"{layer_name}/{concept}",
save_to=ASSETS_FOLDER
/ "figures"
/ "sanity_checks"
/ subfolder
/ f"{layer_name}_{concept}_pixel_boxes.png",
)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser("fgsm-probing")
parser.add_argument("--mode", type=str, default="torch_clf")
parser.add_argument(
"--dataset_name",
type=str,
default="mulsi/fruit-vegetable-concepts",
)
parser.add_argument("--probe_ref", type=str, default=None)
parser.add_argument("--only_labeled", action=argparse.BooleanOptionalAction, default=False)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
main(args)

# Probe refs
# 29c94861ed9922843d4821f23e7e44fbb30f2de4 -> 3 CLF pre-labeling
# ? -> 12 CLF all post-labeling
# ? -> 12 CLF only_labeled post-labeling
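# Example invocation against a specific probe revision (flags as defined in parse_args above):
# poetry run python -m scripts.analysis.probes_sanity_checks --probe_ref 29c94861ed9922843d4821f23e7e44fbb30f2de4 --only_labeled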
4 changes: 4 additions & 0 deletions scripts/constants.py
@@ -123,6 +123,10 @@
"cabbage",
"bell pepper",
"carrot",
"turnip",
"mango",
"capsicum",
"sweetpotato",
]

assert len(CLASSES) == len(CLASS_CONCEPTS_VALUES.keys())
18 changes: 18 additions & 0 deletions src/mulsi/analysis.py
@@ -238,3 +238,21 @@ def plot_mean_proba_through_layers(
plt.close()
else:
plt.show()


def plot_metric_boxes(
data,
title=None,
save_to=None,
):
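"""Draw one notched box per metric, aggregating values across pixel positions."""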
labels = next(iter(data.values())).keys()
boxed_data = list(zip(*[m.values() for m in data.values()]))
plt.boxplot(boxed_data, notch=True, vert=True, patch_artist=True, labels=labels)
plt.ylabel("Metric value")
plt.title(title)
if save_to is not None:
plt.savefig(save_to)
plt.close()
else:
plt.show()