diff --git a/.github/workflows/environment-update.yml b/.github/workflows/environment-update.yml index 1b926ee23..247533337 100644 --- a/.github/workflows/environment-update.yml +++ b/.github/workflows/environment-update.yml @@ -56,7 +56,11 @@ jobs: id: unittest shell: bash -l -c "conda run -n avalanche-env --no-capture-output bash {0}" run: | - python -m unittest discover tests + python -m unittest discover tests && + echo "Running checkpointing tests..." && + bash ./tests/checkpointing/test_checkpointing.sh && + echo "Running distributed training tests..." && + python ./tests/run_dist_tests.py - name: checkout avalanche-docker repo if: always() uses: actions/checkout@v3 diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 7eee2e3ff..a2baa3717 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -55,6 +55,8 @@ jobs: python -m unittest discover tests && echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && + echo "Running distributed training tests..." && + python ./tests/run_dist_tests.py && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data diff --git a/avalanche/benchmarks/scenarios/classification_scenario.py b/avalanche/benchmarks/scenarios/classification_scenario.py index c99aaca44..3d875b64a 100644 --- a/avalanche/benchmarks/scenarios/classification_scenario.py +++ b/avalanche/benchmarks/scenarios/classification_scenario.py @@ -1,5 +1,6 @@ import copy import re +import warnings from abc import ABC from typing import ( Generic, @@ -18,10 +19,8 @@ Mapping, ) -from typing_extensions import Protocol - -import warnings from torch.utils.data.dataset import Dataset +from typing_extensions import Protocol from avalanche.benchmarks.scenarios.generic_definitions import ( TCLExperience, @@ -32,10 +31,8 @@ from avalanche.benchmarks.scenarios.lazy_dataset_sequence import ( LazyDatasetSequence, ) -from avalanche.benchmarks.utils import make_classification_dataset -from avalanche.benchmarks.utils.classification_dataset import ( - ClassificationDataset, -) +from avalanche.benchmarks.utils import \ make_classification_dataset, AvalancheDataset from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing TGenericCLClassificationScenario = TypeVar( @@ -494,7 +491,7 @@ def _check_and_adapt_user_stream_def( # exp_data[0] must contain the generator stream_length = exp_data[1] is_lazy = True - elif isinstance(exp_data, ClassificationDataset): + elif isinstance(exp_data, AvalancheDataset): # Single element exp_data = [exp_data] is_lazy = False @@ -506,7 +503,7 @@ def _check_and_adapt_user_stream_def( if not is_lazy: for i, dataset in enumerate(exp_data): - if not isinstance(dataset, ClassificationDataset): + if not isinstance(dataset, AvalancheDataset): raise ValueError( "All experience datasets must be subclasses of" " AvalancheDataset" diff --git a/avalanche/benchmarks/scenarios/detection_scenario.py b/avalanche/benchmarks/scenarios/detection_scenario.py index db78d9c76..18cef4d08 100644 --- a/avalanche/benchmarks/scenarios/detection_scenario.py +++ b/avalanche/benchmarks/scenarios/detection_scenario.py @@ -8,26 +8,34 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ - -from typing import TypeVar, List, Callable +import copy +import warnings +from abc import abstractmethod, ABC +from typing import TypeVar, List, Callable,
Protocol, runtime_checkable, \ + Union, Iterable, Generic, Sequence, Optional, Mapping, Set from avalanche.benchmarks import ( - GenericClassificationExperience, - ClassificationExperience, TCLScenario, TCLStream, GenericCLScenario, - TStreamsUserDict, - ClassificationStream, -) -from avalanche.benchmarks.utils import make_classification_dataset + TStreamsUserDict, TCLExperience, ) +from avalanche.benchmarks.scenarios.classification_scenario import \ + _get_slice_ids +from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing +from avalanche.benchmarks.utils.detection_dataset import DetectionDataset -TDetectionExperience = TypeVar( - "TDetectionExperience", bound=GenericClassificationExperience +TGenericCLDetectionScenario = TypeVar( + "TGenericCLDetectionScenario", bound="DetectionCLScenario" +) +TGenericDetectionExperience = TypeVar( + "TGenericDetectionExperience", bound="GenericDetectionExperience" +) +TGenericScenarioStream = TypeVar( + "TGenericScenarioStream", bound="DetectionStream" ) -class DetectionCLScenario(GenericCLScenario[TDetectionExperience]): +class DetectionCLScenario(GenericCLScenario[TCLExperience]): """ Base implementation of a Continual Learning object detection benchmark. @@ -43,7 +51,7 @@ def __init__( n_classes: int = None, complete_test_set_only: bool = False, experience_factory: Callable[ - ["ClassificationStream", int], TDetectionExperience + ["DetectionStream", int], TCLExperience ] = None, ): """ @@ -66,7 +74,7 @@ def __init__( """ if experience_factory is None: - experience_factory = DetectionExperience + experience_factory = GenericDetectionExperience super(DetectionCLScenario, self).__init__( stream_definitions=stream_definitions, @@ -79,50 +87,419 @@ def __init__( The number of classes in the scenario. """ + @GenericCLScenario.classes_in_experience.getter + def classes_in_experience( + self, + ) -> Mapping[str, Sequence[Optional[Set[int]]]]: + """ + A dictionary mapping each stream (by name) to a list. + + Each element of the list is a set describing the classes included in + that experience (identified by its index). + + In previous releases this field contained the list of sets for the + training stream (that is, there was no way to obtain the list for other + streams). That behavior is deprecated and support for that usage way + will be removed in the future. + """ + + return _LazyStreamClassesInDetectionExps(self) + + +class _LazyStreamClassesInDetectionExps(Mapping[str, + Sequence[Optional[Set[int]]]]): + def __init__(self, benchmark: GenericCLScenario): + self._benchmark = benchmark + self._default_lcie = _LazyClassesInDetectionExps( + benchmark, stream="train") + + def __len__(self): + return len(self._benchmark.stream_definitions) + + def __getitem__(self, stream_name_or_exp_id): + if isinstance(stream_name_or_exp_id, str): + return _LazyClassesInDetectionExps( + self._benchmark, stream=stream_name_or_exp_id + ) + + warnings.warn( + "Using classes_in_experience[exp_id] is deprecated. 
" + "Consider using classes_in_experience[stream_name][exp_id]" + "instead.", + stacklevel=2, + ) + return self._default_lcie[stream_name_or_exp_id] + + def __iter__(self): + yield from self._benchmark.stream_definitions.keys() + + +class _LazyClassesInDetectionExps(Sequence[Optional[Set[int]]]): + def __init__(self, benchmark: GenericCLScenario, stream: str = "train"): + self._benchmark = benchmark + self._stream = stream + + def __len__(self): + return len(self._benchmark.streams[self._stream]) + + def __getitem__(self, exp_id) -> Set[int]: + return manage_advanced_indexing( + exp_id, + self._get_single_exp_classes, + len(self), + _LazyClassesInDetectionExps._slice_collate, + ) -class DetectionExperience(ClassificationExperience[TCLScenario, TCLStream]): + def __str__(self): + return ( + "[" + ", ".join([str(self[idx]) for idx in range(len(self))]) + "]" + ) + + def _get_single_exp_classes(self, exp_id): + b = self._benchmark.stream_definitions[self._stream] + if not b.is_lazy and exp_id not in b.exps_data.targets_field_sequence: + raise IndexError + targets = b.exps_data.targets_field_sequence[exp_id] + if targets is None: + return None + + classes_in_exp = set() + for target in targets: + for label in target['labels']: + classes_in_exp.add(int(label)) + return classes_in_exp + + @staticmethod + def _slice_collate(*classes_in_exps: Optional[Set[int]]): + if any(x is None for x in classes_in_exps): + return None + + return [list(x) for x in classes_in_exps] + + +class DetectionScenarioStream(Protocol[TCLScenario, TCLExperience]): """ - Definition of a learning experience based on a :class:`DetectionScenario` - instance. + A scenario stream describes a sequence of incremental experiences. + Experiences are described as :class:`IExperience` instances. They contain a + set of patterns which has become available at a particular time instant + along with any optional, scenario specific, metadata. - This experience implementation uses the generic experience-patterns - assignment defined in the :class:`DetectionScenario` instance. Instances of - this class are usually obtained from an object detection benchmark stream. + Most scenario expose two different streams: the training stream and the test + stream. + """ + + name: str + """ + The name of the stream. + """ + + benchmark: TCLScenario + """ + A reference to the scenario this stream belongs to. + """ + + @property + def scenario(self) -> TCLScenario: + """This property is DEPRECATED, use self.benchmark instead.""" + warnings.warn( + "Using self.scenario is deprecated ScenarioStream. " + "Consider using self.benchmark instead.", + stacklevel=2, + ) + return self.benchmark + + def __getitem__( + self: TCLStream, experience_idx: Union[int, slice, Iterable[int]] + ) -> Union[TCLExperience, TCLStream]: + """ + Gets an experience given its experience index (or a stream slice given + the experience order). + + :param experience_idx: An int describing the experience index or an + iterable/slice object describing a slice of this stream. + :return: The Experience instance associated to the given experience + index or a sliced stream instance. + """ + ... + + def __len__(self) -> int: + """ + Used to get the length of this stream (the amount of experiences). + + :return: The amount of experiences in this stream. + """ + ... 
+ + +class DetectionStream( + Generic[TCLExperience, TGenericCLDetectionScenario], + DetectionScenarioStream[ + TGenericCLDetectionScenario, TCLExperience + ], + Sequence[TCLExperience], +): + def __init__( + self: TGenericScenarioStream, + name: str, + benchmark: TGenericCLDetectionScenario, + *, + slice_ids: List[int] = None, + ): + super(DetectionStream, self).__init__() + self.slice_ids: Optional[List[int]] = slice_ids + """ + Describes which experiences are contained in the current stream slice. + Can be None, which means that this object is the original stream. """ + + self.name: str = name + """ + The name of the stream (for instance: "train", "test", "valid", ...). + """ + + self.benchmark = benchmark + """ + A reference to the benchmark. + """ + + def __len__(self) -> int: + """ + Gets the number of experiences this stream is made of. + + :return: The number of experiences in this stream. + """ + if self.slice_ids is None: + return len(self.benchmark.stream_definitions[self.name].exps_data) + else: + return len(self.slice_ids) + + def __getitem__( + self, exp_idx: Union[int, slice, Iterable[int]] + ) -> Union[TCLExperience, TCLStream]: + """ + Gets an experience given its experience index (or a stream slice given + the experience order). + + :param exp_idx: An int describing the experience index or an + iterable/slice object describing a slice of this stream. + + :return: The experience instance associated to the given experience + index or a sliced stream instance. + """ + if isinstance(exp_idx, int): + if exp_idx < len(self): + if self.slice_ids is None: + return self.benchmark.experience_factory(self, exp_idx) + else: + return self.benchmark.experience_factory( + self, self.slice_ids[exp_idx] + ) + raise IndexError( + "Experience index out of bounds: " + str(int(exp_idx)) + ) + else: + return self._create_slice(exp_idx) + + def _create_slice( + self: TGenericScenarioStream, + exps_slice: Union[int, slice, Iterable[int]], + ) -> TCLStream: + """ + Creates a sliced version of this stream. + + In its base version, a shallow copy of this stream is created and + then its ``slice_ids`` field is adapted. + + :param exps_slice: The slice to use. + :return: A sliced version of this stream. + """ + stream_copy = copy.copy(self) + slice_exps = _get_slice_ids(exps_slice, len(self)) + + if self.slice_ids is None: + stream_copy.slice_ids = slice_exps + else: + stream_copy.slice_ids = [self.slice_ids[x] for x in slice_exps] + return stream_copy + + def drop_previous_experiences(self, to_exp: int) -> None: + """ + Drop the reference to experiences up to a certain experience ID + (inclusive). + + This means that any reference to experiences with ID [0, to_exp] will + be released. By dropping the reference to previous experiences, the + memory associated with them can be freed, especially the one occupied by + the dataset. However, if external references to the experience or the + dataset still exist, dropping previous experiences at the stream level + will have little to no impact on the memory usage. + + To make sure that the underlying dataset can be freed, make sure that: + - No reference to previous datasets or experiences is kept in your code; + - The replay implementation doesn't keep a reference to previous + datasets (in which case, it is better to store a copy of the raw + tensors instead); + - The benchmark is being generated using a lazy initializer. + + By dropping previous experiences, those experiences will no longer be + available in the stream.
Trying to access them will result in an + exception. + + :param to_exp: The ID of the last exp to drop (inclusive). Can be a + negative number, in which case this method doesn't have any effect. + Can be greater or equal to the stream length, in which case all + currently loaded experiences will be dropped. + :return: None + """ + self.benchmark.stream_definitions[ + self.name + ].exps_data.drop_previous_experiences(to_exp) + + +@runtime_checkable +class DetectionExperience(Protocol[TCLScenario, TCLStream]): + """Definition of a detection experience. + + A detection experience contains a set of patterns + which has become available at a particular time instant. The content and + size of an Experience are defined by the specific benchmark that creates the + IExperience instance. + + Experiences of Single Incremental Task (a.k.a. task-free) scenarios are + usually called "batches" while in Multi Task scenarios an Experience is + usually associated with a "task". Finally, in a Multi Incremental Task + scenario the Experience may be composed of patterns from different tasks. + """ + + origin_stream: TCLStream + """ + A reference to the original stream from which this experience was obtained. + """ + + benchmark: TCLScenario + """ + A reference to the benchmark. + """ + + current_experience: int + """ + This is an incremental, 0-indexed, value used to keep track of the position + of current experience in the original stream. + + Beware that this value only describes the experience position in the + original stream and may be unrelated to the order in which the strategy will + encounter experiences. + """ + + dataset: DetectionDataset + """ + The dataset containing the patterns available in this experience. + """ + + @property + @abstractmethod + def task_labels(self) -> List[int]: + """ + This list will contain the unique task labels of the patterns contained + in this experience. In the most common scenarios this will be a list + with a single value. Note: for scenarios that don't produce task labels, + a placeholder task label value like 0 is usually set to each pattern + (see the description of the originating scenario for details). + """ + ... + + @property + @abstractmethod + def task_label(self) -> int: + """ + The task label. This value will never have value "None". However, + for scenarios that don't produce task labels a placeholder value like 0 + is usually set. Beware that this field is meant as a shortcut to obtain + a unique task label: it assumes that only patterns labeled with a + single task label are present. If this experience contains patterns from + multiple tasks, accessing this property will result in an exception. + """ + ... + + @property + def scenario(self) -> TCLScenario: + """This property is DEPRECATED, use self.benchmark instead.""" + warnings.warn( + "Using self.scenario is deprecated in Experience. " + "Consider using self.benchmark instead.", + stacklevel=2, + ) + return self.benchmark + + +class AbstractDetectionExperience( + DetectionExperience[TGenericCLDetectionScenario, TCLStream], ABC +): + """ + Definition of a learning experience. A learning experience contains a set of + patterns which has become available at a particular time instant. The + content and size of an Experience are defined by the specific benchmark that + creates the experience. + + For instance, an experience of a New Classes scenario will contain all + patterns belonging to a subset of classes of the original training set.
An + experience of a New Instance scenario will contain patterns from previously + seen classes. """ def __init__( - self: TDetectionExperience, + self, origin_stream: TCLStream, current_experience: int, + classes_in_this_exp: Sequence[int], + previous_classes: Sequence[int], + classes_seen_so_far: Sequence[int], + future_classes: Optional[Sequence[int]], ): """ - Creates an instance of an experience given the stream from this - experience was taken and the current experience ID. + Creates an instance of the abstract experience given the benchmark + stream, the current experience ID and data about the classes timeline. :param origin_stream: The stream from which this experience was obtained. :param current_experience: The current experience ID, as an integer. + :param classes_in_this_exp: The list of classes in this experience. + :param previous_classes: The list of classes in previous experiences. + :param classes_seen_so_far: List of classes of current and previous + experiences. + :param future_classes: The list of classes of next experiences. """ + self.origin_stream: TCLStream = origin_stream + + # benchmark keeps a reference to the base benchmark self.benchmark: TCLScenario = origin_stream.benchmark + + # current_experience is usually an incremental, 0-indexed, value used to + # keep track of the current batch/task. self.current_experience: int = current_experience - self.dataset: make_classification_dataset = ( - origin_stream.benchmark.stream_definitions[ - origin_stream.name - ].exps_data[current_experience] - ) + self.classes_in_this_experience: Sequence[int] = classes_in_this_exp + """ The list of classes in this experience """ - def _get_stream_def(self): - return self.benchmark.stream_definitions[self.origin_stream.name] + self.previous_classes: Sequence[int] = previous_classes + """ The list of classes in previous experiences """ - @property - def task_labels(self) -> List[int]: - stream_def = self._get_stream_def() - return list(stream_def.exps_task_labels[self.current_experience]) + self.classes_seen_so_far: Sequence[int] = classes_seen_so_far + """ List of classes of current and previous experiences """ + + self.future_classes: Optional[Sequence[int]] = future_classes + """ The list of classes of next experiences """ @property def task_label(self) -> int: + """ + The task label. This value will never have value "None". However, + for scenarios that don't produce task labels a placeholder value like 0 + is usually set. Beware that this field is meant as a shortcut to obtain + a unique task label: it assumes that only patterns labeled with a + single task label are present. If this experience contains patterns from + multiple tasks, accessing this property will result in an exception. + """ if len(self.task_labels) != 1: raise ValueError( "The task_label property can only be accessed " @@ -132,4 +509,77 @@ def task_label(self) -> int: return self.task_labels[0] -__all__ = ["TDetectionExperience", "DetectionCLScenario", "DetectionExperience"] +class GenericDetectionExperience( + AbstractDetectionExperience[ + TGenericCLDetectionScenario, + DetectionStream[ + TGenericDetectionExperience, TGenericCLDetectionScenario + ], + ] +): + """ + Definition of a learning experience based on a :class:`GenericCLScenario` + instance. + + This experience implementation uses the generic experience-patterns + assignment defined in the :class:`GenericCLScenario` instance. Instances of + this class are usually obtained from a benchmark stream. 
+ """ + + def __init__( + self: TGenericDetectionExperience, + origin_stream: DetectionStream[ + TGenericDetectionExperience, TGenericCLDetectionScenario + ], + current_experience: int, + ): + """ + Creates an instance of a generic experience given the stream from this + experience was taken and the current experience ID. + + :param origin_stream: The stream from which this experience was + obtained. + :param current_experience: The current experience ID, as an integer. + """ + self.dataset: DetectionDataset = ( + origin_stream.benchmark.stream_definitions[ + origin_stream.name + ].exps_data[current_experience] + ) + + ( + classes_in_this_exp, + previous_classes, + classes_seen_so_far, + future_classes, + ) = origin_stream.benchmark.get_classes_timeline( + current_experience, stream=origin_stream.name + ) + + super().__init__( + origin_stream, + current_experience, + classes_in_this_exp, + previous_classes, + classes_seen_so_far, + future_classes, + ) + + def _get_stream_def(self): + return self.benchmark.stream_definitions[self.origin_stream.name] + + @property + def task_labels(self) -> List[int]: + stream_def = self._get_stream_def() + return list(stream_def.exps_task_labels[self.current_experience]) + + +__all__ = [ + 'TGenericCLDetectionScenario', + 'TGenericDetectionExperience', + 'TGenericScenarioStream', + 'DetectionCLScenario', + 'DetectionStream', + 'AbstractDetectionExperience', + 'GenericDetectionExperience' +] diff --git a/avalanche/benchmarks/scenarios/lazy_dataset_sequence.py b/avalanche/benchmarks/scenarios/lazy_dataset_sequence.py index 6460199c7..46ce1b8c5 100644 --- a/avalanche/benchmarks/scenarios/lazy_dataset_sequence.py +++ b/avalanche/benchmarks/scenarios/lazy_dataset_sequence.py @@ -12,10 +12,8 @@ from collections import defaultdict from typing import Sequence, Iterable, Dict, Optional, Iterator -from avalanche.benchmarks.utils import make_classification_dataset -from avalanche.benchmarks.utils.classification_dataset import ( - ClassificationDataset, -) +from avalanche.benchmarks.utils import \ + make_classification_dataset, AvalancheDataset class LazyDatasetSequence(Sequence[make_classification_dataset]): @@ -212,7 +210,7 @@ def load_all_experiences(self, to_exp: int = None) -> None: f"while generating experience {exp_id}." 
) - if not isinstance(generated_exp, ClassificationDataset): + if not isinstance(generated_exp, AvalancheDataset): raise ValueError( "All experience datasets must be subclasses of" " AvalancheDataset" diff --git a/avalanche/benchmarks/utils/__init__.py b/avalanche/benchmarks/utils/__init__.py index 773520ac8..c3e29b407 100644 --- a/avalanche/benchmarks/utils/__init__.py +++ b/avalanche/benchmarks/utils/__init__.py @@ -1,5 +1,6 @@ from .transforms import * from .classification_dataset import * +from .detection_dataset import * from .datasets_from_filelists import * from .torchvision_wrapper import * from .data import * diff --git a/avalanche/benchmarks/utils/classification_dataset.py b/avalanche/benchmarks/utils/classification_dataset.py index 20c4c3f4b..46eb068de 100644 --- a/avalanche/benchmarks/utils/classification_dataset.py +++ b/avalanche/benchmarks/utils/classification_dataset.py @@ -23,6 +23,7 @@ from torch.utils.data import Dataset from torch.utils.data.dataset import Subset, ConcatDataset, TensorDataset +from .collate_functions import ClassificationCollate from .data import make_avalanche_dataset, AvalancheDataset from .transform_groups import TransformGroups, DefaultTransformGroups from .data_attribute import DataAttribute @@ -216,6 +217,9 @@ def make_classification_dataset( if len(das) == 0: das = None + if collate_fn is None: + collate_fn = getattr(dataset, 'collate_fn', ClassificationCollate()) + data = ClassificationDataset( [dataset], data_attributes=das, diff --git a/avalanche/benchmarks/utils/collate_functions.py b/avalanche/benchmarks/utils/collate_functions.py index 342504691..2088423af 100644 --- a/avalanche/benchmarks/utils/collate_functions.py +++ b/avalanche/benchmarks/utils/collate_functions.py @@ -10,9 +10,17 @@ ################################################################################ import itertools +from abc import ABC, abstractmethod from collections import defaultdict +from typing import List, TypeVar, Generic, Sequence, Tuple, Dict import torch +from torch import Tensor +from torch.utils.data import default_collate + +BatchT = TypeVar("BatchT") +ExampleT = TypeVar("ExampleT") +FeatureT = TypeVar("FeatureT") def classification_collate_mbatches_fn(mbatches): @@ -26,17 +34,22 @@ def classification_collate_mbatches_fn(mbatches): """ batch = [] for i in range(len(mbatches[0])): - t = classification_single_values_collate_fn( + t = classification_single_values_collate_mbatches_fn( [el[i] for el in mbatches], i ) batch.append(t) return batch -def classification_single_values_collate_fn(values_list, index): +def classification_single_values_collate_mbatches_fn(values_list, index): """ Collate function used to merge the single elements (x or y or t, - etcetera) of a minibatch of data from a classification dataset. + etcetera) of multiple minibatches of data from a classification dataset. + + Beware that this function expects a list of already batched values, + which means that it accepts a list of [mb_size, X, Y, Z, ...] tensors. + This is different from :func:`classification_single_values_collate_fn`, + which expects a flat list of tensors [X, Y, Z, ...] to be collated. This function assumes that all values are tensors of the same shape (excluding the first dimension). 
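The rename above draws a line between merging already-batched values and collating single examples; the next hunk adds the flat-list variant. The difference boils down to `torch.cat` versus `torch.stack`, illustrated here with made-up shapes:

```python
import torch

# Pre-batched values of shape [mb_size, ...]: minibatches are merged by
# concatenating along the existing batch dimension (the *_mbatches_fn
# variant uses torch.cat).
mb1 = torch.zeros(4, 3, 32, 32)
mb2 = torch.zeros(2, 3, 32, 32)
assert torch.cat([mb1, mb2], dim=0).shape == (6, 3, 32, 32)

# Flat list of single examples: collation stacks them, creating a new
# leading batch dimension (the flat variant uses torch.stack).
x1 = torch.zeros(3, 32, 32)
x2 = torch.zeros(3, 32, 32)
assert torch.stack([x1, x2]).shape == (2, 3, 32, 32)
```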
@@ -49,6 +62,26 @@ def classification_single_values_collate_fn(values_list, index): return torch.cat(values_list, dim=0) +def classification_single_values_collate_fn(values_list, index): + """ + Collate function used to merge the single elements (x or y or t, + etcetera) of a minibatch of data from a classification dataset. + + This function expects a flat list of tensors [X, Y, Z, ...] to be collated. + For a version of this function that can collate pre-batched values + [mb_size, X, Y, Z, ...], refer to + :func:`classification_single_values_collate_mbatches_fn`. + + This function assumes that all values are tensors of the same shape. + + :param values_list: The list of values to merge. + :param index: The index of the element. 0 for x values, 1 for y values, + etcetera. In this implementation, this parameter is ignored. + :return: The merged values. + """ + return torch.stack(values_list) + + def detection_collate_fn(batch): """ Collate function used when loading detection datasets using a DataLoader. @@ -83,9 +116,141 @@ def detection_collate_mbatches_fn(mbatches): return lists +class Collate(ABC, Generic[ExampleT, BatchT]): + + @abstractmethod + def collate_fn(self, batch: Sequence[ExampleT]) -> BatchT: + """ + Merge multiple examples to create a batch. + + This function expects a list of elements as obtained from + the dataset. + + The official PyTorch documentation describes the default collate + function as: "Function that takes in a batch of data and puts the + elements within the batch into a tensor with an additional + outer dimension - batch size." + + :param batch: The list of examples. + :return: The batch. + """ + pass + + @abstractmethod + def collate_single_value_fn( + self, + feature_batch: Sequence[FeatureT], + feature_idx: int) -> Sequence[FeatureT]: + """ + Merge a specific feature to create a single-feature batch. + + This function expects a list of features. + + :param feature_batch: The list of features to be batched. + :param feature_idx: The index of the feature being batched. + This may be useful to customize how features are merged. + + :return: The batched features. + """ + pass + + @abstractmethod + def collate_batches_fn(self, batches: Sequence[BatchT]) -> BatchT: + """ + Merge multiple batches. + + This function expects a list of pre-collated batches + (as collated through :meth:`collate_fn`). + + :param batches: A list of batches to be merged together. + :return: A batch made by collating the input batches. + """ + pass + + @abstractmethod + def collate_single_value_batches_fn( + self, + feature_batches: Sequence[Sequence[FeatureT]], + feature_idx: int) -> FeatureT: + """ + Merge a specific feature of examples contained in multiple batches. + + This function expects a list of pre-batched features. + + :param feature_batches: A list of batched features to be merged + together. + :param feature_idx: The index of the feature being batched. + This may be useful to customize how features are merged. + :return: A batch of features made by collating the input batched + features. + """ + pass + + def __call__(self, batch: List[ExampleT]) -> BatchT: + """ + Merges multiple examples to create a batch. + + In practice, this will call :meth:`collate_fn`.
+ """ + return self.collate_fn(batch) + + +class ClassificationCollate(Collate[Tuple[Tensor, ...], Tuple[Tensor, ...]]): + + def collate_fn(self, batch): + return default_collate(batch) + + def collate_single_value_fn( + self, + feature_batch: Sequence[Tensor], + feature_idx): + return torch.stack(feature_batch) + + def collate_batches_fn(self, batches): + batch = [] + for i in range(len(batches[0])): + t = self.collate_single_value_batches_fn( + [el[i] for el in batches], i + ) + batch.append(t) + return batch + + def collate_single_value_batches_fn( + self, + feature_batch: Sequence[Tensor], + feature_idx) -> Tensor: + return torch.cat(feature_batch, dim=0) + + +class DetectionCollate(Collate[Tuple[Tensor, Dict, int], + Tuple[Tuple[Tensor], Tuple[Dict], Tuple[int]]]): + + def collate_fn(self, batch): + return detection_collate_fn(batch) + + def collate_single_value_fn(self, feature_batch, feature_idx): + return tuple(feature_batch) + + def collate_batches_fn(self, batches): + return detection_collate_mbatches_fn(batches) + + def collate_single_value_batches_fn( + self, + feature_batch: Sequence[Sequence[FeatureT]], + feature_idx) -> Sequence[FeatureT]: + flattened_features = [] + for batch in feature_batch: + flattened_features.extend(batch) + return tuple(flattened_features) + + __all__ = [ "classification_collate_mbatches_fn", - "classification_single_values_collate_fn", + "classification_single_values_collate_mbatches_fn", "detection_collate_fn", "detection_collate_mbatches_fn", + "Collate", + "ClassificationCollate", + "DetectionCollate" ] diff --git a/avalanche/benchmarks/utils/data.py b/avalanche/benchmarks/utils/data.py index 533d79eed..3801a09eb 100644 --- a/avalanche/benchmarks/utils/data.py +++ b/avalanche/benchmarks/utils/data.py @@ -34,7 +34,7 @@ TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset") -class AvalancheDataset(FlatData): +class AvalancheDataset(FlatData[T_co]): """Avalanche Dataset. Avlanche dataset are pytorch-compatible Datasets with some additional @@ -255,7 +255,7 @@ def _getitem_recursive_call(self, idx, group_name): element = self._transform_groups(element, group_name=group_name) return element - def __getitem__(self, idx) -> Union[T_co, Sequence[T_co]]: + def __getitem__(self, idx) -> T_co: elem = self._getitem_recursive_call( idx, self._transform_groups.current_group ) diff --git a/avalanche/benchmarks/utils/data_attribute.py b/avalanche/benchmarks/utils/data_attribute.py index 0505160ae..6780264d0 100644 --- a/avalanche/benchmarks/utils/data_attribute.py +++ b/avalanche/benchmarks/utils/data_attribute.py @@ -15,6 +15,7 @@ concatenation and subsampling operations and are automatically managed by AvalancheDatasets. """ +from typing import TypeVar, Generic, Sequence, Set, Dict, Optional import torch @@ -22,7 +23,10 @@ from .flat_data import ConstantSequence, FlatData -class DataAttribute: +DataT = TypeVar("DataT") + + +class DataAttribute(Generic[DataT]): """Data attributes manage sample-wise information such as task or class labels. @@ -32,7 +36,11 @@ class labels. Data attributes can be efficiently concatenated and subsampled. """ - def __init__(self, data: IDataset, name: str = None, use_in_getitem=False): + def __init__( + self, + data: IDataset[DataT], + name: str = None, + use_in_getitem: bool = False): """Data Attribute. :param data: a sequence of values, one for each sample. 
@@ -42,16 +50,16 @@ def __init__(self, data: IDataset, name: str = None, use_in_getitem=False): :param use_in_getitem: If True, `AvalancheDataset` will add the value at the end of each sample. """ - self.name = name - self.use_in_getitem = use_in_getitem + self.name: str = name + self.use_in_getitem: bool = use_in_getitem - self._data = self._normalize_sequence(data) + self._data: FlatData = self._normalize_sequence(data) - self._uniques = None # set() - self._val_to_idx = None # dict() - self._count = None # dict() + self._uniques: Optional[Set[DataT]] = None + self._val_to_idx: Optional[Dict[DataT, Sequence[int]]] = None + self._count: Optional[Dict[DataT, int]] = None - def __getitem__(self, item): + def __getitem__(self, item) -> DataT: return self.data[item] def __len__(self): @@ -64,26 +72,18 @@ def __str__(self): return str(self.data[:]) @property - def data(self): + def data(self) -> FlatData[DataT]: return self._data @property - def uniques(self): + def uniques(self) -> Set[DataT]: """Set of unique values in the attribute.""" if self._uniques is None: - self._uniques = set() - # init. uniques with fast paths for special cases - if isinstance(self.data, ConstantSequence): - self.uniques.add(self.data[0]) - elif isinstance(self.data, DataAttribute): - self.uniques.update(self.data.uniques) - else: - for el in self.data: - self.uniques.add(el) + self._uniques = set(self.data) return self._uniques @property - def count(self): + def count(self) -> Dict[DataT, int]: """Dictionary of value -> count.""" if self._count is None: self._count = {} @@ -94,7 +94,7 @@ def count(self): return self._count @property - def val_to_idx(self): + def val_to_idx(self) -> Dict[DataT, Sequence[int]]: """Dictionary mapping unique values to indices.""" if self._val_to_idx is None: # init. val-to-idx @@ -108,7 +108,7 @@ def val_to_idx(self): self._val_to_idx[x].append(i) return self._val_to_idx - def subset(self, indices): + def subset(self, indices) -> "DataAttribute[DataT]": """Subset operation. Return a new `DataAttribute` by keeping only the elements in `indices`. @@ -122,14 +122,14 @@ def subset(self, indices): use_in_getitem=self.use_in_getitem, ) - def concat(self, other: "DataAttribute"): + def concat(self, other: "DataAttribute[DataT]") -> "DataAttribute[DataT]": """Concatenation operation. :param other: the other `DataAttribute` :return: the new concatenated `DataAttribute` """ assert self.name == other.name, ( - "Cannot concatenate DataAttributes" + "with different names." + "Cannot concatenate DataAttributes with different names." ) return DataAttribute( self.data.concat(other.data), @@ -155,4 +155,7 @@ def __init__(self, task_labels): super().__init__(task_labels, "task_labels", use_in_getitem=True) -__all__ = ["DataAttribute", "TaskLabels"] +__all__ = [ + "DataAttribute", + "TaskLabels" +] diff --git a/avalanche/benchmarks/utils/data_loader.py b/avalanche/benchmarks/utils/data_loader.py index 0276d2e08..7c0b582fa 100644 --- a/avalanche/benchmarks/utils/data_loader.py +++ b/avalanche/benchmarks/utils/data_loader.py @@ -15,10 +15,10 @@ between the current data and the replay memory. 
""" from itertools import chain -from typing import Dict, Sequence, Union +from typing import Dict, Sequence, Union, Any import torch -from torch.utils.data import RandomSampler, DistributedSampler +from torch.utils.data import RandomSampler, DistributedSampler, Dataset from torch.utils.data.dataloader import DataLoader from avalanche.benchmarks.utils import make_classification_dataset @@ -31,6 +31,7 @@ from avalanche.benchmarks.utils.collate_functions import ( detection_collate_mbatches_fn as _detection_collate_mbatches_fn, ) +from avalanche.distributed import DistributedHelper _default_collate_mbatches_fn = classification_collate_mbatches_fn @@ -233,6 +234,7 @@ def __iter__(self): removed_dataloaders_idxs.append(tid) continue mb_curr.extend(batch) + yield self.collate_fn(mb_curr) # clear empty data-loaders @@ -274,14 +276,14 @@ def __init__( self.collate_mbatches = collate_mbatches for data in self.datasets: - if _DistributedHelper.is_distributed and distributed_sampling: + if DistributedHelper.is_distributed and distributed_sampling: seed = torch.randint( 0, - 2 ** 32 - 1 - _DistributedHelper.world_size, + 2 ** 32 - 1 - DistributedHelper.world_size, (1,), dtype=torch.int64, ) - seed += _DistributedHelper.rank + seed += DistributedHelper.rank generator = torch.Generator() generator.manual_seed(int(seed)) else: @@ -307,6 +309,7 @@ def __iter__(self): for tid, t_loader in enumerate(iter_dataloaders): batch = next(t_loader) mb_curr.append(batch) + yield self.collate_mbatches(mb_curr) def __len__(self): @@ -597,11 +600,11 @@ def _get_batch_sizes( def _make_data_loader( - dataset, - distributed_sampling, - data_loader_args, - batch_size, - force_no_workers=False, + dataset: Dataset, + distributed_sampling: bool, + data_loader_args: Dict[str, Any], + batch_size: int, + force_no_workers=False ): data_loader_args = data_loader_args.copy() @@ -612,14 +615,23 @@ def _make_data_loader( if 'persistent_workers' in data_loader_args: data_loader_args['persistent_workers'] = False - if _DistributedHelper.is_distributed and distributed_sampling: + if DistributedHelper.is_distributed and distributed_sampling: + # Note: shuffle only goes in the sampler, while + # drop_last must be passed to both the sampler + # and the DataLoader + drop_last = data_loader_args.pop("drop_last", False) sampler = DistributedSampler( dataset, - shuffle=data_loader_args.pop("shuffle", False), - drop_last=data_loader_args.pop("drop_last", False), + shuffle=data_loader_args.pop("shuffle", True), + drop_last=drop_last, ) + data_loader = DataLoader( - dataset, sampler=sampler, batch_size=batch_size, **data_loader_args + dataset, + sampler=sampler, + batch_size=batch_size, + drop_last=drop_last, + **data_loader_args ) else: sampler = None @@ -630,15 +642,6 @@ def _make_data_loader( return data_loader, sampler -class __DistributedHelperPlaceholder: - is_distributed = False - world_size = 1 - rank = 0 - - -_DistributedHelper = __DistributedHelperPlaceholder() - - __all__ = [ "detection_collate_fn", "detection_collate_mbatches_fn", diff --git a/avalanche/benchmarks/utils/detection_dataset.py b/avalanche/benchmarks/utils/detection_dataset.py new file mode 100644 index 000000000..5652e0a86 --- /dev/null +++ b/avalanche/benchmarks/utils/detection_dataset.py @@ -0,0 +1,857 @@ +################################################################################ +# Copyright (c) 2021 ContinualAI. # +# Copyrights licensed under the MIT License. # +# See the accompanying LICENSE file for terms. 
# +# # +# Date: 12-05-2020 # +# Author(s): Lorenzo Pellegrini, Antonio Carta # +# E-mail: contact@continualai.org # +# Website: avalanche.continualai.org # +################################################################################ + +""" +This module contains the implementation of the ``DetectionDataset``, +which is the dataset used for supervised continual learning benchmarks. +DetectionDatasets are ``AvalancheDatasets`` that manage targets and task +labels automatically. Concatenation and subsampling operations are optimized +to be used frequently, as is common in replay strategies. +""" +import warnings +from collections import defaultdict, deque +from functools import partial +from typing import ( + List, + Any, + Sequence, + Union, + Optional, + TypeVar, + Callable, + Dict, + Tuple, + Mapping, ) + +import torch +from torch import Tensor +from torch.utils.data import Dataset +from torch.utils.data.dataset import Subset, ConcatDataset +from typing_extensions import Protocol + +from .collate_functions import DetectionCollate +from .data import AvalancheDataset +from .data_attribute import DataAttribute +from .dataset_definitions import ( + IDatasetWithTargets, ) +from .dataset_utils import ( + SubSequence, + find_list_from_index, +) +from .flat_data import ConstantSequence +from .transform_groups import TransformGroups, DefaultTransformGroups + +T_co = TypeVar("T_co", covariant=True) +TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset") +TTargetType = Dict[str, Tensor] + + +# Info: https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols +class XComposedTransformDef(Protocol): + def __call__(self, *input_values: Any) -> Any: + pass + + +class XTransformDef(Protocol): + def __call__(self, input_value: Any) -> Any: + pass + + +class YTransformDef(Protocol): + def __call__(self, input_value: Any) -> Any: + pass + + +XTransform = Optional[Union[XTransformDef, XComposedTransformDef]] +YTransform = Optional[YTransformDef] +TransformGroupDef = Union[None, XTransform, Tuple[XTransform, YTransform]] + + +SupportedDetectionDataset = Union[ + IDatasetWithTargets, + Subset, + ConcatDataset, +] + +# Image (tensor), target dict, task label +DetectionExampleT = Tuple[Tensor, TTargetType, int] + + +class DetectionDataset(AvalancheDataset, + IDatasetWithTargets[DetectionExampleT, TTargetType]): + def __init__(self, *args, **kwargs): + # Here defined only to provide type hinting + self.targets_task_labels: DataAttribute[int] = DataAttribute( + [], + name='targets_task_labels', + use_in_getitem=True + ) + self.targets: DataAttribute[Dict[str, Tensor]] = DataAttribute( + [], + name='targets', + use_in_getitem=False + ) + + del self.targets_task_labels + del self.targets + + super().__init__(*args, **kwargs) + + assert hasattr(self, 'targets_task_labels') + assert hasattr(self, 'targets') + + def subset(self, indices): + data = super().subset(indices) + return data.with_transforms(self._transform_groups.current_group) + + def concat(self, other): + data = super().concat(other) + return data.with_transforms(self._transform_groups.current_group) + + @property + def task_pattern_indices(self): + """A dictionary mapping task ids to their sample indices.""" + return self.targets_task_labels.val_to_idx + + @property + def task_set(self): + """Returns the dataset's ``TaskSet``, which is a mapping .""" + return DetectionTaskSet(self) + + +def make_detection_dataset( + dataset: SupportedDetectionDataset, + *, + transform: XTransform = None, + target_transform: YTransform = 
None, + transform_groups: Dict[str, TransformGroupDef] = None, + initial_transform_group: str = None, + task_labels: Union[int, Sequence[int]] = None, + targets: Sequence[TTargetType] = None, + collate_fn: Callable[[List], Any] = None +): + """Avalanche Detection Dataset. + + Supervised continual learning benchmarks in Avalanche return instances of + this dataset, but it can also be used in a completely standalone manner. + + This dataset applies input/target transformations, it supports + slicing and advanced indexing and it also contains useful fields such as + `targets`, which contains the pattern dictionaries, and + `targets_task_labels`, which contains the pattern task labels. + The `task_set` field can be used to obtain the subset of patterns + labeled with a given task label. + + This dataset can also be used to apply several advanced operations involving + transformations. For instance, it allows the user to add and replace + transformations, freeze them so that they can't be changed, etc. + + This dataset also allows the user to keep distinct transformation groups. + Simply put, a transformation group is a pair of transform+target_transform + (exactly as in torchvision datasets). This dataset natively supports keeping + two transformation groups: the first, 'train', contains transformations + applied to training patterns. Those transformations usually involve some + kind of data augmentation. The second one is 'eval', which will contain + transformations applied to test patterns. Having both groups can be + useful when, for instance, one needs to test on the training data (as this + process usually involves removing data augmentation operations). Switching + between transformations can be easily achieved by using the + :func:`train` and :func:`eval` methods. + + Moreover, arbitrary transformation groups can be added and used. For more + info see the constructor and the :func:`with_transforms` method. + + This dataset will try to inherit the task labels from the input + dataset. If none are available and none are given via the `task_labels` + parameter, each pattern will be assigned a default task label 0. + + Creates a ``DetectionDataset`` instance. + + :param dataset: The dataset to decorate. Beware that + AvalancheDataset will not overwrite transformations already + applied by this dataset. + :param transform: A function/transform that takes the X value of a + pattern from the original dataset and returns a transformed version. + :param target_transform: A function/transform that takes in the target + and transforms it. + :param transform_groups: A dictionary containing the transform groups. + Transform groups are used to quickly switch between training and + eval (test) transformations. This becomes useful when one needs to + test on the training dataset, as test transformations usually don't + contain random augmentations. ``AvalancheDataset`` natively supports + the 'train' and 'eval' groups by calling the ``train()`` and + ``eval()`` methods. When using custom groups one can use the + ``with_transforms(group_name)`` method instead. Defaults to None, + which means that the current transforms will be used to + handle both 'train' and 'eval' groups (just like in standard + ``torchvision`` datasets). + :param initial_transform_group: The name of the initial transform group + to be used. Defaults to None, which means that the current group of + the input dataset will be used (if an AvalancheDataset). If the + input dataset is not an AvalancheDataset, then 'train' will be + used.
+ :param task_labels: The task label of each instance. Must be a sequence + of ints, one for each instance in the dataset. Alternatively can be + a single int value, in which case that value will be used as the + task label for all the instances. Defaults to None, which means that + the dataset will try to obtain the task labels from the original + dataset. If no task labels could be found, a default task label + 0 will be applied to all instances. + :param targets: The target dictionary of each pattern. + Defaults to None, which means that the targets will be retrieved from + the dataset (if possible). + :param collate_fn: The function to use when slicing to merge single + patterns. This function is the function used in the data loading + process, too. If None, the constructor will check if a + `collate_fn` field exists in the dataset. If no such field exists, + the default collate function for detection will be used. + """ + transform_gs = _init_transform_groups( + transform_groups, + transform, + target_transform, + initial_transform_group, + dataset, + ) + targets = _init_targets(dataset, targets) + task_labels = _init_task_labels(dataset, task_labels) + + das = [] + if targets is not None: + das.append(targets) + if task_labels is not None: + das.append(task_labels) + if len(das) == 0: + das = None + + if collate_fn is None: + collate_fn = getattr(dataset, 'collate_fn', DetectionCollate()) + + data = DetectionDataset( + [dataset], + data_attributes=das, + transform_groups=transform_gs, + collate_fn=collate_fn, + ) + if initial_transform_group is not None: + return data.with_transforms(initial_transform_group) + else: + return data + + +def _init_transform_groups( + transform_groups, + transform, + target_transform, + initial_transform_group, + dataset, +): + if transform_groups is not None and ( + transform is not None or target_transform is not None + ): + raise ValueError( + "transform_groups can't be used with transform " + "and target_transform values" + ) + + if transform_groups is not None: + _check_groups_dict_format(transform_groups) + + if initial_transform_group is None: + # Detect from the input dataset. If not an AvalancheDataset then + # use 'train' as the initial transform group + if ( + isinstance(dataset, DetectionDataset) + and dataset._transform_groups is not None + ): + initial_transform_group = dataset._transform_groups.current_group + else: + initial_transform_group = "train" + + if transform_groups is None: + if target_transform is None and transform is None: + tgs = None + else: + tgs = TransformGroups( + { + "train": (transform, target_transform), + "eval": (transform, target_transform), + }, + current_group=initial_transform_group, + ) + else: + tgs = TransformGroups( + transform_groups, current_group=initial_transform_group + ) + return tgs + + +def _check_groups_dict_format(groups_dict): + # The original groups_dict must be convertible to native Python dict + groups_dict = dict(groups_dict) + + # Check if the format of the groups is correct + for map_key in groups_dict: + if not isinstance(map_key, str): + raise ValueError( + "Every group must be identified by a string. " + 'Wrong key was: "' + str(map_key) + '"' + ) + + if "test" in groups_dict: + warnings.warn( + 'A transformation group named "test" has been found. Beware ' + "that by default AvalancheDataset supports test transformations" + ' through the "eval" group. Consider using that one!'
+ ) + + +def _init_targets(dataset, targets, check_shape=True): + if targets is not None: + # User defined targets always take precedence + if len(targets) != len(dataset) and check_shape: + raise ValueError( + "Invalid amount of target labels. It must be equal to the " + "number of patterns in the dataset. Got {}, expected " + "{}!".format(len(targets), len(dataset)) + ) + return DataAttribute(targets, "targets") + + if isinstance(dataset, DetectionDataset): + return None # targets are initialized automatically + else: + targets = _traverse_supported_dataset(dataset, _select_targets) + + if targets is None: + return None + return DataAttribute(targets, "targets") + + +def _init_task_labels(dataset, task_labels, check_shape=True): + """A task label for each pattern in the dataset.""" + if task_labels is not None: + # task_labels has priority over the dataset fields + if isinstance(task_labels, int): + task_labels = ConstantSequence(task_labels, len(dataset)) + elif len(task_labels) != len(dataset) and check_shape: + raise ValueError( + "Invalid amount of task labels. It must be equal to the " + "number of patterns in the dataset. Got {}, expected " + "{}!".format(len(task_labels), len(dataset)) + ) + tls = SubSequence(task_labels, converter=int) + else: + if isinstance(dataset, DetectionDataset): + tls = None + else: + task_labels = _traverse_supported_dataset( + dataset, _select_task_labels + ) + tls = SubSequence(task_labels, converter=int) + + if tls is None: + return None + return DataAttribute(tls, "targets_task_labels", use_in_getitem=True) + + +def _detection_class_mapping_transform(class_mapping, example_target_dict): + example_target_dict = dict(example_target_dict) + + # example_target_dict["labels"] is a tensor containing one label + # for each bounding box in the image. We need to remap each of them + example_target_labels = example_target_dict["labels"] + example_mapped_labels = [class_mapping[int(el)] for el + in example_target_labels] + + if isinstance(example_target_labels, Tensor): + example_mapped_labels = torch.as_tensor(example_mapped_labels) + + example_target_dict["labels"] = example_mapped_labels + + return example_target_dict + + +def detection_subset( + dataset: SupportedDetectionDataset, + indices: Sequence[int] = None, + *, + class_mapping: Sequence[int] = None, + transform: Callable[[Any], Any] = None, + target_transform: Callable[[int], int] = None, + transform_groups: Dict[str, Tuple[XTransform, YTransform]] = None, + initial_transform_group: str = None, + task_labels: Union[int, Sequence[int]] = None, + targets: Sequence[TTargetType] = None, + collate_fn: Callable[[List], Any] = None +): + """Creates an ``AvalancheSubset`` instance. + + For simple subset operations you should use the method + `dataset.subset(indices)`. + Use this constructor only if you need to redefine transformation or + class/task labels. + + A Dataset that behaves like a PyTorch :class:`torch.utils.data.Subset`. + This Dataset also supports transformations, slicing, advanced indexing, + the targets field, class mapping and all the other goodies listed in + :class:`AvalancheDataset`. + + :param dataset: The whole dataset. + :param indices: Indices in the whole set selected for subset. Can + be None, which means that the whole dataset will be returned. + :param class_mapping: A list that, for each possible class label value, + contains its corresponding remapped value. Can be None. 
+ :param transform: A function/transform that takes the X value of a + pattern from the original dataset and returns a transformed version. + :param target_transform: A function/transform that takes in the target + and transforms it. + :param transform_groups: A dictionary containing the transform groups. + Transform groups are used to quickly switch between training and + eval (test) transformations. This becomes useful when in need to + test on the training dataset as test transformations usually don't + contain random augmentations. ``AvalancheDataset`` natively supports + the 'train' and 'eval' groups by calling the ``train()`` and + ``eval()`` methods. When using custom groups one can use the + ``with_transforms(group_name)`` method instead. Defaults to None, + which means that the current transforms will be used to + handle both 'train' and 'eval' groups (just like in standard + ``torchvision`` datasets). + :param initial_transform_group: The name of the initial transform group + to be used. Defaults to None, which means that the current group of + the input dataset will be used (if an AvalancheDataset). If the + input dataset is not an AvalancheDataset, then 'train' will be + used. + :param task_labels: The task label for each instance. Must be a sequence + of ints, one for each instance in the dataset. This can either be a + list of task labels for the original dataset or the list of task + labels for the instances of the subset (an automatic detection will + be made). In the unfortunate case in which the original dataset and + the subset contain the same amount of instances, then this parameter + is considered to contain the task labels of the subset. + Alternatively can be a single int value, in which case + that value will be used as the task label for all the instances. + Defaults to None, which means that the dataset will try to + obtain the task labels from the original dataset. If no task labels + could be found, a default task label 0 will be applied to all + instances. + :param targets: The target dictionary of each pattern. Defaults to None, + which means that the targets will be retrieved from the dataset (if + possible). This can either be a list of target dictionaries for the + original dataset or the list of target dictionaries for the instances + of the subset (an automatic detection will be made). In the + unfortunate case in which the original dataset and the subset contain + the same amount of instances, then this parameter is considered to + contain the target dictionaries of the subset. + :param collate_fn: The function to use when slicing to merge single + patterns. This function is the function used in the data loading + process, too. If None, the constructor will check if a + `collate_fn` field exists in the dataset. 
If no such field exists, + the default collate function for detection will be used + """ + if isinstance(dataset, DetectionDataset): + if ( + class_mapping is None + and transform is None + and target_transform is None + and transform_groups is None + and initial_transform_group is None + and task_labels is None + and targets is None + and collate_fn is None + ): + return dataset.subset(indices) + + targets = _init_targets(dataset, targets, check_shape=False) + task_labels = _init_task_labels(dataset, task_labels, check_shape=False) + transform_gs = _init_transform_groups( + transform_groups, + transform, + target_transform, + initial_transform_group, + dataset, + ) + + if initial_transform_group is not None and isinstance( + dataset, AvalancheDataset + ): + dataset = dataset.with_transforms(initial_transform_group) + + if class_mapping is not None: # update targets + + if targets is None: + targets = dataset.targets + + tgs = [ + _detection_class_mapping_transform( + class_mapping, example_target_dict) + for example_target_dict in targets] + + targets = DataAttribute(tgs, "targets") + + if class_mapping is not None: + mapping_fn = partial(_detection_class_mapping_transform, class_mapping) + frozen_transform_groups = DefaultTransformGroups( + (None, mapping_fn) + ) + else: + frozen_transform_groups = None + + das = [] + if targets is not None: + das.append(targets) + if task_labels is not None: + das.append(task_labels) + if len(das) == 0: + das = None + + if collate_fn is None: + collate_fn = DetectionCollate() + + return DetectionDataset( + [dataset], + indices=indices, + data_attributes=das, + transform_groups=transform_gs, + frozen_transform_groups=frozen_transform_groups, + collate_fn=collate_fn, + ) + + +def concat_detection_datasets( + datasets: List[SupportedDetectionDataset], + *, + transform: Callable[[Any], Any] = None, + target_transform: Callable[[int], int] = None, + transform_groups: Dict[str, Tuple[XTransform, YTransform]] = None, + initial_transform_group: str = None, + task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]] = None, + targets: Union[ + Sequence[TTargetType], Sequence[Sequence[TTargetType]] + ] = None, + collate_fn: Callable[[List], Any] = None +): + """Creates a ``AvalancheConcatDataset`` instance. + + For simple subset operations you should use the method + `dataset.concat(other)` or + `concat_datasets` from `avalanche.benchmarks.utils.utils`. + Use this constructor only if you need to redefine transformation or + class/task labels. + + A Dataset that behaves like a PyTorch + :class:`torch.utils.data.ConcatDataset`. However, this Dataset also supports + transformations, slicing, advanced indexing and the targets field and all + the other goodies listed in :class:`AvalancheDataset`. + + This dataset guarantees that the operations involving the transformations + and transformations groups are consistent across the concatenated dataset + (if they are subclasses of :class:`AvalancheDataset`). + + :param datasets: A collection of datasets. + :param transform: A function/transform that takes the X value of a + pattern from the original dataset and returns a transformed version. + :param target_transform: A function/transform that takes in the target + and transforms it. + :param transform_groups: A dictionary containing the transform groups. + Transform groups are used to quickly switch between training and + eval (test) transformations. 
This becomes useful when in need to + test on the training dataset as test transformations usually don't + contain random augmentations. ``AvalancheDataset`` natively supports + the 'train' and 'eval' groups by calling the ``train()`` and + ``eval()`` methods. When using custom groups one can use the + ``with_transforms(group_name)`` method instead. Defaults to None, + which means that the current transforms will be used to + handle both 'train' and 'eval' groups (just like in standard + ``torchvision`` datasets). + :param initial_transform_group: The name of the initial transform group + to be used. Defaults to None, which means that if all + AvalancheDatasets in the input datasets list agree on a common + group (the "current group" is the same for all datasets), then that + group will be used as the initial one. If the list of input datasets + does not contain an AvalancheDataset or if the AvalancheDatasets + do not agree on a common group, then 'train' will be used. + :param targets: The label of each pattern. Can either be a sequence of + labels or, alternatively, a sequence containing sequences of labels + (one for each dataset to be concatenated). Defaults to None, which + means that the targets will be retrieved from the datasets (if + possible). + :param task_labels: The task labels for each pattern. Must be a sequence + of ints, one for each pattern in the dataset. Alternatively, task + labels can be expressed as a sequence containing sequences of ints + (one for each dataset to be concatenated) or even a single int, + in which case that value will be used as the task label for all + instances. Defaults to None, which means that the dataset will try + to obtain the task labels from the original datasets. If no task + labels could be found for a dataset, a default task label 0 will + be applied to all patterns of that dataset. + :param collate_fn: The function to use when slicing to merge single + patterns. This function is the function used in the data loading + process, too. If None, the constructor will check if a `collate_fn` + field exists in the first dataset. If no such field exists, the + default collate function for detection will be used. + Beware that the chosen collate function will be applied to all + the concatenated datasets even if a different collate is defined + in different datasets. 
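+
+    As a minimal usage sketch (``d1`` and ``d2`` are assumed to be two
+    already-built detection datasets):
+
+    .. code-block:: python
+
+        concatenated = concat_detection_datasets([d1, d2], task_labels=0)
+        assert len(concatenated) == len(d1) + len(d2)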
+ """ + dds = [] + for dd in datasets: + if not isinstance(dd, AvalancheDataset): + dd = make_detection_dataset( + dd, + transform=transform, + target_transform=target_transform, + transform_groups=transform_groups, + initial_transform_group=initial_transform_group, + task_labels=task_labels, + targets=targets, + collate_fn=collate_fn, + ) + dds.append(dd) + if ( + transform is None + and target_transform is None + and transform_groups is None + and initial_transform_group is None + and task_labels is None + and targets is None + and collate_fn is None + and len(datasets) > 0 + ): + d0 = datasets[0] + if isinstance(d0, DetectionDataset): + for d1 in datasets[1:]: + d0 = d0.concat(d1) + return d0 + + das = [] + if len(dds) > 0: + ####################################### + # TRANSFORMATION GROUPS + ####################################### + transform_groups = _init_transform_groups( + transform_groups, + transform, + target_transform, + initial_transform_group, + dds[0], + ) + + if initial_transform_group is None: + uniform_group = None + for d_set in datasets: + if isinstance(d_set, AvalancheDataset): + if uniform_group is None: + uniform_group = d_set._transform_groups.current_group + else: + if ( + uniform_group + != d_set._transform_groups.current_group + ): + uniform_group = None + break + + if uniform_group is None: + initial_transform_group = "train" + else: + initial_transform_group = uniform_group + + ####################################### + # DATA ATTRIBUTES + ####################################### + + totlen = sum([len(d) for d in datasets]) + if ( + task_labels is not None + ): # User defined targets always take precedence + if isinstance(task_labels, int): + task_labels = ConstantSequence(task_labels, totlen) + elif len(task_labels) != totlen: + raise ValueError( + "Invalid amount of target labels. It must be equal to the " + "number of patterns in the dataset. Got {}, expected " + "{}!".format(len(task_labels), totlen) + ) + das.append( + DataAttribute( + task_labels, "targets_task_labels", use_in_getitem=True + ) + ) + + if targets is not None: # User defined targets always take precedence + if len(targets) != totlen: + raise ValueError( + "Invalid amount of target dictionaries. It must be " + "equal to the number of patterns in the dataset. 
" + "Got {}, expected {}!".format(len(targets), totlen) + ) + das.append(DataAttribute(targets, "targets")) + if len(das) == 0: + das = None + data = DetectionDataset( + dds, transform_groups=transform_groups, data_attributes=das + ) + return data.with_transforms(initial_transform_group) + + +def _select_targets(dataset, indices): + if hasattr(dataset, "targets"): + # Standard supported dataset + found_targets = dataset.targets + else: + raise ValueError( + "Unsupported dataset: must have a valid targets field" + ) + + if indices is not None: + found_targets = SubSequence(found_targets, indices=indices) + + return found_targets + + +def _select_task_labels(dataset, indices): + found_task_labels = None + if hasattr(dataset, "targets_task_labels"): + found_task_labels = dataset.targets_task_labels + + if found_task_labels is None: + if isinstance(dataset, (Subset, ConcatDataset)): + return None # Continue traversing + + if found_task_labels is None: + if indices is None: + return ConstantSequence(0, len(dataset)) + return ConstantSequence(0, len(indices)) + + if indices is not None: + found_task_labels = SubSequence(found_task_labels, indices=indices) + + return found_task_labels + + +def _traverse_supported_dataset( + dataset, values_selector: Callable[[Dataset, List[int]], List], indices=None +) -> List: + initial_error = None + try: + result = values_selector(dataset, indices) + if result is not None: + return result + except BaseException as e: + initial_error = e + + if isinstance(dataset, Subset): + if indices is None: + indices = range(len(dataset)) + indices = [dataset.indices[x] for x in indices] + return list( + _traverse_supported_dataset( + dataset.dataset, values_selector, indices + ) + ) + + if isinstance(dataset, ConcatDataset): + result = [] + if indices is None: + for c_dataset in dataset.datasets: + result += list( + _traverse_supported_dataset( + c_dataset, values_selector, indices + ) + ) + return result + + datasets_to_indexes = defaultdict(list) + indexes_to_dataset = [] + datasets_len = [] + recursion_result = [] + + all_size = 0 + for c_dataset in dataset.datasets: + len_dataset = len(c_dataset) + datasets_len.append(len_dataset) + all_size += len_dataset + + for subset_idx in indices: + dataset_idx, pattern_idx = find_list_from_index( + subset_idx, datasets_len, all_size + ) + datasets_to_indexes[dataset_idx].append(pattern_idx) + indexes_to_dataset.append(dataset_idx) + + for dataset_idx, c_dataset in enumerate(dataset.datasets): + recursion_result.append( + deque( + _traverse_supported_dataset( + c_dataset, + values_selector, + datasets_to_indexes[dataset_idx], + ) + ) + ) + + result = [] + for idx in range(len(indices)): + dataset_idx = indexes_to_dataset[idx] + result.append(recursion_result[dataset_idx].popleft()) + + return result + + if initial_error is not None: + raise initial_error + + raise ValueError("Error: can't find the needed data in the given dataset") + + +class DetectionTaskSet(Mapping): + """A lazy mapping for task dataset>. + + Given a `DetectionDataset`, this class provides an + iterator that splits the data into task subsets, returning tuples + ``. + + Usage: + + .. code-block:: python + + tset = DetectionTaskSet(data) + for tid, tdata in tset: + print(f"task {tid} has {len(tdata)} examples.") + + """ + + def __init__(self, data: DetectionDataset): + """Constructor. 
+ + :param data: original data + """ + super().__init__() + self.data = data + + def __iter__(self): + return iter(self.data.targets_task_labels.uniques) + + def __getitem__(self, task_label): + tl_idx = self.data.targets_task_labels.val_to_idx[task_label] + return detection_subset(self.data, tl_idx) + + def __len__(self): + return len(self.data.targets_task_labels.uniques) + + +__all__ = [ + "SupportedDetectionDataset", + "DetectionDataset", + "make_detection_dataset", + "detection_subset", + "concat_detection_datasets", + "DetectionTaskSet", +] diff --git a/avalanche/benchmarks/utils/flat_data.py b/avalanche/benchmarks/utils/flat_data.py index 2efb8030f..eaa3214cb 100644 --- a/avalanche/benchmarks/utils/flat_data.py +++ b/avalanche/benchmarks/utils/flat_data.py @@ -12,14 +12,17 @@ Datasets with optimized concat/subset operations. """ import bisect -from typing import List +from typing import List, TypeVar, Optional from torch.utils.data import ConcatDataset from avalanche.benchmarks.utils.dataset_definitions import IDataset +FlatDataImplT = TypeVar('FlatDataImplT', bound='FlatData') +DataT = TypeVar("DataT") -class FlatData(IDataset): + +class FlatData(IDataset[DataT]): """FlatData is a dataset optimized for efficient repeated concatenation and subset operations. @@ -42,9 +45,9 @@ class FlatData(IDataset): def __init__( self, - datasets: List[IDataset], + datasets: List[IDataset[DataT]], indices: List[int] = None, - can_flatten=True, + can_flatten: bool = True, ): """Constructor @@ -69,13 +72,14 @@ def _get_indices(self): else: return list(range(len(self))) - def subset(self, indices: List[int]) -> "FlatData": + def subset(self: FlatDataImplT, indices: Optional[List[int]]) \ + -> FlatDataImplT: """Subsampling operation. :param indices: indices of the new samples :return: """ - if self._can_flatten: + if self._can_flatten and indices is not None: if self._indices is None: new_indices = indices else: @@ -84,7 +88,7 @@ def subset(self, indices: List[int]) -> "FlatData": return self.__class__(datasets=self._datasets, indices=new_indices) return self.__class__(datasets=[self], indices=indices) - def concat(self, other: "FlatData") -> "FlatData": + def concat(self: FlatDataImplT, other: "FlatData") -> FlatDataImplT: """Concatenation operation. :param other: other dataset. 
@@ -172,7 +176,7 @@ def _get_idx(self, idx): idx = idx - self._cumulative_sizes[dataset_idx - 1] return dataset_idx, int(idx) - def __getitem__(self, idx): + def __getitem__(self, idx) -> DataT: dataset_idx, idx = self._get_idx(idx) return self._datasets[dataset_idx][idx] @@ -183,10 +187,10 @@ def __len__(self): return len(self._indices) return self._cumulative_sizes[-1] - def __add__(self, other: "FlatData") -> "FlatData": + def __add__(self, other: FlatDataImplT) -> FlatDataImplT: return self.concat(other) - def __radd__(self, other: "FlatData") -> "FlatData": + def __radd__(self, other: FlatDataImplT) -> FlatDataImplT: return other.concat(self) @@ -240,7 +244,8 @@ def __str__(self): ) -def _flatten_dataset_list(datasets: List[FlatData]): +def _flatten_dataset_list(datasets: List[IDataset[DataT]]) \ + -> List[IDataset[DataT]]: """Flatten dataset tree if possible.""" # Concat -> Concat branch # Flattens by borrowing the list of concatenated datasets @@ -259,7 +264,7 @@ def _flatten_dataset_list(datasets: List[FlatData]): flattened_list.append(dataset) # merge consecutive Subsets if compatible - new_data_list = [] + new_data_list: List[IDataset[DataT]] = [] for dataset in flattened_list: if ( isinstance(dataset, FlatData) diff --git a/avalanche/core.py b/avalanche/core.py index ac13aac9f..1441c0754 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -27,6 +27,13 @@ class BasePlugin(Generic[Template], ABC): and loggers. """ + supports_distributed = False + """ + A class-level attribute that indicates whether the plugin is supported + in distributed training. If False, Avalanche will warn when the plugin + is used in distributed training. + """ + def __init__(self): pass diff --git a/avalanche/distributed/__init__.py b/avalanche/distributed/__init__.py new file mode 100644 index 000000000..af11a110e --- /dev/null +++ b/avalanche/distributed/__init__.py @@ -0,0 +1,5 @@ +from .distributed_helper import * +from .distributed_value import * +from .distributed_batch import * +from .distributed_model import * +from .distributed_commons import * diff --git a/avalanche/distributed/distributed_batch.py b/avalanche/distributed/distributed_batch.py new file mode 100644 index 000000000..0fd3ed858 --- /dev/null +++ b/avalanche/distributed/distributed_batch.py @@ -0,0 +1,173 @@ +from abc import abstractmethod, ABC +from typing import TypeVar, List, Optional, Callable, Any, Iterable + +import torch +from torch import Tensor + +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_value import SwitchableDistributedValue + +LocalT = TypeVar('LocalT') +DistributedT = TypeVar('DistributedT') + + +class DistributedObject(SwitchableDistributedValue[LocalT, DistributedT], ABC): + """ + An intermediate abstract class in charge of synchronizing objects. + + The merge procedure must be implemented in child classes. + """ + def _synchronize(self) -> DistributedT: + objects = self._synchronize_objects() + return self._merge_objects(objects) + + def _synchronize_objects(self) -> List[LocalT]: + return DistributedHelper.gather_all_objects( + self._local_value + ) + + @abstractmethod + def _merge_objects(self, objects: List[LocalT]) -> DistributedT: + pass + + +class OnlyTupleSynchronizationSupported(BaseException): + pass + + +class DistributedBatch(DistributedObject[LocalT, LocalT], ABC): + """ + An intermediate abstract class in charge of synchronizing data batches. + + This class can handle batches as either tuples of elements (as usual) or + even single values. 
+ + The merge procedure of tuples and single elements must be implemented in + child classes. By default, the tuples will be merged value by value. + + NOTE: In the future, this class may be replaced with a version in which only + the accessed tuple elements are synchronized, instead of the whole batch. + The current design, in which child classes have to implement + `_merge_single_values`, allows for this change to happen without affecting + child classes. + """ + + def __init__(self, name: str, initial_local_value: LocalT): + super().__init__(name, initial_local_value) + self._value_is_tuple = False + + def _synchronize(self) -> LocalT: + if self._local_value is None: + return None + else: + return super()._synchronize() + + def _set_local(self, new_local_value): + self._value_is_tuple = isinstance(new_local_value, (tuple, list)) + super()._set_local(new_local_value) + + def _merge_objects(self, objects: List[LocalT]) -> LocalT: + if not self._value_is_tuple: + try: + return self._merge_single_values(objects, 0) + except OnlyTupleSynchronizationSupported: + pass + + return self._merge_tuples(objects) + + def _merge_tuples(self, tuples: List[LocalT]): + try: + merged_elements = [] + # Note: _local_value is usually a tuple (mb_x, mb_y, ...) + # which means that n_elements is usually == 2 or 3 + + n_elements = len(self._local_value) + for element_idx in range(n_elements): + to_merge_elements = [] + for tp in tuples: + to_merge_elements.append(tp[element_idx]) + + merged_elements.append( + self._merge_single_values(to_merge_elements, element_idx) + ) + + return tuple(merged_elements) + except OnlyTupleSynchronizationSupported: + raise RuntimeError( + '[DistributedBatch] No proper collate function set.') + + @abstractmethod + def _merge_single_values(self, values: List, value_index: int): + pass + + +class CollateDistributedBatch(DistributedBatch[LocalT]): + """ + An implementation of :class:`DistributedBatch` in which the + `_merge_tuples` mechanism is given as a callable function. + + This assumes that local batches are locally pre-collated and + will thus unroll them before calling the given function. 
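+
+    A minimal sketch of the intended usage (``local_x`` and ``local_y`` are
+    assumed to be the tensors of the locally collated batch; the factory used
+    here is defined at the bottom of this module):
+
+    .. code-block:: python
+
+        mb = make_classification_distributed_batch('mbatch')
+        mb.value = (local_x, local_y)  # set the local batch
+        merged_x, merged_y = mb.value  # gathered and re-collated if distributed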
+ """ + + def __init__(self, name: str, initial_local_value: LocalT, + tuples_collate_fn: Optional[Callable[[List], LocalT]], + single_values_collate_fn: Optional[Callable[[Any, int], Any]]): + super().__init__(name, initial_local_value) + self.tuples_collate_fn = tuples_collate_fn + self.single_values_collate_fn = single_values_collate_fn + + def _unroll_minibatch(self, tuples: List[LocalT]) -> List[LocalT]: + unrolled_elements = [] + for local_tuple in tuples: + n_elements = len(local_tuple) + mb_size = len(local_tuple[0]) + + for mb_element_idx in range(mb_size): + mb_element = [] + for tuple_element_idx in range(n_elements): + mb_element.append( + local_tuple[tuple_element_idx][mb_element_idx]) + unrolled_elements.append(tuple(mb_element)) + return unrolled_elements + + def _unroll_value(self, collated_values: List[Iterable[Any]]) -> Any: + unrolled_values = [] + for val_batch in collated_values: + unrolled_values.extend(val_batch) + + return unrolled_values + + def _merge_tuples(self, tuples: List[LocalT]): + if self.tuples_collate_fn is not None: + unrolled_elements = self._unroll_minibatch(tuples) + + return self.tuples_collate_fn(unrolled_elements) + + return super()._merge_tuples(tuples) + + def _merge_single_values(self, values: List, value_index: int): + if self.single_values_collate_fn is None: + raise OnlyTupleSynchronizationSupported() + + unrolled_elements = self._unroll_value(values) + return self.single_values_collate_fn(unrolled_elements, value_index) + + +def make_classification_distributed_batch(name: str) -> \ + CollateDistributedBatch[Optional[Tensor]]: + """ + Return a :class:`CollateDistributedBatch` that assumes that all values + are Tensors. Values are obtained by concatenating these tensors. + """ + return CollateDistributedBatch( + name, None, None, lambda x, y: torch.stack(x) + ) + + +__all__ = [ + 'DistributedObject', + 'DistributedBatch', + 'CollateDistributedBatch', + 'make_classification_distributed_batch' +] diff --git a/avalanche/distributed/distributed_commons.py b/avalanche/distributed/distributed_commons.py new file mode 100644 index 000000000..7a43654b1 --- /dev/null +++ b/avalanche/distributed/distributed_commons.py @@ -0,0 +1,25 @@ +import torch + +from avalanche.distributed.distributed_tensor import DistributedMeanTensor + + +class DistributedLoss(DistributedMeanTensor): + """ + A distributed value in charge of obtaining the mean loss. + + The mean loss is computed as the mean of losses from all processes, without + weighting using the mini batch sizes in each process. + + This is current mostly an alias for :class:`DistributedMeanTensor`. However, + in the future this class may be extended to add loss-specific features. 
+ """ + def __init__(self, name: str = 'loss'): + super(DistributedLoss, self).__init__(name, torch.zeros((1,))) + + def _merge(self, tensors): + return super(DistributedLoss, self)._merge(tensors) + + +__all__ = [ + 'DistributedLoss' +] diff --git a/avalanche/distributed/distributed_consistency_verification.py b/avalanche/distributed/distributed_consistency_verification.py new file mode 100644 index 000000000..71c0e8602 --- /dev/null +++ b/avalanche/distributed/distributed_consistency_verification.py @@ -0,0 +1,102 @@ +import hashlib +import io + +from typing import Tuple, TYPE_CHECKING + +import torch +from torch import Tensor +from torch.nn import Module +from torch.utils.data import Dataset, DataLoader + +if TYPE_CHECKING: + from avalanche.benchmarks import GenericCLScenario + + +def hash_benchmark(benchmark: 'GenericCLScenario', *, + hash_engine=None, num_workers=0) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for stream_name in sorted(benchmark.streams.keys()): + stream = benchmark.streams[stream_name] + hash_engine.update(stream_name.encode()) + for experience in stream: + exp_dataset = experience.dataset + hash_dataset(exp_dataset, + hash_engine=hash_engine, + num_workers=num_workers) + return hash_engine.hexdigest() + + +def hash_dataset(dataset: 'Dataset', *, hash_engine=None, num_workers=0) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + data_loader = DataLoader( + dataset, + collate_fn=lambda batch: tuple(zip(*batch)), + num_workers=num_workers + ) + for loaded_elem in data_loader: + example = tuple(tuple_element[0] for tuple_element in loaded_elem) + + # https://stackoverflow.com/a/63880190 + buff = io.BytesIO() + torch.save(example, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_minibatch(minibatch: Tuple[Tensor], *, hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for tuple_elem in minibatch: + buff = io.BytesIO() + torch.save(tuple_elem, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_tensor(tensor: Tensor, *, hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + buff = io.BytesIO() + torch.save(tensor, buff) + buff.seek(0) + hash_engine.update(buff.read()) + return hash_engine.hexdigest() + + +def hash_model(model: Module, include_buffers=True, *, hash_engine=None) -> str: + if hash_engine is None: + hash_engine = hashlib.sha256() + + for name, param in model.named_parameters(): + hash_engine.update(name.encode()) + buff = io.BytesIO() + torch.save(param.detach().cpu(), buff) + buff.seek(0) + hash_engine.update(buff.read()) + + if include_buffers: + for name, model_buffer in model.named_buffers(): + hash_engine.update(name.encode()) + buff = io.BytesIO() + torch.save(model_buffer.detach().cpu(), buff) + buff.seek(0) + hash_engine.update(buff.read()) + + return hash_engine.hexdigest() + + +__all__ = [ + 'hash_benchmark', + 'hash_dataset', + 'hash_minibatch', + 'hash_tensor', + 'hash_model' +] diff --git a/avalanche/distributed/distributed_helper.py b/avalanche/distributed/distributed_helper.py new file mode 100644 index 000000000..ef04e19bd --- /dev/null +++ b/avalanche/distributed/distributed_helper.py @@ -0,0 +1,576 @@ +import os +import pickle +import warnings +from io import BytesIO +from typing import Optional, List, Any, Iterable, Dict, TypeVar + +import torch +from torch import Tensor +from torch.nn.modules import Module +from 
torch.nn.parallel import DistributedDataParallel
+from typing_extensions import Literal
+from torch.distributed import (
+    init_process_group,
+    broadcast_object_list
+)
+
+
+BroadcastT = TypeVar('BroadcastT')
+
+
+from avalanche.distributed.distributed_consistency_verification import \
+    hash_tensor
+
+
+class _Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(_Singleton, cls).__call__(
+                *args, **kwargs)
+        return cls._instances[cls]
+
+
+class RollingSeedContext(object):
+    """
+    Implement seed alignment by storing the state of the random number
+    generators.
+
+    Doesn't require a distributed communication (not even a broadcast), which
+    makes this the best choice when wrapping sections that (may) both:
+    - behave differently depending on the rank
+    - change the global state of random number generators
+    """
+    def __init__(self):
+        self.rng_manager_state = None
+
+    def save_generators_state(self):
+        from avalanche.training.determinism.rng_manager import RNGManager
+        self.rng_manager_state = RNGManager.__getstate__()
+
+    def load_generators_state(self):
+        from avalanche.training.determinism.rng_manager import RNGManager
+        RNGManager.__setstate__(self.rng_manager_state)
+
+    def step_random_generators(self):
+        from avalanche.training.determinism.rng_manager import RNGManager
+        RNGManager.step_generators()
+
+    def __enter__(self):
+        self.save_generators_state()
+
+    def __exit__(self, *_):
+        self.load_generators_state()
+        self.step_random_generators()
+
+
+class BroadcastSeedContext(object):
+    """
+    Implement seed alignment by broadcasting a new seed from the main process.
+
+    This is usually slower than using :class:`RollingSeedContext`.
+    """
+    def __init__(self):
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *_):
+        DistributedHelper.align_seeds()
+
+
+class _MainProcessFirstContext(object):
+    """
+    A context in which the main process must enter and exit the section before
+    other processes.
+
+    For instance, can be used to wrap the dataset download procedure.
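+
+    A minimal usage sketch, through the ``main_process_first()`` helper of
+    ``DistributedHelper`` (defined below in this module):
+
+    .. code-block:: python
+
+        with DistributedHelper.main_process_first():
+            benchmark = SplitMNIST(5)  # e.g. a dataset download step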
+ """ + + def __init__( + self, + seed_alignment: Literal["rolling", "broadcast"] = 'rolling', + final_barrier: bool = False): + if seed_alignment == 'rolling': + self._seed_aligner = RollingSeedContext() + else: + self._seed_aligner = BroadcastSeedContext() + + self._final_barrier = final_barrier + + def __enter__(self): + self._seed_aligner.__enter__() + + if not DistributedHelper.is_main_process: + # Wait for the main process + DistributedHelper.barrier() + + def __exit__(self, exc_type, exc_val, exc_tb): + if DistributedHelper.is_main_process: + # Let other process enter the section + DistributedHelper.barrier() + + self._seed_aligner.__exit__() + if self._final_barrier: + DistributedHelper.barrier() + + +class _DistributedHelperCls(object): + __metaclass__ = _Singleton + + def __init__(self): + self.use_cuda = False + self._dev_map = _DistributedHelperCls._make_map('cpu') + + def init_distributed(self, random_seed, backend=None, use_cuda=True): + if self.is_distributed: + raise RuntimeError('Distributed API already initialized') + + use_cuda = use_cuda and torch.cuda.is_available() + + if backend is None: + if use_cuda: + backend = 'nccl' + else: + backend = 'gloo' + + if backend == 'nccl' and not use_cuda: + warnings.warn( + 'Bad configuration: using NCCL, but you set use_cuda=False!') + + could_initialize_distributed = False + if os.environ.get('LOCAL_RANK', None) is None: + warnings.warn( + 'Torch distributed could not be initialized ' + '(missing environment configuration)') + else: + init_process_group(backend=backend) + could_initialize_distributed = True + + self.set_random_seeds(random_seed) + self.use_cuda = use_cuda + + if use_cuda or backend == 'nccl': # TODO: remove in final release + # https://github.com/pytorch/pytorch/issues/6351 + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # Force-init the default CUDA device (if any) + reference_device = self.make_device(set_cuda_device=True) + + # Create map for device placement of unpickled tensors + self._dev_map = _DistributedHelperCls._make_map(reference_device) + + return could_initialize_distributed + + def get_device_id(self): + if self.is_distributed: + device_id = self.rank + else: + device_id = 0 + + if self.use_cuda: + return device_id + + return -1 + + def make_device(self, set_cuda_device=False): + if self.is_distributed: + device_id = self.rank + else: + device_id = 0 + + if self.use_cuda and device_id >= 0: + ref_device = torch.device(f'cuda:{device_id}') + if set_cuda_device: + torch.cuda.set_device(ref_device) + else: + ref_device = torch.device('cpu') + return ref_device + + def wrap_model(self, model: Module) -> Module: + # Note: find_unused_parameters is needed for multi task models. + if self.is_distributed: + if self.forced_cuda_comm or self.use_cuda: + # forced_cuda_comm is True if using NCCL; use_cuda may be true + # even when not using NCCL. + # User already warned if using NCCL with use_cuda==False. 
+ # device_ids must be a single device id + # (an int, a device object or a str) + # If not set, output_device defaults to device_ids[0] + return DistributedDataParallel( + model, device_ids=[self.make_device()], + find_unused_parameters=True) + else: + return DistributedDataParallel( + model, + find_unused_parameters=True) + else: + return model + + def unwrap_model(self, model: Module) -> Module: + if isinstance(model, DistributedDataParallel): + return model.module + + return model + + def set_random_seeds(self, random_seed): + from avalanche.training.determinism.rng_manager import RNGManager + RNGManager.set_random_seeds(random_seed) + + def align_seeds(self): + if not self.is_distributed: + return + + if self.is_main_process: + reference_seed = torch.randint(0, 2**32-1, (1,), dtype=torch.int64) + else: + reference_seed = torch.empty((1,), dtype=torch.int64) + + self.broadcast(reference_seed) + seed = int(reference_seed) + self.set_random_seeds(seed) + + def main_process_first(self): + return _MainProcessFirstContext() + + def barrier(self): + if self.is_distributed: + torch.distributed.barrier() + + def broadcast(self, tensor: Tensor, src=0): + if not self.is_distributed: + return tensor + + tensor_distrib, orig_data = self._prepare_for_distributed_comm(tensor) + torch.distributed.broadcast(tensor_distrib, src=src) + tensor = self._revert_to_original_device(tensor_distrib, orig_data) + + return tensor + + def broadcast_object(self, obj: BroadcastT, src=0) -> BroadcastT: + if not self.is_distributed: + return obj + + io_list = [obj] + + broadcast_object_list(io_list, src=src) + return io_list[0] + + def cat_all(self, tensor: Tensor): + # TODO: use all_gather_into_tensor (if available and + # if NCCL and tensor.device == 'default device') + + if not self.is_distributed: + return tensor + + gathered_tensors = self.gather_all(tensor) + for i, t in enumerate(gathered_tensors): + if len(t.shape) == 0: + # Tensor with 0-length shape + gathered_tensors[i] = torch.reshape(t, (1,)) + + return torch.cat(gathered_tensors) + + def gather_tensor_shapes(self, tensor: Tensor, max_shape_len=10) \ + -> List[List[int]]: + """ + Gathers the shapes of all the tensors. + """ + # Tensor differ by whole shape + tensor_size = torch.zeros(max_shape_len, dtype=torch.int64) + for i in range(len(tensor.shape)): + tensor_size[i] = tensor.shape[i] + all_tensors_shape = [ + self._prepare_for_distributed_comm( + torch.zeros_like(tensor_size))[0] + for _ in range(self.world_size)] + tensor_size, _ = self._prepare_for_distributed_comm(tensor_size) + + torch.distributed.all_gather(all_tensors_shape, tensor_size) + + all_tensors_shape = [t.cpu() for t in all_tensors_shape] + + # Trim shape + for i, t in enumerate(all_tensors_shape): + for x in range(len(t)): + if t[x] == 0: + if x == 0: + # Tensor with 0-length shape + all_tensors_shape[i] = t[:x+1] + else: + all_tensors_shape[i] = t[:x] + + break + + return [t_shape.tolist() for t_shape in all_tensors_shape] + + def gather_all( + self, + tensor: Tensor, + same_shape: bool = False, + shapes: Optional[List[List[int]]] = None): + """ + Gather all for tensors only. + + Note: differently from the original Pytorch function, which requires + that input tensor is to be moved to the default device (forced to + CUDA if using NCCL), this function also manages input tensors + residing on a different devics. The resulting list of tensors will + be moved to the same device of the input tensor. + + This will also manage tensors of different shapes. 
If you + are sure that the tensors will be of the same shape, consider + passing same_shape to speed up the communication. + + Beware that, if you are in need of concatenating multiple tensors, + method `cat_all` may be more suitable. + """ + if not self.is_distributed: + return [tensor] + + # Based on: + # https://discuss.pytorch.org/t/how-to-concatenate-different-size-tensors-from-distributed-processes/44819/4 + + if same_shape: + # Same size for all tensors + if len(tensor.shape) > 0: + tensor_size = list(tensor.shape) + else: + tensor_size = [0] + all_tensors_shape = \ + [tensor_size for _ in range(self.world_size)] + elif shapes is not None: + # Shapes given by the user + # make sure it is a list of lists + all_tensors_shape = [list(s) for s in shapes] + else: + # Tensor differ by whole shape + all_tensors_shape = self.gather_tensor_shapes(tensor) + + same_shape = all(all_tensors_shape[0] == x for x in all_tensors_shape) + orig_device = tensor.device + + if same_shape: + # Same shape: create identical tensors and proceed with all_gather + out_tensors = [torch.empty_like(tensor) for _ in all_tensors_shape] + else: + # Different shapes: create a tensors of the size of the bigger one + all_tensors_numel = [] + dtype = tensor.dtype + for t_shape in all_tensors_shape: + if t_shape[0] == 0 and len(t_shape) == 1: + # Tensor with 0-length shape + curr_size = 1 + else: + curr_size = 1 + for t_s in t_shape: + curr_size *= t_s + all_tensors_numel.append(curr_size) + + max_numel = max(all_tensors_numel) + out_tensors = [torch.empty((max_numel,), dtype=dtype) + for _ in all_tensors_shape] + + tensor = tensor.flatten() + n_padding = max_numel - tensor.numel() + if n_padding > 0: + padding = torch.zeros((n_padding,), + dtype=tensor.dtype, + device=orig_device) + tensor = torch.cat((tensor, padding), dim=0) + + tensor, _ = self._prepare_for_distributed_comm(tensor) + out_tensors = [self._prepare_for_distributed_comm(t)[0] + for t in out_tensors] + + torch.distributed.all_gather(out_tensors, tensor) + + if not same_shape: + # The tensors are flat and of the wrong dimension: re-shape them + for tensor_idx, (tensor_sz, tensor_numel, out_t) in \ + enumerate(zip(all_tensors_shape, + all_tensors_numel, + out_tensors)): + if tensor_sz[0] == 0: + # Tensor with 0-length shape + out_tensors[tensor_idx] = \ + out_t[:tensor_numel].reshape(tuple()) + else: + out_tensors[tensor_idx] = \ + out_t[:tensor_numel].reshape(tensor_sz) + + out_tensors = [t.to(orig_device) for t in out_tensors] + return out_tensors + + def gather_all_objects(self, obj: BroadcastT) -> List[BroadcastT]: + """ + Gather all for objects. This will also take care of moving cuda tensors + (even the ones nested inside objects) to the correct default device. + """ + out_list = [None for _ in range(self.world_size)] + torch.distributed.all_gather_object(out_list, obj) + return out_list + + def check_equal_tensors(self, tensor: Tensor): + if not DistributedHelper.is_distributed: + return + + all_tensors = self.gather_all(tensor) + + tensors_hashes = [hash_tensor(t) for t in all_tensors] + + if len(set(tensors_hashes)) != 1: + # Equal tensors + raise ValueError('Different tensors. 
Got hashes: {}'.format( + tensors_hashes)) + + def check_equal_objects(self, obj): + if not DistributedHelper.is_distributed: + return + + output: List[Any] = [None for _ in range(self.world_size)] + torch.distributed.all_gather_object(output, obj) + + obj_bt = base_typed(obj) + + for i, o in enumerate(output): + o_bt = base_typed(o) + if obj_bt != o_bt: + raise ValueError( + 'Different objects (ranks this={}, remote={}). ' + 'Got this={}, remote={}'.format( + self.rank, i, obj, o)) + + def _prepare_for_distributed_comm(self, tensor: Tensor): + original_device = tensor.device + copy_back = self.forced_cuda_comm and not tensor.is_cuda + if self.forced_cuda_comm: + tensor_distributed = tensor.cuda() + else: + tensor_distributed = tensor + + return tensor_distributed, (original_device, copy_back, tensor) + + def _revert_to_original_device(self, tensor_distributed, orig_data): + original_device, copy_back, tensor = orig_data + if copy_back: + if tensor is None: + tensor = tensor_distributed.to(original_device) + else: + tensor[:] = tensor_distributed + + return tensor + + @property + def rank(self) -> int: + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + @property + def world_size(self) -> int: + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + return 1 + + @property + def is_distributed(self) -> bool: + return torch.distributed.is_initialized() + + @property + def is_main_process(self) -> bool: + return self.rank == 0 + + @property + def backend(self) -> str: + return torch.distributed.get_backend() + + @property + def forced_cuda_comm(self) -> bool: + return self.backend == 'nccl' + + @property + def device_map(self) -> Dict[str, str]: + return self._dev_map + + @staticmethod + def _make_map(device_or_map) -> Dict[str, str]: + # TODO: borrowed from checkpointing plugins + # it would be better to have a single function in a shared utils + if not isinstance(device_or_map, (torch.device, str)): + return device_or_map + + device = torch.device(device_or_map) + map_location = dict() + + map_location['cpu'] = 'cpu' + for cuda_idx in range(100): + map_location[f'cuda:{cuda_idx}'] = str(device) + return map_location + + +BASE_TYPES = [str, int, float, bool, type(None)] + + +def base_typed(obj): + """ + Improved version of https://stackoverflow.com/a/62420097 + """ + T = type(obj) + from_numpy = T.__module__ == 'numpy' + from_pytorch = T.__module__ == 'torch' + + if from_numpy or from_pytorch: + return obj.tolist() + + if T in BASE_TYPES or callable(obj) or ((from_numpy or from_pytorch) + and not isinstance(T, Iterable)): + return obj + + if isinstance(obj, Dict): + return {base_typed(k): base_typed(v) for k, v in obj.items()} + elif isinstance(obj, Iterable): + base_items = [base_typed(item) for item in obj] + return base_items if (from_numpy or from_pytorch) else T(base_items) + + d = obj if T is dict else obj.__dict__ + + return {k: base_typed(v) for k, v in d.items()} + + +DistributedHelper = _DistributedHelperCls() + + +def fix(): + return lambda b: torch.load(BytesIO(b), + map_location=DistributedHelper.device_map) + + +class MappedUnpickler(pickle.Unpickler): + # Based on: + # https://github.com/pytorch/pytorch/issues/16797#issuecomment-777059657 + + # In turn based on: + # https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def find_class(self, module, name): + if module == 'torch.storage' and name == 
'_load_from_bytes':
+            return fix()
+        else:
+            return super().find_class(module, name)
+
+
+torch.distributed.distributed_c10d._unpickler = MappedUnpickler
+
+
+__all__ = [
+    'RollingSeedContext',
+    'BroadcastSeedContext',
+    'DistributedHelper',
+    '_DistributedHelperCls'
+]
diff --git a/avalanche/distributed/distributed_model.py b/avalanche/distributed/distributed_model.py
new file mode 100644
index 000000000..56afeb683
--- /dev/null
+++ b/avalanche/distributed/distributed_model.py
@@ -0,0 +1,175 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 1/12/2021                                                              #
+# Author(s): Lorenzo Pellegrini                                                #
+# E-mail: contact@continualai.org                                              #
+# Website: avalanche.continualai.org                                           #
+################################################################################
+from typing import Optional, Union, Tuple
+
+from torch.nn import Module
+from torch.nn.parallel import DistributedDataParallel
+from typing_extensions import Type
+
+from avalanche.distributed import OptionalDistributedValue
+from avalanche.distributed.distributed_value import DistributedT, \
+    DistributedValue, SettableDistributedValue, SwitchableDistributedValue
+
+
+class DistributedModel(OptionalDistributedValue[Optional[Module]]):
+    """
+    Contains the model used in the :class:`BaseTemplate` strategy template.
+
+    Instances of this class can also carry the distributed (that is, wrapped
+    in a PyTorch `DistributedDataParallel`) version of a local model. If no
+    distributed model is set, then the model returned by the
+    `distributed_model` field will be the local one.
+
+    By setting the `distributed_model` field, the model stored in the
+    `local_model` field will be discarded (from that moment, retrieving the
+    `local_model` will be the same as obtaining the `distributed_model.module`
+    field). Setting the `local_model` will discard the current
+    `distributed_model`.
+
+    Beware that the setter of this class behaves a bit differently from the
+    ones found in superclasses. When setting the `value`, the class of the
+    new value is checked against a list of distributed model classes (by
+    default, only :class:`DistributedDataParallel` is considered). If the
+    model is an instance of these classes, then the distributed value is set
+    instead of the local value.
+    """
+
+    def __init__(
+            self,
+            *,
+            name: str = 'model',
+            initial_model: Module = None,
+            distributed_model_class: Union[Type, Tuple[Type]] =
+            DistributedDataParallel,):
+        """
+        Creates a `DistributedModel` instance.
+
+        :param name: The name of this value. Defaults to 'model'.
+        :param initial_model: The initial model to use. Defaults to None.
+        :param distributed_model_class: The type(s) of the distributed model.
+            Defaults to `DistributedDataParallel`.
+        """
+        super().__init__(name, initial_local_value=initial_model)
+        self.distributed_model_class = distributed_model_class
+
+    @SwitchableDistributedValue.value.setter
+    def value(self, new_value: Module):
+        """
+        Sets the local or distributed model, depending on whether the model
+        is an instance of the distributed model class(es) (by default,
+        DistributedDataParallel).
+
+        This will discard the current distributed value.
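+
+        A minimal sketch of the dispatch behaviour (``net`` is assumed to be
+        a plain module and ``ddp_net`` its DistributedDataParallel-wrapped
+        version):
+
+        .. code-block:: python
+
+            m = DistributedModel(initial_model=net)
+            m.value = ddp_net  # wrapped model -> sets the distributed value
+            m.value = net      # plain Module -> sets the local value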
+ """ + + if isinstance(new_value, self.distributed_model_class): + self.distributed_value = new_value + else: + self.local_value = new_value + + @DistributedValue.local_value.getter + def local_value(self) -> Module: + if self._distributed_value is not None: + return self._distributed_value.module + return self._local_value + + @SettableDistributedValue.distributed_value.setter + def distributed_value(self, new_distributed_value: Module): + if new_distributed_value is None: + self.reset_distributed_value() + else: + self._distributed_value = new_distributed_value + self._distributed_value_set = True + + # Prevent alignment and memory issues. + # The local model will be retrieved from the distributed model. + self._local_value = None + + def reset_distributed_value(self): + if self._distributed_value_set: + if self._distributed_value is not None: + # Unwrap the DistributedDataParallel to obtain the local value. + self._local_value = self._distributed_value.module + self._distributed_value = None + self._distributed_value_set = False + + def reset_distributed_model(self): + """ + Discards the distributed model. + + If the distributed model was not set, nothing happens. + """ + return self.reset_distributed_value() + + def _synchronize(self) -> DistributedT: + raise RuntimeError( + 'The distributed model needs to be wrapped and set by using the ' + f'following class(es): {self.distributed_model_class}') + + # BEGIN ALIASES for "(local|distributed)value" + @property + def model(self): + """ + The current model. + """ + return self.value + + @model.setter + def model(self, new_model: Module): + """ + Sets the current model. + """ + self.value = new_model + + @property + def local_model(self) -> Module: + """ + The current (local) model. + + If a `distributed_model` was set, then the value of the + `distributed_model.module` field will be returned. + """ + return self.local_value + + @local_model.setter + def local_model(self, new_local_value): + """ + Sets the local model. + + This will discard the current distributed model. + """ + self.local_value = new_local_value + + @property + def distributed_model(self): + """ + The current (distributed) model. + + If not set (not running a distributed training, or if the wrapped + model has not been created yet), this is the same as `local_model`. + """ + return self.distributed_value + + @distributed_model.setter + def distributed_model(self, new_distributed_value): + """ + Sets the model wrapped by PyTorch `DistributedDataParallel`. + + Setting this field will release the reference to the current local + model. In that case, the `local_model` field will return + `distributed_model.module` instead. + """ + self.distributed_value = new_distributed_value + # END ALIASES for "(local|distributed)value" + + +__all__ = [ + 'DistributedModel' +] diff --git a/avalanche/distributed/distributed_tensor.py b/avalanche/distributed/distributed_tensor.py new file mode 100644 index 000000000..cfb3d2fec --- /dev/null +++ b/avalanche/distributed/distributed_tensor.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from typing import List + +import torch +from torch import Tensor + +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_value import SwitchableDistributedValue + + +class DistributedTensor(SwitchableDistributedValue[Tensor, Tensor], ABC): + """ + A distributed Tensor wrapper. + + This abstract class is in charge of synchronizing Tensors across processes. 
+ + Child classes must override `_merge` to define how those tensors + should be merged. + """ + def _synchronize(self) -> Tensor: + return self._merge( + DistributedHelper.gather_all(self.local_value)) + + @abstractmethod + def _merge(self, tensors: List[Tensor]) -> Tensor: + """ + Merge all tensors into one. + + :param tensors: The list of tensors obtained from all processes, in the + order defined by the rank. + :return: The merged tensor. + """ + pass + + +class ConcatDistributedTensor(DistributedTensor): + """ + A distributed tensor obtained by concatenating tensors from all processes + (in the order defined by the rank). + + This also correctly manages tensors with 0-length shapes (like losses). + """ + def _merge(self, tensors: List[Tensor]) -> Tensor: + # Manage tensors without shape (0-length shape) + for i, t in enumerate(tensors): + if len(t.shape) == 0: + # Tensor with 0-length shape + tensors[i] = torch.reshape(t, (1,)) + + return torch.cat(tensors) + + +class DistributedMeanTensor(ConcatDistributedTensor): + """ + A distributed 1-item tensor obtained by computing the mean of tensors + from all processes. + """ + def _merge(self, tensors: List[Tensor]) -> Tensor: + concat_tensor = super()._merge(tensors) + return torch.mean(concat_tensor) + + +__all__ = [ + 'DistributedTensor', + 'ConcatDistributedTensor', + 'DistributedMeanTensor' +] diff --git a/avalanche/distributed/distributed_value.py b/avalanche/distributed/distributed_value.py new file mode 100644 index 000000000..8d4e869cf --- /dev/null +++ b/avalanche/distributed/distributed_value.py @@ -0,0 +1,297 @@ +from contextlib import contextmanager +from typing import TypeVar, Generic, Optional, Union, Generator, List, \ + Tuple +from abc import ABC, abstractmethod + +from avalanche.distributed import DistributedHelper + +LocalT = TypeVar('LocalT') +DistributedT = TypeVar('DistributedT') +SwitchableT = TypeVar('SwitchableT', bound='SwitchableDistributedValue') + + +class DistributedValue(Generic[LocalT, DistributedT], ABC): + """ + Class used to generically implement values that may need + a lazy synchronization when running a distributed training. + + When not running a distributed training, this class will act as a + no-op wrapper. + + This class considers setting the 'value' and 'local_value' as the + same operation (setting the local value). However, retrieving 'value' will + trigger the synchronization procedure. + + This class exposes methods that can be customized to define how different + values should be gathered (and merged) from all processes. For instance, + loss values should be averaged together, minibatch outputs should be + concatenated, etcetera. + + Beware that the purpose of this class is to only manage the + local and distributed values. When implementing the subclass, please do not + transform the value and/or type of the local and global values. This + would make it difficult to understand what is going on. + + Also, consider having the same type for the local and distributed value. + That is, if the local value is a Tensor, the distributed value should be + a Tensor as well, not a List[Tensor]. This is because local and distributed + values will be transparently used by users without considering the possibly + distributed nature of the value. + + Feel free to implement, in subclasses, properties with more readable names. + For instance 'mb_output', 'local_mb_output', 'loss', 'local_loss', ... + instead of the default 'value' and 'local_value' already implemented by + this class. 
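+
+    A minimal sketch of the contract (``GatheredTensor`` and ``local_logits``
+    are hypothetical, used only for illustration):
+
+    .. code-block:: python
+
+        class GatheredTensor(DistributedValue[Tensor, Tensor]):
+            def _synchronize(self) -> Tensor:
+                return torch.cat(
+                    DistributedHelper.gather_all(self._local_value))
+
+        v = GatheredTensor('logits', local_logits)
+        v.local_value  # always the per-process tensor
+        v.value        # lazily gathered and merged (when distributed)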
+ """ + + def __init__(self, name: str, initial_local_value: LocalT): + """ + Creates an instance of a distributed value. + + :param name: The name of the value. Also used when obtaining a string + representation. + :param initial_local_value: The initial local value. + """ + self.name: str = name + self._local_value: LocalT = initial_local_value + self._distributed_value: Optional[DistributedT] = None + self._distributed_value_set: bool = False + + @property + def value(self) -> DistributedT: + """ + The current value. + + When running a distributed training, this will be the value obtained + by gathering and merging values coming from all processes. + """ + return self._get_distributed_value() + + @value.setter + def value(self, new_value: LocalT): + """ + Sets the (local) value. + + This will discard the current distributed value. + """ + self._set_local(new_value) + + @property + def local_value(self) -> LocalT: + """ + The current (local) value. + + Even when running a distributed training, this property will always + contain the local value only. + """ + return self._local_value + + @local_value.setter + def local_value(self, new_value: LocalT): + """ + Sets the (local) value. + + This will discard the current distributed value. + """ + self._set_local(new_value) + + def _set_local(self, new_local_value: LocalT): + self._local_value = new_local_value + self._distributed_value = None + self._distributed_value_set = False + + def _get_distributed_value(self) -> DistributedT: + if not DistributedHelper.is_distributed: + return self._local_value + + if not self._distributed_value_set: + self._distributed_value = self._synchronize() + self._distributed_value_set = True + + return self._distributed_value + + @abstractmethod + def _synchronize(self) -> DistributedT: + pass + + def __str__(self): + base_str = f'DistributedObject_{self.name} = {self.local_value}' + if self._distributed_value_set: + return base_str + \ + f' (distributed value = {self.value})' + else: + return base_str + \ + f' (distributed value not synchronized yet)' + + +class SettableDistributedValue(DistributedValue[LocalT, DistributedT], ABC): + """ + A version of :class:`DistributedValue` in which the distributed value can be + set (and reset) externally instead of being synchronized. + + If this class should only allow for distributed values to be set + externally (that is, synchronization should be disabled), please + override `_synchronize` to raise an appropriate error. + In that case, this means this class is mainly used as a switch between a + local and a distributed value based on whether the distributed value has + been set or not. + """ + + def __init__(self, name: str, initial_local_value: LocalT): + super(SettableDistributedValue, self).__init__( + name, initial_local_value + ) + + @property + def distributed_value(self) -> DistributedT: + """ + The current value. + + When running a distributed training, this will be the value obtained + by gathering and merging values coming from all processes. + """ + return self._get_distributed_value() + + @distributed_value.setter + def distributed_value(self, new_distributed_value: DistributedT): + """ + Set the distributed value. + """ + self._distributed_value = new_distributed_value + self._distributed_value_set = True + + def reset_distributed_value(self): + """ + Discards the distributed value (if set). + + If the distributed value was not set, nothing happens. 
+ """ + self._distributed_value = None + self._distributed_value_set = False + + def __str__(self): + base_str = super(SettableDistributedValue, self).__str__() + return f'(Settable){base_str}' + + +class SwitchableDistributedValue(SettableDistributedValue[LocalT, DistributedT], + ABC): + """ + A version of :class:`SettableDistributedValue` in which the behaviour of + the `value` property can be switched so that it returns the local value + instead of the distributed one. The setter behaviour can be customized as + well. + + Useful for situations in which one has to force components interacting with + this value to use the local value.Properties whose name feature an explicit + `local` or `distributed` part are not affected. + """ + + def __init__(self, name: str, initial_local_value: LocalT): + """ + Creates an instance of a distributed value. + + :param name: The name of the value. Also used when obtaining a string + representation. + :param initial_local_value: The initial local value. + """ + super().__init__(name, initial_local_value) + + self._behaviour_stack: List[Tuple[bool, bool]] = list() + """ + If greater than 0, the `value` property will return the local value. + """ + + @contextmanager + def use_local_value(self: SwitchableT, getter=True, setter=True) -> \ + Generator[SwitchableT, None, None]: + """ + A context manager used to set the behaviour of the value property. + + Please note that in a plain code section (not wrapped by this + context manager), the default behaviour is that the getter returns the + distributed value while the setter sets the local value. + + :param getter: If True, the local value will be returned by the getter. + Defaults to True, which means that the getter behaviour will be + changed. + :param setter: If True, the local value will be set by the setter. + Defaults to True, which means that the setter will behave as usual. + :return: This object (self). + """ + self._behaviour_stack.append((getter, setter)) + try: + yield self + finally: + self._behaviour_stack.pop() + + @property + def value(self) -> Union[LocalT, DistributedT]: + if self._use_local_getter(): + return self.local_value + else: + return self.distributed_value + + @value.setter + def value(self, new_value): + if self._use_local_setter(): + self.local_value = new_value + else: + self.distributed_value = new_value + + def _use_local_getter(self): + if len(self._behaviour_stack) == 0: + return False + + return self._behaviour_stack[-1][0] + + def _use_local_setter(self): + if len(self._behaviour_stack) == 0: + return True + + return self._behaviour_stack[-1][1] + + def __str__(self): + base_str = super(SettableDistributedValue, self).__str__() + + current_get_behaviour = 'local' if self._use_local_getter() \ + else 'distributed' + current_set_behaviour = 'local' if self._use_local_setter() \ + else 'distributed' + + return f'(fget={current_get_behaviour},' \ + f'fset={current_set_behaviour}){base_str}' + + +class OptionalDistributedValue(SwitchableDistributedValue[LocalT, LocalT], ABC): + """ + A version of :class:`SettableDistributedValue` in which the + 'value' property returns the local value if no distributed value has + been set yet (without attempting a synchronization). Accessing the + 'distributed_value' property will still force a synchronization. + + Beware that, when using this class, the generic types for the local and + distributed values is enforced to be the same. + + This class is mainly used for managing models wrapped using + `DistributedDataParallel`. 
+ """ + + def __init__(self, name, initial_local_value): + super().__init__(name, initial_local_value) + + def _get_distributed_value(self) -> DistributedT: + if not self._distributed_value_set: + return self._local_value + + return self._distributed_value + + +__all__ = [ + 'DistributedValue', + 'SettableDistributedValue', + 'SwitchableDistributedValue', + 'OptionalDistributedValue', + 'LocalT', + 'DistributedT' +] diff --git a/avalanche/distributed/strategies/__init__.py b/avalanche/distributed/strategies/__init__.py new file mode 100644 index 000000000..9205b85d7 --- /dev/null +++ b/avalanche/distributed/strategies/__init__.py @@ -0,0 +1,4 @@ +from .distributed_strategy_support import * +from .distributed_model_strategy import * +from .distributed_mbatch_strategy import * +from .distributed_loss_strategy import * diff --git a/avalanche/distributed/strategies/distributed_loss_strategy.py b/avalanche/distributed/strategies/distributed_loss_strategy.py new file mode 100644 index 000000000..61a9bfd68 --- /dev/null +++ b/avalanche/distributed/strategies/distributed_loss_strategy.py @@ -0,0 +1,50 @@ +from torch import Tensor + +from avalanche.distributed import DistributedLoss +from avalanche.distributed.strategies import DistributedStrategySupport + + +class DistributedLossStrategySupport(DistributedStrategySupport): + + def __init__(self): + super().__init__() + self._loss = DistributedLoss() + self._use_local_contexts.append(self.use_local_loss) + + @property + def loss(self) -> Tensor: + """ The loss tensor. """ + return self._loss.value + + @loss.setter + def loss(self, value): + """ Sets the loss. """ + self._loss.value = value + + @property + def local_loss(self): + return self._loss.local_value + + @local_loss.setter + def local_loss(self, value): + self._loss.local_value = value + + @property + def distributed_loss(self): + return self._loss.distributed_value + + @distributed_loss.setter + def distributed_loss(self, value): + self._loss.distributed_value = value + + def reset_distributed_loss(self): + """ Resets the distributed value of the loss. 
""" + self._loss.reset_distributed_value() + + def use_local_loss(self, *args, **kwargs): + return self._loss.use_local_value(*args, **kwargs) + + +__all__ = [ + 'DistributedLossStrategySupport' +] diff --git a/avalanche/distributed/strategies/distributed_mbatch_strategy.py b/avalanche/distributed/strategies/distributed_mbatch_strategy.py new file mode 100644 index 000000000..3f02a80de --- /dev/null +++ b/avalanche/distributed/strategies/distributed_mbatch_strategy.py @@ -0,0 +1,201 @@ +from typing import Callable, List, Any, Optional, Union + +from avalanche.benchmarks.utils import AvalancheDataset +from avalanche.benchmarks.utils.collate_functions import \ + Collate, ClassificationCollate +from avalanche.distributed import CollateDistributedBatch +from avalanche.distributed.strategies import DistributedStrategySupport + + +class DistributedMiniBatchStrategySupport(DistributedStrategySupport): + + def __init__(self): + super().__init__() + + default_collate_impl = ClassificationCollate() + self._mbatch = CollateDistributedBatch( + 'mbatch', + None, + default_collate_impl.collate_fn, + default_collate_impl.collate_single_value_fn + ) + + self._mb_output = CollateDistributedBatch( + 'mb_output', + None, + default_collate_impl.collate_fn, + default_collate_impl.collate_single_value_fn + ) + + self._adapted_dataset: Optional[AvalancheDataset] = None + self._collate_fn: Optional[Union[Collate, Callable]] = None + + self._use_local_contexts.append(self.use_local_input_batch) + self._use_local_contexts.append(self.use_local_output_batch) + + # --- START INPUT MINIBATCH PROPERTY --- + @property + def mbatch(self): + """ Current mini-batch. """ + return self._mbatch.value + + @mbatch.setter + def mbatch(self, value): + """ Sets the current mini-batch. """ + self._mbatch.value = value + + @property + def local_mbatch(self): + """ The current local mini-batch. """ + return self._mbatch.local_value + + @local_mbatch.setter + def local_mbatch(self, value): + """ Sets the current local mini-batch. """ + self._mbatch.local_value = value + + @property + def distributed_mbatch(self): + """ The current distributed mini-batch. """ + return self._mbatch.distributed_value + + @distributed_mbatch.setter + def distributed_mbatch(self, value): + """ Sets the current distributed mini-batch. """ + self._mbatch.distributed_value = value + + def reset_distributed_mbatch(self): + """ Resets the distributed value of the mini-batch. """ + self._mbatch.reset_distributed_value() + # --- END INPUT MINIBATCH PROPERTY --- + + # --- START OUTPUT MINIBATCH PROPERTY --- + @property + def mb_output(self): + """ Model's output computed on the current mini-batch. """ + return self._mb_output.value + + @mb_output.setter + def mb_output(self, value): + """ Sets the model's output computed on the current mini-batch. """ + self._mb_output.value = value + + @property + def local_mb_output(self): + """ The current local output. """ + return self._mb_output.local_value + + @local_mb_output.setter + def local_mb_output(self, value): + """ Sets the current local output. """ + self._mb_output.local_value = value + + @property + def distributed_mb_output(self): + """ The current distributed output. """ + return self._mb_output.local_value + + @distributed_mb_output.setter + def distributed_mb_output(self, value): + """ Sets the current distributed output. """ + self._mb_output.distributed_value = value + + def reset_distributed_mb_output(self): + """ Resets the distributed value of the output. 
""" + self._mb_output.reset_distributed_value() + # --- END OUTPUT MINIBATCH PROPERTY --- + + # --- START COLLATE FUNCTIONS (INPUT MB) --- + @property + def input_batch_collate_fn(self): + return self._mbatch.tuples_collate_fn + + @input_batch_collate_fn.setter + def input_batch_collate_fn(self, batch_collate_fn: Callable[[List], Any]): + self._mbatch.tuples_collate_fn = batch_collate_fn + + @property + def input_batch_single_values_collate_fn(self): + return self._mbatch.single_values_collate_fn + + @input_batch_single_values_collate_fn.setter + def input_batch_single_values_collate_fn( + self, single_values_collate_fn: Callable[[List], Any]): + self._mbatch.single_values_collate_fn = single_values_collate_fn + # --- END COLLATE FUNCTIONS (INPUT MB) --- + + # --- START COLLATE FUNCTIONS (OUTPUT MB) --- + @property + def output_batch_collate_fn(self): + return self._mb_output.tuples_collate_fn + + @output_batch_collate_fn.setter + def output_batch_collate_fn(self, batch_collate_fn: Callable[[List], Any]): + self._mb_output.tuples_collate_fn = batch_collate_fn + + @property + def output_batch_single_values_collate_fn(self): + return self._mb_output.single_values_collate_fn + + @output_batch_single_values_collate_fn.setter + def output_batch_single_values_collate_fn( + self, single_values_collate_fn: Callable[[List], Any]): + self._mb_output.single_values_collate_fn = single_values_collate_fn + # --- END COLLATE FUNCTIONS (OUTPUT MB) --- + + # --- START LOCAL CONTEXT MANAGERS --- + def use_local_input_batch(self, *args, **kwargs): + return self._mbatch.use_local_value(*args, **kwargs) + + def use_local_output_batch(self, *args, **kwargs): + return self._mb_output.use_local_value(*args, **kwargs) + # --- END LOCAL CONTEXT MANAGERS --- + + # --- START - GET COLLATE FUNCTIONS FROM DATASET --- + @property + def collate_fn(self): + """ + The collate function used to merge the values obtained from the + dataset into a minibatch. + + This value is obtained from the adapted dataset directly. + """ + return self._collate_fn + + @collate_fn.setter + def collate_fn(self, new_collate): + self._collate_fn = new_collate + + if isinstance(new_collate, Collate): + self.input_batch_collate_fn = new_collate.collate_fn + self.input_batch_single_values_collate_fn = \ + new_collate.collate_single_value_fn + else: + self.input_batch_collate_fn = new_collate + self.input_batch_single_values_collate_fn = None + + @property + def adapted_dataset(self): + return self._adapted_dataset + + @adapted_dataset.setter + def adapted_dataset(self, dataset: Optional[AvalancheDataset]): + # Every time a new dataset is set, the related collate + # function is retrieved and set for sync-ing distributed + # input/output minibatch fields. 
+ self._adapted_dataset = dataset + if self._adapted_dataset is None: + return + + new_collate = self._adapted_dataset.collate_fn + if new_collate is None: + return + + self.collate_fn = new_collate + + # --- END - GET COLLATE FUNCTIONS FROM DATASET --- + + +__all__ = [ + 'DistributedMiniBatchStrategySupport' +] diff --git a/avalanche/distributed/strategies/distributed_model_strategy.py b/avalanche/distributed/strategies/distributed_model_strategy.py new file mode 100644 index 000000000..6a31244db --- /dev/null +++ b/avalanche/distributed/strategies/distributed_model_strategy.py @@ -0,0 +1,47 @@ +from torch.nn import Module + +from avalanche.distributed import DistributedModel +from avalanche.distributed.strategies import DistributedStrategySupport + + +class DistributedModelStrategySupport(DistributedStrategySupport): + + def __init__(self): + super().__init__() + self._model = DistributedModel() + self._use_local_contexts.append(self.use_local_model) + + @property + def model(self) -> Module: + """ PyTorch model. """ + # This will return the local model if training locally + return self._model.value + + @model.setter + def model(self, value): + """ Sets the PyTorch model. """ + self._model.value = value + + @property + def local_model(self): + return self._model.local_model + + @local_model.setter + def local_model(self, value): + self._model.local_model = value + + @property + def distributed_model(self): + return self._model.distributed_model + + @distributed_model.setter + def distributed_model(self, value): + self._model.distributed_model = value + + def use_local_model(self, *args, **kwargs): + return self._model.use_local_value(*args, **kwargs) + + +__all__ = [ + 'DistributedModelStrategySupport' +] diff --git a/avalanche/distributed/strategies/distributed_strategy_support.py b/avalanche/distributed/strategies/distributed_strategy_support.py new file mode 100644 index 000000000..a595aa6ca --- /dev/null +++ b/avalanche/distributed/strategies/distributed_strategy_support.py @@ -0,0 +1,48 @@ +from contextlib import contextmanager, ExitStack + + +class DistributedStrategySupport: + + def __init__(self): + """ + Implements the basic elements needed to support distributed training + in Avalanche strategies. + """ + super().__init__() + self._use_local_contexts = [] + """ + A list of context manager factories to be used in `use_local`. + """ + + @contextmanager + def use_local(self, *args, **kwargs): + """ + A context manager used to change the behavior of some property getters. + + When running code in this context, the property getter implementation + of some distributed-critical fields will return the local value instead + of the distributed (synchronized) one. + + Examples of distributed-critical fields are `model`, `mbatch`, + `mb_output`, `loss`. + + Beware that this method will modify the behavior of getters of ALL + such properties. This may not be desirable. Use the field-specific + `use_local_*` context managers to control the behavior of these + fields in a finer way. + + :param args: Passed to all field-specific `use_local_*` context + managers. + :param kwargs: Passed to all field-specific `use_local_*` context + managers. + :return: The context manager to be used through the `with` syntax. 
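+
+        A minimal sketch (the ``strategy`` object and the plugin code are
+        hypothetical; the affected fields come from the mixins in this
+        package)::
+
+            with strategy.use_local():
+                # model, mbatch, mb_output and loss now resolve to their
+                # local (non-synchronized) values
+                penalty = plugin.compute_penalty(strategy)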
+        """
+        with ExitStack() as stack:
+            for lcm in self._use_local_contexts:
+                stack.enter_context(lcm(*args, **kwargs))
+            yield
+
+
+__all__ = [
+    'DistributedStrategySupport'
+]
diff --git a/avalanche/logging/base_logger.py b/avalanche/logging/base_logger.py
index 77b86864e..8598b219b 100644
--- a/avalanche/logging/base_logger.py
+++ b/avalanche/logging/base_logger.py
@@ -2,9 +2,11 @@
 
 from typing import TYPE_CHECKING, List
 
+from avalanche.distributed import DistributedHelper
+
+
 if TYPE_CHECKING:
     from avalanche.evaluation.metric_results import MetricValue
-    from avalanche.training.templates import SupervisedTemplate
 
 
 class BaseLogger(ABC):
@@ -28,6 +30,31 @@ class BaseLogger(ABC):
     def __init__(self):
         super().__init__()
 
+        if not DistributedHelper.is_main_process:
+            raise RuntimeError(
+                'You are creating a logger in a non-main process during a '
+                'distributed training session. '
+                'Jump to this error for an example on how to fix this.')
+
+        # You have to create the loggers in the main process only. Otherwise,
+        # metrics will end up duplicated in your log files and consistency
+        # errors may arise, too. When creating the EvaluationPlugin in a
+        # non-main process, just pass loggers=None.
+        #
+        # Recommended way:
+        #     if DistributedHelper.is_main_process:
+        #         # Define the loggers
+        #         loggers = [...]
+        #     else:
+        #         loggers = None
+        #
+        #     # Instantiate the evaluation plugin
+        #     eval_plugin = EvaluationPlugin(metricA, metricB, ..., loggers=loggers)
+        #
+        #     # Instantiate the strategy
+        #     strategy = MyStrategy(..., evaluator=eval_plugin)
+
     def log_single_metric(self, name, value, x_plot):
         """Log a metric value.
diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py
index dbac376d5..f00ed393d 100644
--- a/avalanche/models/dynamic_modules.py
+++ b/avalanche/models/dynamic_modules.py
@@ -14,7 +14,6 @@
 """
 import torch
 from torch.nn import Module
-import numpy as np
 
 from avalanche.benchmarks.utils.flat_data import ConstantSequence
 from avalanche.benchmarks.scenarios import CLExperience
@@ -74,6 +73,11 @@ def eval_adaptation(self, experience: CLExperience):
         """
         pass
 
+    @property
+    def model_device(self):
+        """Returns the device of the model."""
+        return next(self.parameters()).device
+
 
 class MultiTaskModule(DynamicModule):
     """Base pytorch Module with support for task labels.
@@ -216,7 +220,7 @@ def __init__(
         self.mask_value = mask_value
 
         self.classifier = torch.nn.Linear(in_features, initial_out_features)
-        au_init = torch.zeros(initial_out_features, dtype=torch.bool)
+        au_init = torch.zeros(initial_out_features, dtype=torch.int8)
         self.register_buffer("active_units", au_init)
 
     @torch.no_grad()
@@ -226,6 +230,7 @@ def adaptation(self, experience: CLExperience):
 
         :param experience: data from the current experience.
:return: """ + device = self.model_device in_features = self.classifier.in_features old_nclasses = self.classifier.out_features curr_classes = experience.classes_in_this_experience @@ -235,7 +240,11 @@ def adaptation(self, experience: CLExperience): if self.masking: if old_nclasses != new_nclasses: # expand active_units mask old_act_units = self.active_units - self.active_units = torch.zeros(new_nclasses, dtype=torch.bool) + self.active_units = torch.zeros( + new_nclasses, + dtype=torch.int8, + device=device) + self.active_units[: old_act_units.shape[0]] = old_act_units # update with new active classes if self.training: @@ -245,7 +254,7 @@ def adaptation(self, experience: CLExperience): if old_nclasses == new_nclasses: return old_w, old_b = self.classifier.weight, self.classifier.bias - self.classifier = torch.nn.Linear(in_features, new_nclasses) + self.classifier = torch.nn.Linear(in_features, new_nclasses).to(device) self.classifier.weight[:old_nclasses] = old_w self.classifier.bias[:old_nclasses] = old_b @@ -318,14 +327,14 @@ def __init__( self.classifiers["0"] = first_head self.max_class_label = max(self.max_class_label, initial_out_features) - au_init = torch.zeros(initial_out_features, dtype=torch.bool) + au_init = torch.zeros(initial_out_features, dtype=torch.int8) self.register_buffer("active_units_T0", au_init) @property def active_units(self): res = {} for tid in self.known_train_tasks_labels: - mask = getattr(self, f"active_units_T{tid}") + mask = getattr(self, f"active_units_T{tid}").to(torch.bool) au = torch.arange(0, mask.shape[0])[mask].tolist() res[tid] = au return res @@ -334,7 +343,7 @@ def active_units(self): def task_masks(self): res = {} for tid in self.known_train_tasks_labels: - res[tid] = getattr(self, f"active_units_T{tid}") + res[tid] = getattr(self, f"active_units_T{tid}").to(torch.bool) return res def adaptation(self, experience: CLExperience): @@ -344,6 +353,7 @@ def adaptation(self, experience: CLExperience): :return: """ super().adaptation(experience) + device = self.model_device curr_classes = experience.classes_in_this_experience task_labels = experience.task_labels if isinstance(task_labels, ConstantSequence): @@ -355,12 +365,14 @@ def adaptation(self, experience: CLExperience): # head adaptation if tid not in self.classifiers: # create new head new_head = IncrementalClassifier( - self.in_features, self.starting_out_features - ) + self.in_features, self.starting_out_features, masking=False + ).to(device) self.classifiers[tid] = new_head au_init = torch.zeros( - self.starting_out_features, dtype=torch.bool + self.starting_out_features, + dtype=torch.int8, + device=device ) self.register_buffer(f"active_units_T{tid}", au_init) @@ -388,7 +400,9 @@ def adaptation(self, experience: CLExperience): if old_nunits != new_nclasses: # expand active_units mask old_act_units = self._buffers[au_name] self._buffers[au_name] = torch.zeros( - new_nclasses, dtype=torch.bool + new_nclasses, + dtype=torch.int8, + device=device ) self._buffers[au_name][ : old_act_units.shape[0] @@ -405,6 +419,7 @@ def forward_single_task(self, x, task_label): :param task_label: :return: """ + device = self.model_device task_label = str(task_label) out = self.classifiers[task_label](x) if self.masking: @@ -413,7 +428,10 @@ def forward_single_task(self, x, task_label): nunits, oldsize = out.shape[-1], curr_au.shape[0] if oldsize < nunits: # we have to update the mask old_mask = self._buffers[au_name] - self._buffers[au_name] = torch.zeros(nunits, dtype=torch.bool) + self._buffers[au_name] = 
torch.zeros( + nunits, + dtype=torch.int8, + device=device) self._buffers[au_name][:oldsize] = old_mask curr_au = self._buffers[au_name] out[..., torch.logical_not(curr_au)] = self.mask_value diff --git a/avalanche/models/utils.py b/avalanche/models/utils.py index 5a1ef3153..b40f88191 100644 --- a/avalanche/models/utils.py +++ b/avalanche/models/utils.py @@ -1,19 +1,29 @@ from avalanche.benchmarks.utils import make_classification_dataset from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel from collections import OrderedDict from avalanche.benchmarks.scenarios import CLExperience +def is_multi_task_module(model: nn.Module): + return isinstance(model, MultiTaskModule) or \ + (isinstance(model, DistributedDataParallel) and + isinstance(model.module, MultiTaskModule)) + + def avalanche_forward(model, x, task_labels): - if isinstance(model, MultiTaskModule): + if is_multi_task_module(model): return model(x, task_labels) else: # no task labels return model(x) def avalanche_model_adaptation(model: nn.Module, experience: CLExperience): + if isinstance(model, DistributedDataParallel): + raise RuntimeError('The model is wrapped in DistributedDataParallel. ' + 'Please unwrap it before calling this method.') for module in model.modules(): if isinstance(module, DynamicModule): module.adaptation(experience) diff --git a/avalanche/training/determinism/rng_manager.py b/avalanche/training/determinism/rng_manager.py index 5052cacdc..9b7b0208a 100644 --- a/avalanche/training/determinism/rng_manager.py +++ b/avalanche/training/determinism/rng_manager.py @@ -1,4 +1,3 @@ -import hashlib import random from collections import OrderedDict diff --git a/avalanche/training/plugins/clock.py b/avalanche/training/plugins/clock.py index 535ef3f72..1718beaf3 100644 --- a/avalanche/training/plugins/clock.py +++ b/avalanche/training/plugins/clock.py @@ -18,6 +18,8 @@ class Clock(SupervisedPlugin): wrong for plugins called after it. """ + supports_distributed = True + def __init__(self): """Init.""" super().__init__() diff --git a/avalanche/training/plugins/cwr_star.py b/avalanche/training/plugins/cwr_star.py index 6bd88c681..2495eae85 100644 --- a/avalanche/training/plugins/cwr_star.py +++ b/avalanche/training/plugins/cwr_star.py @@ -22,6 +22,8 @@ class CWRStarPlugin(SupervisedPlugin): This plugin does not use task identities. """ + supports_distributed = True + def __init__(self, model, cwr_layer_name=None, freeze_remaining_model=True): """ :param model: the model. @@ -47,23 +49,26 @@ def __init__(self, model, cwr_layer_name=None, freeze_remaining_model=True): self.cur_class = None def after_training_exp(self, strategy, **kwargs): - self.consolidate_weights() - self.set_consolidate_weights() + with strategy.use_local_model(): + self.consolidate_weights() + self.set_consolidate_weights() def before_training_exp(self, strategy, **kwargs): - if self.freeze_remaining_model and strategy.clock.train_exp_counter > 0: - self.freeze_other_layers() - - # Count current classes and number of samples for each of them. 
- data = strategy.experience.dataset - self.model.cur_j = examples_per_class(data.targets) - self.cur_class = [ - cls - for cls in set(self.model.cur_j.keys()) - if self.model.cur_j[cls] > 0 - ] - - self.reset_weights(self.cur_class) + with strategy.use_local_model(): + if self.freeze_remaining_model and \ + strategy.clock.train_exp_counter > 0: + self.freeze_other_layers() + + # Count current classes and number of samples for each of them. + data = strategy.experience.dataset + self.model.cur_j = examples_per_class(data.targets) + self.cur_class = [ + cls + for cls in set(self.model.cur_j.keys()) + if self.model.cur_j[cls] > 0 + ] + + self.reset_weights(self.cur_class) def consolidate_weights(self): """Mean-shift for the target layer weights""" diff --git a/avalanche/training/plugins/evaluation.py b/avalanche/training/plugins/evaluation.py index 22a7dfda6..1606613a8 100644 --- a/avalanche/training/plugins/evaluation.py +++ b/avalanche/training/plugins/evaluation.py @@ -2,7 +2,9 @@ from copy import copy from collections import defaultdict from typing import Union, Sequence, TYPE_CHECKING +from typing_extensions import Literal +from avalanche.distributed import DistributedHelper from avalanche.evaluation.metric_results import MetricValue from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics from avalanche.logging import InteractiveLogger @@ -28,10 +30,14 @@ class EvaluationPlugin: This plugin also logs metrics using the provided loggers. """ + supports_distributed = True + def __init__( self, *metrics: Union["PluginMetric", Sequence["PluginMetric"]], - loggers: Union["BaseLogger", Sequence["BaseLogger"]] = None, + loggers: Union["BaseLogger", + Sequence["BaseLogger"], + Literal['default']] = 'default', collect_all=True, strict_checks=False ): @@ -57,14 +63,16 @@ def __init__( flat_metrics_list.append(metric) self.metrics = flat_metrics_list - if loggers is None: + if loggers == 'default': + loggers = make_default_loggers() + elif loggers is None: loggers = [] elif not isinstance(loggers, Sequence): loggers = [loggers] self.loggers: Sequence["BaseLogger"] = loggers - if len(self.loggers) == 0: + if len(self.loggers) == 0 and DistributedHelper.is_main_process: warnings.warn("No loggers specified, metrics will not be logged") if self.collect_all: @@ -200,12 +208,19 @@ def before_eval(self, strategy: "SupervisedTemplate", **kwargs): def default_evaluator(): return EvaluationPlugin( - accuracy_metrics( - minibatch=False, epoch=True, experience=True, stream=True - ), - loss_metrics(minibatch=False, epoch=True, experience=True, stream=True), - loggers=[InteractiveLogger()], + accuracy_metrics(minibatch=False, epoch=True, + experience=True, stream=True), + loss_metrics(minibatch=False, epoch=True, + experience=True, stream=True), + loggers='default' ) +def make_default_loggers(): + if DistributedHelper.is_main_process: + return [InteractiveLogger()] + else: + return [] + + __all__ = ["EvaluationPlugin", "default_evaluator"] diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index 255e686ec..fcbf0f6f1 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -23,6 +23,19 @@ class EWCPlugin(SupervisedPlugin): training set. This plugin does not use task identities. """ + supports_distributed = False + """ + EwC does not support distributed training. + + This is because the plugin needs to compute an additional component of the + loss function that involves model parameters. 
It is not possible, in
+    distributed training, to use model parameters to compute grad elements
+    outside the forward function.
+    This is a limitation of PyTorch DistributedDataParallel.
+
+    Setting parameters like `find_unused_parameters` does not solve this
+    problem.
+    """
+
     def __init__(
         self,
         ewc_lambda,
diff --git a/avalanche/training/plugins/gdumb.py b/avalanche/training/plugins/gdumb.py
index be44c8cdc..0c95224c7 100644
--- a/avalanche/training/plugins/gdumb.py
+++ b/avalanche/training/plugins/gdumb.py
@@ -21,6 +21,8 @@ class GDumbPlugin(SupervisedPlugin):
     https://www.robots.ox.ac.uk/~tvg/publications/2020/gdumb.pdf
     """
 
+    supports_distributed = True
+
     def __init__(self, mem_size: int = 200):
         super().__init__()
         self.mem_size = mem_size
@@ -39,7 +41,7 @@ def before_train_dataset_adaptation(
         if self.init_model is None:
             self.init_model = copy.deepcopy(strategy.model)
         else:
-            strategy.model = copy.deepcopy(self.init_model)
+            strategy.model = copy.deepcopy(self.init_model)
         strategy.model_adaptation(self.init_model)
 
     def before_eval_dataset_adaptation(
diff --git a/avalanche/training/plugins/lwf.py b/avalanche/training/plugins/lwf.py
index ed5c5b8be..d63afbb9e 100644
--- a/avalanche/training/plugins/lwf.py
+++ b/avalanche/training/plugins/lwf.py
@@ -10,6 +10,8 @@ class LwFPlugin(SupervisedPlugin):
     When used with multi-headed models, all heads are distilled.
     """
 
+    supports_distributed = True
+
     def __init__(self, alpha=1, temperature=2):
         """
         :param alpha: distillation hyperparameter. It can be either a float
@@ -24,13 +26,16 @@ def before_backward(self, strategy, **kwargs):
 
         Add distillation loss
         """
-        strategy.loss += self.lwf(
-            strategy.mb_x, strategy.mb_output, strategy.model
-        )
+        with strategy.use_local_loss():
+            with strategy.use_local_input_batch():
+                with strategy.use_local_output_batch():
+                    strategy.loss += self.lwf(
+                        strategy.mb_x, strategy.mb_output, strategy.model
+                    )
 
     def after_training_exp(self, strategy, **kwargs):
         """
         Save a copy of the model after each experience and
         update self.prev_classes to include the newly learned classes.
         """
-        self.lwf.update(strategy.experience, strategy.model)
+        self.lwf.update(strategy.experience, strategy.local_model)
diff --git a/avalanche/training/plugins/replay.py b/avalanche/training/plugins/replay.py
index f653a1834..22bca224b 100644
--- a/avalanche/training/plugins/replay.py
+++ b/avalanche/training/plugins/replay.py
@@ -44,6 +44,8 @@ class ReplayPlugin(SupervisedPlugin):
         in memory
     """
 
+    supports_distributed = True
+
     def __init__(
         self,
         mem_size: int = 200,
diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py
index 882aa97f2..18cd631b3 100644
--- a/avalanche/training/supervised/ar1.py
+++ b/avalanche/training/supervised/ar1.py
@@ -60,7 +60,7 @@ def __init__(
         eval_mb_size: int = 128,
         device=None,
         plugins: Optional[List[SupervisedPlugin]] = None,
-        evaluator: EvaluationPlugin = default_evaluator(),
+        evaluator=default_evaluator,
         eval_every=-1,
     ):
         """
@@ -261,18 +261,24 @@ def make_train_dataloader(self, num_workers=0, shuffle=True, **kwargs):
             if hasattr(self.adapted_dataset, "collate_fn")
             else None
         )
+
+        other_dataloader_args = self._obtain_common_dataloader_parameters(
+            batch_size=current_batch_mb_size,
+            num_workers=num_workers,
+            shuffle=shuffle,
+            **kwargs
+        )
+
+        # AR1 only supports SIT scenarios (no task labels).
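+        # For instance (values are illustrative): with device='cuda' and no
+        # explicit pin_memory, the helper above resolves pin_memory=True;
+        # see _obtain_common_dataloader_parameters in base_sgd.py later in
+        # this patch.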
self.dataloader = DataLoader( self.adapted_dataset, - num_workers=num_workers, - batch_size=current_batch_mb_size, - shuffle=shuffle, collate_fn=collate_fn, + **other_dataloader_args ) def training_epoch(self, **kwargs): for mb_it, self.mbatch in enumerate(self.dataloader): - self._unpack_minibatch() + self.unpack_minibatch() self._before_training_iteration(**kwargs) self.optimizer.zero_grad() diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index f2ae3981b..612d35fee 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -2,12 +2,10 @@ from torch.nn import Module from torch.optim import Optimizer -from torch.utils.data import ConcatDataset -from avalanche.benchmarks.utils import concat_classification_datasets from avalanche.benchmarks.utils.utils import concat_datasets +from avalanche.training.plugins import SupervisedPlugin from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin from avalanche.training.templates import SupervisedTemplate @@ -28,7 +26,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, ): """Init. diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 676652636..36384667d 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -37,7 +37,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator=default_evaluator, eval_every=-1, ): """Init function for the SLDA model. @@ -100,16 +100,17 @@ def __init__( def forward(self, return_features=False): """Compute the model's output given the current mini-batch.""" - self.model.eval() - if isinstance(self.model, MultiTaskModule): - feat = self.model(self.mb_x, self.mb_task_id) - else: # no task labels - feat = self.model(self.mb_x) - out = self.predict(feat) - if return_features: - return out, feat - else: - return out + with self.use_local_input_batch(): + self.model.eval() + if isinstance(self.model, MultiTaskModule): + feat = self.model(self.mb_x, self.mb_task_id) + else: # no task labels + feat = self.model(self.mb_x) + out = self.predict(feat) + if return_features: + return out, feat + else: + return out def training_epoch(self, **kwargs): """ @@ -118,7 +119,7 @@ def training_epoch(self, **kwargs): :return: """ for _, self.mbatch in enumerate(self.dataloader): - self._unpack_minibatch() + self.unpack_minibatch() self._before_training_iteration(**kwargs) self.loss = 0 @@ -130,7 +131,7 @@ def training_epoch(self, **kwargs): self._after_forward(**kwargs) # Loss & Backward - self.loss += self.criterion() + self.loss = self.criterion() # Optimization step self._before_update(**kwargs) diff --git a/avalanche/training/supervised/icarl.py b/avalanche/training/supervised/icarl.py index 91125afa2..05d7d04f1 100644 --- a/avalanche/training/supervised/icarl.py +++ b/avalanche/training/supervised/icarl.py @@ -42,7 +42,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, ): """Init. 
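Throughout these strategy constructors, the default for `evaluator` changes from an `EvaluationPlugin` instance built at import time (`default_evaluator()`) to the factory function itself. A rough sketch of why this matters (benchmark setup omitted; `Naive` and `SimpleMLP` are existing Avalanche classes used here purely for illustration):

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.models import SimpleMLP
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.supervised import Naive

model = SimpleMLP(num_classes=10)
# Passing the factory (no parentheses) means each strategy builds its own
# EvaluationPlugin inside BaseSGDTemplate.__init__, instead of every strategy
# sharing the single instance that a `default_evaluator()` default argument
# would create when the module is first imported.
strategy = Naive(model, SGD(model.parameters(), lr=0.01), CrossEntropyLoss(),
                 evaluator=default_evaluator)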
diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 71fd0fbf2..68bb49c5f 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -54,7 +54,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator=default_evaluator, eval_every=-1, ): """Init. @@ -160,7 +160,7 @@ def train_dataset_adaptation(self, **kwargs): self.adapted_dataset = cat_data self.adapted_dataset = self.adapted_dataset.train() - def model_adaptation(self, model=None): + def _model_adaptation(self, model=None): """Adapts strategy's model for all experiences.""" if model is None: model = self.model diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index 7da505094..41f229192 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -39,7 +39,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, peval_mode="epoch", ): diff --git a/avalanche/training/supervised/naive_object_detection.py b/avalanche/training/supervised/naive_object_detection.py index 1575aa725..c549af229 100644 --- a/avalanche/training/supervised/naive_object_detection.py +++ b/avalanche/training/supervised/naive_object_detection.py @@ -56,7 +56,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, peval_mode="epoch", scaler=None, @@ -127,7 +127,7 @@ def make_train_dataloader( self, num_workers=0, shuffle=True, - pin_memory=True, + pin_memory=None, persistent_workers=False, **kwargs ): @@ -139,46 +139,71 @@ def make_train_dataloader( :param num_workers: number of thread workers for the data loading. :param shuffle: True if the data should be shuffled, False otherwise. :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. + pinned memory before returning them. Defaults to None, which means + that the value will be determined by looking at the strategy + `device` field. :param persistent_workers: If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. Used only if `PyTorch >= 1.7.0`. 
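+
+        A sketch of the resolved defaults (values are illustrative)::
+
+            # On a CUDA strategy, pin_memory=None resolves to True:
+            strategy.make_train_dataloader(num_workers=4, pin_memory=None)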
""" - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers + other_dataloader_args = self._obtain_common_dataloader_parameters( + batch_size=self.train_mb_size, + num_workers=num_workers, + shuffle=shuffle, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + **kwargs + ) self.dataloader = TaskBalancedDataLoader( self.adapted_dataset, oversample_small_groups=True, - num_workers=num_workers, - batch_size=self.train_mb_size, - shuffle=shuffle, - pin_memory=pin_memory, - collate_mbatches=detection_collate_mbatches_fn, collate_fn=detection_collate_fn, **other_dataloader_args ) - def make_eval_dataloader(self, num_workers=0, pin_memory=True, **kwargs): + def make_eval_dataloader( + self, + num_workers=0, + shuffle=False, + pin_memory=None, + persistent_workers=False, + drop_last=False, + **kwargs + ): + """ - Initializes the eval data loader. :param num_workers: How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0). + :param shuffle: True if the data should be shuffled, False otherwise. :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - :param kwargs: - :return: + pinned memory before returning them. Defaults to None, which means + that the value will be determined by looking at the strategy + `device` field. + :param persistent_workers: If True, the data loader will not shut down + the worker processes after a dataset has been consumed once. + Please refer to PyTorch `DataLoader` class for more details. + :param drop_last: If True, the last batch will be skipped if not of size + equal to the eval minibatch size. + :param kwargs: Other dataloader parameters. """ - self.dataloader = DataLoader( - self.adapted_dataset, - num_workers=num_workers, + + other_dataloader_args = self._obtain_common_dataloader_parameters( batch_size=self.eval_mb_size, + num_workers=num_workers, + shuffle=shuffle, pin_memory=pin_memory, + persistent_workers=persistent_workers, + drop_last=drop_last, + **kwargs + ) + + self.dataloader = DataLoader( + self.adapted_dataset, collate_fn=detection_collate_fn, + **other_dataloader_args ) def criterion(self): @@ -192,19 +217,21 @@ def criterion(self): Beware that the loss can only be obtained for the training phase as no loss dictionary is returned when evaluating. """ - if self.is_training: - return sum(loss for loss in self.detection_loss_dict.values()) - else: - # eval does not compute the loss directly. - # Metrics will use self.mb_output and self.detection_predictions - # to compute AP, AR, ... - self.detection_predictions = { - target["image_id"].item(): output - for target, output in zip(self.mb_y, self.mb_output) - } - return torch.zeros((1,)) - - def forward(self): + with self.use_local_output_batch(): + with self.use_local_input_batch(): + if self.is_training: + return sum( + loss for loss in self.detection_loss_dict.values()) + else: + # eval does not compute the loss directly. + # Metrics will use self.mb_output and + # self.detection_predictions to compute AP, AR, ... + self.detection_predictions = \ + {target["image_id"].item(): output + for target, output in zip(self.mb_y, self.mb_output)} + return torch.zeros((1,)) + + def _forward(self): """ Compute the model's output given the current mini-batch. 
@@ -228,10 +255,9 @@ def _unpack_minibatch(self): targets = [ {k: v.to(self.device) for k, v in t.items()} for t in self.mbatch[1] ] - self.mbatch[0] = images - self.mbatch[1] = targets + self.mbatch = (images, targets, *self.mbatch[2:]) - def backward(self): + def _backward(self): if self.scaler is not None: self.scaler.scale(self.loss).backward() else: diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index c463d0540..183a4bc9c 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -61,7 +61,7 @@ def __init__( eval_mb_size: Optional[int] = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -117,7 +117,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -172,7 +172,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -236,7 +236,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -311,7 +311,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, generator_strategy: BaseTemplate = None, replay_size: int = None, @@ -435,7 +435,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = get_default_vae_logger(), + evaluator=get_default_vae_logger, eval_every=-1, **base_kwargs ): @@ -460,11 +460,12 @@ def __init__( :param \*\*base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ + self._vae_criterion = criterion super().__init__( model, optimizer, - criterion, + self._vae_criterion_adapter, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -475,10 +476,10 @@ def __init__( **base_kwargs ) - def criterion(self): + def _vae_criterion_adapter(self, *ignored): """Adapt input to criterion as needed to compute reconstruction loss and KL divergence. 
See default criterion VAELoss.""" - return self._criterion(self.mb_x, self.mb_output) + return self._vae_criterion(self.mb_x, self.mb_output) class GSS_greedy(SupervisedTemplate): @@ -501,7 +502,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -567,7 +568,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -632,7 +633,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -700,7 +701,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -768,7 +769,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -838,7 +839,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -924,7 +925,7 @@ def __init__( eval_mb_size: int = 1, device="cpu", plugins: Optional[Sequence["SupervisedPlugin"]] = None, - evaluator=default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -999,7 +1000,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -1072,7 +1073,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): @@ -1139,7 +1140,7 @@ def __init__( eval_mb_size: int = 1, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator(), + evaluator=default_evaluator, eval_every=-1, **base_kwargs ): diff --git a/avalanche/training/supervised/strategy_wrappers_online.py b/avalanche/training/supervised/strategy_wrappers_online.py index d757e2401..45b3b473f 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -8,7 +8,7 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ -from typing import Optional, Sequence, List, Union +from typing import Optional, Sequence, List, Union, Callable from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer @@ -42,7 +42,7 @@ def __init__( eval_mb_size: int = None, device=None, plugins: Optional[List[SupervisedPlugin]] = None, - evaluator: EvaluationPlugin = default_evaluator, + evaluator=default_evaluator, eval_every=-1, ): """ diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 
c08998069..b33cca462 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -1,3 +1,4 @@ +import sys import warnings from typing import Iterable, Sequence, Optional, Union, List @@ -6,12 +7,14 @@ from avalanche.benchmarks import CLExperience, CLStream from avalanche.core import BasePlugin +from avalanche.distributed.distributed_helper import DistributedHelper +from avalanche.distributed.strategies import DistributedModelStrategySupport from avalanche.training.utils import trigger_plugins ExpSequence = Iterable[CLExperience] -class BaseTemplate: +class BaseTemplate(DistributedModelStrategySupport): """Base class for continual learning skeletons. **Training loop** @@ -39,6 +42,8 @@ def __init__( ): """Init.""" + super(BaseTemplate, self).__init__() + self.model: Module = model """ PyTorch model. """ @@ -66,6 +71,12 @@ def __init__( self.current_eval_stream: Optional[ExpSequence] = None """ Current evaluation stream. """ + self._distributed_check: bool = False + """ + Internal flag used to verify the support for distributed + training only once. + """ + @property def is_eval(self): """True if the strategy is in evaluation mode.""" @@ -91,6 +102,12 @@ def train( If None: use training experiences for evaluation. Use [] if you do not want to evaluate during training. """ + if not self._distributed_check: + # Checks if the strategy elements are compatible with + # distributed training + self._check_distributed_training_compatibility() + self._distributed_check = True + self.is_training = True self._stop_training = False @@ -131,6 +148,12 @@ def eval( :return: dictionary containing last recorded value for each metric name """ + if not self._distributed_check: + # Checks if the strategy elements are compatible with + # distributed training + self._check_distributed_training_compatibility() + self._distributed_check = True + # eval can be called inside the train method. # Save the shared state here to restore before returning. prev_train_state = self._save_train_state() @@ -216,6 +239,29 @@ def is_callback(x): f"callbacks: {cb_p - cb_supported}", ) return + + def _check_distributed_training_compatibility(self): + """ + Check if strategy elements (plugins, ...) are compatible with + distributed training. + + This check does nothing if not training in distributed mode. 
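+
+        Plugins declare compatibility through a class attribute, which this
+        check reads via ``getattr`` (sketch)::
+
+            class MyPlugin(SupervisedPlugin):
+                supports_distributed = True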
+        """
+        if not DistributedHelper.is_distributed:
+            return True
+
+        unsupported_plugins = []
+        for plugin in self.plugins:
+            if not getattr(plugin, "supports_distributed", False):
+                unsupported_plugins.append(plugin)
+
+        if len(unsupported_plugins) > 0:
+            warnings.warn('You are using plugins that are not compatible '
+                          'with distributed training:')
+            for plugin in unsupported_plugins:
+                print(type(plugin), file=sys.stderr)
+
+        return len(unsupported_plugins) == 0
 
     #########################################################
     # Plugin Triggers                                       #
diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py
index dc0ba9d38..ddbfed5ff 100644
--- a/avalanche/training/templates/base_sgd.py
+++ b/avalanche/training/templates/base_sgd.py
@@ -1,24 +1,27 @@
-from typing import Iterable, Sequence, Optional, Union, List
-from pkg_resources import parse_version
+from typing import Iterable, Sequence, Optional, Union, List, final, Callable
 
 import torch
+from pkg_resources import parse_version
 from torch.nn import Module, CrossEntropyLoss
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, DistributedSampler
 
 from avalanche.benchmarks import CLExperience, CLStream
+from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, \
+    collate_from_data_or_kwargs
 from avalanche.core import BaseSGDPlugin
+from avalanche.distributed import DistributedHelper
+from avalanche.distributed.strategies import \
+    DistributedMiniBatchStrategySupport, DistributedLossStrategySupport
 from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
 from avalanche.training.plugins.clock import Clock
 from avalanche.training.plugins.evaluation import default_evaluator
 from avalanche.training.templates.base import BaseTemplate, ExpSequence
-from avalanche.models.utils import avalanche_model_adaptation
-from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, \
-    collate_from_data_or_kwargs
 from avalanche.training.utils import trigger_plugins
 
 
-class BaseSGDTemplate(BaseTemplate):
+class BaseSGDTemplate(BaseTemplate, DistributedMiniBatchStrategySupport,
+                      DistributedLossStrategySupport):
     """Base SGD class for continual learning skeletons.
 
     **Training loop**
@@ -47,7 +50,8 @@ def __init__(
         eval_mb_size: Optional[int] = 1,
         device="cpu",
         plugins: Optional[List["SupervisedPlugin"]] = None,
-        evaluator: EvaluationPlugin = default_evaluator(),
+        evaluator: Union[EvaluationPlugin,
+                         Callable[[], EvaluationPlugin]] = default_evaluator,
         eval_every=-1,
         peval_mode="epoch",
    ):
@@ -91,8 +95,10 @@ def __init__(
 
         if evaluator is None:
             evaluator = EvaluationPlugin()
+        elif isinstance(evaluator, Callable):
+            evaluator = evaluator()
         self.plugins.append(evaluator)
-        self.evaluator = evaluator
+        self.evaluator: EvaluationPlugin = evaluator
         """ EvaluationPlugin used for logging and metric computations. """
 
         # Configure periodic evaluation.
@@ -122,6 +128,14 @@ def __init__(
         use :attr:`.BaseTemplate.experience`.
         """
 
+        self.collate_fn = None
+        """
+        The collate function used to merge the values obtained from the
+        dataset into a minibatch.
+
+        This value is obtained from the adapted dataset directly.
+        """
+
         self.dataloader = None
+        """ Dataloader. 
""" @@ -161,12 +175,6 @@ def eval(self, exp_list: Union[CLExperience, CLStream], **kwargs): super().eval(exp_list, **kwargs) return self.evaluator.get_last_metrics() - def _train_exp( - self, experience: CLExperience, eval_streams, **kwargs - ): - # Should be implemented in Observation Type - raise NotImplementedError() - def _eval_exp(self, **kwargs): self.eval_epoch(**kwargs) @@ -195,8 +203,19 @@ def training_epoch(self, **kwargs): # Should be implemented in Update Type raise NotADirectoryError() + @final def backward(self): - """Run the backward pass.""" + """ + Run the backward pass. + This method should not be overridden by child classes. + Consider overriding :meth:`_backward` instead. + """ + with self.use_local_loss(): + self._backward() + self.reset_distributed_loss() + + def _backward(self): + """ Implementation of the backward pass. """ self.loss.backward() def optimizer_step(self): @@ -206,7 +225,7 @@ def optimizer_step(self): def eval_epoch(self, **kwargs): """Evaluation loop over the current `self.dataloader`.""" for self.mbatch in self.dataloader: - self._unpack_minibatch() + self.unpack_minibatch() self._before_eval_iteration(**kwargs) self._before_eval_forward(**kwargs) @@ -218,6 +237,12 @@ def eval_epoch(self, **kwargs): # ==================================================================> NEW + def wrap_distributed_model(self, model): + """ + Prepare a model for distributed training/eval. + """ + return DistributedHelper.wrap_model(model) + def check_model_and_optimizer(self): # Should be implemented in observation type raise NotImplementedError() @@ -302,11 +327,43 @@ def _before_eval_exp(self, **kwargs): super()._before_eval_exp(**kwargs) + def _obtain_common_dataloader_parameters(self, **kwargs): + """ + Utility function that returns the dictionary of parameters to be passed + to the train and eval dataloaders. + + The resulting dataset does not include the collate function. + + Overriding this function can be useful if particular/runtime computed + parameters are needed. However, when overriding, it is recommended to + first call this implementation (super) to obtain a base dictionary of + parameters . + + :param kwargs: The dataloader arguments as passed to the `train` + or `eval` method. + :return: A dictionary of parameters to be passed to the DataLoader class + or to one of the Avalanche dataloaders. + """ + other_dataloader_args = {} + + if 'persistent_workers' in kwargs: + if parse_version(torch.__version__) >= parse_version("1.7.0"): + other_dataloader_args["persistent_workers"] = \ + kwargs['persistent_workers'] + + for k, v in kwargs.items(): + other_dataloader_args[k] = v + + if other_dataloader_args.get('pin_memory', None) is None: + other_dataloader_args['pin_memory'] = self.device.type == 'cuda' + + return other_dataloader_args + def make_train_dataloader( self, num_workers=0, shuffle=True, - pin_memory=True, + pin_memory=None, persistent_workers=False, **kwargs ): @@ -318,53 +375,79 @@ def make_train_dataloader( :param num_workers: number of thread workers for the data loading. :param shuffle: True if the data should be shuffled, False otherwise. :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. + pinned memory before returning them. Defaults to None, which means + that the value will be determined by looking at the strategy + `device` field. + :param persistent_workers: If True, the data loader will not shut down + the worker processes after a dataset has been consumed once. 
+ Please refer to PyTorch `DataLoader` class for more details. + :param kwargs: Other dataloader parameters. """ - other_dataloader_args = {} - - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v + other_dataloader_args = self._obtain_common_dataloader_parameters( + batch_size=self.train_mb_size, + num_workers=num_workers, + shuffle=shuffle, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + **kwargs + ) self.dataloader = TaskBalancedDataLoader( self.adapted_dataset, oversample_small_groups=True, - num_workers=num_workers, - batch_size=self.train_mb_size, - shuffle=shuffle, - pin_memory=pin_memory, **other_dataloader_args ) def make_eval_dataloader( - self, num_workers=0, pin_memory=True, persistent_workers=False, **kwargs + self, + num_workers=0, + shuffle=False, + pin_memory=None, + persistent_workers=False, + drop_last=False, + **kwargs ): """ Initializes the eval data loader. :param num_workers: How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0). + :param shuffle: True if the data should be shuffled, False otherwise. :param pin_memory: If True, the data loader will copy Tensors into CUDA - pinned memory before returning them. Defaults to True. - :param kwargs: - :return: + pinned memory before returning them. Defaults to None, which means + that the value will be determined by looking at the strategy + `device` field. + :param persistent_workers: If True, the data loader will not shut down + the worker processes after a dataset has been consumed once. + Please refer to PyTorch `DataLoader` class for more details. + :param drop_last: If True, the last batch will be skipped if not of size + equal to the eval minibatch size. + :param kwargs: Other dataloader parameters. """ - other_dataloader_args = {} - if parse_version(torch.__version__) >= parse_version("1.7.0"): - other_dataloader_args["persistent_workers"] = persistent_workers - for k, v in kwargs.items(): - other_dataloader_args[k] = v + other_dataloader_args = self._obtain_common_dataloader_parameters( + batch_size=self.eval_mb_size, + num_workers=num_workers, + shuffle=shuffle, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + drop_last=drop_last, + **kwargs + ) collate_from_data_or_kwargs(self.adapted_dataset, other_dataloader_args) + sampler = None + if DistributedHelper.is_distributed: + sampler = DistributedSampler( + self.adapted_dataset, + shuffle=other_dataloader_args.pop('shuffle'), + drop_last=other_dataloader_args.get('drop_last')) + self.dataloader = DataLoader( self.adapted_dataset, - num_workers=num_workers, - batch_size=self.eval_mb_size, - pin_memory=pin_memory, + sampler=sampler, **other_dataloader_args ) @@ -373,6 +456,17 @@ def eval_dataset_adaptation(self, **kwargs): self.adapted_dataset = self.experience.dataset self.adapted_dataset = self.adapted_dataset.eval() + @final + def unpack_minibatch(self): + """ + Move minibatch elements to device. + This method should not be overridden by child classes. + Consider overriding :meth:`_unpack_minibatch` instead. + """ + with self.use_local_input_batch(): + self._unpack_minibatch() + self.reset_distributed_mbatch() + def _unpack_minibatch(self): """Move to device""" # First verify the mini-batch @@ -448,6 +542,8 @@ class PeriodicEval(SupervisedPlugin): This plugin is automatically configured and added by the BaseTemplate. 
""" + supports_distributed = True + def __init__(self, eval_every=-1, peval_mode="epoch", do_initial=True): """Init. diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index 4ec073849..02cdb1889 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -1,25 +1,36 @@ -from typing import Iterable +from typing import final -from avalanche.benchmarks import CLExperience from avalanche.models.dynamic_optimizers import reset_optimizer from avalanche.models.utils import avalanche_model_adaptation class BatchObservation: + + @final def model_adaptation(self, model=None): """Adapts the model to the current data. + Calls the :class:`~avalanche.models.DynamicModule`s adaptation. + This method should not be overridden by child classes. + Consider overriding :meth:`_model_adaptation` instead. + """ + with self.use_local_model(): + return self._model_adaptation(model=model) + + def _model_adaptation(self, model=None): + """Adapts the model to the current data. Calls the :class:`~avalanche.models.DynamicModule`s adaptation. """ if model is None: model = self.model avalanche_model_adaptation(model, self.experience) + return model.to(self.device) def make_optimizer(self): """Optimizer initialization. - Called before each training experiene to configure the optimizer. + Called before each training experience to configure the optimizer. """ # we reset the optimizer's state after each experience. # This allows to add new parameters (new heads) and @@ -27,5 +38,12 @@ def make_optimizer(self): reset_optimizer(self.optimizer, self.model) def check_model_and_optimizer(self): - self.model = self.model_adaptation() + with self.use_local_model(): + self.model = self.model_adaptation() + self.model = self.wrap_distributed_model(self.model) self.make_optimizer() + + +__all__ = [ + 'BatchObservation' +] diff --git a/avalanche/training/templates/observation_type/online_observation.py b/avalanche/training/templates/observation_type/online_observation.py index d3dbfaac5..10590e8c2 100644 --- a/avalanche/training/templates/observation_type/online_observation.py +++ b/avalanche/training/templates/observation_type/online_observation.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import final from avalanche.benchmarks import OnlineCLExperience from avalanche.models.dynamic_optimizers import reset_optimizer @@ -7,6 +7,7 @@ class OnlineObservation: + def make_optimizer(self): """Optimizer initialization. @@ -26,8 +27,18 @@ def make_optimizer(self): self.model.parameters(), reset_state=False) + @final def model_adaptation(self, model=None): """Adapts the model to the current data. + Calls the :class:`~avalanche.models.DynamicModule`s adaptation. + This method should not be overridden by child classes. + Consider overriding :meth:`_model_adaptation` instead. + """ + with self.use_local_model(): + return self._model_adaptation(model=model) + + def _model_adaptation(self, model=None): + """Adapts the model to the current data. Calls the :class:`~avalanche.models.DynamicModule`s adaptation. 
""" @@ -53,14 +64,18 @@ def model_adaptation(self, model=None): return model.to(self.device) def check_model_and_optimizer(self): - # If strategy has access to the task boundaries, and the current - # sub-experience is the first sub-experience in the online (sub-)stream, - # then adapt the model with the full origin experience: - if self.experience.access_task_boundaries: - if self.experience.is_first_subexp: + with self.use_local_model(): + # If strategy has access to the task boundaries, and the current + # sub-experience is the first sub-experience in the online + # (sub-)stream, then adapt the model with the full origin + # experience: + if self.experience.access_task_boundaries: + if self.experience.is_first_subexp: + self.model = self.model_adaptation() + self.model = self.wrap_distributed_model(self.model) + self.make_optimizer() + # Otherwise, adapt to the current sub-experience: + else: self.model = self.model_adaptation() + self.model = self.wrap_distributed_model(self.model) self.make_optimizer() - # Otherwise, adapt to the current sub-experience: - else: - self.model = self.model_adaptation() - self.make_optimizer() diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index 9432e04ef..0bc94c19f 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -1,3 +1,5 @@ +from typing import final + from avalanche.models import avalanche_forward @@ -20,10 +22,23 @@ def mb_task_id(self): def criterion(self): """Loss function for supervised problems.""" - return self._criterion(self.mb_output, self.mb_y) + # Force self.mb_output and self.mb_y to be from local batch + with self.use_local_output_batch(): + with self.use_local_input_batch(): + return self._criterion(self.mb_output, self.mb_y) + @final def forward(self): - """Compute the model's output given the current mini-batch.""" + """ + Compute the model's output given the current mini-batch. + This method should not be overridden by child classes. + Consider overriding :meth:`_forward` instead. 
+ """ + with self.use_local_input_batch(): + return self._forward() + + def _forward(self): + """Implementation of the forward pass.""" return avalanche_forward(self.model, self.mb_x, self.mb_task_id) def _check_minibatch(self): diff --git a/avalanche/training/templates/update_type/meta_update.py b/avalanche/training/templates/update_type/meta_update.py index d387db9c0..b0bba9727 100644 --- a/avalanche/training/templates/update_type/meta_update.py +++ b/avalanche/training/templates/update_type/meta_update.py @@ -12,7 +12,7 @@ def training_epoch(self, **kwargs): if self._stop_training: break - self._unpack_minibatch() + self.unpack_minibatch() self._before_training_iteration(**kwargs) self.optimizer.zero_grad() diff --git a/avalanche/training/templates/update_type/sgd_update.py b/avalanche/training/templates/update_type/sgd_update.py index d85365f49..e81d8e124 100644 --- a/avalanche/training/templates/update_type/sgd_update.py +++ b/avalanche/training/templates/update_type/sgd_update.py @@ -10,11 +10,10 @@ def training_epoch(self, **kwargs): if self._stop_training: break - self._unpack_minibatch() + self.unpack_minibatch() self._before_training_iteration(**kwargs) self.optimizer.zero_grad() - self.loss = 0 # Forward self._before_forward(**kwargs) @@ -22,7 +21,7 @@ def training_epoch(self, **kwargs): self._after_forward(**kwargs) # Loss & Backward - self.loss += self.criterion() + self.loss = self.criterion() self._before_backward(**kwargs) self.backward() diff --git a/avalanche/training/utils.py b/avalanche/training/utils.py index 4d0800c5d..1f7e76d7c 100644 --- a/avalanche/training/utils.py +++ b/avalanche/training/utils.py @@ -421,6 +421,7 @@ def __str__(self): __all__ = [ + "trigger_plugins", "load_all_dataset", "zerolike_params_dict", "copy_params_dict", diff --git a/examples/detection.py b/examples/detection.py index 30abb29d1..cec1329c1 100644 --- a/examples/detection.py +++ b/examples/detection.py @@ -15,41 +15,29 @@ stream of experiences is obtained by splitting the dataset in equal parts. """ +import argparse import logging from pathlib import Path from typing import Union +import torch +import torchvision from torch.utils.data import random_split, Subset +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor +from torchvision.transforms import ToTensor -from avalanche.benchmarks import StreamUserDef -from avalanche.benchmarks.datasets import LvisDataset, PennFudanDataset -from avalanche.benchmarks.scenarios.detection_scenario import ( - DetectionCLScenario, -) -from avalanche.benchmarks.utils import ( - make_classification_dataset, - classification_subset, -) -from avalanche.training.supervised.naive_object_detection import ( - ObjectDetectionTemplate, -) - +from avalanche.benchmarks.datasets import PennFudanDataset from avalanche.evaluation.metrics import ( - make_lvis_metrics, timing_metrics, loss_metrics, - DetectionMetrics, ) +from avalanche.evaluation.metrics.detection import DetectionMetrics from avalanche.logging import InteractiveLogger from avalanche.training.plugins import LRSchedulerPlugin, EvaluationPlugin -import argparse -import torch -from torchvision.transforms import ToTensor -import torchvision -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor - - +from avalanche.training.supervised.naive_object_detection import ( + ObjectDetectionTemplate, +) # This sets the root logger to write to stdout (your console). 
# Your script/app needs to call this somewhere at least once. from examples.detection_examples_utils import split_detection_benchmark diff --git a/examples/detection_examples_utils.py b/examples/detection_examples_utils.py index b46727066..b13b423fe 100644 --- a/examples/detection_examples_utils.py +++ b/examples/detection_examples_utils.py @@ -5,8 +5,7 @@ DetectionCLScenario, ) from avalanche.benchmarks.utils import ( - make_classification_dataset, - classification_subset, + make_detection_dataset, detection_subset, ) @@ -44,12 +43,12 @@ def split_detection_benchmark( exp_n_imgs = len(train_dataset) // n_experiences remaining = len(train_dataset) % n_experiences - train_dataset_avl = make_classification_dataset( + train_dataset_avl = make_detection_dataset( train_dataset, transform_groups=transform_groups, initial_transform_group="train", ) - test_dataset_avl = make_classification_dataset( + test_dataset_avl = make_detection_dataset( test_dataset, transform_groups=transform_groups, initial_transform_group="eval", @@ -73,9 +72,9 @@ def split_detection_benchmark( last_slice_idx = 0 for exp_id in range(n_experiences): n_imgs = exp_sz[exp_id] - idx_range = train_indices[last_slice_idx : last_slice_idx + n_imgs] + idx_range = train_indices[last_slice_idx:last_slice_idx + n_imgs] train_exps_datasets.append( - classification_subset(train_dataset_avl, indices=idx_range) + detection_subset(train_dataset_avl, indices=idx_range) ) last_slice_idx += n_imgs @@ -100,4 +99,6 @@ def split_detection_benchmark( ) -__all__ = ["split_detection_benchmark"] +__all__ = [ + "split_detection_benchmark" +] diff --git a/examples/distributed_training.py b/examples/distributed_training.py new file mode 100644 index 000000000..486a90bd8 --- /dev/null +++ b/examples/distributed_training.py @@ -0,0 +1,168 @@ +################################################################################ +# Copyright (c) 2021 ContinualAI. # +# Copyrights licensed under the MIT License. # +# See the accompanying LICENSE file for terms. # +# # +# Date: 28-12-2021 # +# Author(s): Lorenzo Pellegrini # +# E-mail: contact@continualai.org # +# Website: avalanche.continualai.org # +################################################################################ + +""" +This is a simple example on how to enable distributed training in Avalanche. +""" + + +import argparse +import os +import sys +import time + +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torchvision import transforms +from torchvision.transforms import ToTensor, RandomCrop + +from avalanche.benchmarks import SplitMNIST +from avalanche.distributed import DistributedHelper +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.logging import TensorboardLogger +from avalanche.models import SimpleMLP +from avalanche.training import Naive, ClassBalancedBuffer +from avalanche.training.plugins import EvaluationPlugin, ReplayPlugin, \ + LRSchedulerPlugin + +OVERALL_MB_SIZE = 192 + + +def main(args): + # >> Notes on enabling distributed training support in Avalanche << + # + # There are only a few changes to be made when enabling distributed + # training in Avalanche. These are all shown in this example. To recap: + # + # 1. Wrap the main code in a function. Call that function from + # within a "if __name__ == '__main__':" section. + # 2. Add a call to `init_distributed` at the beginning of the main function. + # Obtain the device object using `make_device`. + # 3. 
(Optional, recommended) Suppress the output for non-main processes. + # 4. (If needed) Avalanche classic benchmarks already have proper ways + # to ensure that dataset files are not downloaded and written + # concurrently. If you need to dynamically download a custom dataset or + # create other working files, do it in the main process only (the one + # with rank 0). + # 5. Loggers cannot be created in non-main processes. Make sure you create + # them in the main process only. Metrics should be instantiated as usual. + # 6. IMPORTANT! Scale your minibatch size by the number of processes used. + # + # Notice that these changes do not impact your ability to run the same + # script in the classic single-process fashion. + # + # You can check how to run this script in a distributed way by looking at + # the `run_distributed_training_example.sh` script in the `examples` folder. + print('Starting experiment', args.exp_name) + + DistributedHelper.init_distributed(random_seed=4321, use_cuda=args.use_cuda) + rank = DistributedHelper.rank + world_size = DistributedHelper.world_size + device = DistributedHelper.make_device() + print(f'Current process rank: {rank}/{world_size}, ' + f'will use device: {device}') + + if not DistributedHelper.is_main_process: + # Suppress the output of non-main processes + # This prevents the output from being duplicated in the console + sys.stdout = open(os.devnull, 'w') + sys.stderr = open(os.devnull, 'w') + + # --- TRANSFORMATIONS + train_transform = transforms.Compose([ + RandomCrop(28, padding=4), + ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + test_transform = transforms.Compose([ + ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + # --------- + + # --- SCENARIO CREATION + scenario = SplitMNIST( + n_experiences=5, + train_transform=train_transform, + eval_transform=test_transform) + # --------- + + # MODEL CREATION + model = SimpleMLP(num_classes=scenario.n_classes) + + optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9) + + # CREATE THE STRATEGY INSTANCE (NAIVE) + loggers = [] + if DistributedHelper.is_main_process: + # Loggers should be created in the main process only + loggers.append(TensorboardLogger( + tb_log_dir=f'./distributed_training_logs/{args.exp_name}')) + + # Metrics should be created as usual, with no differences between main and + # non-main processes. 
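+    # (Point 4 of the recap above, sketched) When downloading a custom
+    # dataset, let rank 0 do the work before the other ranks proceed;
+    # `MyCustomDataset` here is a hypothetical class used for illustration:
+    #
+    #     with DistributedHelper.main_process_first():
+    #         extra_data = MyCustomDataset(root='./data', download=True)
+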
+    my_evaluator = EvaluationPlugin(
+        accuracy_metrics(epoch=True, experience=True, stream=True),
+        loss_metrics(epoch=True, experience=True, stream=True),
+        loggers=loggers
+    )
+
+    # Adapt the minibatch size
+    mb_size = OVERALL_MB_SIZE // DistributedHelper.world_size
+
+    plugins = []
+    if args.use_replay:
+        class_balanced_policy = ClassBalancedBuffer(1500)
+        plugins.append(ReplayPlugin(
+            1500,
+            storage_policy=class_balanced_policy))
+
+    if args.use_scheduler:
+        plugins.append(
+            LRSchedulerPlugin(
+                ReduceLROnPlateau(optimizer), step_granularity='iteration',
+                metric='train_loss'
+            )
+        )
+
+    cl_strategy = Naive(
+        model, optimizer,
+        CrossEntropyLoss(), train_mb_size=mb_size, train_epochs=4,
+        eval_mb_size=mb_size, plugins=plugins,
+        device=device, evaluator=my_evaluator)
+
+    start_time = time.time()
+
+    # TRAINING LOOP
+    print('Starting experiment...')
+    results = []
+    for experience in scenario.train_stream:
+        print("Start of experience: ", experience.current_experience)
+        print("Current Classes: ", experience.classes_in_this_experience)
+
+        cl_strategy.train(experience, num_workers=4)
+
+        print('Training completed')
+
+        print('Computing accuracy on the whole test set')
+        results.append(cl_strategy.eval(scenario.test_stream, num_workers=4))
+
+    print('Training+eval took', time.time() - start_time)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--use_cuda', action='store_true')
+    parser.add_argument('--use_replay', action='store_true')
+    parser.add_argument('--use_scheduler', action='store_true')
+    parser.add_argument('--exp_name', default='dist_exp')
+    main(parser.parse_args())
diff --git a/examples/run_distributed_training_example.sh b/examples/run_distributed_training_example.sh
new file mode 100755
index 000000000..5d514b685
--- /dev/null
+++ b/examples/run_distributed_training_example.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+eval "$(conda shell.bash hook)"
+conda activate avalanche-dev-env
+set -euo pipefail
+
+CPU_PARALLELISM=4
+GPU_PARALLELISM=0
+
+usage() {
+  echo "This will run single-process and multi-process training for naive, replay, and replay+scheduler setups."
+  echo "Used to check for differences between local and distributed training."
+  echo ""
+  echo "Run me from the avalanche repo root as 'bash examples/run_distributed_training_example.sh'"
+  echo
+  echo "Syntax: examples/run_distributed_training_example.sh [-h] [-c CPU_PARALLELISM] [-g GPU_PARALLELISM]"
+  echo ""
+  echo "Options:"
+  echo "-h     Print this Help."
+  echo "-c     Set the CPU parallelism for distributed experiments. Defaults to 4."
+  echo "       Set this value to 0 to skip CPU experiments."
+  echo "-g     Set the GPU parallelism for distributed experiments. Defaults to 0 (skip GPU experiments)."
+  echo "       Set this value to -1 to auto-detect how many GPUs are in the system."
+}
+
+exit_abnormal() {
+  usage
+  exit 1
+}
+
+# Note: 'h' must be listed in the optstring, otherwise the h) branch below
+# is unreachable and -h would be handled as an unknown option.
+while getopts ":c:g:h" options; do
+  case "${options}" in
+    c)
+      CPU_PARALLELISM=${OPTARG}
+      ;;
+    g)
+      GPU_PARALLELISM=${OPTARG}
+      ;;
+    h)
+      usage
+      exit 0
+      ;;
+    :)
+      echo "Error: -${OPTARG} requires an argument!"
+      echo ""
+      exit_abnormal
+      ;;
+    *)
+      exit_abnormal
+      ;;
+  esac
+done
+
+if [[ "$GPU_PARALLELISM" == "-1" ]]; then
+  GPU_PARALLELISM=$(nvidia-smi -L | wc -l)
+  echo "Auto-detected $GPU_PARALLELISM GPUs."
+fi
+
+export PYTHONPATH="${PYTHONPATH-}:${PWD}"
+
+if [[ "$CPU_PARALLELISM" == "0" ]]; then
+  echo "Skipping CPU experiments."
+else + # Naive experiments + torchrun --standalone --nnodes=1 --nproc_per_node=$CPU_PARALLELISM examples/distributed_training.py \ + --exp_name "distributed_naive_unsched_cpu" + python examples/distributed_training.py \ + --exp_name "single_process_naive_unsched_cpu" + + # Replay experiments + torchrun --standalone --nnodes=1 --nproc_per_node=$CPU_PARALLELISM examples/distributed_training.py \ + --use_replay --exp_name "distributed_replay_unsched_cpu" + python examples/distributed_training.py \ + --use_replay --exp_name "single_process_replay_unsched_cpu" + + # Replay + LR scheduler experiments + torchrun --standalone --nnodes=1 --nproc_per_node=$CPU_PARALLELISM examples/distributed_training.py \ + --use_replay --use_scheduler --exp_name "distributed_replay_scheduler_cpu" + python examples/distributed_training.py \ + --use_replay --use_scheduler --exp_name "single_process_replay_scheduler_cpu" +fi + +if [[ "$GPU_PARALLELISM" == "0" ]]; then + echo "Skipping GPU experiments." + exit 0 +fi + +# Naive experiments (GPU) +torchrun --standalone --nnodes=1 --nproc_per_node=$GPU_PARALLELISM examples/distributed_training.py \ + --exp_name "distributed_naive_unsched_gpu" --use_cuda +python examples/distributed_training.py \ + --exp_name "single_process_naive_unsched_gpu" --use_cuda + +# Replay experiments (GPU) +torchrun --standalone --nnodes=1 --nproc_per_node=$GPU_PARALLELISM examples/distributed_training.py \ + --exp_name "distributed_replay_unsched_gpu" --use_cuda --use_replay +python examples/distributed_training.py \ + --exp_name "single_process_replay_unsched_gpu" --use_cuda --use_replay + +# Replay + LR scheduler experiments (GPU) +torchrun --standalone --nnodes=1 --nproc_per_node=$GPU_PARALLELISM examples/distributed_training.py \ + --exp_name "distributed_replay_scheduler_gpu" --use_cuda --use_replay --use_scheduler +python examples/distributed_training.py \ + --exp_name "single_process_replay_scheduler_gpu" --use_cuda --use_replay --use_scheduler \ No newline at end of file diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/distributed/check_metrics_aligned.py b/tests/distributed/check_metrics_aligned.py new file mode 100644 index 000000000..80c97369d --- /dev/null +++ b/tests/distributed/check_metrics_aligned.py @@ -0,0 +1,33 @@ +import os +import pickle +import sys + + +def load_pickles(directory): + # Load the pickle files into a list of dictionaries. + files = os.listdir(directory) + files.sort() + data = [] + for f in files: + with open(os.path.join(directory, f), 'rb') as fh: + data.append(pickle.load(fh)) + + return data + + +def check_metrics_aligned(directory1, directory2): + data1 = load_pickles(directory1) + data2 = load_pickles(directory2) + assert len(data1) == len(data2) + + # Check that the metrics are aligned. 
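+    # Note: this is an exact equality check on the unpickled metric
+    # dictionaries, so even minimal floating-point drift between the two
+    # runs is reported as a misalignment.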
+    for i in range(len(data1)):
+        if data1[i] != data2[i]:
+            print('Metrics are not aligned for experience {}'.format(i))
+            sys.exit(1)
+
+    print('Metrics are aligned')
+
+
+if __name__ == '__main__':
+    check_metrics_aligned(sys.argv[1], sys.argv[2])
diff --git a/tests/distributed/distributed_test_utils.py b/tests/distributed/distributed_test_utils.py
new file mode 100644
index 000000000..4e17e8f4b
--- /dev/null
+++ b/tests/distributed/distributed_test_utils.py
@@ -0,0 +1,42 @@
+import contextlib
+import os
+
+import torch
+
+from avalanche.distributed import DistributedHelper
+
+
+def common_dst_tests_setup():
+    use_gpu_in_tests = os.environ.get('USE_GPU', 'false').lower() in [
+        '1', 'true']
+    use_gpu_in_tests = use_gpu_in_tests and torch.cuda.is_available()
+    DistributedHelper.init_distributed(1234, use_cuda=use_gpu_in_tests)
+    return use_gpu_in_tests
+
+
+def check_skip_distributed_test() -> bool:
+    return os.environ.get('DISTRIBUTED_TESTS', 'false').lower() \
+        not in ['1', 'true']
+
+
+def check_skip_distributed_slow_test() -> bool:
+    return check_skip_distributed_test() or \
+        os.environ.get('FAST_TEST', 'false').lower() in ['1', 'true']
+
+
+@contextlib.contextmanager
+def suppress_dst_tests_output():
+    # LOCAL_RANK is set by torchrun as a *string*: read it defensively and
+    # compare it as an int, so that rank 0 output is never suppressed.
+    if int(os.environ.get('LOCAL_RANK', 0)) != 0:
+        with contextlib.redirect_stderr(None):
+            with contextlib.redirect_stdout(None):
+                yield
+    else:
+        yield
+
+
+__all__ = [
+    'common_dst_tests_setup',
+    'check_skip_distributed_test',
+    'check_skip_distributed_slow_test',
+    'suppress_dst_tests_output'
+]
diff --git a/tests/distributed/distributed_training_main.py b/tests/distributed/distributed_training_main.py
new file mode 100644
index 000000000..d05e1e71b
--- /dev/null
+++ b/tests/distributed/distributed_training_main.py
@@ -0,0 +1,299 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI. #
+# Copyrights licensed under the MIT License. #
+# See the accompanying LICENSE file for terms. #
+# #
+# Date: 06-12-2022 #
+# Author(s): Lorenzo Pellegrini #
+# E-mail: contact@continualai.org #
+# Website: avalanche.continualai.org #
+################################################################################
+
+"""
+This is a deterministic version of the script with the same name found in the
+examples folder.
+
+Used in unit tests.
+
+Adapted from the one used for unit testing the checkpointing functionality.
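+
+Typically launched through ``torchrun``, as done by ``test_distributed.sh``
+(a sketch; the process count and plugin selection are illustrative):
+
+    torchrun --standalone --nnodes=1 --nproc_per_node=4 \
+        distributed_training_main.py --plugins replay --benchmark TestBenchmark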
+""" + + +import argparse +import os +import sys +import time +import pickle +from pathlib import Path +from typing import Sequence + +import torch +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from avalanche.benchmarks import CLExperience, \ + SplitCIFAR100, SplitMNIST, SplitFMNIST, SplitCIFAR10 +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_consistency_verification import \ + hash_benchmark, hash_model +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, \ + class_accuracy_metrics +from avalanche.logging import InteractiveLogger, TensorboardLogger, \ + WandBLogger, TextLogger +from avalanche.models import SimpleMLP, as_multitask +from avalanche.training import Naive +from avalanche.training.plugins import EvaluationPlugin, CWRStarPlugin, \ + ReplayPlugin, GDumbPlugin, LwFPlugin, SynapticIntelligencePlugin, \ + EWCPlugin, LRSchedulerPlugin, SupervisedPlugin +from tests.unit_tests_utils import get_fast_benchmark + +OVERALL_MB_SIZE = 192 +BENCHMARK_HASH = \ + '8ac6f78597e6f7279c601f1f75113aec6c56abd1518e3386a6729c7be9262cdd' +MODEL_HASH = \ + 'cbb45bc281908892402fda9794e82d71c3593631f76229f1f396fa7a936affaa' + + +class CheckModelAlignedPlugin(SupervisedPlugin): + + supports_distributed = True + + def after_update(self, strategy, *args, **kwargs): + DistributedHelper.check_equal_objects( + hash_model(strategy.model, include_buffers=True)) + + +def main(args): + torch.use_deterministic_algorithms(True) + + is_dist = DistributedHelper.init_distributed( + random_seed=4321, use_cuda=args.cuda + ) + + rank = DistributedHelper.rank + world_size = DistributedHelper.world_size + device = DistributedHelper.make_device() + print(f'Current process rank: {rank}/{world_size}, ' + f'will use device: {device}') + + if not DistributedHelper.is_main_process: + # Suppress the output of non-main processes + # This prevents the output from being duplicated in the console + sys.stdout = open(os.devnull, 'w') + sys.stderr = open(os.devnull, 'w') + + # --- SCENARIO CREATION + use_tasks = 'si' not in args.plugins and 'cwr' not in args.plugins \ + and args.benchmark != 'Stream51' + input_size = 32*32*3 + + if args.benchmark == 'TestBenchmark': + input_size = 28 * 28 * 1 + scenario = get_fast_benchmark( + use_task_labels=use_tasks, + n_features=input_size, + n_samples_per_class=256, + seed=1337 + ) + + if use_tasks: + # print(hash_benchmark(scenario, num_workers=4)) + assert hash_benchmark(scenario, num_workers=4) == BENCHMARK_HASH + print('Benchmark hash is correct.') + elif args.benchmark == 'SplitMNIST': + scenario = SplitMNIST(n_experiences=5, return_task_id=True) + input_size = 28*28*1 + elif args.benchmark == 'SplitFMNIST': + scenario = SplitFMNIST(n_experiences=5, return_task_id=True) + input_size = 28*28*1 + elif args.benchmark == 'SplitCifar100': + scenario = SplitCIFAR100(n_experiences=5, return_task_id=use_tasks) + elif args.benchmark == 'SplitCifar10': + scenario = SplitCIFAR10(n_experiences=5, return_task_id=use_tasks) + else: + raise ValueError('Unrecognized benchmark name from CLI.') + train_stream: Sequence[CLExperience] = scenario.train_stream + test_stream: Sequence[CLExperience] = scenario.test_stream + + print('Testing using the', args.benchmark, 'benchmark') + # --------- + + # MODEL CREATION + if use_tasks: + model = SimpleMLP(input_size=input_size, + num_classes=scenario.n_classes // 5) + model = as_multitask(model, 'classifier') + if args.benchmark 
== 'TestBenchmark' and use_tasks: + # print(hash_model(model)) + assert hash_model(model) == MODEL_HASH + print('Model hash is correct.') + else: + model = SimpleMLP(input_size=input_size, num_classes=scenario.n_classes) + + DistributedHelper.check_equal_objects( + hash_model(model, include_buffers=True)) + DistributedHelper.check_equal_objects( + hash_benchmark(scenario, num_workers=4)) + + optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) + criterion = CrossEntropyLoss() + + # CREATE THE STRATEGY INSTANCE (NAIVE) + + # Adapt the minibatch size + mb_size = OVERALL_MB_SIZE // DistributedHelper.world_size + + plugins = [ + CheckModelAlignedPlugin() + ] + + cli_plugins = [] + cli_plugin_names = '_'.join(args.plugins) + for cli_plugin in args.plugins: + if cli_plugin == 'cwr': + plugin_instance = CWRStarPlugin( + model, freeze_remaining_model=True) + elif cli_plugin == 'replay': + plugin_instance = ReplayPlugin(mem_size=500) + elif cli_plugin == 'gdumb': + plugin_instance = GDumbPlugin(mem_size=500) + elif cli_plugin == 'lwf': + plugin_instance = LwFPlugin() + elif cli_plugin == 'si': + plugin_instance = SynapticIntelligencePlugin(0.001) + elif cli_plugin == 'ewc': + plugin_instance = EWCPlugin(0.001) + elif cli_plugin == 'reduce_on_plateau': + plugin_instance = LRSchedulerPlugin( + ReduceLROnPlateau(optimizer), step_granularity='iteration', + metric='train_loss' + ) + else: + raise ValueError('Unrecognized plugin name from CLI.') + print('Adding plugin', plugin_instance) + cli_plugins.append(plugin_instance) + plugins += cli_plugins + + loggers = [] + if DistributedHelper.is_main_process: + use_cuda_str = 'cuda' if args.cuda else 'cpu' + is_dist_str = 'distributed' if is_dist else 'single' + eval_every = f'peval{args.eval_every}' + + log_location: Path = Path('logs') / \ + (f'distributed_{args.benchmark}_' + + f'{use_cuda_str}_{is_dist_str}_{eval_every}_{cli_plugin_names}') + + # Loggers should be created in the main process only + os.makedirs(log_location, exist_ok=True) + loggers = [ + TextLogger(open(log_location / 'log.txt', 'w')), + InteractiveLogger(), + TensorboardLogger(log_location) + ] + + if args.wandb: + loggers.append(WandBLogger( + project_name='AvalancheDistributedTraining', + run_name=f'distributed_{args.benchmark}_' + f'{use_cuda_str}_{is_dist_str}_' + f'{eval_every}_{cli_plugin_names}' + )) + Path(args.log_metrics_to).mkdir(parents=True, exist_ok=True) + + # Metrics should be created as usual, with no differences between main and + # non-main processes. 
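+    # On non-main ranks `loggers` is left empty, so this evaluator computes
+    # metrics everywhere but writes log output only from the main process.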
+ evaluation_plugin = EvaluationPlugin( + accuracy_metrics(minibatch=False, epoch=True, + experience=True, stream=True), + loss_metrics(minibatch=False, epoch=True, + experience=True, stream=True), + class_accuracy_metrics( + stream=True + ), + loggers=loggers + ) + + cl_strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=mb_size, + train_epochs=2, + eval_mb_size=mb_size, + eval_every=args.eval_every, + peval_mode=args.eval_every_mode, + device=device, + plugins=plugins, + evaluator=evaluation_plugin + ) + + start_time = time.time() + + # TRAINING LOOP + + for experience in train_stream: + cl_strategy.train( + experience, + num_workers=8, + drop_last=True, + shuffle=False) + + metrics = cl_strategy.eval( + test_stream, + num_workers=8, + drop_last=True, + shuffle=False) + + if DistributedHelper.is_main_process: + with open(Path(args.log_metrics_to) / + f'metrics_exp' + f'{experience.current_experience}.pkl', 'wb') as f: + pickle.dump(metrics, f) + + print('Training+eval took', time.time() - start_time) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--cuda', + default=False, + action='store_true', + help="If set, use GPUs." + ) + parser.add_argument( + "--benchmark", + type=str, + default='SplitCifar100', + help="The benchmark to use." + ) + parser.add_argument( + "--eval_every", + type=int, + default=-1, + help="Evaluation frequency." + ) + parser.add_argument( + "--eval_every_mode", + type=str, + default="epoch", + help="Periodic evaluation mode (epoch, experience, iteration)." + ) + parser.add_argument( + "--log_metrics_to", + type=str, + default='./metrics' + ) + parser.add_argument( + "--wandb", + action='store_true' + ) + parser.add_argument( + "--plugins", + nargs='*', + required=False, + default=[] + ) + main(parser.parse_args()) diff --git a/tests/distributed/test_distributed.sh b/tests/distributed/test_distributed.sh new file mode 100755 index 000000000..2f61bcf4f --- /dev/null +++ b/tests/distributed/test_distributed.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Script used to automatically test various combinations of plugins when used with +# the distributed training functionality. +set -euo pipefail +cd tests/distributed +rm -rf logs +rm -rf metrics_no_distributed +rm -rf metrics_distributed + +export PYTHONUNBUFFERED=1 +export PYTHONPATH=../.. +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + +BENCHMARK="TestBenchmark" + +# Config from env +# https://blog.stigok.com/2022/02/08/parsing-boolean-string-statements-in-bash.html +function str_bool { + local str="${1:-false}" + local pat='^(true|1|yes)$' + if [[ ${str,,} =~ $pat ]] + then + echo 'true' + else + echo 'false' + fi +} + +RUN_FAST_TESTS=$(str_bool "${FAST_TEST:-False}") +RUN_GPU_TESTS=$(str_bool "${USE_GPU:-False}") + +TESTS_PARALLELISM=4 + +GPU_PARAM="" + +if [ "$RUN_GPU_TESTS" = "true" ] +then + GPU_PARAM="--cuda" + TESTS_PARALLELISM=$(nvidia-smi -L | wc -l) + echo "Auto-detected $TESTS_PARALLELISM GPUs." 
+fi + +EXP_RUN_LINE="torchrun --standalone --nnodes=1 --nproc_per_node=$TESTS_PARALLELISM" + +run_and_check() { + set -x + # Run distributed training + $EXP_RUN_LINE distributed_training_main.py $GPU_PARAM \ + --plugins "$@" --benchmark $BENCHMARK --log_metrics_to './metrics_distributed' + + # Without distributed training + python distributed_training_main.py $GPU_PARAM \ + --plugins "$@" --benchmark $BENCHMARK --log_metrics_to './metrics_no_distributed' + + #python -u check_metrics_aligned.py \ + # "./metrics_no_distributed" "./metrics_distributed" + + rm -r metrics_no_distributed + rm -r metrics_distributed + rm -r logs + set +x +} + +run_and_check "replay" + +if [ "$RUN_FAST_TESTS" = "false" ] +then + echo "Running slow tests..." + run_and_check "lwf" + run_and_check "gdumb" + run_and_check "cwr" "replay" +fi diff --git a/tests/distributed/test_distributed_batch.py b/tests/distributed/test_distributed_batch.py new file mode 100644 index 000000000..227d7de9c --- /dev/null +++ b/tests/distributed/test_distributed_batch.py @@ -0,0 +1,106 @@ +import unittest +from typing import Tuple + +import torch +from torch import Tensor +from torch.utils.data import default_collate + +from avalanche.distributed import DistributedHelper, \ + make_classification_distributed_batch, CollateDistributedBatch +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_test, suppress_dst_tests_output, \ + common_dst_tests_setup + + +class DistributedBatchesTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_classification_batch(self): + dt = make_classification_distributed_batch('mb') + + self.assertEqual(None, dt.local_value) + self.assertEqual(None, dt.value) + + batch = (torch.ones((8, 1, 28, 28)), + torch.full( + (8,), fill_value=DistributedHelper.rank, dtype=torch.long)) + + dt.value = batch + + distrib_val = dt.value + + self.assertEqual(2, len(distrib_val)) + self.assertIsInstance(distrib_val, tuple) + self.assertSequenceEqual((8*DistributedHelper.world_size, 1, 28, 28), + distrib_val[0].shape) + self.assertIsInstance(distrib_val[0], Tensor) + self.assertIsInstance(distrib_val[1], Tensor) + for rank in range(DistributedHelper.world_size): + expect = torch.full((8,), + rank, + dtype=torch.long) + self.assertTrue(torch.equal(expect, + distrib_val[1][8*rank:8*(rank+1)])) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_unsupervised_classification_batch(self): + dt = make_classification_distributed_batch('mb') + + self.assertEqual(None, dt.local_value) + self.assertEqual(None, dt.value) + + batch = torch.ones((8, 1, 28, 28)) + + dt.value = batch + + distrib_val = dt.value + + self.assertIsInstance(distrib_val, Tensor) + self.assertSequenceEqual((8*DistributedHelper.world_size, 1, 28, 28), + distrib_val.shape) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_tuple_merge_batch_vanilla_collate(self): + dt: CollateDistributedBatch[Tuple[Tensor, Tensor]] = \ + CollateDistributedBatch( + 'mb', + None, + default_collate, + None) + + self.assertEqual(None, dt.local_value) + self.assertEqual(None, dt.value) + + batch = (torch.ones((8, 1, 28, 28)), + torch.full( + (8,), fill_value=DistributedHelper.rank, dtype=torch.long)) + + dt.value = batch + + distrib_val = dt.value + + self.assertEqual(2, len(distrib_val)) + self.assertSequenceEqual((8 * 
DistributedHelper.world_size, 1, 28, 28), + distrib_val[0].shape) + for rank in range(DistributedHelper.world_size): + expect = torch.full((8,), + rank, + dtype=torch.long) + self.assertTrue( + torch.equal( + expect, + distrib_val[1][8 * rank:8 * (rank + 1)])) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/distributed/test_distributed_helper.py b/tests/distributed/test_distributed_helper.py new file mode 100644 index 000000000..6bafc9931 --- /dev/null +++ b/tests/distributed/test_distributed_helper.py @@ -0,0 +1,506 @@ +import os +import random +import shutil +import tempfile +import time +import unittest +import numpy as np + +import torch +import torch.distributed as dst +from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel +from avalanche.benchmarks.generators.benchmark_generators import \ + dataset_benchmark +from avalanche.benchmarks.utils.classification_dataset import \ + make_tensor_classification_dataset + +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_helper import \ + RollingSeedContext, BroadcastSeedContext +from avalanche.models import SimpleMLP, as_multitask +from avalanche.models.utils import avalanche_model_adaptation + +from avalanche.training.determinism.rng_manager import RNGManager +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_slow_test, check_skip_distributed_test, \ + suppress_dst_tests_output, common_dst_tests_setup + + +class DistributedHelperTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_device_id(self): + if self.use_gpu_in_tests: + self.assertEqual(dst.get_rank(), DistributedHelper.get_device_id()) + self.assertEqual(torch.device(f'cuda:{dst.get_rank()}'), + DistributedHelper.make_device()) + else: + self.assertEqual(-1, DistributedHelper.get_device_id()) + self.assertEqual(torch.device('cpu'), + DistributedHelper.make_device()) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_wrap_model(self): + mb_size = 1*2*2*3*5 + num_classes = 11 + torch.manual_seed(1234 + DistributedHelper.rank) + mb_x = torch.randn((mb_size, 32)) + mb_y = torch.randint(0, num_classes, (mb_size,)) + mb_t = torch.full((mb_size,), 1) + model = SimpleMLP(num_classes=num_classes, input_size=32) + model = as_multitask(model, 'classifier') + self.assertIsInstance(model, Module) + + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: must raise an error if the model + # is not already in the correct device + with self.assertRaises(Exception): + model_wrapped = DistributedHelper.wrap_model(model) + + model = model.to(device) + + model_wrapped = DistributedHelper.wrap_model(model) + self.assertIsInstance(model_wrapped, DistributedDataParallel) + self.assertNotIsInstance(model, DistributedDataParallel) + + device = DistributedHelper.make_device() + mb_x = mb_x.to(device) + mb_y = mb_y.to(device) + mb_t = mb_t.to(device) + model = model.to(device) + + model.eval() + model_wrapped.eval() + + benchmark = dataset_benchmark( + [make_tensor_classification_dataset( + mb_x, mb_y, mb_t, task_labels=mb_t.tolist() + )], + [make_tensor_classification_dataset( + mb_x, mb_y, mb_t, task_labels=mb_t.tolist() + )] + ) + + 
avalanche_model_adaptation(model, benchmark.train_stream[0]) + + with torch.no_grad(): + mb_out1 = model(mb_x, mb_t).detach() + self.assertEqual(mb_out1.device, device) + self.assertSequenceEqual([mb_size, num_classes], mb_out1.shape) + + mb_out2 = model_wrapped(mb_x, mb_t).detach() + self.assertEqual(mb_out2.device, device) + self.assertSequenceEqual([mb_size, num_classes], mb_out2.shape) + + self.assertTrue(torch.equal(mb_out1, mb_out2)) + + mb_out_all = DistributedHelper.cat_all(mb_out2) + + start_idx = mb_size * DistributedHelper.rank + end_idx = start_idx + mb_size + + self.assertTrue(torch.equal(mb_out1, + mb_out_all[start_idx: end_idx])) + + self.assertTrue(model is DistributedHelper.unwrap_model(model_wrapped)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_broadcast_tensor_or_objects(self): + ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long) + DistributedHelper.broadcast(ts) + self.assertTrue(torch.equal(ts, torch.zeros((10,), dtype=torch.long))) + + device = DistributedHelper.make_device() + ts = ts.to(device) + + my_object = {'a': DistributedHelper.rank, 'b': ts} + my_object_from_main = DistributedHelper.broadcast_object(my_object) + + expect = { + 'a': 0, + 'b': torch.full((10,), 0, dtype=torch.long).tolist()} + + self.assertEqual(device, my_object_from_main['b'].device) + my_object_from_main['b'] = my_object_from_main['b'].tolist() + self.assertEqual(expect, my_object_from_main) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_objects(self): + ts = torch.full((10,), DistributedHelper.rank, dtype=torch.long) + + device = DistributedHelper.make_device() + ts = ts.to(device) + + my_object = {'a': DistributedHelper.rank, 'b': ts} + all_objects = DistributedHelper.gather_all_objects(my_object) + self.assertIsInstance(all_objects, list) + self.assertEqual(DistributedHelper.world_size, len(all_objects)) + + for rank in range(DistributedHelper.world_size): + expect = { + 'a': rank, + 'b': torch.full((10,), rank, dtype=torch.long).tolist()} + + self.assertEqual(device, all_objects[rank]['b'].device) + all_objects[rank]['b'] = all_objects[rank]['b'].tolist() + self.assertEqual(expect, all_objects[rank]) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_cat_all(self): + if DistributedHelper.rank == 0: + ts = torch.full((10+1, 5), DistributedHelper.rank, dtype=torch.long) + else: + ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: tensors do not need to be on the default device + DistributedHelper.cat_all(ts) + + ts = ts.to(device) + + concatenated_tensor = DistributedHelper.cat_all(ts) + + self.assertEqual(device, concatenated_tensor.device) + + expect = torch.empty((DistributedHelper.world_size * 10 + 1, 5), + dtype=torch.long).to(device) + for rank in range(DistributedHelper.world_size): + if rank == 0: + expect[rank * 10: (rank + 1) * 10 + 1] = rank + else: + expect[1 + rank * 10: 1 + (rank + 1) * 10] = rank + + self.assertTrue(torch.equal(concatenated_tensor, expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_size(self): + ts = torch.full((10, 5), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + if device.type == 'cuda': + # Additional test: tensors do not need to be on the default device + 
DistributedHelper.gather_all(ts) + + # On the other hand, PyTorch all_gather requires tensors to be on + # the default device + with self.assertRaises(Exception): + + out_t = [torch.empty_like(ts) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(out_t, ts) + + # ... while this should work + out_t = [torch.empty_like(ts).to(device) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(out_t, ts.to(device)) + + ts = ts.to(device) + + for same_shape in [False, True]: + print(f'same_shape={same_shape}') + # with self.subTest(same_shape=same_shape): + tensor_list = DistributedHelper.gather_all( + ts, same_shape=same_shape) + + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10, 5), rank, dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_slow_test(), + 'Distributed tests ignored') + def test_gather_all_performance_known_same_shape(self): + ts = torch.full((128, 224, 224, 3), + DistributedHelper.rank, + dtype=torch.float32) + device = DistributedHelper.make_device() + ts = ts.to(device) + + resulting_tensors = [torch.empty_like(ts).to(device) + for _ in range(DistributedHelper.world_size)] + + from tqdm import tqdm + n_times = 30 + torch.distributed.all_gather(resulting_tensors, ts) + start_time = time.time() + for _ in tqdm(range(n_times)): + torch.distributed.all_gather(resulting_tensors, ts) + end_time = time.time() + print('Time taken by PyTorch all_gather', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + start_time = time.time() + out_list = [None for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather_object(out_list, ts) + + for _ in tqdm(range(n_times)): + torch.distributed.all_gather_object(out_list, ts) + end_time = time.time() + print('Time taken by PyTorch all_gather_object', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + @unittest.skipIf(check_skip_distributed_slow_test(), + 'Distributed tests ignored') + def test_gather_all_performance_sync_shape(self): + max_shape_size = 10 + shape = [128, 6, DistributedHelper.rank+1] + \ + ([3] * DistributedHelper.rank) + + device = DistributedHelper.make_device() + + def shape_all_gather(): + ts = torch.zeros((max_shape_size,), dtype=torch.int64) + for i in range(len(shape)): + ts[i] = shape[i] + + ts = ts.to(device) + all_tensors_shape = [torch.empty_like(ts) + for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather(all_tensors_shape, ts) + all_tensors_shape = [t.cpu() for t in all_tensors_shape] + + for i, t in enumerate(all_tensors_shape): + for x in range(len(t)): + if t[x] == 0: + if x == 0: + # Tensor with 0-length shape + all_tensors_shape[i] = t[:x+1] + else: + all_tensors_shape[i] = t[:x] + break + + def shape_all_gather_objects(): + out_list = [None for _ in range(DistributedHelper.world_size)] + torch.distributed.all_gather_object(out_list, shape) + + from tqdm import tqdm + n_times = 1000 + shape_all_gather() + start_time = time.time() + for _ in tqdm(range(n_times)): + shape_all_gather() + end_time = time.time() + print('Time taken by PyTorch all_gather', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + start_time = time.time() + shape_all_gather_objects() + + for _ in tqdm(range(n_times)): + shape_all_gather_objects() + end_time = time.time() + 
print('Time taken by PyTorch all_gather_object', end_time-start_time, + 'avg', (end_time-start_time) / n_times) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_dim0(self): + ts = torch.full((10, DistributedHelper.rank+1), + DistributedHelper.rank, + dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + tensor_list = DistributedHelper.gather_all(ts) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10, rank+1), + rank, + dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_same_dim1_n(self): + ts = torch.full((10+DistributedHelper.rank, 5), + DistributedHelper.rank, + dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + tensor_list = DistributedHelper.gather_all(ts) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full((10+rank, 5), + rank, + dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_gather_all_zero_shaped(self): + ts = torch.full(tuple(), DistributedHelper.rank, dtype=torch.long) + device = DistributedHelper.make_device() + + ts = ts.to(device) + + for same_shape in [False, True]: + print(f'same_shape={same_shape}') + # with self.subTest(same_shape=same_shape): + tensor_list = DistributedHelper.gather_all( + ts, + same_shape=same_shape) + self.assertEqual(DistributedHelper.world_size, len(tensor_list)) + + for t in tensor_list: + self.assertEqual(device, t.device) + + for rank in range(DistributedHelper.world_size): + expect = torch.full(tuple(), rank, dtype=torch.long).to(device) + self.assertTrue(torch.equal(tensor_list[rank], expect)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_check_equal_tensors(self): + torch.manual_seed(1234) + ts = torch.randn((100,)) + DistributedHelper.check_equal_tensors(ts) + + torch.manual_seed(1234 + DistributedHelper.rank) + ts = torch.randn((100,)) + with self.assertRaises(Exception): + DistributedHelper.check_equal_tensors(ts) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_fields(self): + self.assertEqual(dst.get_rank(), DistributedHelper.rank) + self.assertEqual(dst.get_world_size(), DistributedHelper.world_size) + self.assertEqual(True, DistributedHelper.is_distributed) + self.assertEqual(dst.get_rank() == 0, DistributedHelper.is_main_process) + + if self.use_gpu_in_tests: + self.assertEqual('nccl', DistributedHelper.backend) + self.assertTrue(DistributedHelper.forced_cuda_comm) + else: + self.assertEqual('gloo', DistributedHelper.backend) + self.assertFalse(DistributedHelper.forced_cuda_comm) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_set_random_seeds_and_align(self): + DistributedHelper.set_random_seeds(5678) + + self.assertEqual(297076, np.random.randint(0, 1000000)) + self.assertEqual(643380, torch.randint(0, 1000000, (1,)).item()) + self.assertEqual(683410, random.randint(0, 1000000)) + + if 
DistributedHelper.is_main_process: + np.random.randint(0, 1000000) + torch.randint(0, 1000000, (1,)) + random.randint(0, 1000000) + + DistributedHelper.align_seeds() + + ref_values = ( + int(np.random.randint(0, 1000000)), + int(torch.randint(0, 1000000, (1,))), + int(random.randint(0, 1000000)) + ) + + DistributedHelper.check_equal_objects(ref_values) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_rolling_seed_aligner(self): + RNGManager.set_random_seeds(4321) + + with RollingSeedContext(): + RNGManager.set_random_seeds(1234 + DistributedHelper.rank) + random.randint(0, 2 ** 64 - 1) + + final_value = random.randint(0, 2 ** 64 - 1) + self.assertEqual(14732185405572191734, final_value) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_broadcast_seed_aligner(self): + RNGManager.set_random_seeds(4321) + + with BroadcastSeedContext(): + RNGManager.set_random_seeds(1234 + DistributedHelper.rank) + random.randint(0, 2 ** 64 - 1) + + final_value = random.randint(0, 2 ** 64 - 1) + self.assertEqual(15306775005444441373, final_value) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_main_process_first(self): + tmpdirname = '' + try: + my_rank = DistributedHelper.rank + if DistributedHelper.is_main_process: + tmpdirname = tempfile.mkdtemp() + + tmpdirname = DistributedHelper.broadcast_object(tmpdirname) + + with DistributedHelper.main_process_first(): + + for _ in range(2): + time.sleep(0.1 + my_rank * 0.05) + files = list(os.listdir(tmpdirname)) + if DistributedHelper.is_main_process: + self.assertEqual(0, len(files)) + else: + self.assertIn(f'rank0', files) + self.assertNotIn(f'rank{my_rank}', files) + + with open(os.path.join(tmpdirname, f'rank{my_rank}'), 'w') \ + as f: + f.write('ok') + + for _ in range(2): + time.sleep(0.1 + my_rank * 0.05) + files = list(os.listdir(tmpdirname)) + if DistributedHelper.is_main_process: + self.assertEqual(1, len(files)) + self.assertIn(f'rank0', files) + else: + self.assertIn(f'rank0', files) + self.assertIn(f'rank{my_rank}', files) + + DistributedHelper.barrier() + files = set(os.listdir(tmpdirname)) + expect = set([f'rank{rnk}' + for rnk in range(DistributedHelper.world_size)]) + self.assertSetEqual(expect, files) + DistributedHelper.barrier() + finally: + if tmpdirname is not None and DistributedHelper.is_main_process: + shutil.rmtree(tmpdirname) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/distributed/test_distributed_model.py b/tests/distributed/test_distributed_model.py new file mode 100644 index 000000000..afd50f3fc --- /dev/null +++ b/tests/distributed/test_distributed_model.py @@ -0,0 +1,180 @@ +import unittest + +import torch +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader + +from avalanche.distributed import DistributedHelper, DistributedModel +from avalanche.models import SimpleMLP +from avalanche.models.helper_method import as_multitask +from avalanche.models.utils import avalanche_forward, avalanche_model_adaptation +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_test, suppress_dst_tests_output, \ + common_dst_tests_setup +from tests.unit_tests_utils import get_fast_benchmark + + +class DistributedModelTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = 
common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_distributed_model(self): + dt: DistributedModel = DistributedModel() + model = SimpleMLP() + self.assertIsNone(dt.local_value) + self.assertIsNone(dt.value) + self.assertIsNone(dt.distributed_value) + + device = DistributedHelper.make_device() + + dt.model = model + + self.assertEqual(model, dt.local_value) + self.assertEqual(model, dt.value) + self.assertEqual(model, dt.distributed_value) + + if device.type == 'cuda': + # Additional test: must raise an error if the model + # is not already in the correct device + with self.assertRaises(Exception): + wrapped = DistributedDataParallel( + model, + device_ids=[device]) + + model = model.to(device) + wrapped = DistributedDataParallel( + model, + device_ids=[device]) + + dt.model = wrapped + + self.assertEqual(model, dt.local_value) + self.assertNotIsInstance(dt.local_value, DistributedDataParallel) + + self.assertIsInstance(dt.value, DistributedDataParallel) + self.assertEqual(wrapped, dt.value) + self.assertEqual(wrapped, dt.distributed_value) + + dt.reset_distributed_value() + + self.assertEqual(model, dt.local_value) + self.assertEqual(model, dt.value) + self.assertEqual(model, dt.distributed_value) + + self.assertNotIsInstance(dt.value, DistributedDataParallel) + + dt.reset_distributed_value() + self.assertIsNotNone(dt.local_value) + + dt.value = wrapped + dt.distributed_model = None + + self.assertIsNotNone(dt.local_value) + + dt.value = None + + self.assertIsNone(dt.local_value) + self.assertIsNone(dt.distributed_value) + self.assertIsNone(dt.value) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_distributed_model_multitask(self): + dt: DistributedModel = DistributedModel() + model = SimpleMLP() + model = as_multitask(model, 'classifier') + self.assertIsNone(dt.local_value) + self.assertIsNone(dt.value) + self.assertIsNone(dt.distributed_value) + + device = DistributedHelper.make_device() + + dt.model = model + + self.assertEqual(model, dt.local_value) + self.assertEqual(model, dt.value) + self.assertEqual(model, dt.distributed_value) + + if device.type == 'cuda': + # Additional test: must raise an error if the model + # is not already in the correct device + with self.assertRaises(Exception): + wrapped = DistributedDataParallel( + model, + device_ids=[device]) + + model = model.to(device) + wrapped = DistributedDataParallel( + model, + device_ids=[device]) + + dt.model = wrapped + + self.assertEqual(model, dt.local_value) + self.assertNotIsInstance(dt.local_value, DistributedDataParallel) + + self.assertIsInstance(dt.value, DistributedDataParallel) + self.assertEqual(wrapped, dt.value) + self.assertEqual(wrapped, dt.distributed_value) + + dt.reset_distributed_value() + + self.assertEqual(model, dt.local_value) + self.assertEqual(model, dt.value) + self.assertEqual(model, dt.distributed_value) + + self.assertNotIsInstance(dt.value, DistributedDataParallel) + + dt.reset_distributed_value() + self.assertIsNotNone(dt.local_value) + + dt.value = wrapped + dt.distributed_model = None + + self.assertIsNotNone(dt.local_value) + + dt.value = None + + self.assertIsNone(dt.local_value) + self.assertIsNone(dt.distributed_value) + self.assertIsNone(dt.value) + + # test model adaptation + input_size = 28 * 28 * 1 + scenario = get_fast_benchmark( + use_task_labels=True, + n_features=input_size, + n_samples_per_class=256, + seed=1337 + ) + avalanche_model_adaptation(model, 
scenario.train_stream[1]) + model.eval() + dt.value = model + + wrapped = DistributedDataParallel(model, device_ids=[device]) + dt.model = wrapped + + self.assertEqual(model, dt.local_value) + loader = DataLoader(scenario.train_stream[1].dataset, batch_size=32) + with torch.no_grad(): + for x, y, t in loader: + x = x.to(device) + y = y.to(device) + t = t.to(device) + self.assertEqual([1] * len(t), t.tolist()) + out_mb = avalanche_forward(dt.model, x, t) + DistributedHelper.check_equal_tensors(out_mb) + out_mb_local = avalanche_forward(dt.local_value, x, t) + DistributedHelper.check_equal_tensors(out_mb_local) + self.assertTrue(torch.equal(out_mb, out_mb_local)) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/distributed/test_distributed_strategy_support.py b/tests/distributed/test_distributed_strategy_support.py new file mode 100644 index 000000000..45d7e67f1 --- /dev/null +++ b/tests/distributed/test_distributed_strategy_support.py @@ -0,0 +1,309 @@ +import hashlib +import math +import unittest + +import torch +from torch import Tensor +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torch.utils.data import DistributedSampler, DataLoader + +from avalanche.core import SupervisedPlugin +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_consistency_verification import \ + hash_dataset +from avalanche.distributed.strategies import DistributedMiniBatchStrategySupport +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, \ + confusion_matrix_metrics, topk_acc_metrics, class_accuracy_metrics, \ + amca_metrics +from avalanche.models import SimpleMLP +from avalanche.training import Naive +from avalanche.training.plugins import EvaluationPlugin +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_test, suppress_dst_tests_output, \ + common_dst_tests_setup +from tests.unit_tests_utils import get_fast_benchmark + + +class DistributedStrategySupportTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_use_local_works(self): + uut = DistributedMiniBatchStrategySupport() + uut.mbatch = torch.full((5, 10), DistributedHelper.rank, + dtype=torch.float32) + uut.mb_output = torch.full((5, 10), DistributedHelper.rank, + dtype=torch.float32) + + # Test without use_local + got_mbatch = uut.mbatch + got_mb_output = uut.mb_output + + expected_shape = (DistributedHelper.world_size * 5, 10) + + self.assertSequenceEqual(expected_shape, got_mbatch.shape) + self.assertSequenceEqual(expected_shape, got_mb_output.shape) + + for row_idx in range(expected_shape[0]): + from_rank = row_idx // 5 + self.assertTrue(torch.equal( + torch.full((10,), from_rank, dtype=torch.float32), + got_mbatch[row_idx])) + self.assertTrue(torch.equal( + torch.full((10,), from_rank, dtype=torch.float32), + got_mb_output[row_idx])) + + # Test with use_local + uut.mbatch = torch.full((5, 10), DistributedHelper.rank, + dtype=torch.float32) + uut.mb_output = torch.full((5, 10), DistributedHelper.rank, + dtype=torch.float32) + + with uut.use_local(): + got_mbatch = uut.mbatch + got_mb_output = uut.mb_output + + expected_shape = (5, 10) + + self.assertSequenceEqual(expected_shape, got_mbatch.shape) + self.assertSequenceEqual(expected_shape, 
+
+    def _check_loss_equal(self, uut):
+        local_loss = uut.local_loss
+        global_loss = uut.loss
+
+        self.assertIsInstance(local_loss, Tensor)
+        self.assertIsInstance(global_loss, Tensor)
+        self.assertEqual(uut.device, local_loss.device)
+        self.assertEqual(uut.device, global_loss.device)
+
+        all_losses = DistributedHelper.gather_all_objects(float(local_loss))
+        # Note: the results of torch.mean are different from the ones
+        # of statistics.mean
+        self.assertAlmostEqual(
+            float(torch.mean(torch.as_tensor(all_losses))),
+            float(global_loss))
+
+    def _check_batches_equal(self, uut: Naive, rank: int, mb_size: int,
+                             mb_dist_size: int, input_size: int):
+        local_input_mb = uut.local_mbatch
+        global_input_mb = uut.mbatch
+
+        self.assertEqual(3, len(local_input_mb))
+        self.assertEqual(3, len(global_input_mb))
+
+        for mb_i, mb_elem in enumerate(local_input_mb):
+            self.assertIsInstance(mb_elem, Tensor)
+            self.assertEqual(uut.device, mb_elem.device)
+
+        for mb_i, mb_elem in enumerate(global_input_mb):
+            self.assertIsInstance(mb_elem, Tensor)
+            self.assertEqual(uut.device, mb_elem.device)
+
+        self.assertTrue(torch.equal(global_input_mb[0], uut.mb_x))
+        self.assertTrue(torch.equal(global_input_mb[1], uut.mb_y))
+        self.assertTrue(torch.equal(global_input_mb[2], uut.mb_task_id))
+
+        self.assertSequenceEqual(local_input_mb[0].shape,
+                                 [mb_dist_size, input_size])
+        self.assertSequenceEqual(local_input_mb[1].shape, [mb_dist_size])
+        self.assertSequenceEqual(local_input_mb[2].shape, [mb_dist_size])
+
+        self.assertSequenceEqual(global_input_mb[0].shape,
+                                 [mb_size, input_size])
+        self.assertSequenceEqual(global_input_mb[1].shape, [mb_size])
+        self.assertSequenceEqual(global_input_mb[2].shape, [mb_size])
+
+        global_index_start = mb_dist_size * rank
+        global_index_end = global_index_start + mb_dist_size
+
+        for i in range(3):
+            self.assertTrue(
+                torch.equal(
+                    local_input_mb[i],
+                    global_input_mb[i][global_index_start:global_index_end]))
+
+    def _check_adapted_datasets_equal(self, uut: Naive):
+        local_adapted_dataset = uut.adapted_dataset
+
+        DistributedHelper.check_equal_objects(
+            hash_dataset(local_adapted_dataset,
+                         num_workers=4,
+                         hash_engine=hashlib.sha1())
+        )
+
+    @unittest.skipIf(check_skip_distributed_test(),
+                     'Distributed tests ignored')
+    def test_naive_classification_dst(self):
+        self.assertTrue(DistributedHelper.is_distributed)
+
+        input_size = 28 * 28
+        # mb_size == 60, so that the test can be run with any number of
+        # parallel processes in [1, 6]
+        mb_size = 1 * 2 * 2 * 3 * 5
+        model = SimpleMLP(input_size=input_size)
+        optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
+        criterion = CrossEntropyLoss()
+        device = DistributedHelper.make_device()
+
+        # DST parameters adaptation
+        mb_size_dst = mb_size // DistributedHelper.world_size
+
+        class IterationCheckerPlugin(SupervisedPlugin):
+
+            supports_distributed = True
+
+            def __init__(self, test_suite):
+                super().__init__()
+                self.test_suite = test_suite
+
+            def after_training_iteration(self, strategy, *args, **kwargs):
+                self._check_aligned(strategy)
+
+            def after_eval_iteration(self, strategy, *args, **kwargs):
+                self._check_aligned(strategy)
+
+            def _check_aligned(self, strategy: Naive):
+                is_last_iteration = \
+                    strategy.clock.train_epoch_iterations == \
+                    (len(strategy.dataloader) - 1)
+                # With drop_last=False the last minibatch may be smaller
+                # than mb_size, so the fixed-size checks below do not apply
+                # to it (it is verified separately, after training).
+                if is_last_iteration:
+                    return
+
+                self.test_suite._check_batches_equal(
+                    strategy,
+                    DistributedHelper.rank,
+                    mb_size,
+                    mb_size_dst,
+                    input_size)
+                self.test_suite._check_loss_equal(strategy)
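+
+        # Every metric below is exercised under distributed training; after
+        # each train/eval call the resulting metric dictionaries are compared
+        # across ranks through DistributedHelper.check_equal_objects().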
+        metrics = EvaluationPlugin(
+            accuracy_metrics(minibatch=True, epoch=True,
+                             experience=True, stream=True),
+            loss_metrics(minibatch=True, epoch=True,
+                         experience=True, stream=True),
+            confusion_matrix_metrics(save_image=False,
+                                     stream=True),
+            topk_acc_metrics(minibatch=True, epoch=True,
+                             experience=True, stream=True),
+            class_accuracy_metrics(minibatch=True, epoch=True,
+                                   experience=True, stream=True),
+            amca_metrics(),
+            loggers='default'
+        )
+
+        uut = Naive(
+            model,
+            optimizer,
+            criterion,
+            train_mb_size=mb_size_dst,
+            eval_mb_size=mb_size_dst,
+            train_epochs=2,
+            device=device,
+            plugins=[IterationCheckerPlugin(self)],
+            evaluator=metrics
+        )
+
+        self.assertEqual(device, uut.device)
+
+        if not DistributedHelper.is_main_process:
+            self.assertEqual(0, len(uut.evaluator.loggers))
+
+        benchmark = get_fast_benchmark(
+            n_samples_per_class=250,
+            n_features=input_size)
+
+        for exp_idx, train_experience in enumerate(benchmark.train_stream):
+            metrics = uut.train(train_experience, drop_last=False)
+
+            # Check that drop_last=False works correctly
+            train_dataset_sz = len(uut.adapted_dataset)
+            world_size = DistributedHelper.world_size
+            last_mb_size_without_dropping = \
+                math.ceil(train_dataset_sz / world_size) * \
+                world_size % mb_size
+            if last_mb_size_without_dropping == 0:
+                # Corner case: no drop needed
+                last_mb_size_without_dropping = mb_size
+            last_mb_size_without_dropping_dst = \
+                last_mb_size_without_dropping // world_size
+
+            self._check_batches_equal(
+                uut,
+                DistributedHelper.rank,
+                last_mb_size_without_dropping,
+                last_mb_size_without_dropping_dst,
+                input_size)
+
+            # Other checks
+            self._check_loss_equal(uut)
+            if exp_idx < 2:
+                # Do it only for the first 2 experiences to speed up tests
+                self._check_adapted_datasets_equal(uut)
+            DistributedHelper.check_equal_objects(metrics)
+
+            metrics = uut.eval(benchmark.test_stream, drop_last=True)
+            # Also checks that drop_last=True works correctly
+            self._check_batches_equal(uut, DistributedHelper.rank, mb_size,
+                                      mb_size_dst, input_size)
+            self._check_loss_equal(uut)
+            if exp_idx < 2:
+                # Do it only for the first 2 experiences to speed up tests
+                self._check_adapted_datasets_equal(uut)
+            DistributedHelper.check_equal_objects(metrics)
+
+    @unittest.skipIf(check_skip_distributed_test(),
+                     'Distributed tests ignored')
+    def test_pytorch_distributed_sampler(self):
+        """
+        Only used to test the DistributedSampler class from PyTorch.
+        """
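+        # Baseline check that bypasses the Avalanche wrappers entirely:
+        # with drop_last=True on both the sampler and the loader, every
+        # rank must see same-sized minibatches of mb_size // world_size
+        # examples.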
+ """ + self.assertTrue(DistributedHelper.is_distributed) + + input_size = 28 * 28 + mb_size = 210 # Can be tested using [1, 10] parallel processes + + # DST parameters adaptation + mb_size_dst = mb_size // DistributedHelper.world_size + + benchmark = get_fast_benchmark( + n_samples_per_class=175 * 4, + n_features=input_size) + + for train_experience in benchmark.train_stream: + dataset = train_experience.dataset + sampler = DistributedSampler( + dataset, + shuffle=True, + drop_last=True + ) + dataloader = DataLoader( + dataset, + batch_size=mb_size_dst, + sampler=sampler, + drop_last=True + ) + + for mb_x, mb_y, mb_t in dataloader: + self.assertSequenceEqual(mb_x.shape, [mb_size_dst, input_size]) + self.assertEqual(len(mb_y), mb_size_dst) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/distributed/test_distributed_tensor.py b/tests/distributed/test_distributed_tensor.py new file mode 100644 index 000000000..e4ca40cad --- /dev/null +++ b/tests/distributed/test_distributed_tensor.py @@ -0,0 +1,79 @@ +import unittest + +import torch + +from avalanche.distributed import DistributedHelper +from avalanche.distributed.distributed_tensor import \ + DistributedMeanTensor +from tests.distributed.distributed_test_utils import \ + check_skip_distributed_test, suppress_dst_tests_output, \ + common_dst_tests_setup + + +class DistributedTensorTests(unittest.TestCase): + + def setUp(self) -> None: + self.use_gpu_in_tests = common_dst_tests_setup() + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_one_element_tensor(self): + dt = DistributedMeanTensor('dt', torch.zeros((1,), dtype=torch.float32)) + + self.assertEqual(0.0, dt.local_value.float()) + self.assertEqual(0.0, dt.value.float()) + + i = DistributedHelper.rank + 1 + + dt.value = torch.full((1,), fill_value=i, + dtype=torch.float32) + + n = DistributedHelper.world_size + expected = n * (n + 1) / 2 + + self.assertEqual(i, float(dt.local_value)) + self.assertEqual(expected / n, float(dt.value)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_one_element_tensor_random(self): + dt = DistributedMeanTensor('dt', torch.zeros((1,), dtype=torch.float32)) + + rnd_value = torch.randint(0, 100000, (10,), dtype=torch.float32) + dt.value = rnd_value + + expected = torch.mean(rnd_value) + + self.assertTrue(torch.allclose(expected, torch.mean(dt.local_value))) + self.assertTrue(torch.allclose(expected, dt.value)) + + @unittest.skipIf(check_skip_distributed_test(), + 'Distributed tests ignored') + def test_unshaped_tensor(self): + dt = DistributedMeanTensor('dt', + torch.as_tensor(5, dtype=torch.float32)) + + self.assertEqual(5.0, dt.local_value.float()) + self.assertEqual(5.0, dt.value.float()) + self.assertEqual(0, len(dt.local_value.shape)) + self.assertEqual(0, len(dt.value.shape)) + + i = DistributedHelper.rank + 1 + + dt.value = torch.as_tensor(i, dtype=torch.float32) + + n = DistributedHelper.world_size + expected = n * (n + 1) / 2 + + self.assertEqual(i, float(dt.local_value)) + self.assertEqual(expected / n, float(dt.value)) + self.assertEqual(0, len(dt.local_value.shape)) + self.assertEqual(0, len(dt.value.shape)) + + +if __name__ == "__main__": + with suppress_dst_tests_output(): + verbosity = 1 + if DistributedHelper.rank > 0: + verbosity = 0 + unittest.main(verbosity=verbosity) diff --git a/tests/run_dist_tests.py 
diff --git a/tests/run_dist_tests.py b/tests/run_dist_tests.py
new file mode 100644
index 000000000..207c0c371
--- /dev/null
+++ b/tests/run_dist_tests.py
@@ -0,0 +1,103 @@
+import os
+import signal
+import sys
+import unittest
+from subprocess import Popen
+from typing import Union, Set
+from unittest import TestSuite, TestCase
+
+import click
+
+os.environ['DISTRIBUTED_TESTS'] = '1'
+
+
+def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]:
+    found_cases = set()
+    if isinstance(suite, TestSuite):
+        for x in suite:
+            found_cases.update(get_distributed_test_cases(x))
+
+    if isinstance(suite, TestCase):
+        case_id = suite.id()
+
+        if case_id.startswith('distributed.') or \
+                case_id.startswith('tests.distributed.'):
+            found_cases.add(case_id)
+
+        if '_FailedTest' in case_id:
+            raise RuntimeError(
+                f'Errors encountered while listing test cases: {case_id}')
+
+    return found_cases
+
+
+@click.command()
+@click.argument('test_cases', nargs=-1)
+def run_distributed_suites(test_cases):
+    cases_names = get_distributed_test_cases(
+        unittest.defaultTestLoader.discover('.'))  # Don't change the path!
+    cases_names = list(sorted(cases_names))
+    print(cases_names)
+    if len(test_cases) > 0:
+        test_cases = set(test_cases)
+        cases_names = [x for x in cases_names if x in test_cases]
+
+        if set(cases_names) != test_cases:
+            print('Some cases have not been found!',
+                  test_cases - set(cases_names))
+            sys.exit(1)
+
+    print('Running', len(cases_names), 'tests')
+    p = None
+    success = True
+    exited = False
+    failed_test_cases = set()
+
+    use_gpu_in_tests = os.environ.get('USE_GPU', 'false').lower() in [
+        '1', 'true']
+    if use_gpu_in_tests:
+        print('Running tests using GPUs')
+        import torch
+        nproc_per_node = torch.cuda.device_count()
+    else:
+        print('Running tests using CPU only')
+        nproc_per_node = 2
+
+    for case_name in cases_names:
+        if exited:
+            print('Exiting due to keyboard interrupt')
+            break
+        print('Running test:', case_name, flush=True)
+        try:
+            p = Popen(
+                ['python', '-m', 'torch.distributed.run', '--nnodes=1',
+                 f'--nproc_per_node={nproc_per_node}',
+                 '-m', 'unittest', case_name],
+                stdout=sys.stdout, stderr=sys.stderr)
+            p.communicate()
+        except KeyboardInterrupt:
+            success = False
+            exited = True
+            p.send_signal(signal.SIGINT)
+        finally:
+            exit_code = p.wait()
+            print('Test completed with code', exit_code)
+            success = success and exit_code == 0
+            p = None
+
+            if exit_code != 0:
+                failed_test_cases.add(case_name)
+
+    if success:
+        print('Tests completed successfully')
+        sys.exit(0)
+    else:
+        print('The following tests terminated with errors:')
+        for failed_case in sorted(failed_test_cases):
+            print(failed_case)
+
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    run_distributed_suites()
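+
+# Example usage (the case id below is illustrative; the real ids are printed
+# by the script itself at startup):
+#   python tests/run_dist_tests.py
+#       runs every discovered tests.distributed.* case on 2 CPU processes
+#   USE_GPU=true python tests/run_dist_tests.py \
+#       tests.distributed.test_distributed_tensor.DistributedTensorTests.test_unshaped_tensor
+#       runs a single case, with one process per visible GPU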
diff --git a/tests/test_avalanche_classification_dataset.py b/tests/test_avalanche_classification_dataset.py
index 9dd5c972f..13bd81ec6 100644
--- a/tests/test_avalanche_classification_dataset.py
+++ b/tests/test_avalanche_classification_dataset.py
@@ -1713,7 +1713,7 @@ def test_replace_transforms(self):
         dataset_other = make_classification_dataset(dataset_reset)
 
         dataset_other = dataset_other.replace_current_transform_group(
-            (None, lambda l: l + 1)
+            (None, lambda val: val + 1)
         )
 
         _, y6, _ = dataset_other[0]
diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py
index fdced8935..f1820371e 100644
--- a/tests/training/test_online_strategies.py
+++ b/tests/training/test_online_strategies.py
@@ -53,7 +53,7 @@ def test_naive(self):
             train_mb_size=1,
             device=self.device,
             eval_mb_size=50,
-            evaluator=default_evaluator(),
+            evaluator=default_evaluator,
         )
         ocl_benchmark = OnlineCLScenario(benchmark_streams,
                                          access_task_boundaries=True)
@@ -68,7 +68,7 @@ def test_naive(self):
             train_mb_size=1,
             device=self.device,
             eval_mb_size=50,
-            evaluator=default_evaluator(),
+            evaluator=default_evaluator,
         )
         ocl_benchmark = OnlineCLScenario(benchmark_streams,
                                          access_task_boundaries=False)
diff --git a/tests/training/test_supervised_regression.py b/tests/training/test_supervised_regression.py
index 6fc521c93..88b1d6020 100644
--- a/tests/training/test_supervised_regression.py
+++ b/tests/training/test_supervised_regression.py
@@ -317,7 +317,7 @@ def training_epoch(self, **kwargs):
             if self._stop_training:
                 break
 
-            self._unpack_minibatch()
+            self.unpack_minibatch()
             trigger_plugins(self, "before_training_iteration")
             self.optimizer.zero_grad()
@@ -354,7 +354,7 @@ def eval_dataset_adaptation(self, **kwargs):
     def eval_epoch(self, **kwargs):
         """Evaluation loop over the current `self.dataloader`."""
         for self.mbatch in self.dataloader:
-            self._unpack_minibatch()
+            self.unpack_minibatch()
             trigger_plugins(self, "before_eval_iteration")
 
             trigger_plugins(self, "before_eval_forward")
diff --git a/tests/unit_tests_utils.py b/tests/unit_tests_utils.py
index bd6885d79..11032420a 100644
--- a/tests/unit_tests_utils.py
+++ b/tests/unit_tests_utils.py
@@ -29,7 +29,7 @@
 if "UPDATE_METRICS" in os.environ:
     UPDATE_METRICS = os.environ["UPDATE_METRICS"].lower() == "true"
 
-print(f"UPDATE_METRICS: {UPDATE_METRICS}")
+# print(f"UPDATE_METRICS: {UPDATE_METRICS}")
 
 
 def is_github_action():