Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a frequency parameter to TriggerWandbSyncLightningCallback #101

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
17 changes: 15 additions & 2 deletions src/wandb_osh/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@


class TriggerWandbSyncHook:
def __init__(self, communication_dir: PathLike = _comm_default_dir):
def __init__(
self,
communication_dir: PathLike = _comm_default_dir,
sync_every_n_epochs: int = 1,
):
"""Hook to trigger synchronization of wandb with wandb-osh

Args:
Expand All @@ -26,16 +30,25 @@ def __init__(self, communication_dir: PathLike = _comm_default_dir):
__version__,
self.communication_dir,
)
self._sync_every_n_epochs = sync_every_n_epochs

def __call__(self, logdir: str | PathLike | None = None):
def __call__(
self,
logdir: str | PathLike | None = None,
current_epoch: int = 0,
):
"""Trigger synchronization on the head nodes

Args:
logdir: The directory in which wandb puts its run files.
current_epoch: The epoch the hook is called from.

Returns:
None
"""
# Only trigger the hook every n epochs
if current_epoch % self._sync_every_n_epochs != 0:
return
if logdir is None:
# run.dir actually points to the `/files` subdirectory of the run,
# but we need the directory above that.
Expand Down
22 changes: 19 additions & 3 deletions src/wandb_osh/lightning_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,29 @@ class TriggerWandbSyncLightningCallback(pl.Callback):
def __init__(
self,
communication_dir: PathLike = _comm_default_dir,
sync_every_n_epochs: int = 1,
):
"""Hook to be used when interfacing wandb with Lightning.

Args:
communication_dir: Directory used for communication with wandb-osh.
sync_every_n_epochs: Number of epochs between each trigger (default = 1, every epoch).


Usage

.. code-block:: python

from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback

trainer = Trainer(callbacks=[TriggerWandbSyncLightningCallback()])
trainer = Trainer(callbacks=[TriggerWandbSyncLightningCallback(sync_every_n_epochs = 5)])

"""
super().__init__()
self._hook = TriggerWandbSyncHook(communication_dir=communication_dir)
self.sync_every_n_epochs = sync_every_n_epochs
self._hook = TriggerWandbSyncHook(
communication_dir=communication_dir, sync_every_n_epochs=sync_every_n_epochs
)

def on_validation_epoch_end(
self,
Expand All @@ -49,4 +55,14 @@ def on_validation_epoch_end(
) -> None:
if trainer.sanity_checking:
return
self._hook()
self._hook(current_epoch=trainer.current_epoch)

def on_test_epoch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
) -> None:
if trainer.sanity_checking:
return
# Force the hook to trigger on last epoch end
self._hook(current_epoch=self.sync_every_n_epochs)
13 changes: 10 additions & 3 deletions src/wandb_osh/ray_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@


class TriggerWandbSyncRayHook(LoggerCallback):
def __init__(self, communication_dir: PathLike = _comm_default_dir):
def __init__(
self,
communication_dir: PathLike = _comm_default_dir,
sync_every_n_epochs: int = 1,
):
"""Hook to be used when interfacing wandb with ray tune.

Args:
communication_dir: Directory used for communication with wandb-osh.
sync_every_n_epochs: Number of epochs between each trigger (default = 1, every epoch).
"""
super().__init__()
self._hook = TriggerWandbSyncHook(communication_dir=communication_dir)
self._hook = TriggerWandbSyncHook(
communication_dir=communication_dir, sync_every_n_epochs=sync_every_n_epochs
)

def log_trial_result(self, iteration: int, trial: Trial, result: dict):
trial_dir = Path(trial.logdir)
self._hook(trial_dir)
self._hook(trial_dir, current_epoch=iteration)