Integration of merge-kit into PEFT #2179

Open
ParagEkbote opened this issue Oct 25, 2024 · 17 comments

@ParagEkbote

Feature request

Integrate merge-kit functionalities within the PEFT library to enable users to leverage the techniques provided in the library.

This could include additional merging techniques beyond TIES and DARE which are currently natively supported by PEFT.

References:
  1. https://github.com/arcee-ai/mergekit
  2. https://huggingface.co/docs/peft/main/en/developer_guides/model_merging#model-merging

Motivation

For beginners, especially those new to fine-tuning large language models, using mergekit requires familiarity with multiple merging methods and careful handling of model weights.

PEFT could bridge this gap by providing an easy-to-use, fully integrated solution for merging model weights.

Your contribution

With ample support and guidance, I could help in the integration.

@BenjaminBossan
Member

Thanks for opening this feature request and offering to help with the integration.

For my understanding, you are talking about the methods mentioned here? Do you mean that you would like to port over those methods to PEFT, similar to what we have for DARE and TIES, or is your suggestion to integrate the mergekit package itself?

@ParagEkbote
Author

@BenjaminBossan You are correct, I am indeed referring to the methods you mentioned. First, would porting over the entire mergekit package be the better option in the long term?

Secondly, if we can port individual methods like DARE and TIES, would that be simpler to implement and provide a better experience for merging models than merge-kit provides?

I think it depends on how swiftly we could add the merging methods to PEFT. Could you please let me know which of the two approaches you suggested is simpler to implement?

@BenjaminBossan
Member

I think it won't be easy to answer the question of which of the two approaches is better. This would require a very good understanding of how mergekit is implemented, which I don't have.

To me, it looks like mergekit is more of a framework when looking at classes like this one than a utility package, even though there appear to be some standalone functions. A utility package would be easier to integrate than a framework. As such, I tend towards copying the methods into PEFT instead of relying on mergekit directly.

Before doing that, however, it would be important to figure out how beneficial the integration would be. Your main argument, as I understand it, would be to simplify the usage of mergekit functionality with PEFT. However, this could theoretically also be achieved by documentation and examples, except if mergekit is built in a way that makes it incompatible with PEFT. Do you have some experience with using mergekit with PEFT?

@ParagEkbote
Author

ParagEkbote commented Oct 29, 2024

I have a couple of points I'd like to add. After completing a model merge process between two models, there is no native support within mergekit to perform finetuning with LoRA or by using the TRL Library to help improve performance after merging. Adding these methods to PEFT could make that process simpler and more effective.

Secondly, model merging itself is a practice that is being increasingly adopted by companies to help improve model performance, so I think it would be beneficial to add the methods to the PEFT library, if integration of mergekit seems to be difficult.

Let me know if this makes sense.

References:

  1. https://arxiv.org/pdf/2405.15032
  2. https://cameronrwolfe.substack.com/p/model-merging
  3. https://arxiv.org/pdf/2407.06089

@BenjaminBossan
Member

Thanks for the additional context. In that light, I think the more promising way forward is to integrate the methods directly into PEFT without reliance on mergekit. Where it makes sense, we can copy, or at least stick very closely to, the mergekit functions that implement the actual merging logic. Of course, we should attribute mergekit in that case.

If you want to give this a go, that would be fantastic. Feel free to ask any question that comes up and to open a draft PR to get early feedback. You can check how DARE and TIES are implemented and see if you could add a new method. Maybe you have an idea which ones are most popular.
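As a rough orientation for anyone looking at adding a method, here is a minimal pure-Python sketch of the TIES idea (magnitude pruning, sign election, then averaging only the values that agree with the elected sign). It works on flat lists instead of tensors, the function names merely mirror the terms used in this thread, and it is an illustration of the technique rather than a copy of the PEFT or mergekit implementation.

```python
def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def magnitude_prune(values, density):
    """Keep the `density` fraction of entries with the largest magnitude, zero the rest."""
    k = int(density * len(values))
    ranked = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    keep = set(ranked[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(values)]


def ties_merge(task_vectors, weights, density, majority_sign_method="total"):
    """Illustrative TIES-style merge of equally long flat task vectors."""
    pruned = [magnitude_prune(t, density) for t in task_vectors]
    merged = []
    for column in zip(*pruned):
        # Elect one sign per parameter position.
        if majority_sign_method == "total":
            elected = sign(sum(column))
        elif majority_sign_method == "frequency":
            elected = sign(sum(sign(v) for v in column))
        else:
            raise ValueError(f"Unknown majority_sign_method: {majority_sign_method}")
        # Average only the weighted values whose sign agrees with the elected sign.
        agreeing = [w * v for w, v in zip(weights, column) if sign(v) == elected]
        merged.append(sum(agreeing) / max(len(agreeing), 1))
    return merged


pruned_example = magnitude_prune([0.9, -0.1, 0.4, 0.05], density=0.5)
merged_example = ties_merge(
    [[1.0, -2.0, 3.0], [-4.0, -2.0, 1.0]], weights=[1.0, 1.0], density=1.0
)
```

In the second example the first position has conflicting signs, so only the value that matches the elected (negative) sign contributes there; the other two positions are plain averages.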

The current implementations live here:

def add_weighted_adapter(
    self,
    adapters: list[str],
    weights: list[float],
    adapter_name: str,
    combination_type: str = "svd",
    svd_rank: int | None = None,
    svd_clamp: int | None = None,
    svd_full_matrices: bool = True,
    svd_driver: str | None = None,
    density: float | None = None,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> None:
    """
    This method adds a new adapter by merging the given adapters with the given weights.
    When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to
    the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM
    errors.
    Args:
        adapters (`list`):
            List of adapter names to be merged.
        weights (`list`):
            List of weights for each adapter.
        adapter_name (`str`):
            Name of the new adapter.
        combination_type (`str`):
            The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
            `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
            combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
            mixed adapter may be too big and result in OOM errors).
        svd_rank (`int`, *optional*):
            Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
        svd_clamp (`float`, *optional*):
            A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform
            clamping. Defaults to None.
        svd_full_matrices (`bool`, *optional*):
            Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned
            tensors U and Vh. Defaults to True.
        svd_driver (`str`, *optional*):
            Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be
            one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd`
            documentation. Defaults to None.
        density (`float`, *optional*):
            Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
            with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
            `magnitude_prune`, `magnitude_prune_svd`]
        majority_sign_method (`str`):
            The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values.
            Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]
    """
    if adapter_name in list(self.peft_config.keys()):
        return
    combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
        adapters=adapters,
        combination_type=combination_type,
        svd_rank=svd_rank,
    )
    self.peft_config[adapter_name] = replace(
        self.peft_config[adapters[0]],
        r=new_rank,
        lora_alpha=new_rank,
        target_modules=new_target_modules,
    )
    self.inject_adapter(self.model, adapter_name)
    # Do we really need that?
    _freeze_adapter(self.model, adapter_name)
    key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
    for key in key_list:
        _, target, _ = _get_submodules(self.model, key)
        if isinstance(target, LoraLayer):
            if adapter_name in target.lora_A:
                target_lora_A = target.lora_A[adapter_name].weight
                target_lora_B = target.lora_B[adapter_name].weight
            elif adapter_name in target.lora_embedding_A:
                target_lora_A = target.lora_embedding_A[adapter_name]
                target_lora_B = target.lora_embedding_B[adapter_name]
            else:
                continue
            target_lora_A.data = target_lora_A.data * 0.0
            target_lora_B.data = target_lora_B.data * 0.0
            if combination_type == "cat":
                loras_A, loras_B = [], []
                for adapter, weight in zip(adapters, weights):
                    if adapter in target.lora_A:
                        current_adapter_lora_A = target.lora_A[adapter].weight
                        current_adapter_lora_B = target.lora_B[adapter].weight
                    elif adapter in target.lora_embedding_A:
                        current_adapter_lora_A = target.lora_embedding_A[adapter]
                        current_adapter_lora_B = target.lora_embedding_B[adapter]
                    else:
                        continue
                    loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter])
                    loras_B.append(current_adapter_lora_B.data)
                if len(loras_A) == 0:
                    raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
                loras_A = torch.cat(loras_A, dim=0)
                loras_B = torch.cat(loras_B, dim=1)
                target_lora_A.data[: loras_A.shape[0], :] = loras_A
                target_lora_B.data[:, : loras_B.shape[1]] = loras_B
            elif combination_type in [
                "svd",
                "ties_svd",
                "dare_linear_svd",
                "dare_ties_svd",
                "magnitude_prune_svd",
            ]:
                target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
                    combination_type,
                    adapters,
                    weights,
                    new_rank,
                    target,
                    target_lora_A,
                    target_lora_B,
                    density,
                    majority_sign_method,
                    svd_clamp,
                    full_matrices=svd_full_matrices,
                    driver=svd_driver,
                )
            elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
                target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
                    combination_type, adapters, weights, target, density, majority_sign_method
                )
def _svd_generalized_task_arithmetic_weighted_adapter(
    self,
    combination_type,
    adapters,
    weights,
    new_rank,
    target,
    target_lora_A,
    target_lora_B,
    density,
    majority_sign_method,
    clamp=None,
    full_matrices=True,
    driver=None,
):
    valid_adapters = []
    valid_weights = []
    is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters)
    for adapter, weight in zip(adapters, weights):
        if adapter in target.lora_A or adapter in target.lora_embedding_A:
            valid_adapters.append(adapter)
            valid_weights.append(weight * target.scaling[adapter])
    # if no valid adapter, nothing to do
    if len(valid_adapters) == 0:
        raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
    delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
    valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
    if combination_type == "svd":
        delta_weight = task_arithmetic(delta_weight, valid_weights)
    elif combination_type == "ties_svd":
        delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method)
    elif combination_type == "dare_linear_svd":
        delta_weight = dare_linear(delta_weight, valid_weights, density)
    elif combination_type == "dare_ties_svd":
        delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
    elif combination_type == "magnitude_prune_svd":
        delta_weight = magnitude_prune(delta_weight, valid_weights, density)
    else:
        raise ValueError(f"Invalid value passed to combination type: {combination_type}")
    conv2d = isinstance(target, Conv2d)
    if conv2d:
        conv2d_1x1 = target.weight.size()[2:4] == (1, 1)
        if not conv2d_1x1:
            delta_weight = delta_weight.flatten(start_dim=1)
        else:
            delta_weight = delta_weight.squeeze()
    if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
        delta_weight = delta_weight.T
    # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
    U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver)
    U = U[:, :new_rank]
    S = S[:new_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:new_rank, :]
    if clamp is not None:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        hi_val = torch.quantile(dist, clamp)
        low_val = -hi_val
        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(target_lora_B.data.shape)
        Vh = Vh.reshape(target_lora_A.data.shape)
    return Vh, U
def _generalized_task_arithmetic_weighted_adapter(
    self,
    combination_type,
    adapters,
    weights,
    target,
    density,
    majority_sign_method,
):
    # account weights for LoRA A and B layers.
    valid_weights = []
    lora_A_deltas = []
    lora_B_deltas = []
    for adapter, weight in zip(adapters, weights):
        if adapter in target.lora_A:
            current_adapter_lora_A = target.lora_A[adapter].weight
            current_adapter_lora_B = target.lora_B[adapter].weight
        elif adapter in target.lora_embedding_A:
            current_adapter_lora_A = target.lora_embedding_A[adapter]
            current_adapter_lora_B = target.lora_embedding_B[adapter]
        else:
            continue
        valid_weights.append(math.sqrt(weight * target.scaling[adapter]))
        lora_A_deltas.append(current_adapter_lora_A.data)
        lora_B_deltas.append(current_adapter_lora_B.data)
    valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device)
    lora_deltas = [lora_A_deltas, lora_B_deltas]
    dtype = lora_A_deltas[0].dtype
    for i, task_tensors in enumerate(lora_deltas):
        if combination_type == "linear":
            lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
        elif combination_type == "ties":
            lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method)
        elif combination_type == "dare_linear":
            lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
        elif combination_type == "dare_ties":
            lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
        elif combination_type == "magnitude_prune":
            lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
        else:
            raise ValueError("Invalid combination type")
    lora_deltas = [delta.to(dtype) for delta in lora_deltas]
    return lora_deltas
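
To make the SVD path and the linear path above easier to compare, here is a short derivation. It assumes that `task_arithmetic` is a plain weighted sum of its inputs, which the snippet does not show. The symbols are notation introduced here: $w_i$ is the merge weight, $s_i$ is `target.scaling` for adapter $i$, $A_i$ and $B_i$ are the LoRA matrices, $\Delta W$ is the merged delta weight, and $r$ is `new_rank`.

```latex
% SVD path: the merged delta weight is re-factorized at rank r
\Delta W = U \Sigma V^{H}, \qquad
B_{\text{new}} = U_{[:,\,:r]}\,\Sigma_{[:r,\,:r]}, \qquad
A_{\text{new}} = V^{H}_{[:r,\,:]}, \qquad
B_{\text{new}} A_{\text{new}} \approx \Delta W

% Linear path: A and B are merged separately, each scaled by \sqrt{w_i s_i}
A_{\text{new}} = \sum_i \sqrt{w_i s_i}\, A_i, \qquad
B_{\text{new}} = \sum_i \sqrt{w_i s_i}\, B_i

B_{\text{new}} A_{\text{new}}
  = \sum_i w_i s_i\, B_i A_i
  + \sum_{i \neq j} \sqrt{w_i s_i\, w_j s_j}\; B_i A_j
```

For a single adapter the cross terms vanish and the product is exactly $w s\, B A$. With several adapters, the linear path is in general not the same as a weighted sum of the individual delta weights. The `*_svd` variants merge the delta weights directly and then truncate to rank $r$, which (before any clamping) is the best rank-$r$ approximation in the Frobenius norm.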

The tests can be found here:

peft/tests/testing_common.py

Lines 1318 to 1500 in 214345e

def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_list, weight_list):
    model.add_adapter(adapter_list[1], config)
    model.add_adapter(adapter_list[2], replace(config, r=20))
    model = model.to(self.torch_device)
    # test re-weighting single adapter
    model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")
    # test svd re-weighting with multiple adapters
    model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting")
    # test ties_svd re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[1:],
        weight_list[1:],
        "multi_adapter_ties_svd_reweighting",
        combination_type="ties_svd",
        density=0.5,
    )
    # test dare_linear_svd re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[1:],
        weight_list[1:],
        "multi_adapter_dare_linear_svd_reweighting",
        combination_type="dare_linear_svd",
        density=0.5,
    )
    # test dare_ties_svd re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[1:],
        weight_list[1:],
        "multi_adapter_dare_ties_svd_reweighting",
        combination_type="dare_ties_svd",
        density=0.5,
    )
    # test magnitude_prune_svd re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[1:],
        weight_list[1:],
        "multi_adapter_magnitude_prune_svd_reweighting",
        combination_type="magnitude_prune_svd",
        density=0.5,
    )
    # test cat re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
    )
    # test linear re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
    )
    # test ties re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[:2], weight_list[:2], "multi_adapter_ties_reweighting", combination_type="ties", density=0.5
    )
    # test dare_linear re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[:2],
        weight_list[:2],
        "multi_adapter_dare_linear_reweighting",
        combination_type="dare_linear",
        density=0.5,
    )
    # test dare_ties re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[:2],
        weight_list[:2],
        "multi_adapter_dare_ties_reweighting",
        combination_type="dare_ties",
        density=0.5,
    )
    # test magnitude_prune re-weighting with multiple adapters
    model.add_weighted_adapter(
        adapter_list[:2],
        weight_list[:2],
        "multi_adapter_magnitude_prune_reweighting",
        combination_type="magnitude_prune",
        density=0.5,
    )
    # test linear re-weighting with multiple adapters with only first adapter having non zero weight
    model.add_weighted_adapter(
        adapter_list[:2],
        [weight_list[0], 0],
        "multi_adapter_linear_reweighting_single_enabled",
        combination_type="linear",
    )
    with pytest.raises(ValueError):
        model.add_weighted_adapter(
            adapter_list[1:],
            weight_list[1:],
            "multi_adapter_linear_reweighting_uneven_r",
            combination_type="linear",
        )
    with pytest.raises(ValueError):
        model.add_weighted_adapter(
            adapter_list[1:],
            weight_list[1:],
            "multi_adapter_ties_reweighting_uneven_r",
            combination_type="ties",
            density=0.5,
        )
    with pytest.raises(ValueError):
        model.add_weighted_adapter(
            adapter_list[1:],
            weight_list[1:],
            "multi_adapter_dare_linear_reweighting_uneven_r",
            combination_type="dare_linear",
            density=0.5,
        )
    with pytest.raises(ValueError):
        model.add_weighted_adapter(
            adapter_list[1:],
            weight_list[1:],
            "multi_adapter_dare_ties_reweighting_uneven_r",
            combination_type="dare_ties",
            density=0.5,
        )
    with pytest.raises(ValueError):
        model.add_weighted_adapter(
            adapter_list[1:],
            weight_list[1:],
            "multi_adapter_magnitude_prune_reweighting_uneven_r",
            combination_type="magnitude_prune",
            density=0.5,
        )
    new_adapters = [
        "single_adapter_reweighting",
        "multi_adapter_svd_reweighting",
        "multi_adapter_ties_svd_reweighting",
        "multi_adapter_dare_linear_svd_reweighting",
        "multi_adapter_dare_ties_svd_reweighting",
        "multi_adapter_magnitude_prune_svd_reweighting",
        "multi_adapter_cat_reweighting",
        "multi_adapter_linear_reweighting",
        "multi_adapter_linear_reweighting_single_enabled",
        "multi_adapter_ties_reweighting",
        "multi_adapter_dare_linear_reweighting",
        "multi_adapter_dare_ties_reweighting",
        "multi_adapter_magnitude_prune_reweighting",
    ]
    for new_adapter in new_adapters:
        assert new_adapter in model.peft_config
    key_list = [key for key, _ in model.named_modules()]
    for key in key_list:
        _, target, _ = _get_submodules(model, key)
        if isinstance(target, LoraLayer):
            for adapter_name in new_adapters:
                if "single" in adapter_name:
                    new_delta_weight = target.get_delta_weight(adapter_name)
                    weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
                    assert torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
                elif "svd" in adapter_name:
                    assert target.r[adapter_name] == 20
                elif "linear" in adapter_name:
                    assert target.r[adapter_name] == 8
                elif "cat" in adapter_name:
                    assert target.r[adapter_name] == 28
    dummy_input = self.prepare_inputs_for_testing()
    model.eval()
    for adapter_name in new_adapters:
        # ensuring new adapters pass the forward loop
        model.set_adapter(adapter_name)
        assert model.active_adapter == adapter_name
        assert model.active_adapters == [adapter_name]
        model(**dummy_input)[0]
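
The rank assertions in the test (20 for the SVD variants, 8 for `linear`, 28 for `cat`) follow from how each path builds the new adapter. The `cat` case is the easiest to verify by hand: stacking the scaled A matrices along rows and the B matrices along columns gives a rank equal to the sum of the input ranks, and the product equals the weighted sum of the individual products. Below is a small pure-Python check of that identity. The helper names and the toy matrices are made up for illustration and use nested lists instead of tensors.

```python
def matmul(X, Y):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def cat_merge(adapters, weights):
    """Concatenate LoRA factors: A along rows (dim=0), B along columns (dim=1).

    `adapters` is a list of (A, B, scaling) tuples with A of shape (r, in) and
    B of shape (out, r). Only A is multiplied by weight * scaling.
    """
    A_cat, B_cat = [], None
    for (A, B, scaling), weight in zip(adapters, weights):
        A_cat.extend([[a * weight * scaling for a in row] for row in A])
        if B_cat is None:
            B_cat = [list(row) for row in B]
        else:
            B_cat = [merged_row + list(row) for merged_row, row in zip(B_cat, B)]
    return A_cat, B_cat


# Hypothetical toy adapters: rank 1 and rank 2, input dim 2, output dim 2, scaling 1.0.
adapter_1 = ([[1.0, 2.0]], [[1.0], [0.0]], 1.0)
adapter_2 = ([[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 2.0]], 1.0)

A_cat, B_cat = cat_merge([adapter_1, adapter_2], weights=[0.5, 2.0])
merged_delta = matmul(B_cat, A_cat)

# Reference: weighted sum of the individual B @ A products.
delta_1 = matmul(adapter_1[1], adapter_1[0])
delta_2 = matmul(adapter_2[1], adapter_2[0])
expected = [
    [0.5 * d1 + 2.0 * d2 for d1, d2 in zip(row_1, row_2)]
    for row_1, row_2 in zip(delta_1, delta_2)
]
```

Here `A_cat` has 3 rows (1 + 2), mirroring the 8 + 20 = 28 in the test, and `merged_delta` matches `expected`.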

This code is a bit messy and I wouldn't mind refactoring it, but that can be left as an exercise for the future.

> there is no native support within mergekit to perform finetuning with LoRA or by using the TRL Library to help improve performance after merging.

Just for my understanding, the practice would be to merge two trained LoRA adapters, then fine-tune this merged adapter further? Interesting, I didn't know this was commonly done.

@ParagEkbote
Author

ParagEkbote commented Oct 29, 2024

Thank you for the quick reply. Could you please assign this issue to me? That would make it easier to ask for help within the Hugging Face user base and the OSS community. I plan to start working on it in a few weeks, if that works.

> If you want to give this a go, that would be fantastic. Feel free to ask any question that comes up and to open a draft PR to get early feedback. You can check how DARE and TIES are implemented and see if you could add a new method. Maybe you have an idea which ones are most popular.

This is done to improve the convergence of a model. Model merging can also be applied to two full LLMs, not just to LoRA adapters.

> Just for my understanding, the practice would be to merge two trained LoRA adapters, then fine-tune this merged adapter further? Interesting, I didn't know this was commonly done.

Additionally, I have a question. Are we trying to implement something similar to this Arcee article, or to add support for merging language models as in this blog post?

References:

  1. https://blog.arcee.ai/use-mergekit-to-extract-lora-adapters-from-any-fine-tuned-model/

  2. https://mlabonne.github.io/blog/posts/2024-01-08_Merge_LLMs_with_mergekit.html

@ParagEkbote
Author

Could you please clarify?

cc: @BenjaminBossan

@BenjaminBossan
Member

Sorry, what would you like me to clarify?

@ParagEkbote
Author

ParagEkbote commented Oct 29, 2024

I should implement the merging methods for LoRA adapters only, and not for entire language models, right?

https://blog.arcee.ai/use-mergekit-to-extract-lora-adapters-from-any-fine-tuned-model/

@BenjaminBossan
Member

My understanding is that new merging methods for LoRA can be implemented, similar to what we have for DARE and TIES. The LoRA extraction feature seems to be a completely separate issue, as it deals with fully fine-tuned models, whereas in PEFT, we can assume that the LoRA adapter already exists separately (whether via extraction or not wouldn't be relevant). LMK if you had other ideas.
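
To illustrate the kind of tensor-level logic such a method involves, here is a minimal pure-Python sketch of the DARE idea (randomly drop values, rescale the survivors by `1 / density`, then combine linearly). It operates on flat lists, the function names and the `seed` parameter are hypothetical, and it is meant as an illustration of the technique rather than the PEFT or mergekit implementation.

```python
import random


def dare_prune(values, density, rng):
    """Keep each value with probability `density` and rescale survivors by 1 / density."""
    if not 0 < density <= 1:
        raise ValueError("density must be in (0, 1]")
    return [v / density if rng.random() < density else 0.0 for v in values]


def dare_linear_merge(task_vectors, weights, density, seed=0):
    """Illustrative DARE-style merge: random pruning followed by a weighted sum."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    pruned = [dare_prune(t, density, rng) for t in task_vectors]
    return [sum(w * v for w, v in zip(weights, column)) for column in zip(*pruned)]


# With density=1.0 nothing is dropped, so the result is a plain weighted sum.
merged_full = dare_linear_merge([[1.0, 2.0], [3.0, 4.0]], weights=[0.5, 0.5], density=1.0)

# With density=0.5 every value is either dropped or doubled.
original = [1.0, 2.0, 3.0, 4.0]
pruned_half = dare_prune(original, density=0.5, rng=random.Random(0))
```

The rescaling keeps the expected value of each entry unchanged, which is why dropping a large fraction of values can still approximate the original update.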

@ParagEkbote
Author

ParagEkbote commented Oct 29, 2024

I'll do some research on the merging methods and reach out if I need input on specific approaches.

Thanks for the feedback and I'll definitely keep in touch.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@ParagEkbote
Author

not stale

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@ParagEkbote
Author

not stale.

@ParagEkbote
Author

ParagEkbote commented Jan 11, 2025

Can I first contribute examples using mergekit for PEFT before exploring a larger integration?

cc: @BenjaminBossan

@BenjaminBossan
Member

> Can I first contribute examples using mergekit for PEFT before exploring a larger integration?

Yes, it sounds like a good plan to start small and go from there.
