Skip to content

Commit

Permalink
Marigold Update: v1-1 models, Intrinsic Image Decomposition pipeline,…
Browse files Browse the repository at this point in the history
… documentation (huggingface#10884)

* minor documentation fixes of the depth and normals pipelines

* update license headers

* update model checkpoints in examples
fix missing prediction_type in register_to_config in the normals pipeline

* add initial marigold intrinsics pipeline
update comments about num_inference_steps and ensemble_size
minor fixes in comments of marigold normals and depth pipelines

* update uncertainty visualization to work with intrinsics

* integrate iid


---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
  • Loading branch information
3 people authored Feb 26, 2025
1 parent f0ac7aa commit 3fab662
Show file tree
Hide file tree
Showing 14 changed files with 1,886 additions and 258 deletions.
123 changes: 89 additions & 34 deletions docs/source/en/api/pipelines/marigold.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [PAG](pag) | text2image |
Expand Down
485 changes: 312 additions & 173 deletions docs/source/en/using-diffusers/marigold_usage.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@
"Lumina2Text2ImgPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
Expand Down Expand Up @@ -845,6 +846,7 @@
Lumina2Text2ImgPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
]
)
Expand Down Expand Up @@ -603,6 +604,7 @@
from .lumina2 import Lumina2Text2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/marigold/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
else:
_import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"]
_import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"]
_import_structure["pipeline_marigold_intrinsics"] = ["MarigoldIntrinsicsOutput", "MarigoldIntrinsicsPipeline"]
_import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Expand All @@ -35,6 +36,7 @@
else:
from .marigold_image_processing import MarigoldImageProcessor
from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline
from .pipeline_marigold_intrinsics import MarigoldIntrinsicsOutput, MarigoldIntrinsicsPipeline
from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline

