Skip to content

Commit

Permalink
Add Optimizer FSDP and AC on 3xUnet/5xUnet
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #614

Reviewed By: wat3rBro, YanjunChen329

Differential Revision: D48544742

fbshipit-source-id: 9e49f13aa50e065c30e5551a636a83afd2d11acd
  • Loading branch information
Jessica Zhong authored and facebook-github-bot committed Aug 24, 2023
1 parent c3169c1 commit 7ad54f5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
save_binary_outputs,
)
from detectron2.engine.defaults import create_ddp_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

logger = logging.getLogger("d2go.tools.train_net")
# Make sure logging is set up centrally even for e.g. dataloading workers which
Expand Down Expand Up @@ -73,7 +74,7 @@ def main(

# Use DDP if FSDP is not enabled
# TODO (T142223289): rewrite ddp wrapping as modeling hook
if not is_fsdp_enabled(cfg):
if not isinstance(model, FSDP):
model = create_ddp_model(
model,
fp16_compression=cfg.MODEL.DDP_FP16_GRAD_COMPRESS,
Expand Down

0 comments on commit 7ad54f5

Please sign in to comment.