Skip to content

Commit

Permalink
Features
Browse files Browse the repository at this point in the history
Features
 * FrameGetter supports len()
 * Per-Class Accuracy metric
  • Loading branch information
michael-camilleri committed Mar 2, 2022
1 parent 6bbd112 commit 49ea181
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
14 changes: 14 additions & 0 deletions mpctools/extensions/cvext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Author: Michael P. J. Camilleri
"""
import glob

from numba import jit, uint8, uint16, double
from queue import Queue, Empty, Full
Expand Down Expand Up @@ -1202,6 +1203,19 @@ def __getitem__(self, item):
assert os.path.exists(_pth), f"Image {item} does not exist at {_pth}."
return cv2.imread(_pth)

def __len__(self):
"""
Retrieves the number of Frames
This is 'approximate', by counting the number of files matching the extension.
:return: Length
"""
if self.Fmt.lower() == "video":
return int(cv2.VideoCapture(self.Path).get(cv2.CAP_PROP_FRAME_COUNT))
else:
return len(glob.glob(os.path.join(self.Path, f'*{os.path.splitext(self.Fmt)[1]}')))


class SwCLAHE:
"""
Expand Down
24 changes: 23 additions & 1 deletion mpctools/extensions/skext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,32 @@
"""
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from scipy.spatial.distance import squareform
from mpctools.extensions import npext
from mpctools.extensions import npext, utils
from sklearn import metrics as skmetrics
import numpy as np


def class_accuracy(y_true, y_pred, labels=None, normalize=True):
"""
Computes per-class, one-v-rest accuracy
:param y_true: True labels (N)
:param y_pred: Predicted labels (N)
:param labels: If not None, specifies labels to consider: otherwise any label that appears in
y_true or y_pred is considered
:param normalize: If True, normalise relative to all samples: else report number of samples.
:return: Accuracy-score per-class
"""
# Define Labels
labels = utils.default(labels, np.union1d(np.unique(y_pred), np.unique(y_true)))

# compute per-class accuracy
accuracy = np.empty(len(labels))
for i, lbl in enumerate(labels):
accuracy[i] = skmetrics.accuracy_score(y_true == lbl, y_pred == lbl, normalize=normalize)
return accuracy


def hierarchical_log_loss(y_true, y_prob, mapping, eps=1e-15):
"""
Compute the Log-Loss, when y_true contains over-arching labels which are not predictable in
Expand Down

0 comments on commit 49ea181

Please sign in to comment.