diff --git a/.gitignore b/.gitignore index 97a02fa..3e3395e 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ doc/build/html/index.html doc/tmp.sv env/ doc/source/savefig/ + +# Torch hub cache # +################### +src/actinet/torch_hub_cache/* \ No newline at end of file diff --git a/setup.py b/setup.py index 3e4fc40..cfbf6d9 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def main(): "scipy==1.10.*", "pandas==2.0.*", "tqdm==4.64.*", - "matplotlib==3.1.*", + "matplotlib==3.5.*", "joblib==1.2.*", "scikit-learn==1.1.1", "imbalanced-learn==0.9.1", diff --git a/src/actinet/__init__.py b/src/actinet/__init__.py index 83307d9..0565556 100644 --- a/src/actinet/__init__.py +++ b/src/actinet/__init__.py @@ -5,7 +5,7 @@ __license__ = "See LICENSE file." __model_version__ = "ssl_ukb_c24_rw_20240204" -__model_md5__ = "c13390ff024e4714eb027aa66160e2eb" +__model_md5__ = "84f3d5bb73de5c4da057918c45400da4" from . import _version diff --git a/src/actinet/actinet.py b/src/actinet/actinet.py index 618dce2..1f8e452 100644 --- a/src/actinet/actinet.py +++ b/src/actinet/actinet.py @@ -20,7 +20,7 @@ from actinet.summarisation import getActivitySummary, ACTIVITY_LABELS from actinet.utils.utils import infer_freq -BASE_URL = "https://zenodo.org/records/10616280/files/" +BASE_URL = "https://zenodo.org/records/10619096/files/" def main(): @@ -29,7 +29,7 @@ def main(): description="A tool to predict activities from accelerometer data using a self-supervised Resnet 18 model", add_help=True, ) - parser.add_argument("filepath", help="Enter file to be processed") + parser.add_argument("--filepath", "-f", help="Enter file to be processed") parser.add_argument( "--outdir", "-o", @@ -67,6 +67,9 @@ def main(): action="store_true", help="Download and cache ssl module for offline usage", ) + parser.add_argument( + "--ssl-repo-path", "-s", help="Enter repository of ssl model", default=None + ) parser.add_argument("--quiet", "-q", action="store_true", help="Suppress output") args = 
parser.parse_args() @@ -75,7 +78,7 @@ def main(): verbose = not args.quiet if args.cache_ssl: - model = ActivityClassifier(weights_path=True, ssl_repo=None, verbose=verbose) + model = ActivityClassifier(weights_path=None, ssl_repo=None, verbose=verbose, labels=ACTIVITY_LABELS) after = time.time() print(f"Done! ({round(after - before,2)}s)") @@ -101,7 +104,7 @@ def main(): model_path = pathlib.Path(__file__).parent / f"{__model_version__}.joblib.lzma" check_md5 = args.model_path is None model: ActivityClassifier = load_model( - args.model_path or model_path, check_md5, args.force_download + args.model_path or model_path, args.ssl_repo_path, check_md5, args.force_download, verbose ) model.verbose = verbose @@ -237,7 +240,7 @@ def resolve_path(path): return dirname, filename, extension -def load_model(model_path, check_md5=True, force_download=False): +def load_model(model_path, ssl_repo_path=None, check_md5=True, force_download=False, verbose=True): """Load trained model. Download if not exists.""" pth = pathlib.Path(model_path) @@ -246,7 +249,8 @@ def load_model(model_path, check_md5=True, force_download=False): url = f"{BASE_URL}{__model_version__}.joblib.lzma" - print(f"Downloading {url}...") + if verbose: + print(f"Downloading {url}...") with urllib.request.urlopen(url) as f_src, open(pth, "wb") as f_dst: shutil.copyfileobj(f_src, f_dst) @@ -257,7 +261,15 @@ def load_model(model_path, check_md5=True, force_download=False): "to download the model file again." 
) - return joblib.load(pth) + model: ActivityClassifier = joblib.load(pth) + + if ssl_repo_path and pathlib.Path(ssl_repo_path).exists(): + if verbose: + print(f"Loading ssl repository from {ssl_repo_path}.") + + model.model = model.load_ssl(ssl_repo_path) + + return model def md5(fname): diff --git a/src/actinet/models.py b/src/actinet/models.py index f0a272e..f2769c0 100644 --- a/src/actinet/models.py +++ b/src/actinet/models.py @@ -13,7 +13,7 @@ def __init__( device="cpu", batch_size=512, window_sec=30, - weights_path="state_dict.pt", + weights_path=None, labels=[], ssl_repo=None, repo_tag="v1.0.0", @@ -21,16 +21,15 @@ def __init__( verbose=False, ): self.device = device - self.weights_path = weights_path self.repo_tag = repo_tag self.batch_size = batch_size self.window_sec = window_sec - self.state_dict = None self.labels = labels self.window_len = int(np.ceil(self.window_sec * sslmodel.SAMPLE_RATE)) self.verbose = verbose - - self.model = self._load_ssl(ssl_repo, weights_path) + + self.model_weights = sslmodel.get_model_dict(weights_path, device) if weights_path else None + self.model = self.load_ssl(ssl_repo) hmm_params = hmm_params or dict() self.hmms = hmm.HMM(**hmm_params) @@ -90,12 +89,11 @@ def _predict(self, X): return y_pred - def _load_ssl(self, ssl_repo, weights): + def load_ssl(self, ssl_repo): model = sslmodel.get_sslnet( - self.device, tag=self.repo_tag, local_repo_path=ssl_repo, - pretrained=weights, + pretrained_weights = self.model_weights or True, window_sec=self.window_sec, num_labels=len(self.labels), ) diff --git a/src/actinet/sslmodel.py b/src/actinet/sslmodel.py index 7daaed4..d4c8c90 100644 --- a/src/actinet/sslmodel.py +++ b/src/actinet/sslmodel.py @@ -1,5 +1,6 @@ """ Helper classes and functions for the SSL model """ +from collections import OrderedDict import torch import torch.nn as nn import numpy as np @@ -184,10 +185,9 @@ def save_checkpoint(self, val_loss, model): def get_sslnet( - device, tag="v1.0.0", local_repo_path=None, - 
pretrained=False, + pretrained_weights=False, window_sec: int = 30, num_labels: int = 4, ): @@ -196,8 +196,8 @@ def get_sslnet( :param str device: PyTorch device to use :param str tag: Tag on the ssl-wearables repo to check out - :param str local_repo_path: Path to local version of the SSL repo for offline usage - :param bool/str pretrained: Initialise the model with UKB self-supervised pretrained weights + :param str local_repo_path: Path to local version of the SSL repository for offline usage + :param bool/OrderedDict pretrained_weights: Initialise the model with UKB self-supervised/specified pretrained weights :param int window_sec: The length of the window of data in seconds (limited to 5, 10 or 30) :param int num_labels: The number of labels to predict :return: pytorch SSL model @@ -218,7 +218,7 @@ def get_sslnet( f"harnet{window_sec}", source="local", class_num=num_labels, - pretrained=pretrained == True, + pretrained=pretrained_weights == True, ) else: @@ -251,16 +251,20 @@ def get_sslnet( trust_repo=True, source=source, class_num=num_labels, - pretrained=pretrained == True, + pretrained=pretrained_weights == True, verbose=verbose, ) - model_dict = torch.load(pretrained, map_location=device) - sslnet.load_state_dict(model_dict) + if isinstance(pretrained_weights, OrderedDict): + sslnet.load_state_dict(pretrained_weights) return sslnet +def get_model_dict(weights_path, device): + return torch.load(weights_path, map_location=device) + + def predict(model, dataloader, device, output_logits=False): """ Iterate over the dataloader and do prediction with a pytorch model.