FSDP fixes (#8)
* check for nans when unsharding

* don't cast root forward inputs

* use all_gather_into_tensor when possible

* fix

* fix

* fix dtype

* clean up

* clean up

* assert dtypes in `summon_full_params` context

* no write back

* more tests

* explicit cast when checking

* updates

* ensure `cast` and `writeback` not both set

* cast in other direction

* revert

* Add mp option to train script

* update

* adjust LR

* update prefetching logic in forward pass

* update backward prefetch logic

* define stream in top-level of package

* debugging

* add to stream test

* more test

* clean up

* Try recording stream

* Add comment

* clean up

* updates

* fix how many mods are prefetched

* don't check for nan loss

* clean up
epwalsh authored Apr 12, 2024
1 parent 15be9f2 commit 963f7dd
Showing 12 changed files with 272 additions and 102 deletions.
27 changes: 24 additions & 3 deletions src/benchmarks/fsdp/common.py
@@ -29,6 +29,7 @@ class TransformerConfig:
mlp_ratio: int = 4
max_sequence_length: int = 2048
init_device: torch.device = torch.device("cpu")
debug: bool = False

@classmethod
def tiniest(cls) -> TransformerConfig:
@@ -50,6 +51,7 @@ def medium(cls) -> TransformerConfig:
class Transformer(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.config = config
self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
self.wpe = nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)
self.blocks = nn.ModuleList(
@@ -82,11 +84,24 @@ def __init__(self, config: TransformerConfig):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.config.debug:
for param in self.parameters(recurse=False):
assert not param.isnan().any()
assert not x.isnan().any()
x = self.wte(x)
if self.config.debug:
assert not x.isnan().any()
x = x + self.wpe(self.positions)
if self.config.debug:
assert not x.isnan().any()
for block in self.blocks:
x = block(x, src_mask=self.causal_mask, is_causal=True)
return self.decoder(x)
if self.config.debug:
assert not x.isnan().any()
x = self.decoder(x)
if self.config.debug:
assert not x.isnan().any()
return x


class Dataloader:
@@ -135,6 +150,8 @@ def build_components(
fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
wrap_blocks: bool = True,
mixed_precision: bool = True,
max_prefetch_count: int = 1,
learning_rate: float = 1e-4,
) -> Tuple[nn.Module, torch.optim.Optimizer, Dataloader]:
model = Transformer(config)

@@ -148,6 +165,7 @@
precision=FSDPPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
if mixed_precision
else None,
max_prefetch_count=max_prefetch_count,
)

model.apply(init_function)
@@ -164,7 +182,10 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
model = FullyShardedDataParallel(
model,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
cast_root_forward_inputs=False,
)
if mixed_precision
else None,
@@ -180,7 +201,7 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
print_rank0(model)

print_rank0("Initializing optimizer...")
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)
return model, optim, Dataloader(batch_size, config, num_batches=num_batches)


37 changes: 31 additions & 6 deletions src/benchmarks/fsdp/test.py
@@ -61,15 +61,35 @@ def main(
load_model_and_optim_state(checkpoint_dir, olmo_model, olmo_optim)

print_rank0("Checking state dict...")
with TorchFSDP.summon_full_params(torch_model), olmo_model.summon_full_params():
torch_state_dict = {k.replace("_fsdp_wrapped_module.", ""): v for k, v in torch_model.state_dict().items()}
olmo_state_dict = olmo_model.state_dict()
assert torch_state_dict.keys() == olmo_state_dict.keys()
for key in torch_state_dict:
with TorchFSDP.summon_full_params(torch_model, writeback=False), olmo_model.summon_full_params(
writeback=False
):
torch_fp32_state_dict = {
k.replace("_fsdp_wrapped_module.", ""): v for k, v in torch_model.state_dict().items()
}
olmo_fp32_state_dict = olmo_model.state_dict()
assert torch_fp32_state_dict.keys() == olmo_fp32_state_dict.keys()
for key in torch_fp32_state_dict:
assert torch_fp32_state_dict[key].dtype == torch.float32
assert olmo_fp32_state_dict[key].dtype == torch.float32
torch.testing.assert_close(
torch_state_dict[key], olmo_state_dict[key], msg=lambda msg: f"Failure for {key}: {msg}"
torch_fp32_state_dict[key], olmo_fp32_state_dict[key], msg=lambda msg: f"Failure for {key}: {msg}"
)

if mixed_precision:
print_rank0("Checking gathering full params in low precision...")
with olmo_model.summon_full_params(cast=True, writeback=False):
olmo_bf16_state_dict = olmo_model.state_dict()
assert olmo_bf16_state_dict.keys() == olmo_fp32_state_dict.keys()
for key in olmo_bf16_state_dict.keys():
torch.testing.assert_close(
olmo_bf16_state_dict[key],
olmo_fp32_state_dict[key].to(torch.bfloat16),
msg=lambda msg: f"Failure for {key}: {msg}",
rtol=1.3e-6,
atol=1e-5,
)

if dry_run:
print_rank0("Dry run complete")
return
@@ -84,6 +104,11 @@ def main(
olmo_logits = olmo_model(batch1)
torch_loss = compute_loss(torch_model, batch1, logits=torch_logits)
olmo_loss = compute_loss(olmo_model, batch1, logits=olmo_logits)

if mixed_precision:
assert torch_logits.dtype == torch.bfloat16
assert olmo_logits.dtype == torch.bfloat16

torch.testing.assert_close(olmo_logits, torch_logits)
torch.testing.assert_close(olmo_loss, torch_loss)

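Note on the `cast`/`writeback` interaction exercised above: the commit's "ensure `cast` and `writeback` not both set" guard exists because writing bf16-gathered values back into the shards would silently degrade the full-precision weights. A minimal sketch in plain PyTorch (not the olmo_core API) of why a bf16 round trip is lossy, and why the bf16 comparison above needs explicit `rtol`/`atol` rather than exact equality:

import torch

# Pushing float32 values through bfloat16 and back drops mantissa bits,
# so a cast-then-writeback would corrupt the full-precision master copy.
x = torch.randn(8, dtype=torch.float32)
roundtrip = x.to(torch.bfloat16).to(torch.float32)
print((roundtrip - x).abs().max())  # almost always > 0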
35 changes: 30 additions & 5 deletions src/benchmarks/fsdp/train.py
@@ -31,9 +31,16 @@ def main(
dry_run: bool = False,
save_path: Optional[str] = None,
load_path: Optional[str] = None,
mixed_precision: bool = True,
**kwargs,
):
model, optim, dataloader = build_components(
config, batch_size, num_batches=num_batches, fsdp_wrapper=fsdp_wrapper
config,
batch_size,
num_batches=num_batches,
fsdp_wrapper=fsdp_wrapper,
mixed_precision=mixed_precision,
**kwargs,
)

if load_path is not None:
@@ -58,13 +65,11 @@ def main(
optim.zero_grad()

# Run forward pass.
with torch.autocast("cuda", dtype=torch.bfloat16):
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
loss = compute_loss(model, batch)

# Trigger backward pass.
loss.backward()
if not torch.isfinite(loss):
raise ValueError("NaN loss encountered.")

# Clip gradient norms.
model.clip_grad_norm_(1.0)
@@ -76,7 +81,7 @@ def main(
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.1f}",
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
)

if save_path is not None:
@@ -129,8 +134,24 @@ def main(
"--load-path",
type=str,
)
parser.add_argument(
"--no-mixed-precision",
action="store_true",
)
parser.add_argument(
"--max-prefetch-count",
type=int,
default=1,
)
parser.add_argument(
"--lr",
type=float,
default=1e-4,
)
args = parser.parse_args()

mixed_precision = not args.no_mixed_precision

config: TransformerConfig
if args.model_size == "tiny":
config = TransformerConfig.tiny()
@@ -143,6 +164,7 @@ def main(

if args.debug:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
config.debug = True

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
@@ -158,4 +180,7 @@ def main(
dry_run=args.dry_run,
save_path=args.save_path,
load_path=args.load_path,
mixed_precision=mixed_precision,
max_prefetch_count=args.max_prefetch_count,
learning_rate=args.lr,
)
15 changes: 13 additions & 2 deletions src/olmo_core/distributed/checkpoint.py
@@ -501,6 +501,11 @@ def unshard(
# Load the state dict in place.
self.load(dir, state_dict, metadata=metadata, no_dist=no_dist or rank0_only)

# Check for NaNs which would indicate we didn't fill the state dict correctly.
for key, tensor in state_dict.items():
if tensor.isnan().any().item():
raise RuntimeError(f"error loading {key} from checkpoint, NaNs encountered")

return state_dict

def get_metadata(self, dir: str, no_dist: bool = False) -> StorageMetadata:
@@ -665,7 +670,10 @@ def torch_dtype(self) -> torch.dtype:
def materialize_empty(
self, *, device: Optional[torch.device] = None, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
return torch.empty(shape if shape is not None else self.shape, dtype=self.torch_dtype, device=device)
tensor = torch.empty(shape if shape is not None else self.shape, dtype=self.torch_dtype, device=device)
if tensor.dtype.is_floating_point:
tensor.fill_(torch.nan)
return tensor

def materialize_from_sharded(
self, tensor: torch.Tensor, device: Optional[torch.device] = None
@@ -675,7 +683,10 @@ def materialize_from_sharded(
raise ValueError(
f"unexpected shape for sharded tensor, expected {self.shape}, got {tensor.unsharded_shape}"
)
return torch.empty(tensor.shape, device=device, dtype=self.torch_dtype)
tensor = torch.empty(tensor.shape, device=device, dtype=self.torch_dtype)
if tensor.dtype.is_floating_point:
tensor.fill_(torch.nan)
return tensor
else:
raise NotImplementedError(f"`materialize_from_sharded()` not implemented for {tensor}")

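The two `materialize_*` changes above pre-fill empty floating-point buffers with NaN so that the new check in `unshard()` can catch any entry the checkpoint loader never wrote (uninitialized memory can coincidentally look valid; a NaN canary cannot). A condensed, standalone sketch of the same pattern; the names here are illustrative, not the olmo_core API:

import torch
from typing import Optional, Tuple

def materialize_placeholder(
    shape: Tuple[int, ...], dtype: torch.dtype, device: Optional[torch.device] = None
) -> torch.Tensor:
    # NaN-fill floating-point placeholders so unwritten entries are detectable.
    t = torch.empty(shape, dtype=dtype, device=device)
    if t.dtype.is_floating_point:
        t.fill_(torch.nan)
    return t

state_dict = {"wte.weight": materialize_placeholder((16, 8), torch.float32)}
state_dict["wte.weight"].copy_(torch.randn(16, 8))  # stand-in for the real load
for key, tensor in state_dict.items():
    if tensor.is_floating_point() and tensor.isnan().any().item():
        raise RuntimeError(f"error loading {key} from checkpoint, NaNs encountered")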
35 changes: 27 additions & 8 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -32,7 +32,7 @@ class FlatParamHandle:
The FQNs of the managed params.
"""

grads: List[Optional[torch.Tensor]] = field(default_factory=list)
grads_cache: List[Optional[torch.Tensor]] = field(default_factory=list)
"""
Used for caching gradients during gradient accumulation.
"""
@@ -121,7 +121,7 @@ def collate_flat_params(
return cls(
params=params,
param_fqns=list(param_fqns),
grads=[None] * len(params),
grads_cache=[None] * len(params),
params_data=params_data,
params_offsets_per_rank=params_offsets_per_rank,
process_group=process_group,
@@ -134,9 +134,27 @@ def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False
"""
if not self.params:
return

local_rank = get_rank(self.process_group)
world_size = get_world_size(self.process_group)
all_params_unsharded_data = self.params_data.gather(dtype=dtype, rank0_only=rank0_only)

# Gather full, padded, unsharded data for all params.
all_params_unsharded_data: torch.Tensor
if rank0_only or dist.get_backend() == dist.Backend.GLOO:
all_params_unsharded_data = self.params_data.gather(dtype=dtype, rank0_only=rank0_only)
else:
# We prefer to use `all_gather_into_tensor()` directly when possible as it involves
# fewer allocations.
all_params_unsharded_data = torch.empty(
self.params_data.unsharded_shape, dtype=dtype or self.params_data.dtype, device=self.device
)
dist.all_gather_into_tensor(
all_params_unsharded_data,
self.params_data.data.to(dtype or self.params_data.dtype),
group=self.process_group,
)

# Set the data for each param as a view into `all_params_unsharded_data`.
for i, (param, param_offsets) in enumerate(zip(self.params, self.params_offsets_per_rank)):
if rank0_only and local_rank != 0:
param.unshard_(
@@ -163,8 +181,8 @@ def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False
# We should only be caching these between the pre-backward and post-backward
# hooks. The post-backward hook will remove the cached grad as it accumulates
# it into the persistent sharded grad.
assert self.grads[i] is None
self.grads[i] = param.grad.data
assert self.grads_cache[i] is None
self.grads_cache[i] = param.grad.data
param.grad = None

del all_params_unsharded_data
@@ -194,7 +212,8 @@ def reduce_scatter_grads(
if grad_dtype is None:
grad_dtype = param.dtype

# TODO: batch reductions together
# TODO: batch reductions together? This is complicated, especially if we want to allow
# a mixture of trainable and frozen params.

# Only NCCL supports 'reduce_scatter'. So with other backends we use 'all_reduce'.
if dist.get_backend() == dist.Backend.NCCL:
@@ -209,7 +228,7 @@

del unsharded_grad

if (cached_grad := self.grads[i]) is not None:
if (cached_grad := self.grads_cache[i]) is not None:
param.grad.add_(cached_grad)
self.grads[i] = None
self.grads_cache[i] = None
del cached_grad
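For reference, a stripped-down sketch of the gather strategy introduced in `unshard_()` above, assuming each rank holds one equally sized flat shard (standalone PyTorch, not the `FlatParamHandle` API): `all_gather_into_tensor` writes every rank's shard straight into one pre-allocated buffer, while the fallback path goes through a per-rank tensor list and a concatenation.

import torch
import torch.distributed as dist

def gather_flat_shard(local_shard: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if dist.get_backend(group) == dist.Backend.NCCL:
        # One output allocation; each rank's shard lands in its slice of `out`.
        out = torch.empty(
            world_size * local_shard.numel(),
            dtype=local_shard.dtype,
            device=local_shard.device,
        )
        dist.all_gather_into_tensor(out, local_shard.contiguous(), group=group)
        return out
    # Fallback: world_size temporary tensors plus a final concatenation.
    shards = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(shards, local_shard.contiguous(), group=group)
    return torch.cat(shards)

The diff keeps the existing `params_data.gather()` path for `rank0_only` and for the Gloo backend, and routes everything else through `all_gather_into_tensor` for the reduced allocation overhead noted in the comment above.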