From 7c51938e00b797bdd33ffb2a717e527d2a825a4d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 16 Nov 2024 00:57:24 +0000 Subject: [PATCH 1/2] move obcq to pruning.sparsegpt Signed-off-by: Kyle Sayers --- .../example_alternating_recipe.yaml | 4 +- src/llmcompressor/modifiers/README.md | 2 +- src/llmcompressor/modifiers/obcq/base.py | 337 +----------------- .../modifiers/pruning/__init__.py | 1 + .../modifiers/pruning/sparsegpt/__init__.py | 3 + .../modifiers/pruning/sparsegpt/base.py | 337 ++++++++++++++++++ .../sparsegpt}/utils/__init__.py | 0 .../sparsegpt}/utils/helpers.py | 0 .../sparsegpt}/utils/sgpt_wrapper.py | 0 .../modifiers/quantization/gptq/base.py | 2 +- .../transformers/finetune/README.md | 2 +- .../modifiers/pruning/sparsegpt/test_base.py | 2 +- .../pruning/sparsegpt/test_pytorch.py | 2 +- tests/llmcompressor/recipe/test_recipe.py | 2 +- .../finetune/test_alternate_recipe.yaml | 2 +- .../test_finetune_oneshot_with_modifier.py | 2 +- .../obcq/recipes/additional_sparsity.yaml | 2 +- .../additional_sparsity_with_quant.yaml | 2 +- .../transformers/obcq/recipes/quant.yaml | 2 +- .../obcq/recipes/quant_and_sparse.yaml | 2 +- .../transformers/obcq/recipes/sparse.yaml | 2 +- .../recipes/sparse_with_mask_structure.yaml | 2 +- .../transformers/obcq/recipes/test_tiny2.yaml | 2 +- .../obcq/test_obcq_infer_targets.py | 2 +- .../transformers/obcq/test_obcq_lm_head.py | 2 +- .../transformers/obcq/test_sgpt_defaults.py | 2 +- .../oneshot_configs/recipes/recipe.yaml | 2 +- .../oneshot_configs/tiny_stories_conf1.yaml | 2 +- .../oneshot_configs/tiny_stories_conf4.yaml | 2 +- 29 files changed, 371 insertions(+), 353 deletions(-) create mode 100644 src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py create mode 100644 src/llmcompressor/modifiers/pruning/sparsegpt/base.py rename src/llmcompressor/modifiers/{obcq => pruning/sparsegpt}/utils/__init__.py (100%) rename src/llmcompressor/modifiers/{obcq => pruning/sparsegpt}/utils/helpers.py (100%) rename src/llmcompressor/modifiers/{obcq => pruning/sparsegpt}/utils/sgpt_wrapper.py (100%) diff --git a/examples/finetuning/example_alternating_recipe.yaml b/examples/finetuning/example_alternating_recipe.yaml index ca186150c..a3be682a4 100644 --- a/examples/finetuning/example_alternating_recipe.yaml +++ b/examples/finetuning/example_alternating_recipe.yaml @@ -1,6 +1,6 @@ initial_sparsity_stage: run_type: oneshot - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 @@ -18,7 +18,7 @@ initial_training_stage: start: 0 next_sparsity_stage: run_type: oneshot - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 diff --git a/src/llmcompressor/modifiers/README.md b/src/llmcompressor/modifiers/README.md index 77a4cd425..009b31f5d 100644 --- a/src/llmcompressor/modifiers/README.md +++ b/src/llmcompressor/modifiers/README.md @@ -8,7 +8,7 @@ are relevant only during training. Below is a summary of the key modifiers avail Modifiers that introduce sparsity into a model -### [SparseGPT](./obcq/base.py) +### [SparseGPT](./pruning/gptq/base.py) One-shot algorithm that uses calibration data to introduce unstructured or structured sparsity into weights. Implementation based on [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774). 
A small amount of calibration data is used to calculate a Hessian for each layers input activations, this Hessian is then used to diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index 3da0e3d0c..cbd4c0e09 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -1,335 +1,12 @@ -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import warnings -import numpy as np -import torch -from loguru import logger -from torch.nn import Module -from tqdm import tqdm +from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier -from llmcompressor.core import State -from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - get_prunable_layers, +warnings.warn( + "llmcompressor.modifiers.obcq has been moved to " + "llmcompressor.modifiers.pruning.sparsegpt Please update your paths", + DeprecationWarning, ) -__all__ = ["SparseGPTModifier"] - - -class SparseGPTModifier(Modifier): - """ - Modifier for applying the one-shot SparseGPT algorithm to a model - - Lifecycle: - - on_initialize - - initialize_compression() - - compressible_layers() - - LayerCompressor.pre_compress() - - apply_compression() - - run_calibration_forward() - - LayerCompressor.compress() - - LayerCompressor.post_compress() - - LayerCompressor.revert_layer_wrappers() - - | Sample yaml: - | test_stage: - | obcq_modifiers: - | SparseGPTModifier: - | sparsity: 0.5 - | mask_structure: "2:4" - | sequential_update: True - | dampening_frac: 0.001 - | block_size: 128 - - :param sparsity: Sparsity to compress model to - :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed - Layerwise Sparsity (OWL), more information can be found - in the paper https://arxiv.org/pdf/2310.05175 - :param owl_m: Number of outliers to use for OWL - :param owl_lmbda: Lambda value to use for OWL - :param mask_structure: String to define the structure of the mask to apply. - Must be of the form N:M where N, M are integers that define a custom block - shape. Defaults to 0:0 which represents an unstructured mask. - :param sequential_update: Whether or not to update weights sequentially by layer, - True saves on GPU memory - :param targets: list of layer names to compress during OBCQ, or '__ALL__' - to compress every layer in the model - :param block_size: Used to determine number of columns to compress in one pass - :param dampening_frac: Amount of dampening to apply to H, as a fraction of the - diagonal norm - :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask - during when applying sparsegpt, this becomes useful when starting from a - previously pruned model, defaults to False. 
- """ - - sparsity: Union[float, List[float]] = 0.0 - sparsity_profile: Optional[str] = None - owl_m: Optional[int] = None - owl_lmbda: Optional[float] = None - mask_structure: str = "0:0" - sequential_update: Optional[bool] = False - targets: Union[str, List[str], None] = None - block_size: int = 128 - dampening_frac: Optional[float] = 0.01 - preserve_sparsity_mask: bool = False - - model: Optional[Any] = None - layer_compressors_: Optional[List[Any]] = None - prunen_: Optional[int] = None - prunem_: Optional[int] = None - compressible_layers_: Optional[List] = None - - def on_initialize(self, state: "State", **kwargs) -> bool: - """ - Initialize and run the OBCQ algorithm on the current state - - :param state: session state storing input model and calibration data - """ - if self.sparsity == 0.0: - raise ValueError( - "To use the SparseGPTModifier, target sparsity must be > 0.0" - ) - - modifiable_model = state.model - calibration_dataloader = state.data.calib - - if self.targets is None: - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - self.targets = get_no_split_params(modifiable_model) - - self.initialize_compression(modifiable_model, calibration_dataloader) - self.apply_compression(calibration_dataloader) - - return True - - def initialize_compression( - self, - model: Module, - dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, - ): - """ - Setup for SparseGPT, initializes the model, device, - and other parameters, also initilializes the - compressible layers of model, and sets the device - - :param model: model to initialize for compression - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_compressors_ = [] - self._infer_mask_block_size() - - if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl": - logger.info( - "Inferring layer-wise sparsities from " - f"{len(dataloader)} calibration samples..." 
- ) - self.sparsity = self._infer_layer_sparsity(dataloader) - self._validate_layerwise_sparsity() - - for idx, (name, layer) in enumerate(self.compressible_layers_.items()): - logger.info(f"Preparing {name} for compression") - if isinstance(self.sparsity, Dict): - layer_sparsity = self.sparsity[name] - elif isinstance(self.sparsity, List): - layer_sparsity = self.sparsity[idx] - else: # float - layer_sparsity = self.sparsity - args = self._pruning_arguments(layer_sparsity) - comp_cls = self._compression_class() - compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) - if not self.sequential_update: - # add all batch processing hooks before the forward pass - compressor.pre_compress() - self.layer_compressors_.append(compressor) - - def compressible_layers(self) -> Dict: - """ - Retrieves the modules corresponding to a list of - compressible layer names - - :precondition: self.model is set and is a torch.nn.Module - :return: dictionary of modules to compress - """ - if not isinstance(self.model, Module): - raise ValueError( - "`self.model` must be a PyTorch Module to use " - f"the {self.__class__.__qualname__} modifier but got " - f"{type(self.model)} instead" - ) - - return get_layers(self.targets, self.model) - - @torch.no_grad() - def apply_compression( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run Wanda on the loaded model, using dataloader as calibration data - :param dataloader: calibration data for WANDA - """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " - f"{len(dataloader) if dataloader else 0} samples..." - ) - if not self.sequential_update: - # in non-sequential mode we run one forward batch for all modules - run_calibration_forward(self.model, dataloader, mask_padding=True) - - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - layer_sparsity = layer_compressor.args["sparsity"] - logger.info( - f"\n===== Compressing layer {idx+1}/{num_layers} " - f"to sparsity {layer_sparsity} =====" - ) - - # Prune/quantize using SparseGPT - if self.sequential_update: - # in sequential mode we run one forward pass for each module we - # want to compress, this will be really slow but allows compression in - # earlier layers to affect later layers - layer_compressor.pre_compress() - logger.info(f"Calibrating {layer_compressor.name}...") - run_calibration_forward(self.model, dataloader, mask_padding=True) - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() - torch.cuda.empty_cache() - - def _validate_layerwise_sparsity(self): - if isinstance(self.sparsity, float): - # single sparsity will be applied to all layers - return - - target_layers = list(self.compressible_layers_.keys()) - - if len(target_layers) != len(self.sparsity): - raise ValueError( - "Number of layer targets must match the number of sparsities. 
" - "Received {len(target_layers)} layers and " - f"{len(self.sparsity)} sparsities" - ) - - def _pruning_arguments(self, sparsity): - """ - Gather the parameters needed for root module compression in a dict - - :param sparsity: target sparsity - :return: dict of params for pruning - """ - return { - "sparsity": sparsity, - "prunen": self.prunen_, - "prunem": self.prunem_, - "blocksize": self.block_size, - "percdamp": self.dampening_frac, - "preserve_sparsity_mask": self.preserve_sparsity_mask, - } - - def _compression_class(self): - """ - :return: wrapper class used for root modules of this compression class - """ - return SparseGptWrapper - - def _infer_mask_block_size(self): - """ - Infer the mask block size from the mask structure. - Parses mask_structure of the form N:M where N, M are integers that - define a custom block shape; and sets prunen_ and prunem_ accordingly. - - :post-condition: prunen_ and prunem_ are set - """ - if self.mask_structure is None: - raise ValueError("mask_structure must be defined") - - self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) - - def _infer_layer_sparsity(self, calibration_dataloader): - acts = _get_activations(self.model, calibration_dataloader) - sparsegpt_groups = {} - for name, layer in self.compressible_layers_.items(): - prunable_layers = get_prunable_layers(layer) - z = [ - m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) - for n, m in prunable_layers.items() - ] - sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) - - acts = None - del acts - torch.cuda.empty_cache() - - outlier_ratios = {} - for group in sparsegpt_groups: - threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m - outlier_ratios[group] = ( - 100 - * (sparsegpt_groups[group] > threshold).sum().item() - / sparsegpt_groups[group].numel() - ) - outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) - for k in outlier_ratios: - outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * ( - 1 - / (outlier_ratios_arr.max() - outlier_ratios_arr.min()) - * self.owl_lmbda - * 2 - ) - outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) - sparsities = { - k: 1 - - ( - outlier_ratios[k] - - np.mean(outlier_ratios_arr) - + (1 - float(self.sparsity)) - ) - for k in outlier_ratios - } - logger.info(f"OWL sparsities for sp={self.sparsity} are:") - for k in sparsities: - logger.info(f"Sparsity for {k}: {sparsities[k]}") - return sparsities - - -@torch.no_grad() -def _get_activations(model, data_loader, nsamples=128): - import functools - - model.eval() - acts = {} - - def save_acts(module, input, name): - if isinstance(input, tuple): - input = input[0] - if name not in acts: - acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - else: - acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() - - hooks = [] - for name, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: - hooks.append( - mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) - ) - device = next(model.parameters()).device - for batch in tqdm(data_loader): - batch = {k: v.to(device) for k, v in batch.items()} - model(**batch) - batch = None - torch.cuda.empty_cache() - - for h in hooks: - h.remove() - - return acts +__all__ = ["SparseGPTModifier"] diff --git a/src/llmcompressor/modifiers/pruning/__init__.py b/src/llmcompressor/modifiers/pruning/__init__.py index d54a770f1..664215219 100644 --- 
a/src/llmcompressor/modifiers/pruning/__init__.py +++ b/src/llmcompressor/modifiers/pruning/__init__.py @@ -2,4 +2,5 @@ from .constant import * from .magnitude import * +from .sparsegpt import * from .wanda import * diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py b/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/base.py b/src/llmcompressor/modifiers/pruning/sparsegpt/base.py new file mode 100644 index 000000000..a816bc152 --- /dev/null +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/base.py @@ -0,0 +1,337 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from loguru import logger +from torch.nn import Module +from tqdm import tqdm + +from llmcompressor.core import State +from llmcompressor.modifiers import Modifier +from llmcompressor.modifiers.pruning.sparsegpt.utils.sgpt_wrapper import ( + SparseGptWrapper, +) +from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.utils.pytorch.module import ( + get_layers, + get_no_split_params, + get_prunable_layers, +) + +__all__ = ["SparseGPTModifier"] + + +class SparseGPTModifier(Modifier): + """ + Modifier for applying the one-shot SparseGPT algorithm to a model + + Lifecycle: + - on_initialize + - initialize_compression() + - compressible_layers() + - LayerCompressor.pre_compress() + - apply_compression() + - run_calibration_forward() + - LayerCompressor.compress() + - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() + + | Sample yaml: + | test_stage: + | modifiers: + | SparseGPTModifier: + | sparsity: 0.5 + | mask_structure: "2:4" + | sequential_update: True + | dampening_frac: 0.001 + | block_size: 128 + + :param sparsity: Sparsity to compress model to + :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed + Layerwise Sparsity (OWL), more information can be found + in the paper https://arxiv.org/pdf/2310.05175 + :param owl_m: Number of outliers to use for OWL + :param owl_lmbda: Lambda value to use for OWL + :param mask_structure: String to define the structure of the mask to apply. + Must be of the form N:M where N, M are integers that define a custom block + shape. Defaults to 0:0 which represents an unstructured mask. + :param sequential_update: Whether or not to update weights sequentially by layer, + True saves on GPU memory + :param targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model + :param block_size: Used to determine number of columns to compress in one pass + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm + :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask + during when applying sparsegpt, this becomes useful when starting from a + previously pruned model, defaults to False. 
+ """ + + sparsity: Union[float, List[float]] = 0.0 + sparsity_profile: Optional[str] = None + owl_m: Optional[int] = None + owl_lmbda: Optional[float] = None + mask_structure: str = "0:0" + sequential_update: Optional[bool] = False + targets: Union[str, List[str], None] = None + block_size: int = 128 + dampening_frac: Optional[float] = 0.01 + preserve_sparsity_mask: bool = False + + model: Optional[Any] = None + layer_compressors_: Optional[List[Any]] = None + prunen_: Optional[int] = None + prunem_: Optional[int] = None + compressible_layers_: Optional[List] = None + + def on_initialize(self, state: "State", **kwargs) -> bool: + """ + Initialize and run the OBCQ algorithm on the current state + + :param state: session state storing input model and calibration data + """ + if self.sparsity == 0.0: + raise ValueError( + "To use the SparseGPTModifier, target sparsity must be > 0.0" + ) + + modifiable_model = state.model + calibration_dataloader = state.data.calib + + if self.targets is None: + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + self.targets = get_no_split_params(modifiable_model) + + self.initialize_compression(modifiable_model, calibration_dataloader) + self.apply_compression(calibration_dataloader) + + return True + + def initialize_compression( + self, + model: Module, + dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, + ): + """ + Setup for SparseGPT, initializes the model, device, + and other parameters, also initilializes the + compressible layers of model, and sets the device + + :param model: model to initialize for compression + """ + self.model = model + self.compressible_layers_ = self.compressible_layers() + self.layer_compressors_ = [] + self._infer_mask_block_size() + + if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl": + logger.info( + "Inferring layer-wise sparsities from " + f"{len(dataloader)} calibration samples..." 
+ ) + self.sparsity = self._infer_layer_sparsity(dataloader) + self._validate_layerwise_sparsity() + + for idx, (name, layer) in enumerate(self.compressible_layers_.items()): + logger.info(f"Preparing {name} for compression") + if isinstance(self.sparsity, Dict): + layer_sparsity = self.sparsity[name] + elif isinstance(self.sparsity, List): + layer_sparsity = self.sparsity[idx] + else: # float + layer_sparsity = self.sparsity + args = self._pruning_arguments(layer_sparsity) + comp_cls = self._compression_class() + compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) + if not self.sequential_update: + # add all batch processing hooks before the forward pass + compressor.pre_compress() + self.layer_compressors_.append(compressor) + + def compressible_layers(self) -> Dict: + """ + Retrieves the modules corresponding to a list of + compressible layer names + + :precondition: self.model is set and is a torch.nn.Module + :return: dictionary of modules to compress + """ + if not isinstance(self.model, Module): + raise ValueError( + "`self.model` must be a PyTorch Module to use " + f"the {self.__class__.__qualname__} modifier but got " + f"{type(self.model)} instead" + ) + + return get_layers(self.targets, self.model) + + @torch.no_grad() + def apply_compression( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run Wanda on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for WANDA + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + logger.info( + f"Running {class_name} calibration with " + f"{len(dataloader) if dataloader else 0} samples..." + ) + if not self.sequential_update: + # in non-sequential mode we run one forward batch for all modules + run_calibration_forward(self.model, dataloader, mask_padding=True) + + num_layers = len(self.compressible_layers_) + for idx, layer_compressor in enumerate(self.layer_compressors_): + layer_sparsity = layer_compressor.args["sparsity"] + logger.info( + f"\n===== Compressing layer {idx+1}/{num_layers} " + f"to sparsity {layer_sparsity} =====" + ) + + # Prune/quantize using SparseGPT + if self.sequential_update: + # in sequential mode we run one forward pass for each module we + # want to compress, this will be really slow but allows compression in + # earlier layers to affect later layers + layer_compressor.pre_compress() + logger.info(f"Calibrating {layer_compressor.name}...") + run_calibration_forward(self.model, dataloader, mask_padding=True) + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + torch.cuda.empty_cache() + + def _validate_layerwise_sparsity(self): + if isinstance(self.sparsity, float): + # single sparsity will be applied to all layers + return + + target_layers = list(self.compressible_layers_.keys()) + + if len(target_layers) != len(self.sparsity): + raise ValueError( + "Number of layer targets must match the number of sparsities. 
" + "Received {len(target_layers)} layers and " + f"{len(self.sparsity)} sparsities" + ) + + def _pruning_arguments(self, sparsity): + """ + Gather the parameters needed for root module compression in a dict + + :param sparsity: target sparsity + :return: dict of params for pruning + """ + return { + "sparsity": sparsity, + "prunen": self.prunen_, + "prunem": self.prunem_, + "blocksize": self.block_size, + "percdamp": self.dampening_frac, + "preserve_sparsity_mask": self.preserve_sparsity_mask, + } + + def _compression_class(self): + """ + :return: wrapper class used for root modules of this compression class + """ + return SparseGptWrapper + + def _infer_mask_block_size(self): + """ + Infer the mask block size from the mask structure. + Parses mask_structure of the form N:M where N, M are integers that + define a custom block shape; and sets prunen_ and prunem_ accordingly. + + :post-condition: prunen_ and prunem_ are set + """ + if self.mask_structure is None: + raise ValueError("mask_structure must be defined") + + self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) + + def _infer_layer_sparsity(self, calibration_dataloader): + acts = _get_activations(self.model, calibration_dataloader) + sparsegpt_groups = {} + for name, layer in self.compressible_layers_.items(): + prunable_layers = get_prunable_layers(layer) + z = [ + m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + for n, m in prunable_layers.items() + ] + sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) + + acts = None + del acts + torch.cuda.empty_cache() + + outlier_ratios = {} + for group in sparsegpt_groups: + threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m + outlier_ratios[group] = ( + 100 + * (sparsegpt_groups[group] > threshold).sum().item() + / sparsegpt_groups[group].numel() + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + for k in outlier_ratios: + outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * ( + 1 + / (outlier_ratios_arr.max() - outlier_ratios_arr.min()) + * self.owl_lmbda + * 2 + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + sparsities = { + k: 1 + - ( + outlier_ratios[k] + - np.mean(outlier_ratios_arr) + + (1 - float(self.sparsity)) + ) + for k in outlier_ratios + } + logger.info(f"OWL sparsities for sp={self.sparsity} are:") + for k in sparsities: + logger.info(f"Sparsity for {k}: {sparsities[k]}") + return sparsities + + +@torch.no_grad() +def _get_activations(model, data_loader, nsamples=128): + import functools + + model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + else: + acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + + hooks = [] + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + hooks.append( + mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) + ) + device = next(model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + batch = None + torch.cuda.empty_cache() + + for h in hooks: + h.remove() + + return acts diff --git a/src/llmcompressor/modifiers/obcq/utils/__init__.py b/src/llmcompressor/modifiers/pruning/sparsegpt/utils/__init__.py similarity index 100% rename from 
src/llmcompressor/modifiers/obcq/utils/__init__.py rename to src/llmcompressor/modifiers/pruning/sparsegpt/utils/__init__.py diff --git a/src/llmcompressor/modifiers/obcq/utils/helpers.py b/src/llmcompressor/modifiers/pruning/sparsegpt/utils/helpers.py similarity index 100% rename from src/llmcompressor/modifiers/obcq/utils/helpers.py rename to src/llmcompressor/modifiers/pruning/sparsegpt/utils/helpers.py diff --git a/src/llmcompressor/modifiers/obcq/utils/sgpt_wrapper.py b/src/llmcompressor/modifiers/pruning/sparsegpt/utils/sgpt_wrapper.py similarity index 100% rename from src/llmcompressor/modifiers/obcq/utils/sgpt_wrapper.py rename to src/llmcompressor/modifiers/pruning/sparsegpt/utils/sgpt_wrapper.py diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b6dbda485..14dd98f0d 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -46,7 +46,7 @@ class GPTQModifier(Modifier): - LayerCompressor.revert_layer_wrappers() | Sample yaml: | test_stage: - | obcq_modifiers: + | modifiers: | GPTQModifier: | dampening_frac: 0.001 | block_size: 128 diff --git a/src/llmcompressor/transformers/finetune/README.md b/src/llmcompressor/transformers/finetune/README.md index 7384b077b..f811dcfdc 100644 --- a/src/llmcompressor/transformers/finetune/README.md +++ b/src/llmcompressor/transformers/finetune/README.md @@ -45,7 +45,7 @@ See [configure_fsdp.md](../../../../examples/finetuning/configure_fsdp.md) for a ```python from llmcompressor.transformers import train -model = "./obcq_deployment" +model = "./model_path" teacher_model = "Xenova/llama2.c-stories15M" dataset_name = "open_platypus" concatenate_data = False diff --git a/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py b/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py index 2126baa99..e4baccc13 100644 --- a/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py +++ b/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py @@ -3,7 +3,7 @@ import pytest from llmcompressor.modifiers.factory import ModifierFactory -from llmcompressor.modifiers.obcq.base import SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from tests.llmcompressor.modifiers.conf import setup_modifier_factory diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 5421af4cf..ee49b20fd 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -4,7 +4,7 @@ from compressed_tensors.quantization import QuantizationScheme from parameterized import parameterized -from llmcompressor.modifiers.obcq import SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier from llmcompressor.modifiers.quantization.gptq import GPTQModifier from llmcompressor.modifiers.quantization.quantization import QuantizationModifier from llmcompressor.utils.pytorch.module import qat_active diff --git a/tests/llmcompressor/recipe/test_recipe.py b/tests/llmcompressor/recipe/test_recipe.py index 7a3674052..729543eee 100644 --- a/tests/llmcompressor/recipe/test_recipe.py +++ b/tests/llmcompressor/recipe/test_recipe.py @@ -4,7 +4,7 @@ import yaml from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.obcq.base import 
SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from llmcompressor.recipe import Recipe from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers from tests.llmcompressor.helpers import valid_recipe_strings diff --git a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml index 009eb2c6c..4f9d4293d 100644 --- a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml +++ b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml @@ -1,5 +1,5 @@ test_oneshot_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py index 47ef85244..0b63bcabd 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py @@ -22,7 +22,7 @@ def setUp(self): self.output = Path("./finetune_output") def test_oneshot_with_modifier_object(self): - from llmcompressor.modifiers.obcq.base import SparseGPTModifier + from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from llmcompressor.transformers import oneshot recipe_str = [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml index 6a50deae3..64ce30250 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml index 906d0c8da..027c56363 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: QuantizationModifier: config_groups: group_0: diff --git a/tests/llmcompressor/transformers/obcq/recipes/quant.yaml b/tests/llmcompressor/transformers/obcq/recipes/quant.yaml index 435503e50..1df51c804 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/quant.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/quant.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SmoothQuantModifier: smoothing_strength: 0.6 GPTQModifier: diff --git a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml index 05022fd80..eb02ea81d 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SmoothQuantModifier: smoothing_strength: 0.5 mappings: [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml index e47ac2bdc..d485064fa 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml @@ -1,5 +1,5 @@ 
test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.3 block_size: 128 diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml index 5f283b609..20c4c9397 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 diff --git a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml index 3633cfef6..8a97ff733 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py index 5d5a06fbc..a43911a19 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py @@ -17,7 +17,7 @@ def setUp(self): self.targets = get_no_split_params(self.modifiable_model) def test_infer_targets(self): - from llmcompressor.modifiers.obcq import SparseGPTModifier + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier self.assertEqual(len(self.targets), 1) self.assertEqual(self.targets[0], "LlamaDecoderLayer") diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py index 6b2729f6a..69b84570f 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py @@ -34,7 +34,7 @@ def setUp(self): def test_lm_head_target(self): from llmcompressor.core.state import State - from llmcompressor.modifiers.obcq import SparseGPTModifier + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier sparsegpt_modifier_no_head = SparseGPTModifier(**self.kwargs) diff --git a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py index 8cdca786a..527807873 100644 --- a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py +++ b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py @@ -10,7 +10,7 @@ class TestSGPTDefaults(unittest.TestCase): def test_sgpt_defaults(self): from llmcompressor.core.state import State - from llmcompressor.modifiers.obcq import SparseGPTModifier + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier kwargs = {"sparsity": 0.5} sparsegpt_modifier_only_sparsity = SparseGPTModifier(**kwargs) diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index c5bf782d4..b9aa59e06 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml 
b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index 39f9d6576..b4f61ff9f 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -5,7 +5,7 @@ model: "Xenova/llama2.c-stories15M" dataset: open_platypus recipe: | test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml index c6cc1376c..6443c09c7 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml @@ -6,7 +6,7 @@ dataset: "gsm8k" dataset_config_name: "main" recipe: | test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 From a24180415702b745d111ff6dfd9011b1914f8fd8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 15 Nov 2024 20:21:46 -0500 Subject: [PATCH 2/2] update readme path Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/README.md b/src/llmcompressor/modifiers/README.md index 009b31f5d..72ff0b058 100644 --- a/src/llmcompressor/modifiers/README.md +++ b/src/llmcompressor/modifiers/README.md @@ -8,7 +8,7 @@ are relevant only during training. Below is a summary of the key modifiers avail Modifiers that introduce sparsity into a model -### [SparseGPT](./pruning/gptq/base.py) +### [SparseGPT](./pruning/sparsegpt/base.py) One-shot algorithm that uses calibration data to introduce unstructured or structured sparsity into weights. Implementation based on [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774). A small amount of calibration data is used to calculate a Hessian for each layers input activations, this Hessian is then used to
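
For code that consumes these modules, the patch series implies a small migration: imports move from `llmcompressor.modifiers.obcq` to `llmcompressor.modifiers.pruning.sparsegpt`, and recipes use the `modifiers:` key in place of `obcq_modifiers:`. The sketch below illustrates this under the assumption (inferred from the shim and test updates in the diff above, not stated elsewhere) that the old `obcq` package continues to re-export the same class while emitting a `DeprecationWarning`.

```python
# Minimal migration sketch for downstream code; field names follow the
# sample yaml in the diff above.

# Preferred import path after this patch
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier

# Legacy path still resolves through the deprecation shim, but loading it
# emits a DeprecationWarning (assumes obcq/__init__.py keeps re-exporting)
from llmcompressor.modifiers.obcq import SparseGPTModifier as LegacySparseGPTModifier

assert LegacySparseGPTModifier is SparseGPTModifier  # same class, new home

# Construction is unchanged by the move
modifier = SparseGPTModifier(
    sparsity=0.5,          # target sparsity
    mask_structure="2:4",  # N:M semi-structured mask ("0:0" = unstructured)
    block_size=128,        # columns compressed per pass
)
```

Existing recipes keep working once the stage key is renamed, since only the module layout and the `modifiers:` key change; the `SparseGPTModifier` parameters themselves are untouched.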