Skip to content

Commit

Permalink
Merge pull request #51 from Giskard-AI/perturbation-detectors
Browse files Browse the repository at this point in the history
working on perturbation detectors
  • Loading branch information
rabah-khalek authored Aug 13, 2024
2 parents 677cf94 + ffbb425 commit b9ab6b8
Show file tree
Hide file tree
Showing 33 changed files with 403 additions and 179 deletions.
15 changes: 1 addition & 14 deletions giskard_vision/core/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,10 @@
get_image_channel_number,
get_image_size,
)
from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import AttributesIssueMeta

from ..types import TypesBase

# Issue groups used to label findings according to the kind of metadata the data was filtered by.
# Each group pairs a short name with a human-readable description.
EthicalIssueMeta = IssueGroup(
    "Ethical",
    description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
    "Performance",
    description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
    "Attributes",
    description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)


class DataIteratorBase(ABC):
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.
Expand Down
3 changes: 2 additions & 1 deletion giskard_vision/core/dataloaders/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
from giskard_vision.core.issues import AttributesIssueMeta
from giskard_vision.utils.errors import GiskardError, GiskardImportError


Expand Down
2 changes: 1 addition & 1 deletion giskard_vision/core/dataloaders/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import IssueGroup


class MetaData:
Expand Down
72 changes: 72 additions & 0 deletions giskard_vision/core/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,78 @@ def get_image(self, idx: int) -> np.ndarray:
return cv2.GaussianBlur(image, self._kernel_size, *self._sigma)


class NoisyDataLoader(DataLoaderWrapper):
    """Wrapper class for a DataIteratorBase, providing images with additive Gaussian noise.

    Args:
        dataloader (DataIteratorBase): The data loader to be wrapped.
        sigma (float): Standard deviation of the Gaussian noise, expressed as a
            fraction of 255 (it is multiplied by 255 before being applied).

    Returns:
        NoisyDataLoader: Noisy data loader instance.
    """

    def __init__(
        self,
        dataloader: DataIteratorBase,
        sigma: float = 0.1,
    ) -> None:
        """
        Initializes the NoisyDataLoader.

        Args:
            dataloader (DataIteratorBase): The data loader to be wrapped.
            sigma (float): Standard deviation of the Gaussian noise, as a fraction of 255.
        """
        super().__init__(dataloader)
        self._sigma = sigma

    @property
    def name(self) -> str:
        """
        Gets the name of the noisy data loader.

        Returns:
            str: The name of the noisy data loader.
        """
        return "noisy"

    def get_image(self, idx: int) -> np.ndarray:
        """
        Gets the wrapped loader's image with Gaussian noise added.

        Args:
            idx (int): Index of the data.

        Returns:
            np.ndarray: Noisy image data (uint8).
        """
        image = super().get_image(idx)
        # sigma is stored as a fraction; scale it to the 0-255 range used when clipping below.
        return self.add_gaussian_noise(image, self._sigma * 255)

    def add_gaussian_noise(self, image: np.ndarray, std_dev: float) -> np.ndarray:
        """
        Add zero-mean Gaussian noise to the image.

        The noise is drawn from NumPy's global random state, so repeated calls
        yield different noise unless that state is seeded by the caller.

        Args:
            image (np.ndarray): Image
            std_dev (float): Standard deviation of the Gaussian noise, in pixel-value units.

        Returns:
            np.ndarray: Noisy image, clipped to [0, 255] and cast to uint8.
        """
        # Generate Gaussian noise with the same shape as the image
        noise = np.random.normal(0, std_dev, image.shape).astype(np.float32)

        # Add in float32 so that negative / >255 intermediate values are not wrapped or saturated early
        noisy_image = cv2.add(image.astype(np.float32), noise)

        # Clip the values to stay within valid range (0-255 for uint8)
        noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8)

        return noisy_image


