Commit

remove .abs() and avoid division by zero
epwalsh committed Sep 6, 2024
1 parent e56024d commit a14de85
Showing 1 changed file with 4 additions and 5 deletions.
src/olmo_core/optim/skip_step_optimizer.py (9 changes: 4 additions & 5 deletions)
@@ -95,14 +95,13 @@ def get_step_factor(self) -> torch.Tensor:
             return torch.tensor(1.0).to(device=self.device, non_blocking=True)
 
         loss_std, loss_mean = torch.std_mean(torch.stack(self._losses[:-1]))
-        loss_z_score = (self.latest_loss - loss_mean).abs().div(loss_std)
-
         if self._grad_norms:
             grad_norm_std, grad_norm_mean = torch.std_mean(torch.stack(self._grad_norms[:-1]))
-            grad_norm_z_score = (self.latest_grad_norm - grad_norm_mean).abs().div(grad_norm_std)
-            return loss_z_score <= self.sigma_factor and grad_norm_z_score <= self.sigma_factor
+            return ((self.latest_loss - loss_mean) <= self.sigma_factor * loss_std) and (
+                (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std
+            )
         else:
-            return loss_z_score <= self.sigma_factor
+            return (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std
 
     @property
     def step_skipped(self) -> torch.Tensor:
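For context, here is a minimal, self-contained sketch of the behavior change. The loss values and `sigma_factor` threshold below are made up for illustration; only `torch.std_mean` and the comparison logic come from the diff:

```python
import torch

# Hypothetical rolling window in which the loss has been perfectly flat.
losses = torch.tensor([2.5, 2.5, 2.5])
latest_loss = torch.tensor(2.5)
sigma_factor = 6.0  # made-up threshold

loss_std, loss_mean = torch.std_mean(losses)  # loss_std == 0 here

# Old check: the z-score divides by loss_std, so a zero std yields
# 0 / 0 == nan, and `nan <= sigma_factor` is False -- the step would
# be skipped even though the loss is perfectly stable.
old_keep = (latest_loss - loss_mean).abs().div(loss_std) <= sigma_factor
print(old_keep)  # tensor(False)

# New check: multiply the threshold by loss_std instead of dividing by it.
# With std == 0 this reduces to 0 <= 0, so the step is kept.
new_keep = (latest_loss - loss_mean) <= sigma_factor * loss_std
print(new_keep)  # tensor(True)
```

Note that dropping `.abs()` also makes the check one-sided: only a latest loss (or grad norm) sufficiently far above the rolling mean now triggers a skip, while values below the mean always pass.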