else:
Expand Down
141 changes: 127 additions & 14 deletions src/diffusers/pipelines/marigold/marigold_image_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
from typing import List, Optional, Tuple, Union
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
Expand Down Expand Up @@ -379,7 +397,7 @@ def visualize_depth(
val_min: float = 0.0,
val_max: float = 1.0,
color_map: str = "Spectral",
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.
Expand All @@ -391,7 +409,7 @@ def visualize_depth(
color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
depth prediction into colored representation.
Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.
Returns: `List[PIL.Image.Image]` with depth maps visualization.
"""
if val_max <= val_min:
raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")
Expand Down Expand Up @@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
val_min: float = 0.0,
val_max: float = 1.0,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
def export_depth_to_16bit_png_one(img, idx=None):
prefix = "Depth" + (f"[{idx}]" if idx else "")
if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
Expand Down Expand Up @@ -478,7 +496,7 @@ def visualize_normals(
flip_x: bool = False,
flip_y: bool = False,
flip_z: bool = False,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.
Expand All @@ -492,7 +510,7 @@ def visualize_normals(
flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
Default direction is facing the observer.
Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.
Returns: `List[PIL.Image.Image]` with surface normals visualization.
"""
flip_vec = None
if any((flip_x, flip_y, flip_z)):
Expand Down Expand Up @@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
else:
raise ValueError(f"Unexpected input type: {type(normals)}")

@staticmethod
def visualize_intrinsics(
    prediction: Union[
        np.ndarray,
        torch.Tensor,
        List[np.ndarray],
        List[torch.Tensor],
    ],
    target_properties: Dict[str, Any],
    color_map: Union[str, Dict[str, str]] = "binary",
) -> List[Dict[str, PIL.Image.Image]]:
    """
    Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.

    Args:
        prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
            Intrinsic image decomposition. A single array or tensor must resolve to the shape `[N*T,3,H,W]`, where
            `T` is the number of entries in `target_properties["target_names"]`; it is regrouped into `N` items of
            `T` targets each. When a list is passed, each element is iterated as the targets of one item.
        target_properties (`Dict[str, Any]`):
            Decomposition properties. Expected entries: `target_names: List[str]` and a dictionary with keys
            `prediction_space: str`, `sub_target_names: List[Union[str, None]]` (must have 3 entries, `None` for
            missing modalities), `up_to_scale: bool`, one for each target and sub-target.
        color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"binary"`):
            Color map used to convert single-channel predictions into colored representations. When a dictionary
            is passed, each modality can be colored with its own color map; modalities missing from the
            dictionary fall back to `"binary"`.

    Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization.

    Raises:
        ValueError: If `target_names` is missing, `color_map` has an unsupported type, `prediction` is (or
            contains) `None`, the input shape or type is unexpected, or a stacked target declares invalid
            `sub_target_names`.
    """
    if "target_names" not in target_properties:
        raise ValueError("Missing `target_names` in target_properties")
    if not isinstance(color_map, str) and not (
        isinstance(color_map, dict)
        and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
    ):
        raise ValueError("`color_map` must be a string or a dictionary of strings")
    n_targets = len(target_properties["target_names"])

    def linear_to_display(img, properties):
        # Shared by targets and sub-targets predicted in "linear" space: optionally normalize by the maximum
        # value (floored at 1e-6 to avoid division by zero), then raise to the power 1/2.2.
        if properties.get("up_to_scale", False):
            img = img / max(img.max().item(), 1e-6)
        return img ** (1 / 2.2)

    def visualize_targets_one(images, idx=None):
        # images: [T,3,H,W]
        out = {}
        for target_name, img in zip(target_properties["target_names"], images):
            img = img.permute(1, 2, 0)  # [H,W,3]
            properties = target_properties[target_name]
            prediction_space = properties.get("prediction_space", "srgb")
            if prediction_space == "stack":
                # Each of the three channels holds an independent single-channel modality (or nothing).
                sub_target_names = properties["sub_target_names"]
                if len(sub_target_names) != 3 or any(
                    not (isinstance(s, str) or s is None) for s in sub_target_names
                ):
                    raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
                for i, sub_target_name in enumerate(sub_target_names):
                    if sub_target_name is None:
                        continue
                    sub_img = img[:, :, i]
                    sub_properties = target_properties[sub_target_name]
                    if sub_properties.get("prediction_space", "srgb") == "linear":
                        sub_img = linear_to_display(sub_img, sub_properties)
                    cmap_name = (
                        color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
                    )
                    sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
                    out[sub_target_name] = PIL.Image.fromarray(sub_img.cpu().numpy())
            elif prediction_space == "linear":
                img = linear_to_display(img, properties)
            # Any other prediction space (including "srgb") is exported without conversion. The three-channel
            # image is stored under the target name in every case, including for stacked targets.
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            out[target_name] = PIL.Image.fromarray(img)
        return out

    if prediction is None or (isinstance(prediction, list) and any(o is None for o in prediction)):
        raise ValueError("Input prediction is `None`")
    if isinstance(prediction, (np.ndarray, torch.Tensor)):
        prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
        if isinstance(prediction, np.ndarray):
            prediction = MarigoldImageProcessor.numpy_to_pt(prediction)  # [N*T,3,H,W]
        if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
            raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
        num_stacked, _, height, width = prediction.shape
        prediction = prediction.reshape(num_stacked // n_targets, n_targets, 3, height, width)
        return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
    elif isinstance(prediction, list):
        # NOTE(review): list elements are passed through unconverted, and `visualize_targets_one` calls
        # `.permute` on each target, so elements presumably need to be `[T,3,H,W]` tensors; `np.ndarray`
        # elements do not appear to be handled on this path - confirm against callers.
        return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
    else:
        raise ValueError(f"Unexpected input type: {type(prediction)}")

@staticmethod
def visualize_uncertainty(
uncertainty: Union[
Expand All @@ -537,24 +648,26 @@ def visualize_uncertainty(
List[torch.Tensor],
],
saturation_percentile=95,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.
Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
`MarigoldIntrinsicsPipeline`.
Args:
uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
Uncertainty maps.
saturation_percentile (`int`, *optional*, defaults to `95`):
Specifies the percentile uncertainty value visualized with maximum intensity.
Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.
Returns: `List[PIL.Image.Image]` with uncertainty visualization.
"""

def visualize_uncertainty_one(img, idx=None):
prefix = "Uncertainty" + (f"[{idx}]" if idx else "")
if img.min() < 0:
raise ValueError(f"{prefix}: unexected data range, min={img.min()}.")
img = img.squeeze(0).cpu().numpy()
raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
img = img.permute(1, 2, 0) # [H,W,C]
img = img.squeeze(2).cpu().numpy() # [H,W] or [H,W,3]
saturation_value = np.percentile(img, saturation_percentile)
img = np.clip(img * 255 / saturation_value, 0, 255)
img = img.astype(np.uint8)
Expand All @@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
if isinstance(uncertainty, np.ndarray):
uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W]
if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,C,H,W]
if not (uncertainty.ndim == 4 and uncertainty.shape[1] in (1, 3)):
raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,C,H,W] with C in (1,3).")
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
elif isinstance(uncertainty, list):
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
Expand Down
34 changes: 19 additions & 15 deletions src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
# Marigold project website: https://marigoldmonodepth.github.io
# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from dataclasses import dataclass
from functools import partial
Expand Down Expand Up @@ -64,7 +64,7 @@
>>> import torch
>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
... "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")
>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
Expand All @@ -86,11 +86,12 @@ class MarigoldDepthOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
\times width$, regardless of whether the images were passed as a 4D array or a list.
Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
\times 1 \times height \times width$.
\times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
Expand Down Expand Up @@ -208,6 +209,11 @@ def check_inputs(
output_type: str,
output_uncertainty: bool,
) -> int:
actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if actual_vae_scale_factor != self.vae_scale_factor:
raise ValueError(
f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
)
if num_inference_steps is None:
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
if num_inference_steps < 1:
Expand Down Expand Up @@ -320,6 +326,7 @@ def check_inputs(

return num_images

@torch.compiler.disable
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
Expand Down Expand Up @@ -370,11 +377,9 @@ def __call__(
same width and height.
num_inference_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
for Marigold-LCM models.
selection.
ensemble_size (`int`, defaults to `1`):
Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
faster inference.
Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
processing_resolution (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, matches the larger input image dimension. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
Expand Down Expand Up @@ -486,9 +491,7 @@ def __call__(
# `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
# into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
# reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
# code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
# as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
# noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
# code. Latents compatible with the `latents` argument can be obtained from a previous call by setting the `output_latent` argument to `True`. The latent space
# dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
# Model invocation: self.vae.encoder.
image_latent, pred_latent = self.prepare_latents(
Expand Down Expand Up @@ -733,6 +736,7 @@ def init_param(depth: torch.Tensor):
param = init_s.cpu().numpy()
else:
raise ValueError("Unrecognized alignment.")
param = param.astype(np.float64)

return param

Expand Down Expand Up @@ -775,7 +779,7 @@ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:

if regularizer_strength > 0:
prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
err_near = (0.0 - prediction.min()).abs().item()
err_near = prediction.min().abs().item()
err_far = (1.0 - prediction.max()).abs().item()
cost += (err_near + err_far) * regularizer_strength

Expand Down
Loading

0 comments on commit 3fab662

Please sign in to comment.