Optionally load trainer state #573

Open · wants to merge 9 commits into base: main
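This PR threads a load_trainer_state flag through each restore_checkpoint implementation in olmo/checkpoint.py, so a caller can restore model (and optionally optimizer) weights without reading the trainer-state files at all; when the flag is off, the method returns None instead of a state dict, which is why the return annotation becomes Optional[Dict[str, Any]]. A minimal, hypothetical usage sketch (the checkpointer, fsdp_model, optim, and path names below are illustrative, not taken from the PR):

# Restore weights only, e.g. to fine-tune from a pretraining checkpoint
# where the old optimizer and trainer state are irrelevant.
trainer_state = checkpointer.restore_checkpoint(
    "checkpoints/step1000",  # illustrative load path
    fsdp_model,
    optim,
    load_optimizer_state=False,
    load_trainer_state=False,
)
assert trainer_state is None  # nothing to hand to load_trainer_state_dict()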
68 changes: 41 additions & 27 deletions olmo/checkpoint.py
@@ -522,7 +522,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         """
         Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
         """
@@ -678,7 +679,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         with FSDP.state_dict_type(
             fsdp_model,
             state_dict_type=StateDictType.FULL_STATE_DICT,
@@ -751,11 +753,13 @@ def restore_checkpoint(
         del optim_state_dict_to_load

         # Load other state.
-        try:
-            trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
-        except FileNotFoundError:
-            # for backwards compatibility
-            trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
+            except FileNotFoundError:
+                # for backwards compatibility
+                trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -872,7 +876,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         # Load model and optimizer state in place.
         log.info("Loading model and optimizer state...")
         load_fsdp_model_and_optim_state(
@@ -885,14 +890,16 @@ def restore_checkpoint(

         # Load trainer state dict.
         log.info("Loading trainer state...")
-        try:
-            trainer_state = load_state_dict(
-                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
-            )
-        except FileNotFoundError:
-            # Fall back to rank 0 train state.
-            # This can happen when we're restoring a checkpoint with a different world size.
-            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(
+                    load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
+                )
+            except FileNotFoundError:
+                # Fall back to rank 0 train state.
+                # This can happen when we're restoring a checkpoint with a different world size.
+                trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -949,6 +956,7 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
+        load_trainer_state: bool = True,
     ) -> Dict[str, Any]:
         with FSDP.state_dict_type(
             fsdp_model,
@@ -1562,7 +1570,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         # Load metadata and make sure checkpoint is compatible.
         metadata = self._load_metadata(load_path, local_cache=local_cache)
         assert metadata.world_size == get_world_size()
@@ -1599,7 +1608,9 @@ def restore_checkpoint(

         # Load local trainer state.
         log.info("Loading local trainer state...")
-        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
         barrier()
         return trainer_state

@@ -1868,7 +1879,8 @@ def restore_checkpoint(
         *,
         local_cache: Optional[PathOrStr] = None,
         load_optimizer_state: bool = True,
-    ) -> Dict[str, Any]:
+        load_trainer_state: bool = True,
+    ) -> Optional[Dict[str, Any]]:
         from olmo_core.distributed.checkpoint import ( # type: ignore
             load_model_and_optim_state,
         )
@@ -1877,14 +1889,16 @@ def restore_checkpoint(
         load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)

         log.info("Loading trainer state...")
-        try:
-            trainer_state = load_state_dict(
-                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
-            )
-        except FileNotFoundError:
-            # Fall back to rank 0 train state.
-            # This can happen when we're restoring a checkpoint with a different world size.
-            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+        trainer_state = None
+        if load_trainer_state:
+            try:
+                trainer_state = load_state_dict(
+                    load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
+                )
+            except FileNotFoundError:
+                # Fall back to rank 0 train state.
+                # This can happen when we're restoring a checkpoint with a different world size.
+                trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)

         barrier()
         return trainer_state
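The sharded checkpointers above all converge on the same guarded-load shape: start with trainer_state set to None, and only when the flag is set attempt the rank-specific train/rank{N}.pt file, falling back to rank 0's copy when the world size has changed. A standalone sketch of that shape, assuming a load_state_dict(dir, fname, local_cache=None)-style helper like the one used in the diff (this function is illustrative, not the project's actual helper):

def _maybe_load_trainer_state(load_path, rank, *, load_trainer_state=True, local_cache=None):
    # Sketch of the pattern introduced in the diff above.
    trainer_state = None
    if load_trainer_state:
        try:
            # Prefer this rank's own saved trainer state.
            trainer_state = load_state_dict(load_path, f"train/rank{rank}.pt", local_cache=local_cache)
        except FileNotFoundError:
            # Fall back to rank 0, e.g. when restoring with a different world size.
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
    return trainer_state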
1 change: 1 addition & 0 deletions olmo/train.py
@@ -504,6 +504,7 @@ def restore_sharded_checkpoint(
             self.optim,
             local_cache=local_cache,
             load_optimizer_state=load_optimizer_state,
+            load_trainer_state=load_trainer_state,
         )
         if load_trainer_state:
             self.load_trainer_state_dict(trainer_state)
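On the trainer side, restore_sharded_checkpoint simply forwards the flag, and the existing if load_trainer_state: guard ensures the now-possibly-None return value is never passed to load_trainer_state_dict. A hypothetical trainer-level call, assuming the checkpoint path is the first positional argument of restore_sharded_checkpoint:

# Reuse pretrained weights but start optimizer and trainer state fresh.
trainer.restore_sharded_checkpoint(
    "checkpoints/step1000-sharded",  # illustrative path
    load_optimizer_state=False,
    load_trainer_state=False,
)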