Integrate into trainer
epwalsh committed Sep 6, 2024
1 parent 276a4fb commit f7b6e76
Showing 3 changed files with 27 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/olmo_core/optim/skip_step_optimizer.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -18,6 +18,11 @@ class SkipStepOptimizer(Optimizer):
     :data:`latest_grad_norm` to the current loss and grad norm, respectively, *before* calling
     :meth:`step()`.
 
+    The :class:`~olmo_core.train.Trainer` will automatically set the :data:`latest_loss` whenever
+    its optimizer is a subclass of :class:`SkipStepOptimizer`, and the
+    :class:`~olmo_core.train.callbacks.GradClipperCallback` will automatically set the
+    :data:`latest_grad_norm`.
+
     .. tip::
         When implementing a :class:`SkipStepOptimizer` you should be careful to avoid host-device
         syncs. You can use :meth:`get_step_factor()` within your :meth:`step()` method to do this.
@@ -50,10 +55,11 @@ def latest_loss(self, loss: torch.Tensor):
             self._losses.pop(0)
 
     @property
-    def latest_grad_norm(self) -> torch.Tensor:
+    def latest_grad_norm(self) -> Optional[torch.Tensor]:
         if not self._grad_norms:
-            raise RuntimeError("'latest_grad_norm' has not been set yet")
-        return self._grad_norms[-1]
+            return None
+        else:
+            return self._grad_norms[-1]
 
     @latest_grad_norm.setter
     def latest_grad_norm(self, grad_norm: torch.Tensor):
@@ -73,8 +79,12 @@ def get_step_factor(self) -> torch.Tensor:
             return torch.tensor(1.0).to(device=self.latest_loss.device, non_blocking=True)
 
         loss_std = torch.std(torch.stack(self._losses[:-1]))
-        grad_norm_std = torch.std(torch.stack(self._grad_norms[:-1]))
-        return (
-            self.latest_loss <= self.sigma_factor * loss_std
-            and self.latest_grad_norm <= self.sigma_factor * grad_norm_std
-        )
+
+        if self._grad_norms:
+            grad_norm_std = torch.std(torch.stack(self._grad_norms[:-1]))
+            return (
+                self.latest_loss <= self.sigma_factor * loss_std
+                and self.latest_grad_norm <= self.sigma_factor * grad_norm_std
+            )
+        else:
+            return self.latest_loss <= self.sigma_factor * loss_std
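
For reference, a minimal sketch (not part of this commit) of how a SkipStepOptimizer subclass might consume get_step_factor() inside step(), per the tip in the docstring above. The SkipStepSGD class and its constructor are hypothetical assumptions; only the get_step_factor() call and the olmo_core.optim.SkipStepOptimizer import are taken from the repo.

import torch

from olmo_core.optim import SkipStepOptimizer


class SkipStepSGD(SkipStepOptimizer):
    """Hypothetical minimal subclass, assuming the base constructor accepts
    (params, defaults) like torch.optim.Optimizer and that its own keyword
    arguments (rolling window length, sigma factor) have defaults."""

    def __init__(self, params, lr: float = 1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        # get_step_factor() returns an on-device scalar tensor that is nonzero
        # for a normal step and zero for a skipped step, so we never call
        # .item() on it and avoid a host-device sync.
        step_factor = self.get_step_factor()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Scaling the gradient by step_factor zeroes the update whenever
                # the latest loss or grad norm looks like an outlier.
                p.add_(step_factor * p.grad, alpha=-group["lr"])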
4 changes: 4 additions & 0 deletions src/olmo_core/train/callbacks/grad_clipper.py
@@ -4,6 +4,8 @@
 import torch.nn as nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+from olmo_core.optim import SkipStepOptimizer
+
 from .callback import Callback
 
 
@@ -26,3 +28,5 @@ def pre_optim_step(self):
 
         # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`.
         self.trainer.record_metric("optim/total grad norm", grad_norm, reduce_type=None)
+        if isinstance(self.trainer.optim, SkipStepOptimizer):
+            self.trainer.optim.latest_grad_norm = grad_norm
4 changes: 4 additions & 0 deletions src/olmo_core/train/trainer.py
@@ -45,6 +45,7 @@
     cross_entropy_loss,
     fused_cross_entropy_loss,
 )
+from ..optim import SkipStepOptimizer
 from ..utils import move_to_device
 from .callbacks import (
     Callback,
@@ -1148,6 +1149,9 @@ def _train_batch(self, batch: Dict[str, Any]):
         if z_batch_loss is not None:
             self.record_metric(TRAIN_Z_LOSS_METRIC, z_batch_loss, ReduceType.mean)
 
+        if isinstance(self.optim, SkipStepOptimizer):
+            self.optim.latest_loss = ce_batch_loss
+
         # Run through callbacks.
         for callback in self.callbacks.values():
             callback.pre_optim_step()
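
Putting the three files together, the per-batch ordering is: Trainer._train_batch records the cross-entropy loss on the optimizer, GradClipperCallback.pre_optim_step records the clipped gradient norm, and the optimizer's step() can then consult get_step_factor(). The snippet below is a rough, hypothetical illustration of that ordering, not code from the repo.

import torch

from olmo_core.optim import SkipStepOptimizer


def train_batch_sketch(optim, ce_batch_loss: torch.Tensor, grad_norm: torch.Tensor) -> None:
    """Hypothetical illustration of the ordering this commit wires up."""
    # 1. Trainer._train_batch hands the latest loss to a SkipStepOptimizer.
    if isinstance(optim, SkipStepOptimizer):
        optim.latest_loss = ce_batch_loss

    # 2. GradClipperCallback.pre_optim_step hands over the clipped grad norm.
    if isinstance(optim, SkipStepOptimizer):
        optim.latest_grad_norm = grad_norm

    # 3. step() can now compare both values against sigma_factor times their
    #    rolling standard deviations via get_step_factor() and skip outliers.
    optim.step()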
