From 14343aaae46ebb15551aa116221b1e91c9cf4b96 Mon Sep 17 00:00:00 2001 From: Aidan Date: Sun, 4 Feb 2024 20:30:21 +0000 Subject: [PATCH] Adapt code to .lzma file --- src/actinet/__init__.py | 2 +- src/actinet/actinet.py | 8 ++-- src/actinet/hmm.py | 16 +++++-- src/actinet/models.py | 50 +++++++++++++------- src/actinet/sslmodel.py | 91 ++++++++++++++++++++++-------------- src/actinet/summarisation.py | 4 +- 6 files changed, 107 insertions(+), 64 deletions(-) diff --git a/src/actinet/__init__.py b/src/actinet/__init__.py index fe2146a..7469a36 100644 --- a/src/actinet/__init__.py +++ b/src/actinet/__init__.py @@ -4,7 +4,7 @@ __maintainer_email__ = "shing.chan@ndph.ox.ac.uk" __license__ = "See LICENSE file." -__model_version__ = "ssl-ukb-c24-rw" +__model_version__ = "ssl_ukb_c24_rw_20240204" __model_md5__ = "" from . import _version diff --git a/src/actinet/actinet.py b/src/actinet/actinet.py index 97c7702..3b77ddb 100644 --- a/src/actinet/actinet.py +++ b/src/actinet/actinet.py @@ -20,6 +20,8 @@ from actinet.summarisation import getActivitySummary, ACTIVITY_LABELS from actinet.utils.utils import infer_freq +BASE_URL = "https://zenodo.org/records/10616280/files/" + def main(): @@ -86,7 +88,7 @@ def main(): model_path = pathlib.Path(__file__).parent / f"{__model_version__}.joblib.lzma" check_md5 = args.model_path is None model: ActivityClassifier = load_model( - args.model_path or model_path, args.model_type, check_md5, args.force_download + args.model_path or model_path, check_md5, args.force_download ) model.verbose = verbose @@ -222,14 +224,14 @@ def resolve_path(path): return dirname, filename, extension -def load_model(model_path, model_type, check_md5=True, force_download=False): +def load_model(model_path, check_md5=True, force_download=False): """Load trained model. Download if not exists.""" pth = pathlib.Path(model_path) if force_download or not pth.exists(): - url = f"https://wearables-files.ndph.ox.ac.uk/files/models/stepcount/{__model_version__}.joblib.lzma" + url = f"{BASE_URL}{__model_version__}.joblib.lzma" print(f"Downloading {url}...") diff --git a/src/actinet/hmm.py b/src/actinet/hmm.py index 7456a79..7af26f6 100644 --- a/src/actinet/hmm.py +++ b/src/actinet/hmm.py @@ -8,15 +8,23 @@ class HMM: Implement a basic HMM model with parameter saving/loading. """ - def __init__(self, labels=None, uniform_prior=True): - self.prior = None - self.emission = None - self.transition = None + def __init__( + self, + prior=None, + emission=None, + transition=None, + labels=None, + uniform_prior=True, + ): + self.prior = prior + self.emission = emission + self.transition = transition self.labels = labels self.uniform_prior = uniform_prior def __str__(self): return ( + "Hidden Markov Model\n" "prior: {prior}\n" "emission: {emission}\n" "transition: {transition}\n" diff --git a/src/actinet/models.py b/src/actinet/models.py index 4c0a640..f0a272e 100644 --- a/src/actinet/models.py +++ b/src/actinet/models.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd -from sklearn.preprocessing import LabelEncoder from tqdm.auto import tqdm from torch.utils.data import DataLoader @@ -16,6 +15,7 @@ def __init__( window_sec=30, weights_path="state_dict.pt", labels=[], + ssl_repo=None, repo_tag="v1.0.0", hmm_params=None, verbose=False, @@ -26,14 +26,26 @@ def __init__( self.batch_size = batch_size self.window_sec = window_sec self.state_dict = None - self.label_encoder = LabelEncoder().fit(labels) + self.labels = labels self.window_len = int(np.ceil(self.window_sec * sslmodel.SAMPLE_RATE)) - self.verbose = verbose + self.model = self._load_ssl(ssl_repo, weights_path) + hmm_params = hmm_params or dict() self.hmms = hmm.HMM(**hmm_params) + def __str__(self): + return ( + "Activity Classifier\n" + "class_labels: {self.labels}\n" + "window_length: {self.window_sec}\n" + "batch_size: {self.batch_size}\n" + "device: {self.device}\n" + "hmm: {self.hmms}\n" + "model: {self.model}".format(self=self) + ) + def predict_from_frame(self, data): def fn(chunk): @@ -55,16 +67,14 @@ def fn(chunk): data, self.window_sec, fn=fn, return_index=True, verbose=self.verbose ) - Y_labels = self.label_encoder.inverse_transform(self._predict(X)) - - Y = raw_to_df(X, Y_labels, T, self.label_encoder.classes_, reindex=False) + Y = raw_to_df(X, self._predict(X), T, self.labels, reindex=False) return Y - def _predict(self, X, groups=None): + def _predict(self, X): sslmodel.verbose = self.verbose - dataset = sslmodel.NormalDataset(X, name="prediction") + dataset = sslmodel.NormalDataset(X) dataloader = DataLoader( dataset, batch_size=self.batch_size, @@ -72,22 +82,26 @@ def _predict(self, X, groups=None): num_workers=0, ) + _, y_pred, _ = sslmodel.predict( + self.model, dataloader, self.device, output_logits=False + ) + + y_pred = self.hmms.predict(y_pred) + + return y_pred + + def _load_ssl(self, ssl_repo, weights): model = sslmodel.get_sslnet( + self.device, tag=self.repo_tag, - pretrained=False, + local_repo_path=ssl_repo, + pretrained=weights, window_sec=self.window_sec, - num_labels=len(self.label_encoder.classes_), + num_labels=len(self.labels), ) - model.load_state_dict(self.state_dict) model.to(self.device) - _, y_pred, _ = sslmodel.predict( - model, dataloader, self.device, output_logits=False - ) - - y_pred = self.hmms.predict(y_pred, groups=groups) - - return y_pred + return model def make_windows(data, window_sec, fn=None, return_index=False, verbose=True): diff --git a/src/actinet/sslmodel.py b/src/actinet/sslmodel.py index e130550..7daaed4 100644 --- a/src/actinet/sslmodel.py +++ b/src/actinet/sslmodel.py @@ -184,42 +184,26 @@ def save_checkpoint(self, val_loss, model): def get_sslnet( - tag="v1.0.0", pretrained=False, window_sec: int = 30, num_labels: int = 4 + device, + tag="v1.0.0", + local_repo_path=None, + pretrained=False, + window_sec: int = 30, + num_labels: int = 4, ): """ - Load and return the Self Supervised Learning (SSL) model from pytorch hub. + Load and return the Self Supervised Learning (SSL) model from pytorch hub or local storage. + :param str device: PyTorch device to use :param str tag: Tag on the ssl-wearables repo to check out - :param bool pretrained: Initialise the model with UKB self-supervised pretrained weights + :param str local_repo_path: Path to local version of the SSL repo for offline usage + :param bool/str pretrained: Initialise the model with UKB self-supervised pretrained weights :param int window_sec: The length of the window of data in seconds (limited to 5, 10 or 30) :param int num_labels: The number of labels to predict :return: pytorch SSL model :rtype: nn.Module """ - repo_name = "ssl-wearables" - repo = f"OxWearables/{repo_name}:{tag}" - - if not torch_cache_path.exists(): - Path.mkdir(torch_cache_path, parents=True, exist_ok=True) - - torch.hub.set_dir(str(torch_cache_path)) - - # find repo cache dir that matches repo name and tag - cache_dirs = [f for f in torch_cache_path.iterdir() if f.is_dir()] - repo_path = next( - (f for f in cache_dirs if repo_name in f.name and tag in f.name), None - ) - - if repo_path is None: - repo_path = repo - source = "github" - else: - repo_path = str(repo_path) - source = "local" - if verbose: - print(f"Using local {repo_path}") - if window_sec not in [5, 10, 30]: raise ValueError( "Length of window in seconds must be either 5, 10 or 30 seconds" @@ -228,15 +212,52 @@ def get_sslnet( if num_labels < 1: raise ValueError("Numer of class labels should be > 0") - sslnet: nn.Module = torch.hub.load( - repo_path, - f"harnet{window_sec}", - trust_repo=True, - source=source, - class_num=num_labels, - pretrained=pretrained, - verbose=verbose, - ) + if local_repo_path is not None: + sslnet: nn.Module = torch.hub.load( + local_repo_path, + f"harnet{window_sec}", + source="local", + class_num=num_labels, + pretrained=pretrained == True, + ) + + else: + repo_name = "ssl-wearables" + repo = f"OxWearables/{repo_name}:{tag}" + + if not torch_cache_path.exists(): + Path.mkdir(torch_cache_path, parents=True, exist_ok=True) + + torch.hub.set_dir(str(torch_cache_path)) + + # find repo cache dir that matches repo name and tag + cache_dirs = [f for f in torch_cache_path.iterdir() if f.is_dir()] + repo_path = next( + (f for f in cache_dirs if repo_name in f.name and tag in f.name), None + ) + + if repo_path is None: + repo_path = repo + source = "github" + else: + repo_path = str(repo_path) + source = "local" + if verbose: + print(f"Using local {repo_path}") + + sslnet: nn.Module = torch.hub.load( + repo_path, + f"harnet{window_sec}", + trust_repo=True, + source=source, + class_num=num_labels, + pretrained=pretrained == True, + verbose=verbose, + ) + + model_dict = torch.load(pretrained, map_location=device) + sslnet.load_state_dict(model_dict) + return sslnet diff --git a/src/actinet/summarisation.py b/src/actinet/summarisation.py index dad4b47..c81e4c4 100644 --- a/src/actinet/summarisation.py +++ b/src/actinet/summarisation.py @@ -7,7 +7,7 @@ from actinet.utils.utils import date_parser, toScreen from actinet import circadian -ACTIVITY_LABELS = ["light", "MVPA", "sedentary", "sleep"] +ACTIVITY_LABELS = ["light", "moderate-vigorous", "sedentary", "sleep"] def getActivitySummary( @@ -98,8 +98,6 @@ def _summarise( for col in cols: summary[f"day{i}-recorded-{col}(hrs)"] = row.loc[col] - summary["day_avg"] - # Calculate empirical cumulative distribution function of vector magnitudes if intensityDistribution: summary = calculateECDF(data["acc"], summary)