class ColoredDataLoader(DataLoaderWrapper):
"""Wrapper class for a DataIteratorBase, providing color-altered images using OpenCV color conversion.
Expand Down
51 changes: 39 additions & 12 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

from giskard_vision.core.issues import IssueGroup
from giskard_vision.utils.errors import GiskardImportError


@dataclass(frozen=True)
class IssueGroup:
    """Immutable (frozen dataclass) record naming and describing a group of issues."""

    # Short name of the group.
    name: str
    # Human-readable description of the group.
    description: str
from .specs import DetectorSpecsBase


@dataclass
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
}


class DetectorVisionBase:
class DetectorVisionBase(DetectorSpecsBase):
"""
Abstract class for Vision Detectors
Expand All @@ -67,12 +64,6 @@ class DetectorVisionBase:
evaluation results for the scan.
"""

issue_group: IssueGroup
warning_messages: dict
issue_level_threshold: float = 0.2
deviation_threshold: float = 0.05
num_images: int = 0

def run(
self,
model: Any,
Expand Down Expand Up @@ -139,6 +130,42 @@ def get_issues(

return issues

def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Build a ScanResult, grading how far a metric deviates from its reference value.

    The delta is ``metric_value - metric_reference_value``; when ``self.metric_type``
    is ``"relative"`` it is further divided by the reference value. Depending on
    ``self.metric_direction`` ("better_lower" or "better_higher"), a degradation
    larger than ``issue_level_threshold`` is graded MEDIUM and one larger than
    ``issue_level_threshold + deviation_threshold`` is graded MAJOR; everything
    else (including any other direction value) is graded MINOR.

    Args:
        metric_value: Metric measured on the evaluated data.
        metric_reference_value: Metric measured on the reference data.
        metric_name: Name of the metric.
        filename_examples: Example file names attached to the result.
        name: Name stored on the result.
        size_data: Stored on the result as ``slice_size``.
        issue_group: Issue group stored on the result.

    Returns:
        ScanResult: The populated scan result.

    Raises:
        GiskardImportError: If the ``giskard`` package cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as err:
        raise GiskardImportError(["giskard"]) from err

    relative_delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        # NOTE(review): no guard for a zero reference value here - confirm callers never pass one.
        relative_delta /= metric_reference_value

    direction = self.metric_direction
    issue_level = IssueLevel.MINOR
    if direction in ("better_lower", "better_higher"):
        # Orient the delta so that a positive value always means "worse than the reference".
        degradation = relative_delta if direction == "better_lower" else -relative_delta
        medium_cutoff = self.issue_level_threshold
        if degradation > medium_cutoff + self.deviation_threshold:
            issue_level = IssueLevel.MAJOR
        elif degradation > medium_cutoff:
            issue_level = IssueLevel.MEDIUM

    return ScanResult(
        name=name,
        issue_group=issue_group,
        issue_level=issue_level,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        relative_delta=relative_delta,
        slice_size=size_data,
        filename_examples=filename_examples,
    )

@abstractmethod
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
"""Returns a list of ScanResult
Expand Down
38 changes: 1 addition & 37 deletions giskard_vision/core/detectors/metadata_scan_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
import pandas as pd

from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import PerformanceIssueMeta
from giskard_vision.core.tests.base import MetricBase
from giskard_vision.utils.errors import GiskardImportError

Expand Down Expand Up @@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
pass

return pd.DataFrame(df)

def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Build a ScanResult, grading how far a metric deviates from its reference value.

    The delta is ``metric_value - metric_reference_value``; when ``self.metric_type``
    is ``"relative"`` it is further divided by the reference value. Depending on
    ``self.metric_direction`` ("better_lower" or "better_higher"), a degradation
    larger than ``issue_level_threshold`` is graded MEDIUM and one larger than
    ``issue_level_threshold + deviation_threshold`` is graded MAJOR; everything
    else (including any other direction value) is graded MINOR.

    Args:
        metric_value: Metric measured on the evaluated data.
        metric_reference_value: Metric measured on the reference data.
        metric_name: Name of the metric.
        filename_examples: Example file names attached to the result.
        name: Name stored on the result.
        size_data: Stored on the result as ``slice_size``.
        issue_group: Issue group stored on the result.

    Returns:
        ScanResult: The populated scan result.

    Raises:
        GiskardImportError: If the ``giskard`` package cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as err:
        raise GiskardImportError(["giskard"]) from err

    relative_delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        # NOTE(review): no guard for a zero reference value here - confirm callers never pass one.
        relative_delta /= metric_reference_value

    direction = self.metric_direction
    issue_level = IssueLevel.MINOR
    if direction in ("better_lower", "better_higher"):
        # Orient the delta so that a positive value always means "worse than the reference".
        degradation = relative_delta if direction == "better_lower" else -relative_delta
        medium_cutoff = self.issue_level_threshold
        if degradation > medium_cutoff + self.deviation_threshold:
            issue_level = IssueLevel.MAJOR
        elif degradation > medium_cutoff:
            issue_level = IssueLevel.MEDIUM

    return ScanResult(
        name=name,
        issue_group=issue_group,
        issue_level=issue_level,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        relative_delta=relative_delta,
        slice_size=size_data,
        filename_examples=filename_examples,
    )
87 changes: 87 additions & 0 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase


