Skip to content

Commit

Permalink
[Tests] fix: lumina2 lora fuse_nan test (huggingface#10911)
Browse files Browse the repository at this point in the history
fix: lumina2 lora fuse_nan test
  • Loading branch information
sayakpaul authored Feb 26, 2025
1 parent 3fab662 commit 764d7ed
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions tests/lora/test_lora_layers_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import sys
import unittest

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GemmaForCausalLM

Expand All @@ -24,12 +26,12 @@
Lumina2Text2ImgPipeline,
Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device


sys.path.append(".")

from utils import PeftLoraLoaderMixinTests # noqa: E402
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402


@require_peft_backend
Expand Down Expand Up @@ -130,3 +132,41 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass

@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]

self.assertTrue(np.isnan(out).all())

0 comments on commit 764d7ed

Please sign in to comment.