FIX for ConvNd layers using the groups argument. #2403
Thanks for the PR. The change seems to be much smaller than I initially thought, which is great. Before we proceed, could we do the following:
Let's add a test case for this. First, let's create an entry like this one:
peft/tests/test_custom_models.py
Line 114 in f51203f
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
Then we need to define a model with a conv layer that uses groups. Something similar to this with groups=5 should work:
peft/tests/test_custom_models.py
Lines 864 to 880 in f51203f
class ModelConv2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(5, 10, 3)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float().reshape(-1, 5, 3, 3)
        X = self.conv2d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X
Then make sure that the model is being used when its model ID is passed by adding an entry similar to this one:
peft/tests/test_custom_models.py
Lines 967 to 968 in f51203f
if model_id == "Conv2d":
    return ModelConv2D().to(torch_dtype)
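The last two steps could be combined roughly as in the sketch below. It is only an illustration, not the code that was merged: the class name `ModelConv2DGroups`, the model ID `"Conv2dGroups"`, and the standalone `build_model` helper are hypothetical names, and the layer sizes are assumptions chosen so that both channel counts are divisible by `groups=5`.

```python
import torch
from torch import nn


class ModelConv2DGroups(nn.Module):
    """Hypothetical test model: same layout as ModelConv2D, but with a grouped conv."""

    def __init__(self):
        super().__init__()
        # in_channels (5) and out_channels (10) must both be divisible by groups (5).
        self.conv2d = nn.Conv2d(5, 10, 3, groups=5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float().reshape(-1, 5, 3, 3)
        X = self.conv2d(X)  # 3x3 kernel on a 3x3 input -> shape (N, 10, 1, 1)
        X = self.relu(X)
        X = self.flat(X)  # -> shape (N, 10)
        X = self.lin0(X)
        X = self.sm(X)
        return X


def build_model(model_id, torch_dtype=torch.float32):
    """Hypothetical stand-in for the model-ID dispatch in the test file."""
    if model_id == "Conv2dGroups":
        return ModelConv2DGroups().to(torch_dtype)
    raise ValueError(f"model_id {model_id} not implemented")
```

With `groups=5`, each output channel only sees `5 // 5 = 1` input channel, so the conv weight has shape `(10, 1, 3, 3)` instead of `(10, 5, 3, 3)`. That makes it a reasonable case to exercise for adapters that read shapes from the base layer.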
LMK if anything is unclear.
Moreover, don't forget to run make style for the linter.
Co-authored-by: Benjamin Bossan <[email protected]>
Thanks for the fast reply. I implemented the test as you described and added one for LoRA and one for DoRA. LMK if anything is missing.
Thanks for adding the tests. Unfortunately, a lot of them are failing for me locally. Do they pass for you? E.g. this one:
I had a bug in the test model. I've fixed it now, and the test case you mentioned should pass. Let's see if anything else fails.
As discussed in #2153.
Code example