class PerturbationBaseDetector(DetectorVisionBase):
    """
    Abstract class for detectors that compare a model on perturbed dataloaders against the original dataset

    Methods:
        set_specs_from_model_type(model_type: str) -> None:
            Loads the ``DetectorSpecs`` of the given model type and copies its
            settings onto this detector

        get_dataloaders(dataset: Any) -> Sequence[Any]:
            Abstract method that returns a list of dataloaders corresponding to
            slices or transformations

        get_results(model: Any, dataset: Any) -> Sequence[ScanResult]:
            Returns a list of ScanResult containing the evaluation results
    """

    issue_group = Robustness

    def set_specs_from_model_type(self, model_type: str) -> None:
        """Copy the settings of ``giskard_vision.<model_type>.detectors.specs.DetectorSpecs`` onto self.

        Only attributes defined directly on ``DetectorSpecs`` whose names do not
        start with ``__`` and that already exist on this detector are copied.

        Args:
            model_type (str): Name of the giskard_vision sub-package to load the specs from.

        Raises:
            ValueError: If the specs module does not provide a ``DetectorSpecs``.
        """
        module = import_module(f"giskard_vision.{model_type}.detectors.specs")
        # Default to None so that a module lacking DetectorSpecs reaches the ValueError
        # below instead of escaping as a bare AttributeError.
        detector_specs = getattr(module, "DetectorSpecs", None)

        if not detector_specs:
            raise ValueError(f"No detector specifications found for model type: {model_type}")

        # Only set attributes that are not part of Python's special attributes (those starting with __)
        for attr_name, attr_value in vars(detector_specs).items():
            if not attr_name.startswith("__") and hasattr(self, attr_name):
                setattr(self, attr_name, attr_value)

    @abstractmethod
    def get_dataloaders(self, dataset: Any) -> Sequence[Any]:
        """Return the dataloaders (slices or transformations of ``dataset``) to evaluate."""
        ...

    def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
        """Evaluate the model on every dataloader against the reference dataset.

        For each dataloader a ``TestDiffBase`` is run with ``dataset`` as the
        reference, one example image is written under ``examples_images/`` in
        the current working directory, and the outcome is converted into a
        ScanResult through ``get_scan_result``.

        Args:
            model (Any): Model under test; its ``model_type`` selects the detector specs.
            dataset (Any): Reference dataloader.

        Returns:
            Sequence[ScanResult]: One result per dataloader.
        """
        self.set_specs_from_model_type(model.model_type)
        dataloaders = self.get_dataloaders(dataset)

        examples_dir = Path() / "examples_images"

        results = []
        for dl in dataloaders:
            test_result = TestDiffBase(metric=self.metric, threshold=1).run(
                model=model,
                dataloader=dl,
                dataloader_ref=dataset,
            )

            # Save example images from dataloader and dataset
            os.makedirs(examples_dir, exist_ok=True)
            filename_examples = []

            # Fall back to the first image when no example indexes are reported
            # (None or an empty sequence) rather than failing on indexes[0].
            # NOTE(review): the first index is presumably the worst example - confirm against TestDiffBase.
            indexes = test_result.indexes_examples
            index_worst = 0 if indexes is None or len(indexes) == 0 else indexes[0]

            if isinstance(dl, FilteredDataLoader):
                # For filtered dataloaders, also save the corresponding image from the reference dataset
                filename_example_dataloader_ref = str(examples_dir / f"{dataset.name}_{index_worst}.png")
                cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
                filename_examples.append(filename_example_dataloader_ref)

            filename_example_dataloader = str(examples_dir / f"{dl.name}_{index_worst}.png")
            cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
            filename_examples.append(filename_example_dataloader)

            results.append(
                self.get_scan_result(
                    test_result.metric_value_test,
                    test_result.metric_value_ref,
                    test_result.metric_name,
                    filename_examples,
                    dl.name,
                    len(dl),
                    issue_group=self.issue_group,
                )
            )

        return results
13 changes: 13 additions & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from giskard_vision.core.issues import IssueGroup
from giskard_vision.image_classification.tests.performance import MetricBase


class DetectorSpecsBase:
    """Shared configuration attributes for vision detectors (base of DetectorVisionBase)."""

    # Issue group attached to the scan results produced by the detector.
    issue_group: IssueGroup
    # NOTE(review): presumably maps a condition to a user-facing warning text - confirm against its users.
    warning_messages: dict
    # Metric evaluated by the detector (PerturbationBaseDetector hands it to TestDiffBase).
    metric: MetricBase = None
    # When equal to "relative", get_scan_result divides the metric delta by the reference value.
    metric_type: str = None
    # "better_lower" or "better_higher": tells get_scan_result which sign of the delta is a degradation.
    metric_direction: str = None
    # Extra margin on top of issue_level_threshold beyond which get_scan_result grades an issue MAJOR.
    deviation_threshold: float = 0.10
    # Degradation beyond which get_scan_result grades an issue MEDIUM instead of MINOR.
    issue_level_threshold: float = 0.05
    # NOTE(review): usage not visible in this view - presumably a number of example images; confirm.
    num_images: int = 0
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import BlurredDataLoader

from ...core.detectors.decorator import maybe_detector
from .base import LandmarkDetectionBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("blurring_landmark", tags=["vision", "face", "landmark", "transformed", "blurred"])
class TransformationBlurringDetectorLandmark(LandmarkDetectionBaseDetector):
@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationBlurringDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on blurred images
"""

issue_group = Robustness

def __init__(self, kernel_size=(11, 11), sigma=(3, 3)):
self.kernel_size = kernel_size
self.sigma = sigma
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader

from ...core.detectors.decorator import maybe_detector
from .base import LandmarkDetectionBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("color_landmark", tags=["vision", "face", "landmark", "filtered", "colored"])
class TransformationColorDetectorLandmark(LandmarkDetectionBaseDetector):
@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationColorDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance depending on images in grayscale
"""

issue_group = Robustness

def get_dataloaders(self, dataset):
dl = ColoredDataLoader(dataset)

Expand Down
Loading

0 comments on commit b9ab6b8

Please sign in to comment.