Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX for ConvNd layers using the groups argument. #2403

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

gslama12
Copy link
Contributor

As discussed in #2153.

Code example

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. The change seems to be much smaller than I initially thought, which is great. Before we proceed, could we do the following:

Let's add a test case for this. First, let's create an entry like this one:

("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),

Then we need to define a model with a conv layer that uses groups. Something similar to this with groups=5 should work:

class ModelConv2D(nn.Module):
def __init__(self):
super().__init__()
self.conv2d = nn.Conv2d(5, 10, 3)
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10, 2)
self.sm = nn.LogSoftmax(dim=-1)
def forward(self, X):
X = X.float().reshape(-1, 5, 3, 3)
X = self.conv2d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X

Then make sure that the model is being used when it's model ID is passed by adding an entry similar to this one:

if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)

LMK if anything is unclear.

Moreover, don't forget to run make style for the linter.

@gslama12
Copy link
Contributor Author

thanks for the fast reply. I implemented the test as you described. I added one for lora and one for dora. LMK if there is something missing.

@BenjaminBossan
Copy link
Member

Thanks for adding the tests. Unfortunately, a lot of them are failing for me locally. Do they pass for you? E.g. this one:

pytest tests/test_custom_models.py -k test_forward_output_finite_021_Conv2d_Groups_LoRA

@gslama12
Copy link
Contributor Author

gslama12 commented Mar 1, 2025

I had a bug in the test model, i fixed it now and the TC you stated should work. Let's see if anything else fails.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants