Skip to content

Commit

Permalink
Fixing adapter name for paths containing directories
Browse files Browse the repository at this point in the history
  • Loading branch information
leopedroso45 committed Jun 30, 2024
1 parent 64e5345 commit 1def683
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions sevsd/setup_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sevsd.setup_device import setup_device
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from transformers import AutoFeatureExtractor
import os

def setup_pipeline(pretrained_model_link_or_path, loras, **kwargs):
r"""
Expand All @@ -17,7 +18,7 @@ def setup_pipeline(pretrained_model_link_or_path, loras, **kwargs):
StableDiffusionPipeline: The initialized Stable Diffusion pipeline ready for image generation.
Example:
pipeline = setup_pipeline("CompVis/stable-diffusion-v1-4", ["lora1.safetensors", "lora2.safetensors"])
pipeline = setup_pipeline("CompVis/stable-diffusion-v1-4", ["./loras/lora1.safetensors", "./loras/lora2.safetensors"])
Note:
- The function supports both remote model links and local `.safetensors` files.
Expand Down Expand Up @@ -55,7 +56,7 @@ def setup_pipeline(pretrained_model_link_or_path, loras, **kwargs):
set_loras = []
set_weights = []
for lora in loras:
adapter_name = lora.replace(".", "")
adapter_name = os.path.basename(lora).replace(".", "")
pipeline.load_lora_weights(
lora,
weight_name=lora,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_setup_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_setup_pipeline_with_loras(self, mock_scheduler_from_config, mock_setup_
mock_feature_extractor_instance = MagicMock()
mock_feature_extractor.return_value = mock_feature_extractor_instance
config = 'model_path'
loras = ['lora1.safetensors', 'lora2.safetensors']
loras = ['./loras/lora1.safetensors', './loras/lora2.safetensors']

pipeline = setup_pipeline(config, loras)

Expand All @@ -74,8 +74,8 @@ def test_setup_pipeline_with_loras(self, mock_scheduler_from_config, mock_setup_

# Check if LoRA weights were loaded and fused correctly
self.assertEqual(mock_pipeline.load_lora_weights.call_count, 2)
mock_pipeline.load_lora_weights.assert_any_call('lora1.safetensors', weight_name='lora1.safetensors', adapter_name='lora1safetensors')
mock_pipeline.load_lora_weights.assert_any_call('lora2.safetensors', weight_name='lora2.safetensors', adapter_name='lora2safetensors')
mock_pipeline.load_lora_weights.assert_any_call('./loras/lora1.safetensors', weight_name='./loras/lora1.safetensors', adapter_name='lora1safetensors')
mock_pipeline.load_lora_weights.assert_any_call('./loras/lora2.safetensors', weight_name='./loras/lora2.safetensors', adapter_name='lora2safetensors')
mock_pipeline.set_adapters.assert_called_once_with(['lora1safetensors', 'lora2safetensors'], [1.0, 1.0])
mock_pipeline.fuse_lora.assert_called_once()

Expand Down

0 comments on commit 1def683

Please sign in to comment.