[FEAT] Model loading refactor (huggingface#10604)
* first draft model loading refactor

* revert name change

* fix bnb

* revert name

* fix dduf

* fix hunyuan

* style

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* suggestions from reviews

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: YiYi Xu <[email protected]>

* remove safetensors check

* fix default value

* more fix from suggestions

* revert logic for single file

* style

* typing + fix couple of issues

* improve speed

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Aryan <[email protected]>

* fp8 dtype

* add tests

* rename resolved_archive_file to resolved_model_file

* format

* map_location default cpu

* add utility function

* switch to smaller model + test inference

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* rm comment

* add log

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* add decorator

* cosine sim instead

* fix use_keep_in_fp32_modules

* comm

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
4 people authored Feb 19, 2025
1 parent 6fe05b9 commit f5929e0
Showing 12 changed files with 844 additions and 515 deletions.
20 changes: 14 additions & 6 deletions src/diffusers/loaders/single_file_model.py
@@ -52,7 +52,7 @@
 
 
 if is_accelerate_available():
-    from accelerate import init_empty_weights
+    from accelerate import dispatch_model, init_empty_weights
 
     from ..models.modeling_utils import load_model_dict_into_meta

@@ -366,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 keep_in_fp32_modules=keep_in_fp32_modules,
             )
 
+        device_map = None
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
-            named_buffers = model.named_buffers()
-            unexpected_keys = load_model_dict_into_meta(
+            empty_state_dict = model.state_dict()
+            unexpected_keys = [
+                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+            ]
+            device_map = {"": param_device}
+            load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
-                device=param_device,
+                device_map=device_map,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
-                named_buffers=named_buffers,
+                unexpected_keys=unexpected_keys,
             )
-
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -400,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         model.eval()
 
+        if device_map is not None:
+            device_map_kwargs = {"device_map": device_map}
+            dispatch_model(model, **device_map_kwargs)
+
         return model
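
For readers following the single_file_model.py change above, here is a minimal, self-contained sketch of the accelerate-based loading pattern it moves to: initialize the model on the meta device, compute the unexpected keys against the empty state dict, materialize the checkpoint tensors onto the target device, and hand the resulting device map to dispatch_model. This is an illustration rather than the diffusers implementation: MyModel and state_dict are placeholder stand-ins, and the tensor-materialization loop only approximates what load_model_dict_into_meta does internally (it omits dtype casting, quantization, and keep-in-fp32 handling).

# Hedged sketch of the meta-device load + dispatch pattern (placeholders: MyModel, state_dict).
import torch
from accelerate import dispatch_model, init_empty_weights
from accelerate.utils import set_module_tensor_to_device


class MyModel(torch.nn.Module):  # placeholder architecture for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


state_dict = MyModel().state_dict()  # stands in for the converted single-file checkpoint

with init_empty_weights():
    model = MyModel()  # parameters are created on the meta device, no memory allocated

param_device = torch.device("cpu")  # would come from the `device` argument in from_single_file
empty_state_dict = model.state_dict()
unexpected_keys = [k for k in state_dict if k not in empty_state_dict]

# Materialize each expected checkpoint tensor onto the target device; this loop stands in
# for load_model_dict_into_meta, which additionally handles dtype and quantization.
for name, tensor in state_dict.items():
    if name not in unexpected_keys:
        set_module_tensor_to_device(model, name, param_device, value=tensor)

device_map = {"": param_device}
dispatch_model(model, device_map=device_map)  # place modules / attach hooks per the map

With a single-entry map like {"": param_device}, dispatch_model simply places the whole model on that device; multi-device maps are where the hook-based dispatch does real work.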
24 changes: 3 additions & 21 deletions src/diffusers/loaders/single_file_utils.py
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-    if model._keys_to_ignore_on_load_unexpected is not None:
-        for pat in model._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
 
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)

