[Utils] add utilities for checking if certain utilities are properly documented (huggingface#7763)

* add utility to check if attn_procs, norms, acts are properly documented.

* add support listing to the workflows.

* change to 2024.

* small fixes.

* does adding detailed docstrings help?

* uncomment image processor check

* quality

* fix, thanks to @Mishig.

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* style

* JointAttnProcessor2_0

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* Update docs/source/en/api/normalization.md

Co-authored-by: hlky <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: hlky <[email protected]>
3 people authored Feb 20, 2025
1 parent f10d3c6 commit f550745
Showing 7 changed files with 306 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr_tests.yml
@@ -64,6 +64,7 @@ jobs:
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
13 changes: 13 additions & 0 deletions docs/source/en/api/activations.md
@@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers.
## ApproximateGELU

[[autodoc]] models.activations.ApproximateGELU


## SwiGLU

[[autodoc]] models.activations.SwiGLU

## FP32SiLU

[[autodoc]] models.activations.FP32SiLU

## LinearActivation

[[autodoc]] models.activations.LinearActivation
17 changes: 17 additions & 0 deletions docs/source/en/api/attnprocessor.md
@@ -147,3 +147,20 @@ An attention processor is a class for applying different types of attention mechanisms.
## XLAFlashAttnProcessor2_0

[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0

## XFormersJointAttnProcessor

[[autodoc]] models.attention_processor.XFormersJointAttnProcessor

## IPAdapterXFormersAttnProcessor

[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor

## FluxIPAdapterJointAttnProcessor2_0

[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0


## XLAFluxFlashAttnProcessor2_0

[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0
40 changes: 40 additions & 0 deletions docs/source/en/api/normalization.md
@@ -29,3 +29,43 @@ Customized normalization layers for supporting various models in 🤗 Diffusers.
## AdaGroupNorm

[[autodoc]] models.normalization.AdaGroupNorm

## AdaLayerNormContinuous

[[autodoc]] models.normalization.AdaLayerNormContinuous

## RMSNorm

[[autodoc]] models.normalization.RMSNorm

## GlobalResponseNorm

[[autodoc]] models.normalization.GlobalResponseNorm


## LuminaLayerNormContinuous
[[autodoc]] models.normalization.LuminaLayerNormContinuous

## SD35AdaLayerNormZeroX
[[autodoc]] models.normalization.SD35AdaLayerNormZeroX

## AdaLayerNormZeroSingle
[[autodoc]] models.normalization.AdaLayerNormZeroSingle

## LuminaRMSNormZero
[[autodoc]] models.normalization.LuminaRMSNormZero

## LpNorm
[[autodoc]] models.normalization.LpNorm

## CogView3PlusAdaLayerNormZeroTextImage
[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage

## CogVideoXLayerNormZero
[[autodoc]] models.normalization.CogVideoXLayerNormZero

## MochiRMSNormZero
[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero

## MochiRMSNorm
[[autodoc]] models.normalization.MochiRMSNorm
43 changes: 43 additions & 0 deletions src/diffusers/models/normalization.py
@@ -306,6 +306,20 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:


class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""

def __init__(
self,
embedding_dim: int,
@@ -462,6 +476,17 @@ def forward(
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
r"""
LayerNorm with an optional bias parameter.
Args:
dim (`int`): Dimensionality to use for the parameters.
eps (`float`, defaults to 1e-5): Epsilon factor.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
"""

def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()

@@ -484,6 +509,17 @@ def forward(self, input):


class RMSNorm(nn.Module):
r"""
RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
Args:
dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
eps (`float`): Small value to use when calculating the reciprocal of the square-root.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
bias (`bool`, defaults to `False`): Boolean flag to denote if a learnable `bias` parameter should be used.
"""

def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
super().__init__()

@@ -573,6 +609,13 @@ def forward(self, hidden_states):


class GlobalResponseNorm(nn.Module):
r"""
Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
Args:
dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
"""

# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
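The `RMSNorm` signature shown in this diff (`dim`, `eps`, `elementwise_affine`, `bias`) can be exercised directly. A minimal sketch follows, assuming `RMSNorm` is importable from `diffusers.models.normalization` and that its `forward` takes a single hidden-states tensor:

```python
# Minimal usage sketch (assumption, not part of the commit): exercising RMSNorm
# with the constructor arguments documented in the diff above.
import torch

from diffusers.models.normalization import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6, elementwise_affine=True)
hidden_states = torch.randn(2, 16, 64)  # (batch, sequence, dim)
out = norm(hidden_states)
print(out.shape)  # expected: torch.Size([2, 16, 64]), same shape as the input
```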
68 changes: 68 additions & 0 deletions tests/others/test_check_support_list.py
@@ -0,0 +1,68 @@
import os
import sys
import unittest
from unittest.mock import mock_open, patch


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

from check_support_list import check_documentation # noqa: E402


class TestCheckSupportList(unittest.TestCase):
def setUp(self):
# Mock doc and source contents that we can reuse
self.doc_content = """# Documentation
## FooProcessor
[[autodoc]] module.FooProcessor
## BarProcessor
[[autodoc]] module.BarProcessor
"""
self.source_content = """
class FooProcessor(nn.Module):
pass
class BarProcessor(nn.Module):
pass
"""

def test_check_documentation_all_documented(self):
# In this test, both FooProcessor and BarProcessor are documented
with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
doc_file.side_effect = [
mock_open(read_data=self.doc_content).return_value,
mock_open(read_data=self.source_content).return_value,
]

undocumented = check_documentation(
doc_path="fake_doc.md",
src_path="fake_source.py",
doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
)
self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")

def test_check_documentation_missing_class(self):
# In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
doc_content_missing = """# Documentation
## FooProcessor
[[autodoc]] module.FooProcessor
"""
with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
doc_file.side_effect = [
mock_open(read_data=doc_content_missing).return_value,
mock_open(read_data=self.source_content).return_value,
]

undocumented = check_documentation(
doc_path="fake_doc.md",
src_path="fake_source.py",
doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
)
self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")
124 changes: 124 additions & 0 deletions utils/check_support_list.py
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""
Utility that checks that modules like attention processors are listed in the documentation file.
```bash
python utils/check_support_list.py
```
It has no auto-fix mode.
"""

import os
import re


# All paths are set with the intent that you run this script from the root of the repo
REPO_PATH = "."


def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
"""
Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.
Returns a list of documented class names (just the class name portion).
"""
with open(os.path.join(REPO_PATH, doc_path), "r") as f:
doctext = f.read()
matches = re.findall(autodoc_regex, doctext)
return [match.split(".")[-1] for match in matches]


def read_source_classes(src_path, class_regex, exclude_conditions=None):
"""
Reads class names from a source file using a regex that captures class definitions.
Optionally exclude classes based on a list of conditions (functions that take class name and return bool).
"""
if exclude_conditions is None:
exclude_conditions = []
with open(os.path.join(REPO_PATH, src_path), "r") as f:
doctext = f.read()
classes = re.findall(class_regex, doctext)
# Filter out classes that meet any of the exclude conditions
filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
return filtered_classes


def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
"""
Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
Returns a sorted list of undocumented class names.
"""
documented = set(read_documented_classes(doc_path, doc_regex))
source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))

# Find which classes in source are not documented in a deterministic way.
undocumented = sorted(source_classes - documented)
return undocumented


if __name__ == "__main__":
# Define the checks we need to perform
checks = {
"Attention Processors": {
"doc_path": "docs/source/en/api/attnprocessor.md",
"src_path": "src/diffusers/models/attention_processor.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
"exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
},
"Image Processors": {
"doc_path": "docs/source/en/api/image_processor.md",
"src_path": "src/diffusers/image_processor.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
},
"Activations": {
"doc_path": "docs/source/en/api/activations.md",
"src_path": "src/diffusers/models/activations.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
},
"Normalizations": {
"doc_path": "docs/source/en/api/normalization.md",
"src_path": "src/diffusers/models/normalization.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
"exclude_conditions": [
# Exclude LayerNorm as it's an intentional exception
lambda c: c == "LayerNorm"
],
},
"LoRA Mixins": {
"doc_path": "docs/source/en/api/loaders/lora.md",
"src_path": "src/diffusers/loaders/lora_pipeline.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
"src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
},
}

missing_items = {}
for category, params in checks.items():
undocumented = check_documentation(
doc_path=params["doc_path"],
src_path=params["src_path"],
doc_regex=params["doc_regex"],
src_regex=params["src_regex"],
exclude_conditions=params.get("exclude_conditions"),
)
if undocumented:
missing_items[category] = undocumented

# If we have any missing items, raise a single combined error
if missing_items:
error_msg = ["Some classes are not documented properly:\n"]
for category, classes in missing_items.items():
error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
raise ValueError("\n".join(error_msg))
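Besides the CI entry point (`python utils/check_support_list.py`), the `check_documentation` helper can be driven for a single category. A minimal sketch, assuming it is executed from the repository root so the relative paths resolve against `REPO_PATH = "."`:

```python
# Minimal sketch (assumption): running just the "Attention Processors" check,
# mirroring the corresponding entry of the `checks` dict above.
import sys

sys.path.append("utils")
from check_support_list import check_documentation

undocumented = check_documentation(
    doc_path="docs/source/en/api/attnprocessor.md",
    src_path="src/diffusers/models/attention_processor.py",
    doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
    src_regex=r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
    exclude_conditions=[lambda c: "LoRA" in c, lambda c: c == "Attention"],
)
print(undocumented or "All attention processors are documented.")
```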
