AttributeError: ModulesToSaveWrapper has no attribute dense #2326

Closed
KQDtianxiaK opened this issue Jan 13, 2025 · 5 comments

@KQDtianxiaK

KQDtianxiaK commented Jan 13, 2025

System Info

Original model architecture:

EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 640, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 640, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-29): 30 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
           ...
            **(output): EsmSelfOutput(
              (dense): Linear(in_features=640, out_features=640, bias=True)**
              (dropout): Dropout(p=0.0, inplace=False)
            )
          ...
          **(intermediate): EsmIntermediate(
            (dense): Linear(in_features=640, out_features=2560, bias=True)
          )**
          **(output): EsmOutput(
            (dense): Linear(in_features=2560, out_features=640, bias=True)**
     ...
  **(classifier): EsmClassificationHead(
    (dense): Linear(in_features=640, out_features=640, bias=True)**
   ...

my code:

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=7)
config = OFTConfig(task_type=TaskType.SEQ_CLS, target_modules=['dense'])
model_OFT = get_peft_model(model, config)

Peft model architecture:

PeftModelForSequenceClassification(
  (base_model): OFTModel(
    (model): EsmForSequenceClassification(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 640, padding_idx=1)
          (dropout): Dropout(p=0.0, inplace=False)
          (position_embeddings): Embedding(1026, 640, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-29): 30 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): Linear(in_features=640, out_features=640, bias=True)
                  (key): Linear(in_features=640, out_features=640, bias=True)
                  (value): Linear(in_features=640, out_features=640, bias=True)
               ...
                  **(dense): oft.Linear(
                    (base_layer): Linear(in_features=640, out_features=640, bias=True)
                    (oft_r): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8x80x80])
                  )**
                 ...
              **(intermediate): EsmIntermediate(
                (dense): oft.Linear(
                  (base_layer): Linear(in_features=640, out_features=2560, bias=True)
                  (oft_r): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8x320x320])
                )**
              )
              **(output): EsmOutput(
                (dense): oft.Linear(
                  (base_layer): Linear(in_features=2560, out_features=640, bias=True)
                  (oft_r): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8x80x80])
                )**
              ...
      **(classifier): ModulesToSaveWrapper(
        (original_module): EsmClassificationHead(
          (dense): oft.Linear(
            (base_layer): Linear(in_features=640, out_features=640, bias=True)
            (oft_r): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8x80x80])
          )**
       ...
        (modules_to_save): ModuleDict(
          (default): EsmClassificationHead(
            **(dense): oft.Linear(
              (base_layer): Linear(in_features=640, out_features=640, bias=True)
              (oft_r): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 8x80x80])
            )**
           ...

adapter_config.json:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "model/esm2_35M",
  "block_share": false,
  "coft": false,
  "eps": 6e-05,
  "inference_mode": true,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "module_dropout": 0.0,
  "modules_to_save": [
    "classifier",
    "score"
  ],
  "peft_type": "OFT",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "dense"
  ],
  "task_type": "SEQ_CLS"
}

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

After training, I load the model from the saved checkpoint using the following code:

best_model_path = best_model_dir.path
model_peft = AutoPeftModelForSequenceClassification.from_pretrained(best_model_path, num_labels=7)

Got this error:

Traceback (most recent call last):
  File "/root/autodl-tmp/PEFT-PLM/ESM2_scop_OFT.py", line 213, in <module>
    best_model = load_best_model_for_test(training_args.output_dir, i+1)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/autodl-tmp/PEFT-PLM/ESM2_scop_OFT.py", line 189, in load_best_model_for_test
    model_peft = AutoPeftModelForSequenceClassification.from_pretrained(best_model_path, num_labels=7)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/peft/auto.py", line 130, in from_pretrained
    return cls._target_peft_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py", line 541, in from_pretrained
    model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py", line 1311, in __init__
    super().__init__(model, peft_config, adapter_name, **kwargs)
  File "/root/miniconda3/lib/python3.12/site-packages/peft/peft_model.py", line 155, in __init__
    self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/peft/tuners/lycoris_utils.py", line 196, in __init__
    super().__init__(model, config, adapter_name)
  File "/root/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 175, in __init__
    self.inject_adapter(self.model, adapter_name)
  File "/root/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 430, in inject_adapter
    parent, target, target_name = _get_submodules(model, key)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/peft/utils/other.py", line 313, in _get_submodules
    target = model.get_submodule(key)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 717, in get_submodule
    raise AttributeError(
AttributeError: ModulesToSaveWrapper has no attribute `dense`

Expected behavior

The model saved after training should load without error; I would like to find out the cause and solve the problem.

@BenjaminBossan
Member

I could not reproduce the error. This is what I tried:

from transformers import EsmForSequenceClassification
from peft import OFTConfig, TaskType, get_peft_model, AutoPeftModelForSequenceClassification

model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=7)
config = OFTConfig(task_type=TaskType.SEQ_CLS, target_modules=['dense'])
model_OFT = get_peft_model(model, config)

model_OFT.save_pretrained("/tmp/peft/2326")
# no error:
model_peft = AutoPeftModelForSequenceClassification.from_pretrained("/tmp/peft/2326", num_labels=7)

Could you please check how this differs from what you're doing? Ideally, you can post a complete reproducer for me to check. Training the model should not be necessary for this type of error.

@KQDtianxiaK
Author

KQDtianxiaK commented Jan 15, 2025

I could not reproduce the error. This is what I tried:

from transformers import EsmForSequenceClassification
from peft import OFTConfig, TaskType, get_peft_model, AutoPeftModelForSequenceClassification

model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=7)
config = OFTConfig(task_type=TaskType.SEQ_CLS, target_modules=['dense'])
model_OFT = get_peft_model(model, config)

model_OFT.save_pretrained("/tmp/peft/2326")
# no error:
model_peft = AutoPeftModelForSequenceClassification.from_pretrained("/tmp/peft/2326", num_labels=7)

Could you please check how this differs from what you're doing? Ideally, you can post a complete reproducer for me to check. Training the model should not be necessary for this type of error.

The code I use:

model_path = 'model/esm2_35M'
model = EsmForSequenceClassification.from_pretrained(model_path, num_labels=7)
config = OFTConfig(task_type=TaskType.SEQ_CLS, target_modules=['dense'])
model_OFT = get_peft_model(model, config)
model_OFT.save_pretrained("model/peft/2326")
model_peft = AutoPeftModelForSequenceClassification.from_pretrained("model/peft/2326", num_labels=7)

First of all, thank you very much for your reply, and I'm sorry for taking up your time. I ran the same code as you in Jupyter, and the same error occurred. The complete error message is as follows:

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at model/esm2_35M and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at model/esm2_35M and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[11], line 8
      6 model_OFT.save_pretrained("model/peft/2326")
      7 # no error:
----> 8 model_peft = AutoPeftModelForSequenceClassification.from_pretrained("model/peft/2326", num_labels=7)

File ~/miniconda3/lib/python3.12/site-packages/peft/auto.py:130, in _BaseAutoPeftModel.from_pretrained(cls, pretrained_model_name_or_path, adapter_name, is_trainable, config, revision, **kwargs)
    125     tokenizer = AutoTokenizer.from_pretrained(
    126         pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
    127     )
    128     base_model.resize_token_embeddings(len(tokenizer))
--> 130 return cls._target_peft_class.from_pretrained(
    131     base_model,
    132     pretrained_model_name_or_path,
    133     adapter_name=adapter_name,
    134     is_trainable=is_trainable,
    135     config=config,
    136     **kwargs,
    137 )

File ~/miniconda3/lib/python3.12/site-packages/peft/peft_model.py:541, in PeftModel.from_pretrained(cls, model, model_id, adapter_name, is_trainable, config, autocast_adapter_dtype, ephemeral_gpu_offload, **kwargs)
    539     model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)
    540 else:
--> 541     model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](
    542         model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype
    543     )
    545 model.load_adapter(
    546     model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs
    547 )
    549 return model

File ~/miniconda3/lib/python3.12/site-packages/peft/peft_model.py:1311, in PeftModelForSequenceClassification.__init__(self, model, peft_config, adapter_name, **kwargs)
   1308 def __init__(
   1309     self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs
   1310 ) -> None:
-> 1311     super().__init__(model, peft_config, adapter_name, **kwargs)
   1313     classifier_module_names = ["classifier", "score"]
   1314     if self.modules_to_save is None:

File ~/miniconda3/lib/python3.12/site-packages/peft/peft_model.py:155, in PeftModel.__init__(self, model, peft_config, adapter_name, autocast_adapter_dtype)
    153     self._peft_config = None
    154     cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
--> 155     self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
    156     self.set_additional_trainable_modules(peft_config, adapter_name)
    158 if hasattr(self.base_model, "_cast_adapter_dtype"):

File ~/miniconda3/lib/python3.12/site-packages/peft/tuners/lycoris_utils.py:196, in LycorisTuner.__init__(self, model, config, adapter_name)
    195 def __init__(self, model, config, adapter_name):
--> 196     super().__init__(model, config, adapter_name)

File ~/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:175, in BaseTuner.__init__(self, model, peft_config, adapter_name)
    173 self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
    174 if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA:
--> 175     self.inject_adapter(self.model, adapter_name)
    177 # Copy the peft_config in the injected model.
    178 self.model.peft_config = self.peft_config

File ~/miniconda3/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:430, in BaseTuner.inject_adapter(self, model, adapter_name, autocast_adapter_dtype)
    428     self.targeted_module_names.append(key)
    429     is_target_modules_in_base_model = True
--> 430     parent, target, target_name = _get_submodules(model, key)
    431     self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
    433 # Handle X-LoRA case.

File ~/miniconda3/lib/python3.12/site-packages/peft/utils/other.py:313, in _get_submodules(model, key)
    311 parent = model.get_submodule(".".join(key.split(".")[:-1]))
    312 target_name = key.split(".")[-1]
--> 313 target = model.get_submodule(key)
    314 return parent, target, target_name

File ~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:717, in Module.get_submodule(self, target)
    715 for item in atoms:
    716     if not hasattr(mod, item):
--> 717         raise AttributeError(
    718             mod._get_name() + " has no " "attribute `" + item + "`"
    719         )
    721     mod = getattr(mod, item)
    723     if not isinstance(mod, torch.nn.Module):

AttributeError: ModulesToSaveWrapper has no attribute `dense`

Is there something wrong with the peft/transformers package versions I'm using? Here are my versions:

transformers              4.44.0
peft                      0.12.0

Also, I have always wanted to know whether the warning at the beginning, "Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at model/esm2_35M and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.", causes any problems. It often appears both during training and when using the trained model. PEFT should automatically freeze the original weights of the model during training, right?

@BenjaminBossan
Member

The code I use:

Since you're using a local model, I can't reproduce. Is it different from the one from HF that I used?

I ran the same code as you in Jupyter, and the same error occurred.
Is there something wrong with the peft/transformers package versions I'm using? Here are my versions:

If the same code errors for you, it could very well be the versions. Could you try updating both PEFT and transformers to their latest versions?
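
For example, a quick sanity check like this (just a minimal sketch) prints the versions that are actually picked up in your environment:

import peft
import transformers

# Print the versions that are actually imported in this environment/kernel:
print("peft:", peft.__version__)
print("transformers:", transformers.__version__)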

Also, I have always wanted to know whether the warning at the beginning, "Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at model/esm2_35M and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.", causes any problems. It often appears both during training and when using the trained model. PEFT should automatically freeze the original weights of the model during training, right?

Yes, PEFT should take care of that when you specify task_type=TaskType.SEQ_CLS. To be really sure, you can print the model and you should see that the classification head is wrapped in a ModulesToSaveWrapper from PEFT.
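
For example, with the model_OFT from the snippet in my earlier comment, a rough check could look like this:

# Only the OFT parameters and the classifier copy should be trainable:
model_OFT.print_trainable_parameters()

# Every parameter outside the OFT adapter and the modules_to_save copy
# should be frozen:
frozen = all(
    not p.requires_grad
    for n, p in model_OFT.named_parameters()
    if "oft_" not in n and "modules_to_save" not in n
)
print("base weights frozen:", frozen)

# The classification head should be wrapped by PEFT:
print(type(model_OFT.base_model.model.classifier).__name__)  # ModulesToSaveWrapper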

@KQDtianxiaK
Author

@BenjaminBossan
I updated PEFT to version 0.14.0, and the model can now be loaded normally after retraining.
Today I encountered the following problem when using the ESMC model:

RuntimeError: The weights trying to be saved contained shared tensors [{'sequence_head.3.weight', 'embed.weight'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.

I modified the Trainer using the following code:

from transformers import Trainer

class CustomTrainer(Trainer):
    def save_model(self, output_dir=None, _internal_call=False):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        # save the PEFT model without safetensors to avoid the shared-tensor error
        self.model.save_pretrained(output_dir, safe_serialization=False)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

trainer = CustomTrainer(...)

But of course the following error occurred again:

AttributeError: ModulesToSaveWrapper has no attribute `2`

But in the end, my problems were solved by using version 0.14.0 of PEFT. Thank you very much for your help.
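
For anyone who hits the same RuntimeError about shared tensors: a possible alternative to subclassing Trainer might be to disable safetensors directly in the training arguments, roughly like this (untested sketch, path is a placeholder):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="model/peft/2326",  # placeholder output path
    save_safetensors=False,        # write torch-style checkpoints instead of safetensors
)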

@BenjaminBossan
Member

Great to hear that the initial issue was solved by updating the PEFT version. I'll close this issue then.

Regarding your new issue, I don't have an immediate idea what the reason could be. If this continues to bother you, feel free to open a new issue, ideally with a full reproducer and using an open model. Thanks.
