
Commit

Work on checkpoints
b8raoult committed May 27, 2024
1 parent 453d5f5 commit 10f911e
Showing 2 changed files with 85 additions and 24 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ classifiers = [
dependencies = [
    "tomli", # Only needed before 3.11
    "pyyaml",
    "tqdm",
]

[project.optional-dependencies]
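The new "tqdm" dependency supports the progress bar introduced in the checkpoints.py changes below. As a minimal sketch of that pattern, outside the commit itself and with a hypothetical member list:

import tqdm

members = ["weights.pt", "ai-models.json"]  # hypothetical archive members
with tqdm.tqdm(total=len(members), desc="Rebuilding checkpoint") as pbar:
    for member in members:
        # write the member into the rebuilt zip archive here
        pbar.update(1)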
108 changes: 84 additions & 24 deletions src/anemoi/utils/checkpoints.py
@@ -10,34 +10,41 @@
are zip archives containing the model weights.
"""

import fnmatch
import json
import logging
import os
import time
import zipfile
from tempfile import TemporaryDirectory

import tqdm

LOG = logging.getLogger(__name__)

DEFAULT_NAME = "anemoi-metadata.json"
DEFAULT_NAME = "ai-models.json"
DEFAULT_FOLDER = "anemoi-metadata"


def metadata_files(path: str):
    """List all JSON files in a zip archive
def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
    """Check if a checkpoint file has a metadata file
    Parameters
    ----------
    path : str
        The path to the zip archive
        The path to the checkpoint file
    name : str, optional
        The name of the metadata file in the zip archive
    Returns
    -------
    List[str]
        The list of JSON files in the archive
    bool
        True if the metadata file is found
    """
    with zipfile.ZipFile(path, "r") as f:
        return [os.path.basename(b) for b in f.namelist() if fnmatch.fnmatch(os.path.basename(b), "*.json")]
        for b in f.namelist():
            if os.path.basename(b) == name:
                return True
    return False


def load_metadata(path: str, name: str = DEFAULT_NAME):
@@ -63,19 +70,19 @@ def load_metadata(path: str, name: str = DEFAULT_NAME):
    with zipfile.ZipFile(path, "r") as f:
        metadata = None
        for b in f.namelist():
            if fnmatch.fnmatch(os.path.basename(b), name):
            if os.path.basename(b) == name:
                if metadata is not None:
                    LOG.warning(f"Found two '{name}' if {path}")
                    raise ValueError(f"Found two or more '{name}' in {path}.")
                metadata = b

    if metadata is not None:
        with zipfile.ZipFile(path, "r") as f:
            return json.load(f.open(metadata, "r"))
    else:
        raise ValueError(f"Could not find {name} in {path}")
        raise ValueError(f"Could not find '{name}' in {path}.")


def save_metadata(path, metadata, name=DEFAULT_NAME):
def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER):
"""Save metadata to a checkpoint file
Parameters
@@ -88,32 +95,85 @@ def save_metadata(path, metadata, name=DEFAULT_NAME):
        The name of the metadata file in the zip archive
    """
    with zipfile.ZipFile(path, "a") as zipf:
        base, _ = os.path.splitext(os.path.basename(path))

        directories = set()

        for b in zipf.namelist():
            directory = os.path.dirname(b)
            while os.path.dirname(directory) not in (".", ""):
                directory = os.path.dirname(directory)
            directories.add(directory)

            if os.path.basename(b) == name:
                raise ValueError(f"'{name}' already in {path}")

        if len(directories) != 1:
            # PyTorch checkpoints should have a single directory
            # otherwise PyTorch will complain
            raise ValueError(f"No or multiple directories in the checkpoint {path}, directories={directories}")

        directory = list(directories)[0]

        LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

        zipf.writestr(
            f"{base}/{name}",
            f"{directory}/{folder}/{name}",
            json.dumps(metadata),
        )


def replace_metadata(path, metadata, name, rename=None):
def _edit_metadata(path, name, callback):
new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

found = False

with TemporaryDirectory() as temp_dir:
zipfile.ZipFile(path, "r").extractall(temp_dir)
total = 0
for root, dirs, files in os.walk(temp_dir):
for f in files:
total += 1
full = os.path.join(root, f)
if f == name:
with open(full, "w") as f:
json.dump(metadata, f)
if rename is not None:
os.rename(full, os.path.join(root, os.path.basename(rename)))
found = True
callback(full)

if not found:
raise ValueError(f"Could not find '{name}' in {path}")

with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(temp_dir):
for f in files:
full = os.path.join(root, f)
rel = os.path.relpath(full, temp_dir)
zipf.write(full, rel)
with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
for root, dirs, files in os.walk(temp_dir):
for f in files:
full = os.path.join(root, f)
rel = os.path.relpath(full, temp_dir)
zipf.write(full, rel)
pbar.update(1)

os.rename(new_path, path)
LOG.info("Updated metadata in %s", path)


def replace_metadata(path, metadata, name=DEFAULT_NAME):

    if not isinstance(metadata, dict):
        raise ValueError(f"metadata must be a dict, got {type(metadata)}")

    if "version" not in metadata:
        raise ValueError("metadata must have a 'version' key")

    def callback(full):
        with open(full, "w") as f:
            json.dump(metadata, f)

    _edit_metadata(path, name, callback)


def remove_metadata(path, name=DEFAULT_NAME):

    LOG.info("Removing metadata '%s' from %s", name, path)

    def callback(full):
        os.remove(full)

    _edit_metadata(path, name, callback)
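
For reference, a minimal usage sketch of the functions touched by this commit; the checkpoint path and metadata values are hypothetical and not part of the change:

from anemoi.utils.checkpoints import has_metadata, load_metadata, replace_metadata, save_metadata

path = "model.ckpt"  # hypothetical PyTorch checkpoint (a zip archive with one top-level directory)

# Attach metadata once; save_metadata raises if the entry already exists.
if not has_metadata(path):
    save_metadata(path, {"version": "1.0.0", "note": "example"})

# Read it back, amend it, and rewrite the archive in place.
metadata = load_metadata(path)
metadata["note"] = "updated"
replace_metadata(path, metadata)  # requires a 'version' key

# remove_metadata(path) would delete the entry again, rebuilding the archive.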
