fix
flybird11111 committed Nov 15, 2024
1 parent ec0fa74 commit b193f1a
Showing 2 changed files with 2 additions and 13 deletions.
colossalai/checkpoint_io/checkpoint_io_base.py: 2 additions & 4 deletions

@@ -10,7 +10,7 @@
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, get_optimizer_state_dict_numl, has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
 
 __all__ = ["CheckpointIO"]
 
@@ -230,9 +230,7 @@ def save_optimizer(
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
-        if not shard and use_async:
-            size_per_shard = get_optimizer_state_dict_numl(optimizer)
-        if shard or use_async:
+        if shard:
             self.save_sharded_optimizer(
                 optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
             )
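
After this change, save_optimizer dispatches on shard alone, and use_async is simply forwarded into the sharded path; unsharded async saves are no longer rerouted through save_sharded_optimizer with an element count stuffed into size_per_shard. A minimal sketch of the resulting control flow; the method signature, the else branch, and the save_unsharded_optimizer call are not shown in this diff and are assumptions:

def save_optimizer(self, optimizer, checkpoint, shard=False, gather_dtensor=True,
                   prefix=None, size_per_shard=1024, use_async=False):
    if shard:
        # Sharded path: size_per_shard is only honored here.
        self.save_sharded_optimizer(
            optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
        )
    else:
        # Unsharded path (assumed call): no size_per_shard pre-computation needed.
        self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)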
colossalai/checkpoint_io/utils.py: 0 additions & 9 deletions

@@ -871,12 +871,3 @@ def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
     for name, tensor in state_dict.items():
         pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
     return pin_mem
-
-
-def get_optimizer_state_dict_numl(optimizer):
-    total_size = 0
-    state_dict = optimizer.state_dict()
-    for param_group in state_dict["state"].values():
-        for param_name, param_tensor in param_group.items():
-            total_size += torch.tensor(param_tensor).numel() if param_name == "step" else param_tensor.numel()
-    return total_size
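
For context, the removed helper (its name is presumably a typo for "numel") summed the element counts of every tensor in the optimizer state, wrapping the per-parameter "step" entry in torch.tensor because standard torch.optim optimizers may store it as a plain scalar. Its loop variable was also misnamed: state_dict["state"] maps parameter indices to per-parameter state, not to param groups. A standalone sketch of the same traversal under those assumptions, with a hypothetical function name:

import torch

def count_optimizer_state_elements(optimizer: torch.optim.Optimizer) -> int:
    # Assumes the standard torch.optim state-dict layout:
    # {"state": {param_idx: {"step": ..., "exp_avg": Tensor, ...}}, "param_groups": [...]}
    total_size = 0
    for param_state in optimizer.state_dict()["state"].values():
        for name, value in param_state.items():
            if name == "step":
                # "step" may be a plain int/float rather than a tensor, so wrap it
                # before calling .numel() (a scalar counts as one element).
                total_size += torch.tensor(value).numel()
            else:
                total_size += value.numel()
    return total_size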
