Distributed support (rework) #996

Draft
lrzpellegrini wants to merge 21 commits into base: master.
Changes shown below are from 7 of the 21 commits.

Commits (21), all by lrzpellegrini:
6bd297b  Apr 8, 2022   Merge remote-tracking branch 'upstream/master'
55a5480  Apr 12, 2022  Reworking distributed support (WIP).
b0ce2e3  Apr 21, 2022  Working strategy composition and example (naive, replay, scheduler).
976e5c5  Apr 22, 2022  Fixed pep8 issues.
efb7f86  Apr 22, 2022  Fixed typing error. Removed debug code.
e13f067  Apr 22, 2022  Merge remote-tracking branch 'upstream/master' into distributed_suppo…
3017aeb  Apr 22, 2022  Removed debug prints.
f8882d7  Apr 29, 2022  Implemented lazy creation of the default logger.
8571b91  Apr 29, 2022  [Distributed] Simplified internal API and example. Added in-code guide.
b752568  Apr 29, 2022  Added support for general use_local in strategies.
f5eaf96  Apr 29, 2022  Merge remote-tracking branch 'upstream/master' into distributed_suppo…
b13cc9b  Jul 19, 2022  Merge remote-tracking branch 'upstream/master' into distributed_suppo…
d1b9d28  Jul 19, 2022  Add type hints to _make_data_loader. Fix distributed training example.
f104a0e  Nov 10, 2022  Partial merge remote-tracking branch 'upstream/master' into distribut…
88f75a9  Nov 22, 2022  Integrated distributed training with RNGManager, new collate system. …
1717b8d  Nov 23, 2022  Improved management of dataloader arguments in strategies. Improved d…
da5c58c  Nov 23, 2022  Improved distributed strategy unit tests. Fixed PEP8 issues.
cdcd8c4  Nov 23, 2022  Aligned environment update action content.
2a93ad8  Dec 11, 2022  Fix multitask issues. Improve distributed training support and tests.
1174f33  Jan 10, 2023  Added additional unit tests. Issue with all_gather to be fixed.
6a3dd1f  Jan 16, 2023  Tests for DistributedHelper. Distributed support field in plugins.
8 changes: 6 additions & 2 deletions avalanche/benchmarks/classic/cmnist.py
@@ -29,6 +29,7 @@
 )
 from avalanche.benchmarks.datasets import default_dataset_location
 from avalanche.benchmarks.utils import AvalancheDataset
+from avalanche.distributed import DistributedHelper

 _default_mnist_train_transform = Compose(
     [ToTensor(), Normalize((0.1307,), (0.3081,))]

@@ -394,9 +395,12 @@ def _get_mnist_dataset(dataset_root):
     if dataset_root is None:
         dataset_root = default_dataset_location("mnist")

-    train_set = MNIST(root=dataset_root, train=True, download=True)
+    with DistributedHelper.main_process_first():
+        train_set = MNIST(root=dataset_root,
+                          train=True, download=True)

-    test_set = MNIST(root=dataset_root, train=False, download=True)
+    test_set = MNIST(root=dataset_root,
+                     train=False, download=True)

     return train_set, test_set
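In the hunk above, the MNIST instantiation is wrapped in DistributedHelper.main_process_first() so that, in a multi-process run, only the main process triggers the dataset download while the other processes wait and then load the files already on disk. The helper's implementation is not part of this diff, so the following is only a sketch of the usual pattern, written against plain torch.distributed rather than Avalanche internals:

# Sketch of the "main process first" pattern (an assumption: the real
# DistributedHelper implementation is not shown in this hunk).
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def main_process_first():
    is_dist = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_dist else 0
    if is_dist and rank != 0:
        # Non-main processes block here until the main process has
        # finished the guarded section (e.g., the dataset download).
        dist.barrier()
    yield
    if is_dist and rank == 0:
        # Main process is done: release the waiting processes.
        dist.barrier()

With this in place, a download=True dataset constructor inside the block is race-free: rank 0 downloads first, and the other processes find the data already present.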
69 changes: 69 additions & 0 deletions avalanche/benchmarks/utils/collate_functions.py
@@ -0,0 +1,69 @@
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 21-04-2022                                                             #
# Author(s): Antonio Carta, Lorenzo Pellegrini                                 #
# E-mail: [email protected]                                            #
# Website: avalanche.continualai.org                                           #
################################################################################

import itertools
from collections import defaultdict

import torch


def classification_collate_mbatches_fn(mbatches):
    """Combines multiple mini-batches together.

    Concatenates each tensor in the mini-batches along dimension 0 (usually
    this is the batch size).

    :param mbatches: sequence of mini-batches.
    :return: a single mini-batch
    """
    batch = []
    for i in range(len(mbatches[0])):
        t = classification_single_values_collate_fn(
            [el[i] for el in mbatches], i)
        batch.append(t)
    return batch


def classification_single_values_collate_fn(values_list, index):
    return torch.cat(values_list, dim=0)


def detection_collate_fn(batch):
    """
    Collate function used when loading detection datasets using a DataLoader.
    """
    return tuple(zip(*batch))


def detection_collate_mbatches_fn(mbatches):
    """
    Collate function used when loading detection datasets using a DataLoader.
    """
    lists_dict = defaultdict(list)
    for mb in mbatches:
        for mb_elem_idx, mb_elem in enumerate(mb):
            lists_dict[mb_elem_idx].append(mb_elem)

    lists = []
    for mb_elem_idx in range(max(lists_dict.keys()) + 1):
        lists.append(list(itertools.chain.from_iterable(
            lists_dict[mb_elem_idx]
        )))

    return lists


__all__ = [
    'classification_collate_mbatches_fn',
    'classification_single_values_collate_fn',
    'detection_collate_fn',
    'detection_collate_mbatches_fn'
]
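To make the behavior of these collate helpers concrete, here is a hypothetical usage sketch (not part of the PR; the tensors and shapes are invented for illustration):

# Hypothetical usage sketch: the data below is invented purely to
# exercise the collate helpers defined in the file above.
import torch

# Two classification mini-batches of (x, y, task_id) tensors.
mb_a = [torch.zeros(2, 3), torch.tensor([0, 1]), torch.tensor([0, 0])]
mb_b = [torch.ones(3, 3), torch.tensor([2, 3, 4]), torch.tensor([0, 0, 0])]

merged = classification_collate_mbatches_fn([mb_a, mb_b])
print([t.shape for t in merged])
# [torch.Size([5, 3]), torch.Size([5]), torch.Size([5])]

# Detection targets vary in size per image, so detection_collate_fn keeps
# per-field tuples instead of stacking tensors.
sample_1 = (torch.rand(3, 32, 32), {"boxes": torch.rand(2, 4)})
sample_2 = (torch.rand(3, 48, 48), {"boxes": torch.rand(5, 4)})
images, targets = detection_collate_fn([sample_1, sample_2])
print(len(images), len(targets))  # 2 2

# Merging two such detection mini-batches flattens each field's sequences.
all_images, all_targets = detection_collate_mbatches_fn(
    [detection_collate_fn([sample_1, sample_2]),
     detection_collate_fn([sample_1, sample_2])])
print(len(all_images), len(all_targets))  # 4 4

Note that classification_single_values_collate_fn receives the field index but does not use it; presumably the index is kept in the signature so alternative collate implementations can treat different fields (e.g., labels vs. task ids) differently.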