From da17c2c68f3c72ea41267e59879616829a2f617b Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Sun, 12 May 2024 22:48:33 -0700 Subject: [PATCH 1/7] Pass load_trainer_state --- olmo/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/olmo/train.py b/olmo/train.py index 0653b6bf3..20765dcc6 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -495,6 +495,7 @@ def restore_sharded_checkpoint( self.optim, local_cache=local_cache, load_optimizer_state=load_optimizer_state, + load_trainer_state=load_trainer_state, ) if load_trainer_state: self.load_trainer_state_dict(trainer_state) From d9aa0440d87b22b6b805399256b8cd6fd331b535 Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Sun, 12 May 2024 22:53:14 -0700 Subject: [PATCH 2/7] Optionally load trainer state --- olmo/checkpoint.py | 58 ++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 343bdfa31..565cec978 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -516,6 +516,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: """ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state. @@ -672,6 +673,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: with FSDP.state_dict_type( fsdp_model, @@ -745,11 +747,13 @@ def restore_checkpoint( del optim_state_dict_to_load # Load other state. - try: - trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache) - except FileNotFoundError: - # for backwards compatibility - trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache) + trainer_state = None + if load_trainer_state: + try: + trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache) + except FileNotFoundError: + # for backwards compatibility + trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache) barrier() return trainer_state @@ -866,6 +870,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: # Load model and optimizer state in place. log.info("Loading model and optimizer state...") @@ -879,14 +884,16 @@ def restore_checkpoint( # Load trainer state dict. log.info("Loading trainer state...") - try: - trainer_state = load_state_dict( - load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache - ) - except FileNotFoundError: - # Fall back to rank 0 train state. - # This can happen when we're restoring a checkpoint with a different world size. - trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache) + trainer_state = None + if load_trainer_state: + try: + trainer_state = load_state_dict( + load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache + ) + except FileNotFoundError: + # Fall back to rank 0 train state. + # This can happen when we're restoring a checkpoint with a different world size. + trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache) barrier() return trainer_state @@ -943,6 +950,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: with FSDP.state_dict_type( fsdp_model, @@ -1556,6 +1564,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: # Load metadata and make sure checkpoint is compatible. metadata = self._load_metadata(load_path, local_cache=local_cache) @@ -1593,7 +1602,9 @@ def restore_checkpoint( # Load local trainer state. log.info("Loading local trainer state...") - trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache) + trainer_state = None + if load_trainer_state: + trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache) barrier() return trainer_state @@ -1862,6 +1873,7 @@ def restore_checkpoint( *, local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, + load_trainer_state: bool = True, ) -> Dict[str, Any]: from olmo_core.distributed.checkpoint import ( # type: ignore load_model_and_optim_state, @@ -1871,14 +1883,16 @@ def restore_checkpoint( load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None) log.info("Loading trainer state...") - try: - trainer_state = load_state_dict( - load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache - ) - except FileNotFoundError: - # Fall back to rank 0 train state. - # This can happen when we're restoring a checkpoint with a different world size. - trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache) + trainer_state = None + if load_trainer_state: + try: + trainer_state = load_state_dict( + load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache + ) + except FileNotFoundError: + # Fall back to rank 0 train state. + # This can happen when we're restoring a checkpoint with a different world size. + trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache) barrier() return trainer_state From 5f8f27e863181c6453b89dd56fb5aef7423f544a Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 13 May 2024 09:22:39 -0700 Subject: [PATCH 3/7] Fix typing Co-authored-by: Pete --- olmo/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 565cec978..4fd2dbc14 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -517,7 +517,7 @@ def restore_checkpoint( local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, load_trainer_state: bool = True, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state. """ From 64eb5a25959b4064fa4a4f33aa1ea5d68a11eecb Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 13 May 2024 09:24:21 -0700 Subject: [PATCH 4/7] Fix typing --- olmo/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 4fd2dbc14..c877df1f6 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -674,7 +674,7 @@ def restore_checkpoint( local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, load_trainer_state: bool = True, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: with FSDP.state_dict_type( fsdp_model, state_dict_type=StateDictType.FULL_STATE_DICT, From 9fe6e39041cb981a4b41aeab651c6ce0f0266506 Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 13 May 2024 09:24:26 -0700 Subject: [PATCH 5/7] Fix typing --- olmo/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index c877df1f6..cfd1cbf1d 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -871,7 +871,7 @@ def restore_checkpoint( local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, load_trainer_state: bool = True, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: # Load model and optimizer state in place. log.info("Loading model and optimizer state...") load_fsdp_model_and_optim_state( From ae6fcc58e205baf9d6d70afc0dc85886ba6f9994 Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 13 May 2024 09:24:31 -0700 Subject: [PATCH 6/7] Fix typing --- olmo/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index cfd1cbf1d..83366e6b2 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -1565,7 +1565,7 @@ def restore_checkpoint( local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, load_trainer_state: bool = True, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: # Load metadata and make sure checkpoint is compatible. metadata = self._load_metadata(load_path, local_cache=local_cache) assert metadata.world_size == get_world_size() From 6346ea123658ddfe2f4eb3b2d199fb2bd01ee44d Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Mon, 13 May 2024 09:24:36 -0700 Subject: [PATCH 7/7] Fix typing --- olmo/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 83366e6b2..68afe31aa 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -1874,7 +1874,7 @@ def restore_checkpoint( local_cache: Optional[PathOrStr] = None, load_optimizer_state: bool = True, load_trainer_state: bool = True, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: from olmo_core.distributed.checkpoint import ( # type: ignore load_model_and_optim_state, )