[Single File] Add Single File support for Lumina Image 2.0 Transformer (huggingface#10781)

* update

* update
DN6 authored Feb 12, 2025
1 parent 067eab1 commit 28f48f4
Showing 5 changed files with 208 additions and 1 deletion.
50 changes: 50 additions & 0 deletions docs/source/en/api/pipelines/lumina2.md
@@ -26,6 +26,56 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

</Tip>

## Using Single File loading with Lumina Image 2.0

Single file loading for Lumina Image 2.0 is available for the `Lumina2Transformer2DModel`.

```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline

ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
ckpt_path, torch_dtype=torch.bfloat16
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
"a cat holding a sign that says hello",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-single-file.png")
```

## Using GGUF Quantized Checkpoints with Lumina Image 2.0

GGUF-quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with a `GGUFQuantizationConfig`.

```python
import torch

from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
transformer = Lumina2Transformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
"a cat holding a sign that says hello",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-gguf.png")
```

## Lumina2Text2ImgPipeline

[[autodoc]] Lumina2Text2ImgPipeline
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -34,6 +34,7 @@
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
@@ -111,6 +112,10 @@
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"Lumina2Transformer2DModel": {
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
}
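
For orientation, a minimal sketch of how this registry entry is consumed; the lookup inside `from_single_file` is simplified here:

```python
from diffusers.loaders.single_file_model import SINGLE_FILE_LOADABLE_CLASSES

# from_single_file resolves the converter for a model class by name and applies
# it to the raw state dict before loading the weights into the model.
entry = SINGLE_FILE_LOADABLE_CLASSES["Lumina2Transformer2DModel"]
convert_fn = entry["checkpoint_mapping_fn"]  # convert_lumina2_to_diffusers
# "default_subfolder" tells the loader where the diffusers-format config lives
# inside the default pipeline repo.
```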


77 changes: 77 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -116,6 +116,7 @@
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,6 +175,7 @@
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
}

# Use to configure model sample size when original config is provided
@@ -657,6 +659,9 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"

else:
model_type = "v1"

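A self-contained illustration of this key-based detection; `looks_like_lumina2` is an invented helper name, and `checkpoint` stands for the loaded state dict. The inferred `"lumina2"` type then keys into `DIFFUSERS_DEFAULT_PIPELINE_PATHS` above to pick `Alpha-VLLM/Lumina-Image-2.0` as the default config source.

```python
CHECKPOINT_KEY_NAMES = {
    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
}

def looks_like_lumina2(checkpoint: dict) -> bool:
    # Both key forms can occur: ComfyUI-style checkpoints prefix weights with
    # "model.diffusion_model.", while the original release does not.
    return any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"])

print(looks_like_lumina2({"cap_embedder.0.weight": None}))  # True
```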
@@ -2798,3 +2803,75 @@ def calculate_layers(keys, key_prefix):
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

return converted_state_dict


def convert_lumina2_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}

# Original Lumina-Image-2 has an extra norm parameter that is unused
# We just remove it here
checkpoint.pop("norm_final.weight", None)

# Comfy checkpoints add this prefix
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

LUMINA_KEY_MAP = {
"cap_embedder": "time_caption_embed.caption_embedder",
"t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
"t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
"attention": "attn",
".out.": ".to_out.0.",
"k_norm": "norm_k",
"q_norm": "norm_q",
"w1": "linear_1",
"w2": "linear_2",
"w3": "linear_3",
"adaLN_modulation.1": "norm1.linear",
}
ATTENTION_NORM_MAP = {
"attention_norm1": "norm1.norm",
"attention_norm2": "norm2",
}
CONTEXT_REFINER_MAP = {
"context_refiner.0.attention_norm1": "context_refiner.0.norm1",
"context_refiner.0.attention_norm2": "context_refiner.0.norm2",
"context_refiner.1.attention_norm1": "context_refiner.1.norm1",
"context_refiner.1.attention_norm2": "context_refiner.1.norm2",
}
FINAL_LAYER_MAP = {
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
"final_layer.linear": "norm_out.linear_2",
}

def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
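# These widths follow the Lumina-Image-2.0 attention layout (an assumption based
# on its published config): 2304 query channels (24 heads x 96) packed ahead of
# 768 key and 768 value channels (8 KV heads x 96, i.e. grouped-query attention).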
q_dim = 2304
k_dim = v_dim = 768

to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)

return {
diffusers_key.replace("qkv", "to_q"): to_q,
diffusers_key.replace("qkv", "to_k"): to_k,
diffusers_key.replace("qkv", "to_v"): to_v,
}

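# Apply the most specific maps first: if the generic substring replacements in
# LUMINA_KEY_MAP ran earlier (e.g. "attention" -> "attn"), the longer
# context-refiner and final-layer keys above would no longer match.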
for key in keys:
diffusers_key = key
for k, v in CONTEXT_REFINER_MAP.items():
diffusers_key = diffusers_key.replace(k, v)
for k, v in FINAL_LAYER_MAP.items():
diffusers_key = diffusers_key.replace(k, v)
for k, v in ATTENTION_NORM_MAP.items():
diffusers_key = diffusers_key.replace(k, v)
for k, v in LUMINA_KEY_MAP.items():
diffusers_key = diffusers_key.replace(k, v)

if "qkv" in diffusers_key:
converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
else:
converted_state_dict[diffusers_key] = checkpoint.pop(key)

return converted_state_dict
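
A quick illustration of the mapping on a fused attention projection; the key and tensor shape below are illustrative rather than verbatim checkpoint contents:

```python
import torch

from diffusers.loaders.single_file_utils import convert_lumina2_to_diffusers

state_dict = {"layers.0.attention.qkv.weight": torch.randn(2304 + 768 + 768, 2304)}
converted = convert_lumina2_to_diffusers(state_dict)

# The fused projection is renamed via LUMINA_KEY_MAP and split into q/k/v:
# ['layers.0.attn.to_k.weight', 'layers.0.attn.to_q.weight', 'layers.0.attn.to_v.weight']
print(sorted(converted))
```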
3 changes: 2 additions & 1 deletion src/diffusers/models/transformers/transformer_lumina2.py
@@ -21,6 +21,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
@@ -333,7 +334,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
)


class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
Lumina2NextDiT: Diffusion model with a Transformer backbone.
74 changes: 74 additions & 0 deletions tests/single_file/test_lumina2_transformer.py
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
model_class = Lumina2Transformer2DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
alternate_keys_ckpt_paths = [
"https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
]

repo_id = "Alpha-VLLM/Lumina-Image-2.0"

def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)

def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)

PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"

def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
torch.cuda.empty_cache()
model = self.model_class.from_single_file(ckpt_path)

del model
gc.collect()
torch.cuda.empty_cache()
