Skip to content

Commit

Permalink
use some threads when unsharding by default
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed May 9, 2024
1 parent a581b06 commit de2ec70
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/olmo_core/distributed/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
serialize_to_tensor,
upload,
)
from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES
from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES, default_thread_count

from .tensors import ShardedFlatTensor, ShardingSpec
from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object
Expand Down Expand Up @@ -510,8 +510,12 @@ def unshard(
:param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed
context. Other ranks will receive an empty dictionary.
:param no_dist: Set to true to avoid any distributed communication whatsoever.
:param num_threads: The maximum number of threads to use to unshard the checkpoint.
Increasing ``num_threads`` can lead to a substantial speed up, especially when loading
from a remote checkpoint. Set to ``0`` to disable threading.
"""
dir = self._normalize_dir(dir)
num_threads = num_threads if num_threads is not None else default_thread_count()

if rank0_only and no_dist and get_rank() != 0:
raise ValueError(
Expand Down

0 comments on commit de2ec70

Please sign in to comment.