Commit

FIX: Generating with mixed adapter batches and with beam search enabled (#2287)

See #2283

Right now, using mixed adapter batches with beam search generation does
not work. This is because users need to pass the adapter names
associated with each sample, i.e. the number of adapter names should be
identical to the number of samples in the input.

When applying beam search, transformers internally repeats the samples
once per beam (or so it looks like). Therefore, we have more samples
during generation than samples in the input. Consequently, the adapter
names have to be extended accordingly. This is now taken care of.
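
As a minimal illustration of that expansion (a standalone sketch with made-up adapter names, not the committed code), each adapter name is repeated once per beam and the nested list is flattened, so the expanded names line up with the beam-expanded inputs:

# Sketch of the adapter-name expansion under beam search (illustrative only).
adapter_names = ["adapter0", "adapter1", "__base__"]  # one name per input sample
num_beams = 3

# transformers repeats each input sample once per beam; the adapter names are
# repeated the same way and the nested list is flattened.
expanded = sum(([name] * num_beams for name in adapter_names), [])
# expanded == ['adapter0', 'adapter0', 'adapter0',
#              'adapter1', 'adapter1', 'adapter1',
#              '__base__', '__base__', '__base__']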
BenjaminBossan authored Jan 17, 2025
1 parent f973b28 commit aa3f41f
Showing 4 changed files with 128 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -444,13 +444,36 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
        if unexpected_adapters:
            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")

        # deal with beam search
        num_beams = kwargs.get("num_beams", None)
        uses_beam_search = isinstance(num_beams, int) and (num_beams > 1)
        original_adapter_names = adapter_names[:]
        if uses_beam_search:
            if not isinstance(adapter_names, (list, tuple)):
                raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
            # When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and
            # then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the
            # encoder part. Further below, the original argument is thus restored for the encoder.
            adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), [])

        hook_handles = []
        for module in self.modules():
            if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
                pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
                handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                hook_handles.append(handle)

        if uses_beam_search and hasattr(self.model, "get_encoder"):
            # For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
            # the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
            for module in self.model.get_encoder().modules():
                if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
                    # Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
                    # trying to exclude the encoder.
                    pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
                    handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                    hook_handles.append(handle)

        yield

        for handle in hook_handles:
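For orientation, here is a hedged usage sketch of what this hunk enables (the model id, prompts, and adapter choices are assumptions for illustration, not taken from the diff): one adapter name is passed per input sample, and with this fix the same call also works when num_beams > 1.

# Illustrative sketch only; model id, prompts, and adapter names are made up.
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-125m"  # assumption: any causal LM supported by LoRA would do
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"  # left-padding is the usual choice for decoder-only generation
base_model = AutoModelForCausalLM.from_pretrained(model_id)

config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False)
model = get_peft_model(base_model, config, adapter_name="adapter0").eval()
model.add_adapter("adapter1", config)

inputs = tokenizer(["Hello world", "Goodbye world", "Hello again"], return_tensors="pt", padding=True)

# One adapter name per sample ("__base__" selects the base model). Before this fix,
# combining adapter_names with num_beams > 1 failed because the names were not
# repeated along with the beam-expanded inputs.
outputs = model.generate(
    **inputs,
    adapter_names=["adapter0", "adapter1", "__base__"],
    num_beams=5,
    max_length=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

For encoder-decoder models the same call applies; the second hook loop above keeps the encoder on the original, non-expanded adapter names.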
12 changes: 12 additions & 0 deletions tests/test_decoder_models.py
@@ -303,6 +303,18 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
    def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
        self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
                "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "task_type": "CAUSAL_LM",
            },
        )
    )
    def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
        self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_oft_or_hra_and_gpt2)
    )
12 changes: 12 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -118,6 +118,18 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
    def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
        self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs)

    @parameterized.expand(
        PeftTestConfigManager.get_grid_parameters(
            {
                "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
                "lora_kwargs": {"init_lora_weights": [False]},
                "task_type": "SEQ_2_SEQ_LM",
            },
        )
    )
    def test_generate_with_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs):
        self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs)

    # skip non lora models - generate does not work for prefix tuning, prompt tuning
    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
    def test_generate(self, test_name, model_id, config_cls, config_kwargs):
81 changes: 81 additions & 0 deletions tests/testing_common.py
@@ -941,6 +941,87 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
        assert torch.allclose(logits_adapter0[1::3], logits_mixed[1::3], atol=atol, rtol=rtol)
        assert torch.allclose(logits_adapter1[2::3], logits_mixed[2::3], atol=atol, rtol=rtol)

    def _test_generate_with_mixed_adapter_batches_and_beam_search(self, model_id, config_cls, config_kwargs):
        # Test generating with beam search and with mixing different adapters in a single batch by passing the
        # adapter_names argument. See #2283.
        if config_cls not in (LoraConfig,):
            return pytest.skip(f"Mixed adapter batches not supported for {config_cls}")

        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
        )

        torch.manual_seed(0)
        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config, adapter_name="adapter0").eval()
        model.add_adapter("adapter1", config)

        # In contrast to forward, for generate, it can sometimes happen that we get the same results as the base model
        # even with LoRA applied because the impact of LoRA is not big enough. Therefore, use this "trick" to make LoRA
        # stronger.
        for name, param in model.named_parameters():
            if model.base_model.prefix in name:
                param.data.mul_(10.0)

        model = model.to(self.torch_device).eval()

        dummy_input = self.prepare_inputs_for_testing()
        # ensure that we have at least 3 samples for this test
        dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()}

        gen_kwargs = {**dummy_input, "max_length": 20, "num_beams": 10, "early_stopping": True}
        with torch.inference_mode():
            with model.disable_adapter():
                gen_base = model.generate(**gen_kwargs)

        model.set_adapter("adapter0")
        with torch.inference_mode():
            gen_adapter0 = model.generate(**gen_kwargs)

        model.set_adapter("adapter1")
        with torch.inference_mode():
            gen_adapter1 = model.generate(**gen_kwargs)

        def remove_padding(seq, pad_value):
            lst = list(seq)
            while lst and (lst[-1] == pad_value):
                lst.pop()
            return lst

        def gens_are_same(gen0, gen1):
            # Special function to compare generations. We cannot use torch.allclose, as it raises an error when the
            # sequence lengths differ. Moreover, we need to remove the padding from the sequences. This is because,
            # even though normally identical sequences should have the same length, when we do mixed adapter batches,
            # each sample will be padded to the longest sequence in that mixed batch, which can be different from the
            # longest sequence without mixed adapter batches.
            pad_value = model.config.eos_token_id
            for sample0, sample1 in zip(gen0, gen1):
                sample0 = remove_padding(sample0, pad_value)
                sample1 = remove_padding(sample1, pad_value)
                if (len(sample0) != len(sample1)) or (sample0 != sample1):
                    # at least one sample differs, the generations are not identical
                    return False
            return True

        # sanity check that there are enough outputs and that they are different
        assert len(gen_base) == len(gen_adapter0) == len(gen_adapter1)
        assert len(gen_adapter1) >= 3
        assert not gens_are_same(gen_base, gen_adapter0)
        assert not gens_are_same(gen_base, gen_adapter1)
        assert not gens_are_same(gen_adapter0, gen_adapter1)

        # alternate between base model, adapter0, and adapter1
        adapters = ["__base__", "adapter0", "adapter1"]
        gen_kwargs["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))]

        with torch.inference_mode():
            gen_mixed = model.generate(**gen_kwargs)

        assert gens_are_same(gen_base[::3], gen_mixed[::3])
        assert gens_are_same(gen_adapter0[1::3], gen_mixed[1::3])
        assert gens_are_same(gen_adapter1[2::3], gen_mixed[2::3])

    def _test_generate(self, model_id, config_cls, config_kwargs):
        model = self.transformers_class.from_pretrained(model_id)
        config = config_cls(
