Various fixes and improvements #1463

Merged (5 commits) on Jul 19, 2023
Changes from all commits
14 changes: 14 additions & 0 deletions avalanche/benchmarks/classic/core50.py
@@ -156,6 +156,20 @@ def CORe50(
        eval_transform=eval_transform,
    )

+    if scenario == "nc":
+        n_classes_per_exp = []
+        classes_order = []
+        for exp in benchmark_obj.train_stream:
+            exp_dataset = exp.dataset
+            unique_targets = list(
+                sorted(set(int(x) for x in exp_dataset.targets))  # type: ignore
+            )
+            n_classes_per_exp.append(len(unique_targets))
+            classes_order.extend(unique_targets)
+        setattr(benchmark_obj, "n_classes_per_exp", n_classes_per_exp)
+        setattr(benchmark_obj, "classes_order", classes_order)
+        setattr(benchmark_obj, "n_classes", 50 if object_lvl else 10)
+
    return benchmark_obj


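For context, a minimal sketch of how the restored attributes can be consumed (assuming the CORe50 data is available under the default dataset location; the printed values follow the NC protocol covered by the regression test further below):

from avalanche.benchmarks.classic import CORe50

# "nc" is the class-incremental CORe50 protocol; object_lvl defaults to True,
# so the benchmark works at the 50-class object level.
benchmark = CORe50(scenario="nc")
print(benchmark.n_classes)           # 50
print(benchmark.n_classes_per_exp)   # [10, 5, 5, 5, 5, 5, 5, 5, 5]
print(len(benchmark.classes_order))  # 50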
23 changes: 16 additions & 7 deletions avalanche/evaluation/metrics/checkpoint.py
@@ -10,9 +10,11 @@
################################################################################

import copy
-from typing import TYPE_CHECKING
+import io
+from typing import TYPE_CHECKING, Optional

from torch import Tensor
+import torch

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue, MetricResult
@@ -46,9 +48,9 @@ def __init__(self):
        retrieved using the `result` method.
        """
        super().__init__()
-        self.weights = None
+        self.weights: Optional[bytes] = None

-    def update(self, weights) -> Tensor:
+    def update(self, weights: bytes):
        """
        Update the weight checkpoint at the current experience.

@@ -57,7 +59,7 @@ def update(self, weights) -> Tensor:
"""
self.weights = weights

def result(self) -> Tensor:
def result(self) -> Optional[bytes]:
"""
Retrieves the weight checkpoint at the current experience.

@@ -75,6 +77,9 @@ def reset(self) -> None:

    def _package_result(self, strategy) -> "MetricResult":
        weights = self.result()
+        if weights is None:
+            return None
+
        metric_name = get_metric_name(
            self, strategy, add_experience=True, add_task=False
        )
@@ -83,9 +88,13 @@ def _package_result(self, strategy) -> "MetricResult":
        ]

    def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
-        model_params = copy.deepcopy(strategy.model.parameters())
-        self.update(model_params)
-        return None
+        buff = io.BytesIO()
+        model_params = copy.deepcopy(strategy.model).to("cpu")
+        torch.save(model_params, buff)
+        buff.seek(0)
+        self.update(buff.read())
+
+        return self._package_result(strategy)

    def __str__(self):
        return "WeightCheckpoint"
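Because the checkpoint is now stored as the raw bytes of a torch.save-serialized CPU copy of the model, it can be restored with torch.load on an in-memory buffer. A minimal sketch, assuming `metric` is a WeightCheckpoint instance that has been updated at least once:

import io

import torch

serialized = metric.result()  # Optional[bytes]; None before the first update
if serialized is not None:
    # torch.load accepts any file-like object, so no temporary file is needed
    model = torch.load(io.BytesIO(serialized))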
6 changes: 4 additions & 2 deletions avalanche/logging/text_logging.py
@@ -9,7 +9,6 @@
# Website: avalanche.continualai.org #
################################################################################
import datetime
-import os.path
import sys
import warnings
from typing import List, TYPE_CHECKING, Tuple, Type, Optional, TextIO
@@ -24,7 +23,10 @@
if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate

-UNSUPPORTED_TYPES: Tuple[Type] = (TensorImage,)
+UNSUPPORTED_TYPES: Tuple[Type, ...] = (
+    TensorImage,
+    bytes,
+)


class TextLogger(BaseLogger, SupervisedPlugin):
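Adding bytes to UNSUPPORTED_TYPES keeps the raw checkpoint blobs produced by WeightCheckpoint out of the text log. A hedged sketch of the filtering idea (the actual skip happens inside TextLogger's metric-printing code):

def is_printable(value) -> bool:
    # Values whose type is listed in UNSUPPORTED_TYPES (TensorImage, and now
    # raw `bytes` checkpoints) are skipped instead of being dumped as text.
    return not isinstance(value, UNSUPPORTED_TYPES)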
94 changes: 65 additions & 29 deletions avalanche/logging/wandb_logger.py
@@ -4,21 +4,21 @@
# See the accompanying LICENSE file for terms. #
# #
# Date: 25-11-2020 #
-# Author(s): Diganta Misra, Andrea Cossu #
+# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini #
# E-mail: [email protected] #
# Website: www.continualai.org #
################################################################################
""" This module handles all the functionalities related to the logging of
Avalanche experiments using Weights & Biases. """

-from typing import Union, List, TYPE_CHECKING
+import re
+from typing import Optional, Union, List, TYPE_CHECKING
from pathlib import Path
import os
import errno
import warnings

import numpy as np
from numpy import array
import torch
from torch import Tensor

from PIL.Image import Image
@@ -37,6 +37,12 @@
    from avalanche.training.templates import SupervisedTemplate


+CHECKPOINT_METRIC_NAME = re.compile(
+    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
+    r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
+)
+
+
class WandBLogger(BaseLogger, SupervisedPlugin):
    """Weights and Biases logger.

@@ -60,18 +66,21 @@ def __init__(
        run_name: str = "Test",
        log_artifacts: bool = False,
        path: Union[str, Path] = "Checkpoints",
-        uri: str = None,
+        uri: Optional[str] = None,
        sync_tfboard: bool = False,
        save_code: bool = True,
-        config: object = None,
-        dir: Union[str, Path] = None,
-        params: dict = None,
+        config: Optional[object] = None,
+        dir: Optional[Union[str, Path]] = None,
+        params: Optional[dict] = None,
    ):
        """Creates an instance of the `WandBLogger`.

        :param project_name: Name of the W&B project.
        :param run_name: Name of the W&B run.
        :param log_artifacts: Option to log model weights as W&B Artifacts.
+            Note that, in order for model weights to be logged, the
+            :class:`WeightCheckpoint` metric must be added to the
+            evaluation plugin.
        :param path: Path to locally save the model checkpoints.
        :param uri: URI identifier for external storage buckets (GCS, S3).
        :param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
@@ -102,6 +111,8 @@ def __init__(
    def import_wandb(self):
        try:
            import wandb
+
+            assert hasattr(wandb, "__version__")
        except ImportError:
            raise ImportError('Please run "pip install wandb" to install wandb')
        self.wandb = wandb
@@ -140,7 +151,7 @@ def after_training_exp(
        self,
        strategy: "SupervisedTemplate",
        metric_values: List["MetricValue"],
-        **kwargs
+        **kwargs,
    ):
        for val in metric_values:
            self.log_metrics([val])
@@ -151,6 +162,11 @@ def after_training_exp(
    def log_single_metric(self, name, value, x_plot):
        self.step = x_plot

+        if name.startswith("WeightCheckpoint"):
+            if self.log_artifacts:
+                self._log_checkpoint(name, value, x_plot)
+            return
+
        if isinstance(value, AlternativeValues):
            value = value.best_supported_value(
                Image,
@@ -192,26 +208,46 @@ def log_single_metric(self, name, value, x_plot):
        elif isinstance(value, TensorImage):
            self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)

-        elif name.startswith("WeightCheckpoint"):
-            if self.log_artifacts:
-                cwd = os.getcwd()
-                ckpt = os.path.join(cwd, self.path)
-                try:
-                    os.makedirs(ckpt)
-                except OSError as e:
-                    if e.errno != errno.EEXIST:
-                        raise
-                suffix = ".pth"
-                dir_name = os.path.join(ckpt, name + suffix)
-                artifact_name = os.path.join("Models", name + suffix)
-                if isinstance(value, Tensor):
-                    torch.save(value, dir_name)
-                    name = os.path.splittext(self.checkpoint)
-                    artifact = self.wandb.Artifact(name, type="model")
-                    artifact.add_file(dir_name, name=artifact_name)
-                    self.wandb.run.log_artifact(artifact)
-                    if self.uri is not None:
-                        artifact.add_reference(self.uri, name=artifact_name)
+    def _log_checkpoint(self, name, value, x_plot):
+        assert self.wandb is not None
+
+        # Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
+        name_match = CHECKPOINT_METRIC_NAME.match(name)
+        if name_match is None:
+            warnings.warn(f"Checkpoint metric has unsupported name {name}.")
+            return
+        # phase_name: str = name_match['phase_name']
+        # stream_name: str = name_match['stream_name']
+        task_id: Optional[int] = (
+            int(name_match["task_id"]) if name_match["task_id"] is not None else None
+        )
+        experience_id: int = int(name_match["experience_id"])
+        assert experience_id >= 0
+
+        cwd = Path.cwd()
+        checkpoint_directory = cwd / self.path
+        checkpoint_directory.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_name = "Model_{}".format(experience_id)
+        checkpoint_file_name = checkpoint_name + ".pth"
+        checkpoint_path = checkpoint_directory / checkpoint_file_name
+        artifact_name = "Models/" + checkpoint_file_name
+
+        # Write the checkpoint blob
+        with open(checkpoint_path, "wb") as f:
+            f.write(value)
+
+        metadata = {
+            "experience": experience_id,
+            "x_step": x_plot,
+            **({"task_id": task_id} if task_id is not None else {}),
+        }
+
+        artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata)
+        artifact.add_file(str(checkpoint_path), name=artifact_name)
+        self.wandb.run.log_artifact(artifact)
+        if self.uri is not None:
+            artifact.add_reference(self.uri, name=artifact_name)

    def __getstate__(self):
        state = self.__dict__.copy()
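The metric-name convention assumed by CHECKPOINT_METRIC_NAME can be exercised in isolation. A small sketch showing how the pattern decomposes a checkpoint metric name into its phase, stream, task, and experience parts:

import re

CHECKPOINT_METRIC_NAME = re.compile(
    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
    r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
)

m = CHECKPOINT_METRIC_NAME.match(
    "WeightCheckpoint/train_phase/train_stream/Task000/Exp000"
)
assert m is not None
assert m["phase_name"] == "train" and m["stream_name"] == "train"
assert int(m["task_id"]) == 0 and int(m["experience_id"]) == 0

The Task segment is optional, so a task-free name such as "WeightCheckpoint/train_phase/train_stream/Exp000" also matches, with m["task_id"] left as None.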
10 changes: 8 additions & 2 deletions avalanche/training/plugins/ewc.py
@@ -121,6 +121,7 @@ def after_training_exp(self, strategy, **kwargs):
            strategy.experience.dataset,
            strategy.device,
            strategy.train_mb_size,
+            num_workers=kwargs.get("num_workers", 0),
        )
        self.update_importances(importances, exp_counter)
        self.saved_params[exp_counter] = copy_params_dict(strategy.model)
@@ -129,7 +130,7 @@ def after_training_exp(self, strategy, **kwargs):
            del self.saved_params[exp_counter - 1]

    def compute_importances(
-        self, model, criterion, optimizer, dataset, device, batch_size
+        self, model, criterion, optimizer, dataset, device, batch_size, num_workers=0
    ) -> Dict[str, ParamData]:
        """
        Compute EWC importance matrix for each parameter
@@ -153,7 +154,12 @@ def compute_importances(
        # list of list
        importances = zerolike_params_dict(model)
        collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None
-        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+        )
        for i, batch in enumerate(dataloader):
            # get only input, target and task_id from the batch
            x, y, task_labels = batch[0], batch[1], batch[-1]
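With this change, the num_workers value passed to the training loop also reaches the DataLoader used by compute_importances. A hedged sketch of the pass-through (assumes `model`, `optimizer`, `criterion`, and a `benchmark` are already defined):

from avalanche.training.plugins import EWCPlugin
from avalanche.training.supervised import Naive

strategy = Naive(
    model,
    optimizer,
    criterion,
    plugins=[EWCPlugin(ewc_lambda=0.4)],
)
for experience in benchmark.train_stream:
    # num_workers is forwarded to plugin callbacks via **kwargs, so the
    # EWC importance-computation DataLoader now uses 4 workers as well.
    strategy.train(experience, num_workers=4)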
4 changes: 2 additions & 2 deletions examples/multihead.py
@@ -69,8 +69,8 @@ def main(args):

# train and test loop
for train_task in train_stream:
strategy.train(train_task)
strategy.eval(test_stream)
strategy.train(train_task, num_workers=4)
strategy.eval(test_stream, num_workers=4)


if __name__ == "__main__":
23 changes: 21 additions & 2 deletions examples/wandb_logger.py
@@ -24,6 +24,7 @@

from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
from avalanche.logging import InteractiveLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
@@ -83,7 +84,11 @@ def main(args):

interactive_logger = InteractiveLogger()
wandb_logger = WandBLogger(
project_name=args.project, run_name=args.run, config=vars(args)
project_name=args.project,
run_name=args.run,
log_artifacts=args.artifacts,
path=args.path if args.path else None,
config=vars(args),
)

eval_plugin = EvaluationPlugin(
@@ -120,6 +125,7 @@
        ),
        disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        MAC_metrics(minibatch=True, epoch=True, experience=True),
+        WeightCheckpoint(),
        loggers=[interactive_logger, wandb_logger],
    )
@@ -157,9 +163,22 @@
        default=0,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
-    parser.add_argument("--run", type=str, help="Provide a run name for WandB")
    parser.add_argument(
        "--project", type=str, help="Define the name of the WandB project"
    )
+    parser.add_argument("--run", type=str, help="Provide a run name for WandB")
+    parser.add_argument(
+        "--artifacts",
+        default=False,
+        action="store_true",
+        help="Log Model Checkpoints as W&B Artifacts",
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="Checkpoint",
+        help="Local path to save the model checkpoints",
+    )

    args = parser.parse_args()
    main(args)
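With the new flags, artifact logging can be enabled from the command line, e.g. python examples/wandb_logger.py --project <project> --run <run> --artifacts (the project and run names are user-chosen placeholders).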
5 changes: 5 additions & 0 deletions tests/test_core50.py
@@ -38,6 +38,11 @@ def test_core50_nc_benchmark(self):
        classes_in_test = benchmark_instance.classes_in_experience["test"][0]
        self.assertSetEqual(set(range(50)), set(classes_in_test))

+        # Regression tests for issue #774
+        self.assertSequenceEqual([10] + ([5] * 8), benchmark_instance.n_classes_per_exp)
+        self.assertSetEqual(set(range(50)), set(benchmark_instance.classes_order))
+        self.assertEqual(50, len(benchmark_instance.classes_order))


if __name__ == "__main__":
    unittest.main()
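The expected sequence follows the CORe50 NC protocol: the first experience introduces 10 classes and each of the remaining eight introduces 5, giving 10 + 8 × 5 = 50 classes overall.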
12 changes: 8 additions & 4 deletions tests/test_models.py
@@ -1,6 +1,7 @@
import sys
import os
import copy
+import tempfile

import unittest

@@ -650,10 +651,13 @@ def test_ncm_save_load(self):
                ),
            }
        )
-        torch.save(classifier.state_dict(), "ncm.pt")
-        del classifier
-        classifier = NCMClassifier()
-        check = torch.load("ncm.pt")
+
+        with tempfile.TemporaryFile() as tmpfile:
+            torch.save(classifier.state_dict(), tmpfile)
+            del classifier
+            classifier = NCMClassifier()
+            tmpfile.seek(0)
+            check = torch.load(tmpfile)
        classifier.load_state_dict(check)
        assert classifier.class_means.shape == (3, 5)
        assert (classifier.class_means[0] == 0).all()