Skip to content

Commit

Permalink
Bug fixes for offline model
Browse files Browse the repository at this point in the history
  • Loading branch information
“Aidan committed Feb 5, 2024
1 parent 6f31338 commit 5857a41
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 25 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,7 @@ doc/build/html/index.html
doc/tmp.sv
env/
doc/source/savefig/

# Torch hub cache #
###################
src/actinet/torch_hub_cache/*
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main():
"scipy==1.10.*",
"pandas==2.0.*",
"tqdm==4.64.*",
"matplotlib==3.1.*",
"matplotlib==3.5.*",
"joblib==1.2.*",
"scikit-learn==1.1.1",
"imbalanced-learn==0.9.1",
Expand Down
2 changes: 1 addition & 1 deletion src/actinet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__license__ = "See LICENSE file."

__model_version__ = "ssl_ukb_c24_rw_20240204"
__model_md5__ = "c13390ff024e4714eb027aa66160e2eb"
__model_md5__ = "84f3d5bb73de5c4da057918c45400da4"

from . import _version

Expand Down
26 changes: 19 additions & 7 deletions src/actinet/actinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from actinet.summarisation import getActivitySummary, ACTIVITY_LABELS
from actinet.utils.utils import infer_freq

BASE_URL = "https://zenodo.org/records/10616280/files/"
BASE_URL = "https://zenodo.org/records/10619096/files/"


def main():
Expand All @@ -29,7 +29,7 @@ def main():
description="A tool to predict activities from accelerometer data using a self-supervised Resnet 18 model",
add_help=True,
)
parser.add_argument("filepath", help="Enter file to be processed")
parser.add_argument("--filepath", "-f", help="Enter file to be processed")
parser.add_argument(
"--outdir",
"-o",
Expand Down Expand Up @@ -67,6 +67,9 @@ def main():
action="store_true",
help="Download and cache ssl module for offline usage",
)
parser.add_argument(
"--ssl-repo-path", "-s", help="Enter repository of ssl model", default=None
)
parser.add_argument("--quiet", "-q", action="store_true", help="Suppress output")
args = parser.parse_args()

Expand All @@ -75,7 +78,7 @@ def main():
verbose = not args.quiet

if args.cache_ssl:
model = ActivityClassifier(weights_path=True, ssl_repo=None, verbose=verbose)
model = ActivityClassifier(weights_path=None, ssl_repo=None, verbose=verbose, labels=ACTIVITY_LABELS)

after = time.time()
print(f"Done! ({round(after - before,2)}s)")
Expand All @@ -101,7 +104,7 @@ def main():
model_path = pathlib.Path(__file__).parent / f"{__model_version__}.joblib.lzma"
check_md5 = args.model_path is None
model: ActivityClassifier = load_model(
args.model_path or model_path, check_md5, args.force_download
args.model_path or model_path, check_md5, args.force_download, verbose
)

model.verbose = verbose
Expand Down Expand Up @@ -237,7 +240,7 @@ def resolve_path(path):
return dirname, filename, extension


def load_model(model_path, check_md5=True, force_download=False):
def load_model(model_path, ssl_repo_path=None, check_md5=True, force_download=False, verbose=True):
"""Load trained model. Download if not exists."""

pth = pathlib.Path(model_path)
Expand All @@ -246,7 +249,8 @@ def load_model(model_path, check_md5=True, force_download=False):

url = f"{BASE_URL}{__model_version__}.joblib.lzma"

print(f"Downloading {url}...")
if verbose:
print(f"Downloading {url}...")

with urllib.request.urlopen(url) as f_src, open(pth, "wb") as f_dst:
shutil.copyfileobj(f_src, f_dst)
Expand All @@ -257,7 +261,15 @@ def load_model(model_path, check_md5=True, force_download=False):
"to download the model file again."
)

return joblib.load(pth)
model: ActivityClassifier = joblib.load(pth)

if ssl_repo_path and pathlib.Path(ssl_repo_path).exists():
if verbose:
print(f"Loading ssl repository from {ssl_repo_path}.")

model = model.load_ssl(ssl_repo_path)

return model


def md5(fname):
Expand Down
14 changes: 6 additions & 8 deletions src/actinet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@ def __init__(
device="cpu",
batch_size=512,
window_sec=30,
weights_path="state_dict.pt",
weights_path=None,
labels=[],
ssl_repo=None,
repo_tag="v1.0.0",
hmm_params=None,
verbose=False,
):
self.device = device
self.weights_path = weights_path
self.repo_tag = repo_tag
self.batch_size = batch_size
self.window_sec = window_sec
self.state_dict = None
self.labels = labels
self.window_len = int(np.ceil(self.window_sec * sslmodel.SAMPLE_RATE))
self.verbose = verbose

self.model = self._load_ssl(ssl_repo, weights_path)

self.model_weights = sslmodel.get_model_dict(weights_path, device) if weights_path else None
self.model = self.load_ssl(ssl_repo)

hmm_params = hmm_params or dict()
self.hmms = hmm.HMM(**hmm_params)
Expand Down Expand Up @@ -90,12 +89,11 @@ def _predict(self, X):

return y_pred

def _load_ssl(self, ssl_repo, weights):
def load_ssl(self, ssl_repo):
model = sslmodel.get_sslnet(
self.device,
tag=self.repo_tag,
local_repo_path=ssl_repo,
pretrained=weights,
pretrained_weights = self.model_weights or True,
window_sec=self.window_sec,
num_labels=len(self.labels),
)
Expand Down
20 changes: 12 additions & 8 deletions src/actinet/sslmodel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Helper classes and functions for the SSL model """

from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
Expand Down Expand Up @@ -184,10 +185,9 @@ def save_checkpoint(self, val_loss, model):


def get_sslnet(
device,
tag="v1.0.0",
local_repo_path=None,
pretrained=False,
pretrained_weights=False,
window_sec: int = 30,
num_labels: int = 4,
):
Expand All @@ -196,8 +196,8 @@ def get_sslnet(
:param str device: PyTorch device to use
:param str tag: Tag on the ssl-wearables repo to check out
:param str local_repo_path: Path to local version of the SSL repo for offline usage
:param bool/str pretrained: Initialise the model with UKB self-supervised pretrained weights
:param str local_repo_path: Path to local version of the SSL repository for offline usage
:param bool/OrderedDict pretrained_weights: Initialise the model with UKB self-supervised/specified pretrained weights
:param int window_sec: The length of the window of data in seconds (limited to 5, 10 or 30)
:param int num_labels: The number of labels to predict
:return: pytorch SSL model
Expand All @@ -218,7 +218,7 @@ def get_sslnet(
f"harnet{window_sec}",
source="local",
class_num=num_labels,
pretrained=pretrained == True,
pretrained=pretrained_weights == True,
)

else:
Expand Down Expand Up @@ -251,16 +251,20 @@ def get_sslnet(
trust_repo=True,
source=source,
class_num=num_labels,
pretrained=pretrained == True,
pretrained=pretrained_weights == True,
verbose=verbose,
)

model_dict = torch.load(pretrained, map_location=device)
sslnet.load_state_dict(model_dict)
if isinstance(pretrained_weights, OrderedDict):
sslnet.load_state_dict(pretrained_weights)

return sslnet


def get_model_dict(weights_path, device):
return torch.load(weights_path, map_location=device)


def predict(model, dataloader, device, output_logits=False):
"""
Iterate over the dataloader and do prediction with a pytorch model.
Expand Down

0 comments on commit 5857a41

Please sign in to comment.