diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index ae4d8245..c73cf975 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -42,23 +42,33 @@ def disable_hooks(cls): def register_hook( self, module: torch.nn.Module, - func: Callable[[Any], Any], + hook: Callable[[Any], Any], hook_type: str, **kwargs, ): - @wraps(func) + """ + Registers a hook on a specified module with the option to disable it with + HooksMixin.disable_hooks + + :param module: the module on which the hook should be registered + :param hook: the hook to register + :param hook_type: the type of hook to register corresponding to the + `register_{hook_type}_hook` attribute on torch.nn.Module. + Ex. "forward", "forward_pre", "full_backward", "state_dict_post" + :param kwargs: keyword arguments to pass to register hook method + """ + + @wraps(hook) def wrapped_hook(*args, **kwargs): if HooksMixin._HOOKS_DISABLED: return - return func(*args, **kwargs) + return hook(*args, **kwargs) handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) def remove_hooks(self): - """ - Remove all hooks belonging to a modifier - """ + """Remove all hooks belonging to a modifier""" for hook in self._hooks: hook.remove()