diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 38a50157c6..ea04ad0046 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -512,7 +512,7 @@ def __init__( self.swa_model = AveragedModel( module_for_swa, - device=device, + device=self.device, use_buffers=swa_params.use_buffers, averaging_method=swa_params.averaging_method, ema_decay=swa_params.ema_decay,