Skip to content

Commit

Permalink
fix issues with loading, add test for pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Jun 7, 2022
1 parent fe99460 commit d8287fc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def get_config_dict(
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
import ipdb; ipdb.set_trace()
init_dict = {}
for key in expected_keys:
if key in kwargs:
Expand Down
20 changes: 14 additions & 6 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,23 @@ def register_modules(self, **kwargs):
class_name = module.__class__.__name__

register_dict = {name: (library, class_name)}
register_dict["_module"] = self.__module__


# save model index config
self.register(**register_dict)

# set models
setattr(self, name, module)

register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
self.register(**register_dict)

def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)

model_index_dict = self._dict_to_save
model_index_dict.pop("_class_name")
model_index_dict.pop("_module")

for name, (library_name, class_name) in self._dict_to_save.items():
importable_classes = LOADABLE_CLASSES[library_name]
Expand Down Expand Up @@ -98,12 +102,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
cached_folder = pretrained_model_name_or_path

config_dict = cls.get_config_dict(cached_folder)

module = config_dict["_module"]
class_name_ = config_dict["_class_name"]
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)

if class_name_ == cls.__name__:
pipeline_class = cls
else:
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)


init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
import ipdb; ipdb.set_trace()
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

init_kwargs = {}

Expand Down Expand Up @@ -132,6 +141,5 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)


model = class_obj(**init_kwargs)
model = pipeline_class(**init_kwargs)
return model
Empty file added tests/__init__.py
Empty file.
45 changes: 45 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch

from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddpm.modeling_ddpm import DDPM


global_rng = random.Random()
Expand Down Expand Up @@ -199,3 +201,46 @@ def test_sample_fast(self):
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3


class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
schedular = GaussianDDPMScheduler(timesteps=10)

ddpm = DDPM(model, schedular)

with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPM.from_pretrained(tmpdirname)

generator = torch.Generator()
generator = generator.manual_seed(669472945848556)

image = ddpm(generator)
generator = generator.manual_seed(669472945848556)
new_image = new_ddpm(generator)

assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"


@slow
def test_from_pretrained_hub(self):
model_path = "fusing/ddpm-cifar10"

ddpm = DDPM.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

ddpm.noise_scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10


generator = torch.Generator(device=torch_device)
generator = generator.manual_seed(669472945848556)

image = ddpm(generator)
generator = generator.manual_seed(669472945848556)
new_image = ddpm_from_hub(generator)

assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

0 comments on commit d8287fc

Please sign in to comment.