Skip to content

Commit

Permalink
Improving multi lora config
Browse files Browse the repository at this point in the history
  • Loading branch information
leopedroso45 committed Jun 30, 2024
1 parent 77f01b3 commit 64e5345
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
18 changes: 15 additions & 3 deletions sevsd/setup_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,22 @@ def setup_pipeline(pretrained_model_link_or_path, loras, **kwargs):

if loras:
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.unfuse_lora()

set_loras = []
set_weights = []
for lora in loras:
if lora.endswith(".safetensors"):
pipeline.load_lora_weights(lora)
pipeline.fuse_lora()
adapter_name = lora.replace(".", "")
pipeline.load_lora_weights(
lora,
weight_name=lora,
adapter_name=adapter_name
)
set_loras.append(adapter_name)
set_weights.append(1.0)

pipeline.set_adapters(set_loras, set_weights)
pipeline.fuse_lora()

pipeline.to(device)
pipeline.enable_attention_slicing()
Expand Down
10 changes: 7 additions & 3 deletions tests/test_setup_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ def test_setup_pipeline_with_loras(self, mock_scheduler_from_config, mock_setup_
)
mock_feature_extractor.assert_called_once_with(config)
mock_scheduler_from_config.assert_called_once()
mock_pipeline.load_lora_weights.assert_any_call('lora1.safetensors')
mock_pipeline.load_lora_weights.assert_any_call('lora2.safetensors')
mock_pipeline.fuse_lora.assert_called()

# Check if LoRA weights were loaded and fused correctly
self.assertEqual(mock_pipeline.load_lora_weights.call_count, 2)
mock_pipeline.load_lora_weights.assert_any_call('lora1.safetensors', weight_name='lora1.safetensors', adapter_name='lora1safetensors')
mock_pipeline.load_lora_weights.assert_any_call('lora2.safetensors', weight_name='lora2.safetensors', adapter_name='lora2safetensors')
mock_pipeline.set_adapters.assert_called_once_with(['lora1safetensors', 'lora2safetensors'], [1.0, 1.0])
mock_pipeline.fuse_lora.assert_called_once()

if __name__ == '__main__':
unittest.main()

0 comments on commit 64e5345

Please sign in to comment.