
integrate with magnitude and constant
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Nov 15, 2024
1 parent d0dc807 commit 55f69d6
Showing 2 changed files with 17 additions and 24 deletions.
21 changes: 5 additions & 16 deletions src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py
@@ -2,11 +2,10 @@
 from typing import Dict
 
 import torch
-from pydantic import BaseModel
 from torch.nn import Parameter
-from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.core import ModelParameterizedLayer
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 
 __all__ = ["LayerParamMasking", "param_mask_name"]
 
@@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings:
     use_hooks: bool = False
 
 
-class LayerParamMasking(BaseModel):
+class LayerParamMasking(HooksMixin):
     _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {}
     _masked_layer_params: Dict[str, ModelParameterizedLayer] = {}
-    _forward_hooks: Dict[str, RemovableHandle] = {}
-    _backward_hooks: Dict[str, RemovableHandle] = {}
     enabled_: bool = False
 
     def add_mask(
Expand Down Expand Up @@ -100,12 +97,8 @@ def _backward_hook_fn(gradients):

             return gradients
 
-        self._forward_hooks[layer_param_name] = (
-            parameterized_layer.layer.register_forward_hook(_forward_hook_fn)
-        )
-        self._backward_hooks[layer_param_name] = (
-            parameterized_layer.param.register_hook(_backward_hook_fn)
-        )
+        self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward")
+        self.register_hook(parameterized_layer.param, _backward_hook_fn, "")
 
     def update_mask(
         self,
@@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str):
         del self._mask_settings[layer_param_name]
 
         if mask_settings.use_hooks:
-            self._forward_hooks[layer_param_name].remove()
-            self._backward_hooks[layer_param_name].remove()
-
-            del self._forward_hooks[layer_param_name]
-            del self._backward_hooks[layer_param_name]
+            self.remove_hooks()
 
     def apply_mask_weight(self, layer_param_name: str):
         if not self.enabled_:
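
For orientation, the sketch below is a standalone illustration of the pattern LayerParamMasking relies on: a forward hook on the layer that re-applies the mask to the weight, a gradient hook on the parameter that zeroes pruned entries, and a single list of handles so both can be detached together, which is the bookkeeping this commit hands off to HooksMixin. The class name and the hook bodies are assumptions for illustration; the diff above only shows the registration lines.

import torch


class MaskedLinearDemo:
    """Toy stand-in for a masked layer; not the repository's LayerParamMasking."""

    def __init__(self, layer: torch.nn.Linear):
        self.layer = layer
        # random 50% sparsity mask, purely for illustration
        self.mask = (torch.rand_like(layer.weight) > 0.5).float()
        self._hooks = []  # analogous to the handle list HooksMixin keeps

    def add_hooks(self):
        def _forward_hook_fn(module, args, output):
            # re-apply the mask to the weight after every forward pass
            module.weight.data *= self.mask

        def _backward_hook_fn(gradient):
            # zero gradients that flow into pruned weights
            return gradient * self.mask

        # module hook -> register_forward_hook, parameter hook -> register_hook
        self._hooks.append(self.layer.register_forward_hook(_forward_hook_fn))
        self._hooks.append(self.layer.weight.register_hook(_backward_hook_fn))

    def remove_hooks(self):
        for handle in self._hooks:
            handle.remove()
        self._hooks = []


layer = torch.nn.Linear(8, 8)
demo = MaskedLinearDemo(layer)
demo.add_hooks()
layer(torch.randn(2, 8)).sum().backward()  # both hooks fire
demo.remove_hooks()                        # one call detaches everything

In the diff above, the same two registrations now go through self.register_hook(..., "forward") and self.register_hook(..., ""), so HooksMixin owns the handle list instead of LayerParamMasking keeping its own dictionaries.
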
20 changes: 12 additions & 8 deletions src/llmcompressor/modifiers/utils/hooks.py
@@ -1,6 +1,6 @@
 import contextlib
 from functools import wraps
-from typing import Any, Callable, ClassVar, List
+from typing import Any, Callable, ClassVar, List, Union
 
 import torch
 from loguru import logger
@@ -19,7 +19,9 @@ class HooksMixin(BaseModel):
     Modifiers which implement hooks should register them using
     `self.register_..._hook(module, hook)` rather than the usual
     `module.register_..._hook(hook)`. Modifiers should remove hooks with
-    `self.remove_hooks()`
+    `self.remove_hooks()`.
+
+    Hooks can be applied to modules or parameters
 
     Lifecycle:
         - modifier.register_forward_hook(module, hook)
@@ -42,20 +44,20 @@ def disable_hooks(cls):

     def register_hook(
         self,
-        module: torch.nn.Module,
+        target: Union[torch.nn.Module, torch.nn.Parameter],
         hook: Callable[[Any], Any],
         hook_type: str,
         **kwargs,
     ) -> RemovableHandle:
         """
-        Registers a hook on a specified module with the option to disable it with
-        HooksMixin.disable_hooks
+        Registers a hook on a specified module/parameter with the option to disable it
+        with HooksMixin.disable_hooks()
 
-        :param module: the module on which the hook should be registered
+        :param target: the module or parameter on which the hook should be registered
         :param hook: the hook to register
         :param hook_type: the type of hook to register corresponding to the
             `register_{hook_type}_hook` attribute on torch.nn.Module.
-            Ex. "forward", "forward_pre", "full_backward", "state_dict_post"
+            Ex. "forward", "forward_pre", "full_backward", "state_dict_post", ""
         :param kwargs: keyword arguments to pass to register hook method
         """

@@ -66,7 +68,7 @@ def wrapped_hook(*args, **kwargs):

             return hook(*args, **kwargs)
 
-        handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
+        handle = getattr(target, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
         self._hooks.append(handle)
         logger.debug(f"{self} added {handle}")

@@ -76,3 +78,5 @@ def remove_hooks(self):
"""Remove all hooks belonging to a modifier"""
for hook in self._hooks:
hook.remove()

self._hooks = []
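
To make the mixin behavior concrete, here is a minimal, self-contained sketch of the same pattern (assumed names such as MiniHooksMixin; not the library's code). It shows the three pieces the diff touches: a class-level disable flag checked inside the wrapper, a single list of handles, and a hook_type string that selects the register_*_hook method. This sketch treats an empty hook_type as a plain Tensor/Parameter hook, an assumption about the intent of the "" example in the docstring.

import contextlib
from functools import wraps
from typing import Any, Callable, List, Union

import torch
from torch.utils.hooks import RemovableHandle


class MiniHooksMixin:
    """Toy re-implementation of the hook bookkeeping; not llmcompressor's HooksMixin."""

    _HOOKS_DISABLED: bool = False

    def __init__(self):
        self._hooks: List[RemovableHandle] = []

    @classmethod
    @contextlib.contextmanager
    def disable_hooks(cls):
        """Temporarily turn every wrapped hook into a no-op."""
        try:
            cls._HOOKS_DISABLED = True
            yield
        finally:
            cls._HOOKS_DISABLED = False

    def register_hook(
        self,
        target: Union[torch.nn.Module, torch.nn.Parameter],
        hook: Callable[..., Any],
        hook_type: str,
        **kwargs,
    ) -> RemovableHandle:
        @wraps(hook)
        def wrapped_hook(*args, **hook_kwargs):
            if MiniHooksMixin._HOOKS_DISABLED:
                return
            return hook(*args, **hook_kwargs)

        # "" selects Tensor.register_hook for parameters; anything else maps to
        # register_{hook_type}_hook on the module (e.g. "forward", "forward_pre")
        name = "register_hook" if hook_type == "" else f"register_{hook_type}_hook"
        handle = getattr(target, name)(wrapped_hook, **kwargs)
        self._hooks.append(handle)
        return handle

    def remove_hooks(self):
        """Remove every hook this instance registered."""
        for handle in self._hooks:
            handle.remove()
        self._hooks = []


# usage: one forward hook on a module, one gradient hook on its weight parameter
mixin = MiniHooksMixin()
layer = torch.nn.Linear(4, 4)
mixin.register_hook(layer, lambda mod, inp, out: print("forward fired"), "forward")
mixin.register_hook(layer.weight, lambda grad: grad, "")

layer(torch.randn(1, 4)).sum().backward()  # both hooks run
with MiniHooksMixin.disable_hooks():       # wrapped hooks become no-ops
    layer(torch.randn(1, 4))
mixin.remove_hooks()                       # detach everything in one call

Keeping every handle on the modifier means one remove_hooks() call tears down everything that modifier attached, and disable_hooks() lets all hooks be bypassed temporarily without unregistering them.
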
