From 960d013544a5f594ff994250b2727cc9de9e5abb Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Wed, 19 Jul 2023 15:51:07 +0200 Subject: [PATCH 1/5] Fix issue #774 --- avalanche/benchmarks/classic/core50.py | 14 ++++++++++++++ tests/test_core50.py | 5 +++++ 2 files changed, 19 insertions(+) diff --git a/avalanche/benchmarks/classic/core50.py b/avalanche/benchmarks/classic/core50.py index 70809cd9b..7a58df9dc 100644 --- a/avalanche/benchmarks/classic/core50.py +++ b/avalanche/benchmarks/classic/core50.py @@ -156,6 +156,20 @@ def CORe50( eval_transform=eval_transform, ) + if scenario == "nc": + n_classes_per_exp = [] + classes_order = [] + for exp in benchmark_obj.train_stream: + exp_dataset = exp.dataset + unique_targets = list( + sorted(set(int(x) for x in exp_dataset.targets)) # type: ignore + ) + n_classes_per_exp.append(len(unique_targets)) + classes_order.extend(unique_targets) + setattr(benchmark_obj, "n_classes_per_exp", n_classes_per_exp) + setattr(benchmark_obj, "classes_order", classes_order) + setattr(benchmark_obj, "n_classes", 50 if object_lvl else 10) + return benchmark_obj diff --git a/tests/test_core50.py b/tests/test_core50.py index fe4e4c736..13f91a2c9 100644 --- a/tests/test_core50.py +++ b/tests/test_core50.py @@ -38,6 +38,11 @@ def test_core50_nc_benchmark(self): classes_in_test = benchmark_instance.classes_in_experience["test"][0] self.assertSetEqual(set(range(50)), set(classes_in_test)) + # Regression tests for issue #774 + self.assertSequenceEqual([10] + ([5] * 8), benchmark_instance.n_classes_per_exp) + self.assertSetEqual(set(range(50)), set(benchmark_instance.classes_order)) + self.assertEqual(50, len(benchmark_instance.classes_order)) + if __name__ == "__main__": unittest.main() From 1e7e13858b37d748a8b615c5a7e9ce824cd39754 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Wed, 19 Jul 2023 15:51:44 +0200 Subject: [PATCH 2/5] Add support for weights artifact in W&B --- avalanche/evaluation/metrics/checkpoint.py | 23 ++++-- avalanche/logging/text_logging.py | 6 +- avalanche/logging/wandb_logger.py | 94 +++++++++++++++------- examples/wandb_logger.py | 23 +++++- 4 files changed, 106 insertions(+), 40 deletions(-) diff --git a/avalanche/evaluation/metrics/checkpoint.py b/avalanche/evaluation/metrics/checkpoint.py index b06926ccb..f0a0cc447 100644 --- a/avalanche/evaluation/metrics/checkpoint.py +++ b/avalanche/evaluation/metrics/checkpoint.py @@ -10,9 +10,11 @@ ################################################################################ import copy -from typing import TYPE_CHECKING +import io +from typing import TYPE_CHECKING, Optional from torch import Tensor +import torch from avalanche.evaluation import PluginMetric from avalanche.evaluation.metric_results import MetricValue, MetricResult @@ -46,9 +48,9 @@ def __init__(self): retrieved using the `result` method. """ super().__init__() - self.weights = None + self.weights: Optional[bytes] = None - def update(self, weights) -> Tensor: + def update(self, weights: bytes): """ Update the weight checkpoint at the current experience. @@ -57,7 +59,7 @@ def update(self, weights) -> Tensor: """ self.weights = weights - def result(self) -> Tensor: + def result(self) -> Optional[bytes]: """ Retrieves the weight checkpoint at the current experience. 
@@ -75,6 +77,9 @@ def reset(self) -> None:
 
     def _package_result(self, strategy) -> "MetricResult":
         weights = self.result()
+        if weights is None:
+            return None
+
         metric_name = get_metric_name(
             self, strategy, add_experience=True, add_task=False
         )
@@ -83,9 +88,13 @@
         ]
 
     def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
-        model_params = copy.deepcopy(strategy.model.parameters())
-        self.update(model_params)
-        return None
+        buff = io.BytesIO()
+        model_params = copy.deepcopy(strategy.model).to("cpu")
+        torch.save(model_params, buff)
+        buff.seek(0)
+        self.update(buff.read())
+
+        return self._package_result(strategy)
 
     def __str__(self):
         return "WeightCheckpoint"
diff --git a/avalanche/logging/text_logging.py b/avalanche/logging/text_logging.py
index 670c20f0d..7222410bc 100644
--- a/avalanche/logging/text_logging.py
+++ b/avalanche/logging/text_logging.py
@@ -9,7 +9,6 @@
 # Website: avalanche.continualai.org #
 ################################################################################
 import datetime
-import os.path
 import sys
 import warnings
 from typing import List, TYPE_CHECKING, Tuple, Type, Optional, TextIO
@@ -24,7 +23,10 @@
 if TYPE_CHECKING:
     from avalanche.training.templates import SupervisedTemplate
 
-UNSUPPORTED_TYPES: Tuple[Type] = (TensorImage,)
+UNSUPPORTED_TYPES: Tuple[Type, ...] = (
+    TensorImage,
+    bytes,
+)
 
 
 class TextLogger(BaseLogger, SupervisedPlugin):
diff --git a/avalanche/logging/wandb_logger.py b/avalanche/logging/wandb_logger.py
index 00120dbdc..1bcbd1b04 100644
--- a/avalanche/logging/wandb_logger.py
+++ b/avalanche/logging/wandb_logger.py
@@ -4,21 +4,21 @@
 # See the accompanying LICENSE file for terms. #
 # #
 # Date: 25-11-2020 #
-# Author(s): Diganta Misra, Andrea Cossu #
+# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini #
 # E-mail: contact@continualai.org #
 # Website: www.continualai.org #
 ################################################################################
 """ This module handles all the functionalities related to the logging of
 Avalanche experiments using Weights & Biases. """
 
-from typing import Union, List, TYPE_CHECKING
+import re
+from typing import Optional, Union, List, TYPE_CHECKING
 from pathlib import Path
 import os
-import errno
+import warnings
 
 import numpy as np
 from numpy import array
-import torch
 from torch import Tensor
 from PIL.Image import Image
 
@@ -37,6 +37,12 @@
     from avalanche.training.templates import SupervisedTemplate
 
 
+CHECKPOINT_METRIC_NAME = re.compile(
+    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
+    r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
+)
+
+
 class WandBLogger(BaseLogger, SupervisedPlugin):
     """Weights and Biases logger.
 
@@ -60,18 +66,21 @@
         run_name: str = "Test",
         log_artifacts: bool = False,
         path: Union[str, Path] = "Checkpoints",
-        uri: str = None,
+        uri: Optional[str] = None,
         sync_tfboard: bool = False,
         save_code: bool = True,
-        config: object = None,
-        dir: Union[str, Path] = None,
-        params: dict = None,
+        config: Optional[object] = None,
+        dir: Optional[Union[str, Path]] = None,
+        params: Optional[dict] = None,
     ):
         """Creates an instance of the `WandBLogger`.
 
         :param project_name: Name of the W&B project.
         :param run_name: Name of the W&B run.
         :param log_artifacts: Option to log model weights as W&B Artifacts.
+            Note that, in order for model weights to be logged, the
+            :class:`WeightCheckpoint` metric must be added to the
+            evaluation plugin.
         :param path: Path to locally save the model checkpoints.
        :param uri: URI identifier for external storage buckets (GCS, S3).
        :param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
@@ -102,6 +111,8 @@ def import_wandb(self):
         try:
             import wandb
+
+            assert hasattr(wandb, "__version__")
         except ImportError:
             raise ImportError('Please run "pip install wandb" to install wandb')
         self.wandb = wandb
@@ -140,7 +151,7 @@ def after_training_exp(
         self,
         strategy: "SupervisedTemplate",
         metric_values: List["MetricValue"],
-        **kwargs
+        **kwargs,
     ):
         for val in metric_values:
             self.log_metrics([val])
@@ -151,6 +162,11 @@
     def log_single_metric(self, name, value, x_plot):
         self.step = x_plot
 
+        if name.startswith("WeightCheckpoint"):
+            if self.log_artifacts:
+                self._log_checkpoint(name, value, x_plot)
+            return
+
         if isinstance(value, AlternativeValues):
             value = value.best_supported_value(
                 Image,
@@ -192,26 +208,46 @@
         elif isinstance(value, TensorImage):
             self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)
 
-        elif name.startswith("WeightCheckpoint"):
-            if self.log_artifacts:
-                cwd = os.getcwd()
-                ckpt = os.path.join(cwd, self.path)
-                try:
-                    os.makedirs(ckpt)
-                except OSError as e:
-                    if e.errno != errno.EEXIST:
-                        raise
-                suffix = ".pth"
-                dir_name = os.path.join(ckpt, name + suffix)
-                artifact_name = os.path.join("Models", name + suffix)
-                if isinstance(value, Tensor):
-                    torch.save(value, dir_name)
-                    name = os.path.splittext(self.checkpoint)
-                    artifact = self.wandb.Artifact(name, type="model")
-                    artifact.add_file(dir_name, name=artifact_name)
-                    self.wandb.run.log_artifact(artifact)
-                    if self.uri is not None:
-                        artifact.add_reference(self.uri, name=artifact_name)
+    def _log_checkpoint(self, name, value, x_plot):
+        assert self.wandb is not None
+
+        # Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
+        name_match = CHECKPOINT_METRIC_NAME.match(name)
+        if name_match is None:
+            warnings.warn(f"Checkpoint metric has unsupported name {name}.")
+            return
+        # phase_name: str = name_match['phase_name']
+        # stream_name: str = name_match['stream_name']
+        task_id: Optional[int] = (
+            int(name_match["task_id"]) if name_match["task_id"] is not None else None
+        )
+        experience_id: int = int(name_match["experience_id"])
+        assert experience_id >= 0
+
+        cwd = Path.cwd()
+        checkpoint_directory = cwd / self.path
+        checkpoint_directory.mkdir(parents=True, exist_ok=True)
+
+        checkpoint_name = "Model_{}".format(experience_id)
+        checkpoint_file_name = checkpoint_name + ".pth"
+        checkpoint_path = checkpoint_directory / checkpoint_file_name
+        artifact_name = "Models/" + checkpoint_file_name
+
+        # Write the checkpoint blob
+        with open(checkpoint_path, "wb") as f:
+            f.write(value)
+
+        metadata = {
+            "experience": experience_id,
+            "x_step": x_plot,
+            **({"task_id": task_id} if task_id is not None else {}),
+        }
+
+        artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata)
+        artifact.add_file(str(checkpoint_path), name=artifact_name)
+        self.wandb.run.log_artifact(artifact)
+        if self.uri is not None:
+            artifact.add_reference(self.uri, name=artifact_name)
 
     def __getstate__(self):
         state = self.__dict__.copy()
diff --git a/examples/wandb_logger.py b/examples/wandb_logger.py
index 5ba9b94df..5404ea6bf 100644
--- a/examples/wandb_logger.py
+++ b/examples/wandb_logger.py
@@ -24,6 +24,7 @@
 from avalanche.benchmarks import nc_benchmark
 from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
+from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
 from avalanche.logging import InteractiveLogger, WandBLogger
 from avalanche.training.plugins import EvaluationPlugin
 from avalanche.evaluation.metrics import (
@@ -83,7 +84,11 @@
     interactive_logger = InteractiveLogger()
 
     wandb_logger = WandBLogger(
-        project_name=args.project, run_name=args.run, config=vars(args)
+        project_name=args.project,
+        run_name=args.run,
+        log_artifacts=args.artifacts,
+        path=args.path if args.path else None,
+        config=vars(args),
     )
 
     eval_plugin = EvaluationPlugin(
@@ -120,6 +125,7 @@
         ),
         disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
         MAC_metrics(minibatch=True, epoch=True, experience=True),
+        WeightCheckpoint(),
         loggers=[interactive_logger, wandb_logger],
     )
 
@@ -157,9 +163,22 @@
         default=0,
         help="Select zero-indexed cuda device. -1 to use CPU.",
     )
-    parser.add_argument("--run", type=str, help="Provide a run name for WandB")
     parser.add_argument(
         "--project", type=str, help="Define the name of the WandB project"
     )
+    parser.add_argument("--run", type=str, help="Provide a run name for WandB")
+    parser.add_argument(
+        "--artifacts",
+        default=False,
+        action="store_true",
+        help="Log Model Checkpoints as W&B Artifacts",
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="Checkpoint",
+        help="Local path to save the model checkpoints",
+    )
 
     args = parser.parse_args()
     main(args)

From 160cf3af2c1b11715d59e512d4227e51f9008f15 Mon Sep 17 00:00:00 2001
From: Lorenzo Pellegrini
Date: Wed, 19 Jul 2023 16:31:34 +0200
Subject: [PATCH 3/5] Prevent test file creation in main folder

---
 tests/test_models.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 5103d39bc..1d7f8af57 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,6 +1,7 @@
 import sys
 import os
 import copy
+import tempfile
 
 import unittest
 
@@ -650,10 +651,13 @@ def test_ncm_save_load(self):
                 ),
             }
         )
-        torch.save(classifier.state_dict(), "ncm.pt")
-        del classifier
-        classifier = NCMClassifier()
-        check = torch.load("ncm.pt")
+
+        with tempfile.TemporaryFile() as tmpfile:
+            torch.save(classifier.state_dict(), tmpfile)
+            del classifier
+            classifier = NCMClassifier()
+            tmpfile.seek(0)
+            check = torch.load(tmpfile)
         classifier.load_state_dict(check)
         assert classifier.class_means.shape == (3, 5)
         assert (classifier.class_means[0] == 0).all()

From 6c2ec8157c4084d585fbfcf9961c4c7a3275719d Mon Sep 17 00:00:00 2001
From: Lorenzo Pellegrini
Date: Wed, 19 Jul 2023 16:39:05 +0200
Subject: [PATCH 4/5] Enable num workers in EWC

---
 avalanche/training/plugins/ewc.py | 5 +++--
 examples/multihead.py             | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py
index f3c2f05a5..02bea2391 100644
--- a/avalanche/training/plugins/ewc.py
+++ b/avalanche/training/plugins/ewc.py
@@ -121,6 +121,7 @@ def after_training_exp(self, strategy, **kwargs):
             strategy.experience.dataset,
             strategy.device,
             strategy.train_mb_size,
+            num_workers=kwargs.get('num_workers', 0)
         )
         self.update_importances(importances, exp_counter)
         self.saved_params[exp_counter] = copy_params_dict(strategy.model)
@@ -129,7 +130,7 @@ def after_training_exp(self, strategy, **kwargs):
             del self.saved_params[exp_counter - 1]
 
     def compute_importances(
-        self, model, criterion, optimizer, dataset, device, batch_size
+        self, model, criterion, optimizer, dataset, device, batch_size, num_workers=0
     ) -> Dict[str, ParamData]:
         """
Compute EWC importance matrix for each parameter @@ -153,7 +154,7 @@ def compute_importances( # list of list importances = zerolike_params_dict(model) collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers) for i, batch in enumerate(dataloader): # get only input, target and task_id from the batch x, y, task_labels = batch[0], batch[1], batch[-1] diff --git a/examples/multihead.py b/examples/multihead.py index c4facb554..c0177418e 100644 --- a/examples/multihead.py +++ b/examples/multihead.py @@ -69,8 +69,8 @@ def main(args): # train and test loop for train_task in train_stream: - strategy.train(train_task) - strategy.eval(test_stream) + strategy.train(train_task, num_workers=4) + strategy.eval(test_stream, num_workers=4) if __name__ == "__main__": From abde4c21e1369eea0f0ea177a0e28e65eff15262 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Wed, 19 Jul 2023 16:46:26 +0200 Subject: [PATCH 5/5] Solve linter issue --- avalanche/training/plugins/ewc.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index 02bea2391..6c829cf1d 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -121,7 +121,7 @@ def after_training_exp(self, strategy, **kwargs): strategy.experience.dataset, strategy.device, strategy.train_mb_size, - num_workers=kwargs.get('num_workers', 0) + num_workers=kwargs.get("num_workers", 0), ) self.update_importances(importances, exp_counter) self.saved_params[exp_counter] = copy_params_dict(strategy.model) @@ -154,7 +154,12 @@ def compute_importances( # list of list importances = zerolike_params_dict(model) collate_fn = dataset.collate_fn if hasattr(dataset, "collate_fn") else None - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_workers, + ) for i, batch in enumerate(dataloader): # get only input, target and task_id from the batch x, y, task_labels = batch[0], batch[1], batch[-1]
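
Usage note (not part of the patches): the sketch below shows how the pieces
introduced in this series are meant to fit together. It assumes a working
Avalanche + wandb installation; the project/run names and the
SplitMNIST/SimpleMLP/Naive setup are placeholders chosen for illustration,
not code from this series.

    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint
    from avalanche.logging import WandBLogger
    from avalanche.models import SimpleMLP
    from avalanche.training.plugins import EvaluationPlugin, EWCPlugin
    from avalanche.training.supervised import Naive

    benchmark = SplitMNIST(n_experiences=5)  # placeholder benchmark
    model = SimpleMLP(num_classes=benchmark.n_classes)

    # Patch 2: log_artifacts=True alone is not enough. log_single_metric only
    # routes values whose metric name starts with "WeightCheckpoint" to the
    # artifact upload path, so the WeightCheckpoint metric must be registered
    # in the evaluation plugin for any weights to be logged.
    wandb_logger = WandBLogger(
        project_name="my-project",  # placeholder
        run_name="artifact-demo",  # placeholder
        log_artifacts=True,
        path="Checkpoints",
    )
    eval_plugin = EvaluationPlugin(
        WeightCheckpoint(),
        loggers=[wandb_logger],
    )

    strategy = Naive(
        model,
        SGD(model.parameters(), lr=0.001),
        CrossEntropyLoss(),
        train_mb_size=128,
        train_epochs=1,
        plugins=[EWCPlugin(ewc_lambda=0.4)],
        evaluator=eval_plugin,
    )

    for experience in benchmark.train_stream:
        # Patch 4: num_workers is now forwarded through after_training_exp to
        # the DataLoader used for the EWC importance computation as well.
        strategy.train(experience, num_workers=4)
        strategy.eval(benchmark.test_stream, num_workers=4)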