add to docs, reorganize
epwalsh committed Apr 4, 2024
1 parent ad550cd commit 4d6369f
Showing 16 changed files with 125 additions and 17 deletions.
6 changes: 6 additions & 0 deletions docs/source/distributed/tensors.rst
@@ -0,0 +1,6 @@
``distributed.tensors``
=======================

.. automodule:: olmo_core.distributed.tensors
:members:
:member-order: bysource
3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -14,10 +14,11 @@
:caption: API Reference

exceptions.rst
utils.rst
io.rst
utils.rst
distributed/checkpoint.rst
distributed/fsdp.rst
distributed/tensors.rst

.. toctree::
:hidden:
2 changes: 1 addition & 1 deletion src/olmo_core/distributed/checkpoint.py
@@ -63,7 +63,7 @@
)
from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES

from .sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
from .tensors import ShardedFlatTensor, ShardingSpec
from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object

log = logging.getLogger(__name__)
66 changes: 66 additions & 0 deletions src/olmo_core/distributed/fsdp/__init__.py
@@ -6,6 +6,72 @@
- Well-defined handling of frozen params. You can mix and match within an FSDP instance as long
as you're consistent across the process group with which parameters are frozen.
- Full support for CPU-only training and inference via the GLOO backend.
- Low-overhead checkpointing with :mod:`olmo_core.distributed.checkpoint`.

Usage Tips
----------

- Always initialize your optimizer *after* wrapping your model with FSDP.
- When you initialize your model (prior to wrapping with FSDP), use ``device=torch.device("meta")``
  when initializing *parameters* to save memory. :class:`FSDP` will automatically materialize and
  move parameters to the right device when wrapping.
  Then you can use :meth:`FSDP.apply()` to initialize parameters however you want (see the sketch
  after this list).
- As with PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`, you should
  use :func:`FSDP.clip_grad_norm_()` for clipping gradient norms instead of :func:`torch.nn.utils.clip_grad_norm_()`.
- Use activation checkpointing via :func:`torch.utils.checkpoint.checkpoint()` to save more memory
  during the forward and backward pass at the expense of more computation.
- To save and load checkpoints for your FSDP model and its optimizer, use
  :func:`~olmo_core.distributed.checkpoint.save_model_and_optim_state()` and
  :func:`~olmo_core.distributed.checkpoint.load_model_and_optim_state()`, respectively.
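A minimal sketch of that initialization order follows. The tiny ``nn.Sequential`` model, the
``init_weights`` helper, and the hyperparameters are hypothetical stand-ins, and it assumes the
:class:`FSDP` constructor is given just the module to wrap:

.. code-block:: python

    import torch
    import torch.nn as nn

    from olmo_core.distributed.fsdp import FSDP

    # 1. Build the model with parameters on the "meta" device so no real memory is allocated yet.
    model = nn.Sequential(
        nn.Linear(1024, 1024, device=torch.device("meta")),
        nn.ReLU(),
        nn.Linear(1024, 1024, device=torch.device("meta")),
    )

    # 2. Wrap with FSDP, which materializes and shards the parameters onto the right device.
    fsdp_model = FSDP(model)

    # 3. (Re-)initialize the parameter values in place via FSDP.apply().
    @torch.no_grad()
    def init_weights(module: nn.Module):
        if isinstance(module, nn.Linear):
            module.weight.normal_(std=0.02)

    fsdp_model.apply(init_weights)

    # 4. Only now construct the optimizer, so it references the sharded parameters.
    optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)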

Implementation Details
----------------------

When you wrap a :class:`~torch.nn.Module` with :class:`FSDP`, the wrapping FSDP instance will replace
each original parameter in the module with a :class:`~olmo_core.distributed.tensors.ShardedFlatParameter` instance,
and each rank will only keep a shard of the original data. Buffers are left as-is.

.. note::
    Further, the sharded data for all of the :class:`~olmo_core.distributed.tensors.ShardedFlatParameter`
    instances will be collected into a single :class:`FlatParamHandle`, and each flat parameter will
    just hold a view into a slice of the data managed by the handle. This makes gathering the full
    params more efficient as it only requires a single all-gather per FSDP node.

Forward Pass
~~~~~~~~~~~~

When the :meth:`~torch.nn.Module.forward()` method is called on the wrapping FSDP instance, it will gather
the full unsharded data for each parameter in the desired :class:`~torch.dtype`
(as defined by the :class:`FSDPPrecision` settings) while caching the sharded data behind the scenes.
Then it runs the forward method of the wrapped module, which is completely unsharded at that point.

After the forward method of the wrapped module returns, the wrapping FSDP instance will reshard
the parameters and, if gradients are enabled, register backward hooks to manage the state of parameters
and gradients during the backward pass.

During the first forward pass the root FSDP instance will also record the order of execution of all
FSDP children, and use that order to prefetch the full parameters for its FSDP children during
subsequent forward passes. The number of children that are prefetched at once is controlled by the
``max_prefetch_count`` setting.

.. note::
    When CUDA is available, :class:`FSDP` instances utilize multiple CUDA streams in order to overlap
    communication (e.g. unsharding params or reducing gradients) with computation
    (e.g. the forward pass or computing gradients during the backward pass).
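For example, here is a sketch of wrapping with mixed precision and a larger prefetch window. It
assumes ``precision`` and ``max_prefetch_count`` are accepted directly by the :class:`FSDP`
constructor, and the ``param_dtype``/``reduce_dtype`` field names on :class:`FSDPPrecision` are
guesses by analogy with PyTorch's ``MixedPrecision``; check the API reference below for the exact names:

.. code-block:: python

    import torch

    from olmo_core.distributed.fsdp import FSDP, FSDPPrecision

    fsdp_model = FSDP(
        model,
        # Gather and compute in bfloat16, reduce gradients in float32 (assumed field names).
        precision=FSDPPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        # Prefetch up to 2 FSDP children at a time during forward/backward.
        max_prefetch_count=2,
    )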

Backward Pass
~~~~~~~~~~~~~

At the end of the forward method, the wrapping FSDP instance registers ephemeral "pre-backward" and "post-backward" hooks
to unshard the parameters and reduce-scatter the gradients, respectively, during the backward pass.

At the end of the backward pass the :attr:`~torch.Tensor.grad` attribute of each (non-frozen) parameter will
be the shard of the full gradient corresponding to the shard of the full parameter, i.e. it will
have the same shape/size as the sharded parameter.

Just as the root FSDP instance records the execution order of its FSDP children during the first
forward pass, it also records the order during the first backward pass and uses that
to prefetch the full parameters of its children during subsequent backward passes.
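Putting the two passes together, a single training step might look like the following sketch. The
``batch``, ``loss_fn``, clipping value, and checkpoint directory are hypothetical stand-ins, and it
assumes :func:`~olmo_core.distributed.checkpoint.save_model_and_optim_state()` takes the directory,
model, and optimizer in that order:

.. code-block:: python

    from olmo_core.distributed.checkpoint import save_model_and_optim_state

    optim.zero_grad(set_to_none=True)

    loss = loss_fn(fsdp_model(batch["inputs"]), batch["targets"])
    loss.backward()  # each param's .grad ends up sharded, matching its sharded shape

    fsdp_model.clip_grad_norm_(1.0)  # use FSDP's method, not torch.nn.utils.clip_grad_norm_()
    optim.step()

    # Periodically save a distributed checkpoint of the model and optimizer state.
    save_model_and_optim_state("/tmp/run/step100", fsdp_model, optim)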

API Reference
-------------
7 changes: 5 additions & 2 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -6,8 +6,11 @@
import torch
import torch.distributed as dist

from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
from olmo_core.distributed.tensors import (
ShardedFlatParameter,
ShardedFlatTensor,
ShardingSpec,
)
from olmo_core.distributed.utils import get_rank, get_world_size
from olmo_core.utils import get_default_device

7 changes: 4 additions & 3 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -27,8 +27,8 @@
import torch.distributed as dist
import torch.nn as nn

from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.utils import apply_to_tensors, get_default_device, get_grad_norm
from olmo_core.distributed.tensors import ShardedFlatParameter
from olmo_core.utils import apply_to_tensors, gc_cuda, get_default_device, get_grad_norm

from .flat_param_handle import FlatParamHandle
from .state import FSDPState
@@ -498,10 +498,11 @@ def _shard(self):

# Collate the data from all flat params into the flat param handle. The data in each flat param
# will then just be a view into a slice of the data managed by the flat param handle.
# This makes unsharded more efficient.
# This makes unsharding more efficient as we'll only need a single `all_gather` call.
self.state.flat_param_handle = FlatParamHandle.collate_flat_params(
params, param_fqns, process_group=self.process_group, device=self.device
)
gc_cuda()

@torch.no_grad()
def _unshard(
8 changes: 8 additions & 0 deletions src/olmo_core/distributed/tensors/__init__.py
@@ -0,0 +1,8 @@
"""
Distributed tensor and parameter classes.
"""

from .sharded_flat_parameter import ShardedFlatParameter
from .sharded_flat_tensor import ShardedFlatTensor, ShardingSpec

__all__ = ["ShardedFlatTensor", "ShardedFlatParameter", "ShardingSpec"]
@@ -11,6 +11,10 @@


class ShardedFlatParameter(ShardedFlatTensor, nn.Parameter):
"""
A :class:`~torch.nn.parameter.Parameter` version of :class:`ShardedFlatTensor`.
"""

def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> ShardedFlatParameter:
if data is not None and data.ndim != 1:
raise ValueError(f"{cls.__name__} requires flat data! Got {data.shape}")
@@ -10,7 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .utils import get_rank, get_world_size
from ..utils import get_rank, get_world_size

__all__ = ["ShardedFlatTensor", "ShardingSpec"]

10 changes: 10 additions & 0 deletions src/olmo_core/utils.py
@@ -1,4 +1,5 @@
import dataclasses
import gc
import os
import time
from enum import Enum
@@ -172,3 +173,12 @@ def same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
x_ptrs = set(e.data_ptr() for e in x.view(-1))
y_ptrs = set(e.data_ptr() for e in y.view(-1))
return (x_ptrs <= y_ptrs) or (y_ptrs <= x_ptrs)


def gc_cuda():
"""
Run CUDA garbage collection.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
7 changes: 5 additions & 2 deletions src/test/distributed/checkpoint_test.py
@@ -20,8 +20,11 @@
unshard_optim_state,
)
from olmo_core.distributed.fsdp import FSDP
from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
from olmo_core.distributed.tensors import (
ShardedFlatParameter,
ShardedFlatTensor,
ShardingSpec,
)

from .utils import (
BACKENDS,
2 changes: 1 addition & 1 deletion src/test/distributed/fsdp/flat_param_handle_test.py
@@ -3,7 +3,7 @@
import torch.distributed as dist

from olmo_core.distributed.fsdp.flat_param_handle import FlatParamHandle
from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.tensors import ShardedFlatParameter
from olmo_core.utils import same_storage

from ..utils import BACKENDS, get_default_device, run_distributed_test
2 changes: 1 addition & 1 deletion src/test/distributed/fsdp/fsdp_test.py
@@ -7,7 +7,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP

from olmo_core.distributed.fsdp import FSDP, FSDPDebugConfig
from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.tensors import ShardedFlatParameter

from ..utils import (
BACKENDS,
Empty file.
@@ -4,10 +4,13 @@
import torch
import torch.distributed as dist

from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
from olmo_core.distributed.tensors.sharded_flat_parameter import ShardedFlatParameter
from olmo_core.distributed.tensors.sharded_flat_tensor import (
ShardedFlatTensor,
ShardingSpec,
)

from .utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test
from ..utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test


def test_init_empty_sharded_parameter():
@@ -4,9 +4,12 @@
import torch
import torch.distributed as dist

from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
from olmo_core.distributed.tensors.sharded_flat_tensor import (
ShardedFlatTensor,
ShardingSpec,
)

from .utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test
from ..utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test


def test_init_sharded():
