
Prefix Tuning dimension error with Qwen2 and missing vocab_size for PaliGemma2 #2315

Open
2 of 4 tasks
Florian-Dreyer opened this issue Jan 8, 2025 · 13 comments

Comments

@Florian-Dreyer

System Info

PEFT: 0.14.0
Transformers: 4.48.0.dev0

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

For Qwen we get the following error:

```
IndexError: Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/{user_name}/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 84, in _worker
    output = module(*input, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/peft/peft_model.py", line 1755, in forward
    return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/{user_name}/venv/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1682, in forward
    position_ids, rope_deltas = self.get_rope_index(
  File "/home/{user_name}/venv/lib/python3.10/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 1486, in get_rope_index
    input_ids = input_ids[attention_mask[i] == 1]
IndexError: The shape of the mask [172] at index 0 does not match the shape of the indexed tensor [122] at index 0
```

And for PaliGemma2 this one:
```
AttributeError Traceback (most recent call last)
Cell In[68], line 8
      6 tokenizer = processor.tokenizer
      7 # Apply PEFT model adaptation
----> 8 peft_model = get_peft_model(model, peft_config)
     10 # Print trainable parameters
     11 peft_model.print_trainable_parameters()

File ~/venv/lib/python3.10/site-packages/peft/mapping.py:222, in get_peft_model(model, peft_config, adapter_name, mixed, autocast_adapter_dtype, revision, low_cpu_mem_usage)
    220 if peft_config.is_prompt_learning:
    221 peft_config = _prepare_prompt_learning_config(peft_config, model_config)
--> 222 return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
    223 model,
    224 peft_config,
    225 adapter_name=adapter_name,
    226 autocast_adapter_dtype=autocast_adapter_dtype,
    227 low_cpu_mem_usage=low_cpu_mem_usage,
    228 )

File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:1684, in PeftModelForCausalLM.__init__(self, model, peft_config, adapter_name, **kwargs)
   1681 def __init__(
   1682 self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
   1683 ) -> None:
-> 1684 super().__init__(model, peft_config, adapter_name, **kwargs)
   1685 self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:170, in PeftModel.__init__(self, model, peft_config, adapter_name, autocast_adapter_dtype, low_cpu_mem_usage)
    168 self._peft_config = {adapter_name: peft_config}
    169 self.base_model = model
--> 170 self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)
    171 else:
    172 self._peft_config = None

File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:958, in PeftModel.add_adapter(self, adapter_name, peft_config, low_cpu_mem_usage)
    955 dict_config = self.config
    957 peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
--> 958 self._setup_prompt_encoder(adapter_name)
    959 elif peft_config.is_adaption_prompt:
    960 self.base_model.add_adapter(adapter_name, peft_config)

File ~/venv/lib/python3.10/site-packages/peft/peft_model.py:642, in PeftModel._setup_prompt_encoder(self, adapter_name)
    635 for named_param, value in list(transformer_backbone.named_parameters()):
    636 # for ZeRO-3, the tensor is sharded across accelerators and deepspeed modifies it to a tensor with shape
    637 # [0] the actual unsharded shape is stored in "ds_shape" attribute special handling is needed in case
    638 # the model is initialized in deepspeed.zero.Init() context or HfDeepSpeedConfig has been called before
    639 # For reference refer to issue: #996
    640 deepspeed_distributed_tensor_shape = getattr(value, "ds_shape", None)
--> 642 if value.shape[0] == self.base_model.config.vocab_size or (
    643 deepspeed_distributed_tensor_shape is not None
    644 and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size
    645 ):
    646 word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
    647 break

File ~/venv/lib/python3.10/site-packages/transformers/configuration_utils.py:211, in PretrainedConfig.__getattribute__(self, key)
    209 if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
    210 key = super().__getattribute__("attribute_map")[key]
--> 211 return super().__getattribute__(key)

AttributeError: 'PaliGemmaConfig' object has no attribute 'vocab_size'
```

You can find the notebook to replicate the errors here:
https://github.com/Florian-Dreyer/PEFT_BUG/blob/main/prefix_tuning_peft.ipynb
Just execute the cells to get the errors.

Expected behavior

We would expect the models to be able to process the input. We also tried calling model(**inputs) directly but ran into the same error with Qwen. Note: the difference between the two dimensions (172 vs. 122, i.e. 50) is exactly the prefix length.
So the question is, how can we get the models to run? Is PaliGemma even supported?
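The Qwen shape mismatch can be reproduced in isolation. This is a minimal sketch, assuming (as the traceback and the note about the prefix length suggest) that prefix tuning prepends a block of ones for the 50 virtual tokens to attention_mask while input_ids keeps its original length of 122. The sizes come from the error message; the tensors and the helper name select_attended_tokens are hypothetical stand-ins, not PEFT or transformers code.

```python
import torch

# Sizes taken from the error message; the prefix length is inferred as 172 - 122.
seq_len = 122
num_virtual_tokens = 172 - seq_len  # 50

input_ids = torch.arange(seq_len).unsqueeze(0)             # shape (1, 122)
attention_mask = torch.ones(1, seq_len, dtype=torch.long)  # shape (1, 122)

# Assumption: the prefix-tuning wrapper extends the mask by the virtual tokens,
# but input_ids is passed through unchanged.
prefix_mask = torch.ones(1, num_virtual_tokens, dtype=torch.long)
extended_mask = torch.cat((prefix_mask, attention_mask), dim=1)  # shape (1, 172)


def select_attended_tokens(input_ids, attention_mask, i=0):
    # Same boolean indexing as the failing line in get_rope_index:
    #     input_ids = input_ids[attention_mask[i] == 1]
    return input_ids[i][attention_mask[i] == 1]


try:
    select_attended_tokens(input_ids, extended_mask)
except IndexError as err:
    # Mask of length 172 vs. indexed tensor of length 122, as in the traceback.
    print(err)

# Dropping the prefix columns restores matching shapes (illustration only).
text_mask = extended_mask[:, num_virtual_tokens:]
print(select_attended_tokens(input_ids, text_mask).shape)  # torch.Size([122])
```

This only illustrates where the mismatch comes from; it is not a fix for either library.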

@marzi9696

@BenjaminBossan Is it okay if I work on this?

@BenjaminBossan
Member

@marzi9696 Sure, you can give it a shot. I'll probably look into this tomorrow if there is no solution by then.

@Florian-Dreyer
Author

Update from my side:
We got Qwen working with transformers version 4.46.3 without any further changes.
For PaliGemma we found a workaround: modify the config with the following two lines:
model.config.vocab_size = model.config._vocab_size
model.config.hidden_size = model.config.hidden_size // 2
and patch one if statement in the Trainer source code that checks whether vocab_size has changed before saving the model.
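The two config lines hard-code a factor of 2. A hypothetical helper that derives the factor from text_config instead could look like the sketch below. It assumes, following the explanation given later in this thread, that the required reduction equals num_attention_heads // num_key_value_heads; the helper name and the SimpleNamespace stand-in config are illustrative only and not part of PEFT or transformers.

```python
from types import SimpleNamespace


def patch_config_for_prefix_tuning(config):
    """Hypothetical helper: apply the thread's PaliGemma workaround to `config`."""
    text_config = config.text_config
    # Ratio of query heads to key/value heads (2 for PaliGemma per the thread).
    ratio = text_config.num_attention_heads // text_config.num_key_value_heads
    # Expose the vocab size under the attribute name PEFT looks up.
    config.vocab_size = config._vocab_size
    # Shrink hidden_size so that hidden_size // num_key_value_heads yields
    # the same per-head size as the original hidden_size // num_attention_heads.
    config.hidden_size = config.hidden_size // ratio
    return config


# Stand-in config with made-up numbers, so the sketch runs without a real model.
fake_config = SimpleNamespace(
    _vocab_size=1000,
    hidden_size=2048,
    text_config=SimpleNamespace(num_attention_heads=8, num_key_value_heads=4),
)
patched = patch_config_for_prefix_tuning(fake_config)
print(patched.vocab_size, patched.hidden_size)  # 1000 1024
```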

@marzi9696

@Florian-Dreyer I was trying to run your Jupyter notebook, but I could not run the data processing part.
I believe you load a specific dataset. Can you tell me where you got the dataset from so I can replicate your code?

Also, you mentioned an "if_statement in train.py". I was reading the source code yesterday and noticed the if statement as well. Was it this block of code:

[Screenshot from 2025-01-10 16-01-46]

@Florian-Dreyer
Author

Do you mean the CSV file for the _gemini dfs? It is the other file in the GitHub repo I shared.

No, it was a different one; this one should work after modifying the model's config. The if statement I meant compares the current vocab_size of the base model with the vocab_size of the original base model, freshly loaded. That freshly loaded config has not been patched, which is why it doesn't have the vocab_size attribute.
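For illustration, a comparison like the one described could avoid the AttributeError by looking up the vocabulary size defensively. This is a hypothetical sketch, not the Trainer's actual code: the helper name, the fallback order (vocab_size, then _vocab_size, then text_config.vocab_size) and the stand-in configs are assumptions.

```python
from types import SimpleNamespace


def resolve_vocab_size(config):
    """Hypothetical helper: return a config's vocab size, or None if none is found."""
    # getattr with a default swallows the AttributeError raised for a missing attribute.
    for name in ("vocab_size", "_vocab_size"):
        value = getattr(config, name, None)
        if value is not None:
            return value
    # Multimodal configs may only carry the vocab size on their text sub-config.
    text_config = getattr(config, "text_config", None)
    return getattr(text_config, "vocab_size", None)


# Stand-ins: a patched config and a freshly loaded, unpatched one (made-up values).
patched = SimpleNamespace(vocab_size=1000, _vocab_size=1000)
fresh = SimpleNamespace(_vocab_size=1000)
print(resolve_vocab_size(patched) == resolve_vocab_size(fresh))  # True
```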

@marzi9696

@Florian-Dreyer Hey Florian.
I could not find the dataset in the Git repo you shared, but I created a fake one with a few examples just to replicate your code, and I got the same error.
I also tested other PEFT methods with PaliGemma, and they have issues when loading the model's config.

I think PEFT just does not support PaliGemma models. Here are the mappings used for loading the model configs before training:
https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py

I think they should add support for PaliGemma models, as you are going to encounter issues when using any PEFT setting with PaliGemma models.

@Florian-Dreyer
Author

Do you mean the train_dataset, val_dataset and test_dataset? They are loaded in the second cell:
from datasets import load_dataset
dataset_id = "derek-thomas/ScienceQA"
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train", "validation", "test"])
I am not aware of any extra steps for loading it.

Thank you, that was also my guess: PEFT prefix tuning just doesn't support PaliGemma without workarounds...

@marzi9696

@Florian-Dreyer No, here is the dataset I was talking about:

[Image]

You are welcome. Yes, for sure the only way is to find workarounds like you did. They do support Qwen2, however; that's probably why the error went away when you changed your transformers version.

@marzi9696

@BenjaminBossan Do you want to add support for PaliGemma models?
I can work on it if it's okay with you.

@Florian-Dreyer
Author

Oh sorry, my bad, I forgot to push the file... It's not pushed, but as you said, it shouldn't make a difference.

@BenjaminBossan
Member

Thanks all for investigating this issue further.

> I think peft just does not support "Paligemma" models. Here are the mappings for loading the configs of models before training: https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py

This is not quite correct. What you see there are the model architectures that are supported out of the box for certain PEFT models. It does not imply that if a model is not listed there, it is not supported. Quite in contrast, PEFT is very flexible, but as a user you might have to configure PEFT a bit to fit the model.

When it comes to PaliGemma2 specifically, I don't know exactly what its architecture looks like and why these workarounds with the vocab size and hidden size are needed. However, I found that it is possible to train it with PEFT when using LoRA instead of prefix tuning, see e.g. this notebook:

https://github.com/merveenoyan/smol-vision/blob/614d538a035e429b77cabd244b58820567c4ef74/Fine_tune_PaliGemma.ipynb

@Florian-Dreyer Would this be an alternative for you?

@marzi9696

@BenjaminBossan Thanks for the clarification. I can investigate further why this model needs workarounds for prefix tuning, if that helps with future use cases.

@Florian-Dreyer
Author

@BenjaminBossan For our project, the workaround I wrote about is enough. We also noticed that LoRA works for PaliGemma without any problems, which is why we were even more confused...
@marzi9696 From my understanding, there are two problems. First, PEFT doesn't recognize _vocab_size as vocab_size because of the leading underscore. Second, prefix tuning computes a value for the model's past_key_values argument, and for that it needs hidden_size divided by the number of attention heads. The weird thing is that for these past_key_values you need to divide hidden_size by the number of attention heads, but PEFT uses the number of key_value_heads as the number of attention heads for a different dimension. That's also why we need to divide hidden_size by two: for PaliGemma, the ratio of attention heads to key_value_heads is 2, but the number of attention heads is only listed in the text_config part of the config.
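The shape argument can be checked with plain arithmetic. The sketch below is a simplified model of that description, not PEFT's actual code: it assumes the model expects per-layer key/value tensors of shape (batch, num_key_value_heads, prefix_len, hidden_size // num_attention_heads), while the prefix encoder divides hidden_size by num_key_value_heads instead. All numbers are made up for illustration.

```python
def expected_kv_shape(batch, prefix_len, hidden_size, n_heads, n_kv_heads):
    # Shape the attention layers expect, per the description above.
    return (batch, n_kv_heads, prefix_len, hidden_size // n_heads)


def produced_kv_shape(batch, prefix_len, hidden_size, n_kv_heads):
    # Shape produced when the key/value head count is also used as the divisor.
    return (batch, n_kv_heads, prefix_len, hidden_size // n_kv_heads)


# Hypothetical values with a 2:1 ratio of attention heads to key/value heads.
batch, prefix_len, hidden_size, n_heads, n_kv_heads = 2, 50, 2048, 8, 4

print(expected_kv_shape(batch, prefix_len, hidden_size, n_heads, n_kv_heads))  # (2, 4, 50, 256)
print(produced_kv_shape(batch, prefix_len, hidden_size, n_kv_heads))           # (2, 4, 50, 512)
# Halving hidden_size, as in the workaround, makes the last dimension match.
print(produced_kv_shape(batch, prefix_len, hidden_size // 2, n_kv_heads))      # (2, 4, 50, 256)
```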
