diff --git a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py index 3ada8c7fb..d59b4563b 100644 --- a/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py @@ -2,11 +2,10 @@ from typing import Dict import torch -from pydantic import BaseModel from torch.nn import Parameter -from torch.utils.hooks import RemovableHandle from llmcompressor.core import ModelParameterizedLayer +from llmcompressor.modifiers.utils.hooks import HooksMixin __all__ = ["LayerParamMasking", "param_mask_name"] @@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings: use_hooks: bool = False -class LayerParamMasking(BaseModel): +class LayerParamMasking(HooksMixin): _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} _masked_layer_params: Dict[str, ModelParameterizedLayer] = {} - _forward_hooks: Dict[str, RemovableHandle] = {} - _backward_hooks: Dict[str, RemovableHandle] = {} enabled_: bool = False def add_mask( @@ -100,12 +97,8 @@ def _backward_hook_fn(gradients): return gradients - self._forward_hooks[layer_param_name] = ( - parameterized_layer.layer.register_forward_hook(_forward_hook_fn) - ) - self._backward_hooks[layer_param_name] = ( - parameterized_layer.param.register_hook(_backward_hook_fn) - ) + self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward") + self.register_hook(parameterized_layer.param, _backward_hook_fn, "") def update_mask( self, @@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str): del self._mask_settings[layer_param_name] if mask_settings.use_hooks: - self._forward_hooks[layer_param_name].remove() - self._backward_hooks[layer_param_name].remove() - - del self._forward_hooks[layer_param_name] - del self._backward_hooks[layer_param_name] + self.remove_hooks() def apply_mask_weight(self, layer_param_name: str): if not self.enabled_: diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py index 39134e273..44e9c13f9 100644 --- a/src/llmcompressor/modifiers/utils/hooks.py +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -1,6 +1,6 @@ import contextlib from functools import wraps -from typing import Any, Callable, ClassVar, List +from typing import Any, Callable, ClassVar, List, Union import torch from loguru import logger @@ -19,7 +19,9 @@ class HooksMixin(BaseModel): Modifiers which implement hooks should register them using `self.register_..._hook(module, hook)` rather than the usual `module.register_..._hook(hook)`. Modifiers should remove hooks with - `self.remove_hooks()` + `self.remove_hooks()`. + + Hooks can be applied to modules or parameters Lifecycle: - modifier.register_forward_hook(module, hook) @@ -42,20 +44,20 @@ def disable_hooks(cls): def register_hook( self, - module: torch.nn.Module, + target: Union[torch.nn.Module, torch.nn.Parameter], hook: Callable[[Any], Any], hook_type: str, **kwargs, ) -> RemovableHandle: """ - Registers a hook on a specified module with the option to disable it with - HooksMixin.disable_hooks + Registers a hook on a specified module/parameter with the option to disable it + with HooksMixin.disable_hooks() - :param module: the module on which the hook should be registered + :param target: the module or parameter on which the hook should be registered :param hook: the hook to register :param hook_type: the type of hook to register corresponding to the `register_{hook_type}_hook` attribute on torch.nn.Module. - Ex. "forward", "forward_pre", "full_backward", "state_dict_post" + Ex. "forward", "forward_pre", "full_backward", "state_dict_post", "" :param kwargs: keyword arguments to pass to register hook method """ @@ -66,7 +68,7 @@ def wrapped_hook(*args, **kwargs): return hook(*args, **kwargs) - handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) + handle = getattr(target, f"register_{hook_type}_hook")(wrapped_hook, **kwargs) self._hooks.append(handle) logger.debug(f"{self} added {handle}") @@ -76,3 +78,5 @@ def remove_hooks(self): """Remove all hooks belonging to a modifier""" for hook in self._hooks: hook.remove() + + self._hooks = []