From d0dc8076dee8912975e4a20e79a36a924440013a Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 15 Nov 2024 21:28:50 +0000
Subject: [PATCH] integrate with wanda

Signed-off-by: Kyle Sayers
---
 .../modifiers/pruning/wanda/base.py        | 76 +++++++++----------
 src/llmcompressor/modifiers/utils/hooks.py |  6 +-
 2 files changed, 42 insertions(+), 40 deletions(-)

diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py
index f056ee1ae..4e6784bea 100644
--- a/src/llmcompressor/modifiers/pruning/wanda/base.py
+++ b/src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,3 +1,4 @@
+import functools
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -9,6 +10,7 @@
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.utils.pytorch.module import (
@@ -20,7 +22,7 @@
 __all__ = ["WandaPruningModifier"]
 
 
-class WandaPruningModifier(Modifier):
+class WandaPruningModifier(Modifier, HooksMixin):
     """
     Modifier for applying the one-shot WANDA algorithm to a model
     from the paper: https://arxiv.org/abs/2306.11695
@@ -121,7 +123,8 @@ def initialize_compression(
                 "Inferring layer-wise sparsities from "
                 f"{len(dataloader) if dataloader else 0} calibration samples..."
             )
-            self.sparsity = self._infer_layer_sparsity(dataloader)
+            activations = self._get_activations(dataloader)
+            self.sparsity = self._infer_layer_sparsity(activations)
         self._validate_layerwise_sparsity()
 
         for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -224,19 +227,17 @@ def _infer_mask_block_size(self):
 
         self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
 
-    def _infer_layer_sparsity(self, calibration_dataloader):
-        acts = _get_activations(self.model, calibration_dataloader)
+    def _infer_layer_sparsity(self, activations):
         wanda = {}
         for name, layer in self.compressible_layers_.items():
             prunable_layers = get_prunable_layers(layer)
             z = [
-                m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
+                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
             ]
             wanda[name] = torch.cat([item.flatten().cpu() for item in z])
 
-        acts = None
-        del acts
+        del activations
         torch.cuda.empty_cache()
 
         outlier_ratios = {}
@@ -268,36 +269,35 @@ def _infer_layer_sparsity(self, calibration_dataloader):
             logger.info(f"Sparsity for {k}: {sparsities[k]}")
         return sparsities
 
+    @torch.no_grad()
+    def _get_activations(self, data_loader, nsamples=128):
+        self.model.eval()
+        acts = {}
+
+        def save_acts(module, input, name):
+            if isinstance(input, tuple):
+                input = input[0]
+            if name not in acts:
+                acts[name] = (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+            else:
+                acts[name] += (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+
+        for name, mod in self.model.named_modules():
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
+                self.register_hook(
+                    mod, functools.partial(save_acts, name=name), "forward_pre"
+                )
+        device = next(self.model.parameters()).device
+        for batch in tqdm(data_loader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            self.model(**batch)
+            batch = None
+        torch.cuda.empty_cache()
 
-@torch.no_grad()
-def _get_activations(model, data_loader, nsamples=128):
-    import functools
-
-    model.eval()
-    acts = {}
-
-    def save_acts(module, input, name):
-        if isinstance(input, tuple):
-            input = input[0]
-        if name not in acts:
-            acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-        else:
-            acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-
-    hooks = []
-    for name, mod in model.named_modules():
-        if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
-            hooks.append(
-                mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
-            )
-    device = next(model.parameters()).device
-    for batch in tqdm(data_loader):
-        batch = {k: v.to(device) for k, v in batch.items()}
-        model(**batch)
-        batch = None
-    torch.cuda.empty_cache()
-
-    for h in hooks:
-        h.remove()
+        self.remove_hooks()
 
-    return acts
+        return acts
diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py
index de4e42898..39134e273 100644
--- a/src/llmcompressor/modifiers/utils/hooks.py
+++ b/src/llmcompressor/modifiers/utils/hooks.py
@@ -46,7 +46,7 @@ def register_hook(
         hook: Callable[[Any], Any],
         hook_type: str,
         **kwargs,
-    ):
+    ) -> RemovableHandle:
         """
         Registers a hook on a specified module with the option to disable it
         with HooksMixin.disable_hooks
@@ -68,7 +68,9 @@ def wrapped_hook(*args, **kwargs):
 
         handle = getattr(module, f"register_{hook_type}_hook")(wrapped_hook, **kwargs)
         self._hooks.append(handle)
-        logger.debug(f"Added {handle} for {self}")
+        logger.debug(f"{self} added {handle}")
+
+        return handle
 
     def remove_hooks(self):
         """Remove all hooks belonging to a modifier"""
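
Usage note (illustrative only, not part of the patch): the change above moves the ad-hoc register_forward_pre_hook()/handle.remove() bookkeeping onto HooksMixin.register_hook()/remove_hooks(). The sketch below is a minimal stand-in for that pattern in plain PyTorch, assuming a hypothetical _StandinHooks class, a collect_input_norms helper, and a toy model; it is not llmcompressor's HooksMixin implementation, only the register-handles/remove-all lifecycle plus the per-channel input-norm accumulation that _get_activations performs.

    import functools

    import torch


    class _StandinHooks:
        """Hypothetical stand-in for HooksMixin: track handles, remove them in bulk."""

        def __init__(self):
            self._hooks = []

        def register_hook(self, module, hook, hook_type):
            # e.g. hook_type="forward_pre" -> module.register_forward_pre_hook(hook)
            handle = getattr(module, f"register_{hook_type}_hook")(hook)
            self._hooks.append(handle)
            return handle

        def remove_hooks(self):
            for handle in self._hooks:
                handle.remove()
            self._hooks.clear()


    @torch.no_grad()
    def collect_input_norms(model, batches, nsamples):
        """Accumulate per-channel input RMS for every Linear, as _get_activations does."""
        owner = _StandinHooks()
        acts = {}

        def save_acts(module, args, name):
            inp = args[0] if isinstance(args, tuple) else args
            norm = 1.0 / nsamples * inp.detach().pow(2).sum(dim=(0, 1)).sqrt()
            acts[name] = acts.get(name, 0) + norm

        for name, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
                owner.register_hook(
                    mod, functools.partial(save_acts, name=name), "forward_pre"
                )

        for batch in batches:
            model(batch)
        owner.remove_hooks()  # hooks are gone; later forwards no longer record activations
        return acts


    if __name__ == "__main__":
        toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
        data = [torch.randn(2, 3, 8) for _ in range(4)]  # (batch, seq_len, hidden)
        norms = collect_input_norms(toy, data, nsamples=len(data))
        print({name: tuple(v.shape) for name, v in norms.items()})

Returning the handle from register_hook (as hooks.py now does, annotated as RemovableHandle) additionally lets a caller remove a single hook early instead of waiting for remove_hooks().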