Skip to content

Commit

Permalink
Comprehensive type checking for from_pretrained kwargs (huggingface…
Browse files Browse the repository at this point in the history
…#10758)

* More robust from_pretrained init_kwargs type checking

* Corrected for Python 3.10

* Type checks subclasses and fixed type warnings

* More type corrections and skip tokenizer type checking

* make style && make quality

* Updated docs and types for Lumina pipelines

* Fixed check for empty signature

* changed location of helper functions

* make style

---------

Co-authored-by: hlky <[email protected]>
  • Loading branch information
guiyrt and hlky authored Feb 22, 2025
1 parent 64dec70 commit 9c7e205
Show file tree
Hide file tree
Showing 26 changed files with 208 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def __init__(
Tuple[HunyuanDiT2DControlNetModel],
HunyuanDiT2DMultiControlNetModel,
],
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
requires_safety_checker: bool = True,
):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
Expand Down Expand Up @@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
image_encoder (`PreTrainedModel`, *optional*):
image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter.
"""

Expand All @@ -202,8 +202,8 @@ def __init__(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
SiglipImageProcessor,
SiglipModel,
T5EncoderModel,
T5TokenizerFast,
)
Expand Down Expand Up @@ -223,8 +223,8 @@ def __init__(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
image_encoder: SiglipModel = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import torch

from ...models import UNet1DModel
from ...schedulers import SchedulerMixin
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
Expand Down Expand Up @@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):

model_cpu_offload_seq = "unet"

def __init__(self, unet, scheduler):
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)

Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/ddim/pipeline_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch

from ...models import UNet2DModel
from ...schedulers import DDIMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
Expand Down Expand Up @@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):

model_cpu_offload_seq = "unet"

def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()

# make sure scheduler can always be converted to DDIM
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import torch

from ...models import UNet2DModel
from ...schedulers import DDPMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
Expand Down Expand Up @@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):

model_cpu_offload_seq = "unet"

def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
scheduler: RePaintScheduler
model_cpu_offload_seq = "unet"

def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def __init__(
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=MT5Tokenizer,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
):
super().__init__()

Expand Down
17 changes: 7 additions & 10 deletions src/diffusers/pipelines/lumina/pipeline_lumina.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
Expand Down Expand Up @@ -144,13 +144,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`AutoModel`]):
Frozen text-encoder. Lumina-T2I uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
tokenizer (`AutoModel`):
Tokenizer of class
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
text_encoder ([`GemmaPreTrainedModel`]):
Frozen Gemma text-encoder.
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
Gemma tokenizer.
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
Expand Down Expand Up @@ -185,8 +182,8 @@ def __init__(
transformer: LuminaNextDiT2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: AutoModel,
tokenizer: AutoTokenizer,
text_encoder: GemmaPreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
):
super().__init__()

Expand Down
17 changes: 7 additions & 10 deletions src/diffusers/pipelines/lumina2/pipeline_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

from ...image_processor import VaeImageProcessor
from ...loaders import Lumina2LoraLoaderMixin
Expand Down Expand Up @@ -143,13 +143,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`AutoModel`]):
Frozen text-encoder. Lumina-T2I uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
tokenizer (`AutoModel`):
Tokenizer of class
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
text_encoder ([`Gemma2PreTrainedModel`]):
Frozen Gemma2 text-encoder.
tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
Gemma tokenizer.
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
Expand All @@ -165,8 +162,8 @@ def __init__(
transformer: Lumina2Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: AutoModel,
tokenizer: AutoTokenizer,
text_encoder: Gemma2PreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
):
super().__init__()

Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/pag/pipeline_pag_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
Expand Down Expand Up @@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: AutoModelForCausalLM,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC,
transformer: SanaTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
Expand Down
75 changes: 74 additions & 1 deletion src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin

import requests
import torch
Expand Down Expand Up @@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
break
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")


def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
"""
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
the correct type as well.
"""
if not isinstance(class_or_tuple, tuple):
class_or_tuple = (class_or_tuple,)

# Unpack unions
unpacked_class_or_tuple = []
for t in class_or_tuple:
if get_origin(t) is Union:
unpacked_class_or_tuple.extend(get_args(t))
else:
unpacked_class_or_tuple.append(t)
class_or_tuple = tuple(unpacked_class_or_tuple)

if Any in class_or_tuple:
return True

obj_type = type(obj)
# Classes with obj's type
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}

# Singular types (e.g. int, ControlNet, ...)
# Untyped collections (e.g. List, but not List[int])
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
if () in elem_class_or_tuple:
return True
# Typed lists or sets
elif obj_type in (list, set):
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
# Typed tuples
elif obj_type is tuple:
return any(
# Tuples with any length and single type (e.g. Tuple[int, ...])
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
or
# Tuples with fixed length and any types (e.g. Tuple[int, str])
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
for t in elem_class_or_tuple
)
# Typed dicts
elif obj_type is dict:
return any(
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
for kt, vt in elem_class_or_tuple
)

else:
return False


def _get_detailed_type(obj: Any) -> Type:
"""
Gets a detailed type for an object, including nested types for collections.
"""
obj_type = type(obj)

if obj_type in (list, set):
obj_origin_type = List if obj_type is list else Set
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
return obj_origin_type[elems_type]
elif obj_type is tuple:
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
elif obj_type is dict:
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
return Dict[keys_type, values_type]
else:
return obj_type
Loading

0 comments on commit 9c7e205

Please sign in to comment.