Implement HooksMixin #917

Open · kylesayrs wants to merge 12 commits into main
Conversation

@kylesayrs (Collaborator) commented Nov 14, 2024

Purpose

  • Precursor to Kylesayrs/gptq partition #914
  • Create a shared API for adding hooks to modules
  • Allow code that handles data pipelines to selectively disable hooks for certain passes. This is needed for modifiers with custom data pipelines (GPTQ/Wanda/SparseGPT) and when multiple modifiers are active at the same time.
    • This is needed for GPTQ-style sequential algorithms, which require one pass with hooks enabled to accumulate the Hessians and compress, and then a second pass with hooks disabled to compute the compressed (weight-quantized) outputs
    • This also gives research users a tool to control when hooks are enabled from within the data pipelines, for example:
for layer in model_layers:
    # pass with hooks enabled: calibration hooks accumulate the Hessians and compress
    unquantized_outputs = layer(*args, **kwargs)

    # pass with hooks disabled: compute the outputs of the compressed layer
    with HooksMixin.disable_hooks():
        quantized_outputs = layer(*args, **kwargs)

    print(f"Mean error from quantization: {get_loss(unquantized_outputs, quantized_outputs)}")

Changes

  • Implement HooksMixin
    • The _HOOKS_DISABLED attribute is a class-level flag shared by all modifiers, used to disable hooks globally
    • The _hooks attribute is an instance-level list on each modifier that tracks all of the hooks created by that modifier (see the sketch after this list)
  • Integrate with QuantizationModifier, refactoring the calibration functions to reference a single shared function rather than generating hook functions
  • Integrate with SmoothQuantModifier
  • Integrate with WandaPruningModifier and SparseGPTModifier
  • Integrate with MagnitudePruningModifier and ConstantPruningModifier via LayerParamMasking
  • Purposefully did not integrate with LayerCompressor, since this will be handled by future data pipelines and doing so would add the BaseModel inheritance to the LayerCompressor class, which would add unnecessary complexity to this PR
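For orientation, here is a minimal sketch of how such a mixin could look, assuming each hook is wrapped so that the class-level _HOOKS_DISABLED flag short-circuits it. This is an illustration of the idea, not the PR's exact implementation:

import contextlib
from functools import wraps

import torch


class HooksMixin:
    # class-level flag shared by all modifiers; toggled by disable_hooks()
    _HOOKS_DISABLED: bool = False

    def __init__(self):
        # handles for every hook registered by this modifier instance
        self._hooks = []

    @classmethod
    @contextlib.contextmanager
    def disable_hooks(cls):
        # temporarily turn every registered hook into a no-op
        try:
            cls._HOOKS_DISABLED = True
            yield
        finally:
            cls._HOOKS_DISABLED = False

    def register_hook(self, module: torch.nn.Module, func, hook_type: str = "forward"):
        # wrap the hook so it respects the global disable flag, then
        # register it and remember the handle for later removal
        @wraps(func)
        def wrapped(*args, **kwargs):
            if HooksMixin._HOOKS_DISABLED:
                return
            return func(*args, **kwargs)

        handle = getattr(module, f"register_{hook_type}_hook")(wrapped)
        self._hooks.append(handle)
        return handle

    def remove_hooks(self):
        # detach every hook this modifier registered
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

A modifier integrating with the mixin would then presumably call register_hook for each calibration function during initialization and remove_hooks when it finalizes.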

Testing

  • Added tests in tests/llmcompressor/modifiers/utils/test_hooks.py
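As an illustration of what such a test could cover (hypothetical code written against the sketch above; the import path llmcompressor.modifiers.utils.hooks is inferred from the test file location and may differ):

import torch

from llmcompressor.modifiers.utils.hooks import HooksMixin  # assumed import path


class DummyModifier(HooksMixin):
    def __init__(self):
        super().__init__()
        self.hook_called = False


def test_disable_hooks():
    module = torch.nn.Linear(4, 4)
    modifier = DummyModifier()

    def forward_hook(module, args, output):
        modifier.hook_called = True

    modifier.register_hook(module, forward_hook, "forward")

    # hooks are skipped inside the disable context
    with HooksMixin.disable_hooks():
        module(torch.zeros(1, 4))
    assert not modifier.hook_called

    # and fire again once the context exits
    module(torch.zeros(1, 4))
    assert modifier.hook_called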


👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@dsikka (Collaborator) left a comment

We briefly looked at the implications of using hooks with FSDP - are we taking care of that already or through this PR?

@kylesayrs (Collaborator, Author) replied

@dsikka I consider that out of scope for this PR. FSDP is unsupported as of now, although this PR makes it easier to support FSDP in the future.

Modifying a module's parameters requires entering special FSDP contexts, for example:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
# TrainingState, HandleTrainingState, and `model` (the FSDP-wrapped model) are
# assumed to be in scope; the enums come from torch's private FSDP internals

@torch.no_grad()
def pre_hook(module, _args):
    # modifying both the training state and the handle training state is required
    with model._use_training_state(TrainingState.IDLE, HandleTrainingState.IDLE):
        with FullyShardedDataParallel.summon_full_params(model):
            # modify the module weight; doing this outside of the contexts
            # raises a non-contiguous tensor error
            module.weight *= 0

We can bake these contexts into the HooksMixin.register_hook function, although there are implementation details associated with that which I'd like to leave for a separate task/PR.
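Purely as a sketch of that idea (a hypothetical extension of the register_hook sketch above, reusing the same private FSDP contexts; fsdp_model is an assumed parameter name, and this is explicitly deferred by this PR):

def register_hook(self, module, func, hook_type="forward", fsdp_model=None):
    # hypothetical: wrap the hook so parameter mutation happens inside the
    # FSDP contexts shown above whenever an FSDP-wrapped model is provided;
    # assumes the HooksMixin sketch and FSDP imports from earlier in this thread
    def wrapped(*args, **kwargs):
        if HooksMixin._HOOKS_DISABLED:
            return
        if fsdp_model is None:
            return func(*args, **kwargs)
        with fsdp_model._use_training_state(TrainingState.IDLE, HandleTrainingState.IDLE):
            with FullyShardedDataParallel.summon_full_params(fsdp_model):
                return func(*args, **kwargs)

    handle = getattr(module, f"register_{hook_type}_hook")(wrapped)
    self._hooks.append(handle)
    return handle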
