
[GPTQ] Vision Model Support #850

Draft: wants to merge 51 commits into base: main
Conversation

@kylesayrs (Collaborator) commented Oct 16, 2024

Purpose

  • Support applying the GPTQ algorithm to models with non-traditional architectures, such as image-text-to-text models
  • Lay the groundwork for allowing multiple modifiers to be active at the same time and for batched one-shot steps

Prerequisites

Changes

Notes

  • This PR is reviewable but will stay in draft mode until all prerequisites are merged
  • Batched updates require more memory
  • The Hessian calculation is not invariant to batch size (see the sketch after this list)
  • Evaluation regression tests
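
For context on the batch-size note above, here is an illustrative GPTQ-style running Hessian update (a sketch only; the function name and exact scaling are not this PR's code). Mathematically the final Hessian is the same for any batch partition, but floating-point accumulation order, and any padding tokens introduced by batching, vary with batch size, so the computed values are not bit-identical:

import torch

def update_hessian(H: torch.Tensor, n_prev: int, x_batch: torch.Tensor):
    """Running Hessian update in the style of GPTQ (illustrative sketch).

    H:       (in_features, in_features) running Hessian estimate
    n_prev:  number of rows (tokens) accumulated so far
    x_batch: (batch_rows, in_features) input activations for the new batch
    """
    n_new = n_prev + x_batch.shape[0]
    H = H * (n_prev / n_new)                    # rescale previous contribution
    x = x_batch.float() * (2.0 / n_new) ** 0.5  # sqrt(2 / n) scaling
    return H + x.T @ x, n_new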

@mgoin (Collaborator) commented Oct 17, 2024

I tried this on microsoft/Phi-3.5-vision-instruct and, once GPTQ initialization finished, I saw this failure:

2024-10-17T00:32:32.353305+0000 | quantize_module | INFO - Compressing model.layers.31.mlp.down_proj...
2024-10-17T00:32:34.083082+0000 | compress | METRIC - time 1.73
2024-10-17T00:32:34.084876+0000 | compress | METRIC - error 0.13
2024-10-17T00:32:34.126879+0000 | compress | METRIC - GPU 0 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127060+0000 | compress | METRIC - GPU 1 | usage: 56.51% | total memory: 79 GB
2024-10-17T00:32:34.127088+0000 | compress | METRIC - GPU 2 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127108+0000 | compress | METRIC - GPU 3 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127125+0000 | compress | METRIC - GPU 4 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127139+0000 | compress | METRIC - GPU 5 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127155+0000 | compress | METRIC - GPU 6 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127170+0000 | compress | METRIC - GPU 7 | usage: 0.55% | total memory: 79 GB
2024-10-17T00:32:34.127219+0000 | compress | METRIC - Compressed layer size: 48.0087890625 MB
manager stage: Modifiers initialized
2024-10-17T00:32:34.338759+0000 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
Traceback (most recent call last):
  File "/home/mgoin/code/llm-compressor/examples/quantization_w8a8_int8/llava1.5_example.py", line 72, in <module>
    oneshot(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 76, in oneshot
    main(model_args, data_args, training_args)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 364, in main
    stage_runner.one_shot()
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 171, in one_shot
    self.trainer.one_shot(calibration_data=calib_data, stage=stage)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py", line 416, in one_shot
    apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session_functions.py", line 184, in apply
    return active_session().apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 212, in apply
    return self.finalize(**kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 192, in finalize
    mod_data = self._lifecycle.finalize(**kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/lifecycle.py", line 158, in finalize
    data = mod.finalize(state=self.state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/stage.py", line 143, in finalize
    modifier.finalize(state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/modifier.py", line 147, in finalize
    finalized = self.on_finalize(state=state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py", line 241, in on_finalize
    self.remove_gptq_hooks(state.model)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/quantization/gptq/base.py", line 365, in remove_gptq_hooks
    self.remove_hooks(child_module)
  File "/home/mgoin/venvs/vllm/lib/python3.10/site-packages/pydantic/main.py", line 856, in __getattr__
    raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
AttributeError: 'GPTQModifier' object has no attribute 'remove_hooks'. Did you mean: 'remove_gptq_hooks'?

I think the error message's suggestion is probably right: AttributeError: 'GPTQModifier' object has no attribute 'remove_hooks'. Did you mean: 'remove_gptq_hooks'?
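
For reference, the fix presumably amounts to using the existing method name in the recursive call, along the lines of the sketch below (the hook-handle attribute and surrounding structure are assumptions, not the actual code in base.py):

# sketch of GPTQModifier.remove_gptq_hooks (hypothetical details)
def remove_gptq_hooks(self, module):
    handle = getattr(module, "_gptq_hook_handle", None)  # attribute name assumed
    if handle is not None:
        handle.remove()
        module._gptq_hook_handle = None

    # recurse into children; the traceback shows this call was written as
    # self.remove_hooks(child_module), which does not exist on GPTQModifier
    for child_module in module.children():
        self.remove_gptq_hooks(child_module)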

@mgoin (Collaborator) commented Oct 17, 2024

After replacing that call, I was able to produce an INT8 W8A8 model that loads in vLLM and seems to give reasonable output! https://huggingface.co/nm-testing/Phi-3.5-vision-instruct-W8A8-Dynamic-Per-Token

However, based on the speed of compression and the logs, it seems like calibration was not actually performed (happy to be incorrect on this):

2024-10-17T00:47:42.871375+0000 | _check_create_state | INFO - State created for compression lifecycle
2024-10-17T00:47:42.872188+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
2024-10-17T00:47:42.872333+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py:91: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  super().__init__(**kwargs)
2024-10-17T00:47:42.887208+0000 | one_shot | INFO - *** One Shot ***
2024-10-17T00:47:42.892085+0000 | from_modifiers | INFO - Creating recipe from modifiers
2024-10-17T00:47:42.922928+0000 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
2024-10-17T00:47:42.925187+0000 | on_initialize_structure | WARNING - GPTQ quantization is set to True without an active quantization modifier.
2024-10-17T00:47:42.925257+0000 | _build_quant_modifier | INFO - Building quantization modifier with args: {'targets': 'Linear', 'scheme': 'W8A8', 'ignore': ['re:.*lm_head', 're:model.vision_embed_tokens.*']}
2024-10-17T00:47:42.961327+0000 | _check_calibration_data | INFO - Skipping QuantizationModifier calibration, it is not required for the provided quantization config.
2024-10-17T00:47:43.169421+0000 | layer_pre_forward | INFO - 
===== Compressing layer 0/32 =====
You are not running the flash-attention implementation, expect numerical differences.
2024-10-17T00:47:43.289524+0000 | quantize_module | INFO - Compressing model.layers.0.self_attn.qkv_proj...

Here is the full script I used to test this:

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from llmcompressor.modifiers.quantization import GPTQModifier
# from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

# Select model and load it.
MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
model_class = wrap_hf_model_class(AutoModelForCausalLM)
model = model_class.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True, 
    _attn_implementation="eager",
)
processor = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": processor.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return processor(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)
print(ds)

# Configure algorithms. In this case, we:
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)
# (SmoothQuant is commented out below; it would make activations easier to quantize)
# Note: set sequential_update: true in the recipe to reduce memory
ignore = ["re:.*lm_head", "re:model.vision_embed_tokens.*"]
recipe = [
    # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(processor.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

@kylesayrs (Collaborator, Author) replied:

@mgoin Nice, indeed the calibration only uses the first sample right now

@dsikka (Collaborator) left a comment:

Looks good. We should add an updated lifecycle docstring since we're no longer using the GPTQ wrapper. Something like this:
https://github.com/neuralmagic/compressed-tensors/blob/232e4944b84798bd05fddc18a7752ae2b5d460da/src/compressed_tensors/compressors/base.py#L29 or

Run calibration if running input/output activation quantization or kv_cache

# decoder layers (ie LlamaDecoderLayer)
self.sequential_targets = get_no_split_params(modifiable_model)
self.sequential_targets = get_no_split_params(state.model)
layers = get_layers(self.sequential_targets, state.model)
Collaborator review comment:

It would be nice to keep compressible_layers() returning the output of get_layers (a sketch follows this list):

  • Slightly better naming/clarity on the layers being returned
  • Consistent with the other modifiers
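
A minimal sketch of what keeping compressible_layers() might look like (the signature and the import path are assumptions, not the PR's code):

# helpers referenced in the quoted snippet above; import path assumed
from llmcompressor.utils.pytorch.module import get_layers, get_no_split_params

def compressible_layers(self, state):
    """Return the sequential (decoder-style) layers targeted for compression,
    mirroring what other modifiers expose."""
    self.sequential_targets = get_no_split_params(state.model)
    return get_layers(self.sequential_targets, state.model)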

@kylesayrs (Collaborator, Author) replied:

I'd prefer to deprecate this and have it handled by SequentialLayerCompressor, but I'd be fine keeping it on GPTQModifier.

Six review threads on src/llmcompressor/modifiers/quantization/gptq/base.py were marked outdated and resolved.
@kylesayrs requested a review from dsikka on October 23, 2024 00:41
@kylesayrs changed the title from "[WIP] GPTQ Vision Model Support" to "GPTQ Vision Model Support" on Oct 23, 2024
@kylesayrs changed the title from "GPTQ Vision Model Support" to "[GPTQ] Vision Model Support" on Oct 23, 2024
@kylesayrs (Collaborator, Author) commented Oct 25, 2024

The hook-based design was initially proposed because of

  1. Closer alignment with how other modifiers use hooks
  2. Easier handling of batched oneshot updates (future)
  3. No need to manually move data through the model; the forward calls handle data movement instead, which better supports models with unconventional data flows such as vision models

However, using hooks has proven to be more difficult than expected. The argument goes something like this:

Suppose that a dataset contains N samples with max length M

Note that attempting to use a batch size of N as input to the model causes the model to allocate an NxM attention mask, which accounts for most of the large memory requirement (afaict)

Consider the case where the batch size is 1

  1. Each module must be called N times in order to accumulate all of the samples into the Hessian
  2. However, a module can only emit an output after the module has been quantized, which can only happen after the Hessian is fully accumulated
    a. Note: a module can avoid emitting an output by raising an EarlyStopException
  3. This means that the first N-1 calls must be skipped
  4. Since hooks can only skip forward calls, not generate additional ones, only one forward call makes it past the first module to subsequent modules
  5. This means that all N outputs must be concatenated into one batch, allocating an NxM matrix
  6. In addition, the model was initialized with a batch size of 1, not a batch size of N. This means that certain matrix multiplications outside of the module may have mismatched shapes

The same argument applies to both modules and layers in the case of true_sequential=False

In summary: a hook cannot return any samples until all samples have been accumulated, it cannot return more than one sample without batching, and forced batching can lead to mismatched shapes and large memory requirements (a sketch of this constraint follows below).
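
To make the constraint concrete, here is a minimal, hypothetical sketch of a hook-based Hessian accumulator and why it cannot emit outputs before the final sample (the class, exception, and simplified scaling are illustrative, not this PR's implementation):

import torch

class EarlyStopException(Exception):
    """Raised from a hook to abort the rest of the forward pass."""

class HessianAccumulationHook:
    """Accumulates a Hessian over N forward calls; the module cannot emit a
    useful (quantized) output until the N-th call, so the first N-1 calls
    must abort the forward pass."""

    def __init__(self, module: torch.nn.Linear, num_samples: int):
        self.num_samples = num_samples
        self.seen = 0
        self.H = torch.zeros(module.in_features, module.in_features)

    def __call__(self, module, inputs, output):
        x = inputs[0].reshape(-1, module.in_features).float()
        self.H += 2.0 * x.T @ x  # accumulate (scaling simplified)
        self.seen += 1
        if self.seen < self.num_samples:
            # quantization requires the fully accumulated Hessian, so the
            # only option is to skip the rest of this forward pass
            raise EarlyStopException()
        # only this final call reaches downstream modules, and it carries a
        # single sample; emitting all N outputs here would require
        # concatenating them into one NxM batch

# usage sketch:
# hook = HessianAccumulationHook(layer.mlp.down_proj, num_samples=N)
# handle = layer.mlp.down_proj.register_forward_hook(hook)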

@kylesayrs (Collaborator, Author) commented Oct 25, 2024

This isn't a formal proof that hook-based compression and batch_size=1 are incompatible, but the solution is not straightforward, and more thought will be needed as to how the two concepts might work together.
