Skip to content

Commit

Permalink
Use importlib.resources to get paths (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtritt authored Jun 14, 2023
1 parent 9e988c6 commit 30957e9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/gtnet/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import glob
import hashlib
from importlib.resources import files
import json
import logging
import os
Expand All @@ -9,7 +10,6 @@
import zipfile

import numpy as np
from pkg_resources import resource_filename
import torch
import torch.nn as nn

Expand Down Expand Up @@ -40,15 +40,15 @@ class DeployPkg:

@classmethod
def check_pkg(cls):
deploy_dir = resource_filename(__name__, 'deploy_pkg')
deploy_dir = files(__package__).joinpath('deploy_pkg')
total = 0
for path in glob.glob(f"{deploy_dir}/*"):
total += os.path.getsize(path)
if total == 0:
msg = ("Downloading GTNet deployment package. This will only happen on the first invocation "
"of gtnet predict or gtnet classify")
warnings.warn(msg)
zip_path = resource_filename(__name__, 'deploy_pkg.zip')
zip_path = files(__package__).joinpath('deploy_pkg.zip')
urllib.request.urlretrieve(cls._deploy_pkg_url, zip_path)
dl_checksum = hashlib.md5(open(zip_path,'rb').read()).hexdigest()
if dl_checksum != cls._checksum:
Expand Down

0 comments on commit 30957e9

Please sign in to comment.