Skip to content

Commit

Permalink
Add DnCNN with noise level input (#349)
Browse files Browse the repository at this point in the history
* fixed bug

* add DnCNN with nosie level input

* add DnCNN with nosie level input

* revise based on suggestion

* Docs improvements and clean up

* Minor improvement

* Minor improvement

* Correct submodule reference

* Add tests

* Update submodule

* Update contributor list

* Apply isort

* fix conflicts

* fix conflicts

* fix a minor issue

* fix issues

* add Optional in denoiser.py

* add more descriptions

* Update examples index

* Docstring changes

* Docstring and comment changes

* Minor docstring edit

* Update submodule

Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
wjgancn and bwohlberg authored Dec 12, 2022
1 parent fd26f49 commit 2e64c7d
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 36 deletions.
2 changes: 1 addition & 1 deletion data
3 changes: 2 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Usage Examples
.. toctree::
:maxdepth: 1

.. include:: include/exampledepend.rst
.. include:: exampledepend.rst


Organized by Application
Expand Down Expand Up @@ -73,6 +73,7 @@ Miscellaneous
examples/denoise_tv_pgm
examples/denoise_tv_multi
examples/denoise_cplx_tv_pdhg
examples/denoise_dncnn_universal
examples/video_rpca_admm


Expand Down
54 changes: 34 additions & 20 deletions docs/source/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ @Article {almeida-2013-deconvolving
doi = {10.1109/TIP.2013.2258354}
}

@Article {balke-2022-scico,
author = {Thilo Balke and Fernando Davis and Cristina
Garcia-Cardona and Soumendu Majee and Michael McCann
and Luke Pfister and Brendt Wohlberg},
title = {Scientific Computational Imaging Code ({SCICO})},
journal = {Journal of Open Source Software},
year = 2022,
volume = 7,
number = 78,
pages = 4722,
doi = {10.21105/joss.04722}
}

@Article {barzilai-1988-stepsize,
author = {Jonathan Barzilai and Jonathan M. Borwein},
title = {Two-point step size gradient methods},
Expand Down Expand Up @@ -70,7 +83,8 @@ @InCollection{beck-2010-gradient
publisher = {Cambridge University Press},
year = 2010,
doi = {10.1017/CBO9780511804458.003},
url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf}
url =
{http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf}
}

@Book {beck-2017-first,
Expand Down Expand Up @@ -331,10 +345,10 @@ @Article {kamilov-2022-plug
T. Buzzard and Brendt Wohlberg},
title = {Plug-and-Play Methods for Integrating Physical and
Learned Models in Computational Imaging},
journal = {IEEE Signal Processing Magazine},
journal = {IEEE Signal Processing Magazine},
year = 2022,
eprint = {arXiv:2203.17061},
note = {To appear.}
note = {To appear.}
}

@Article {liu-2018-first,
Expand All @@ -361,7 +375,7 @@ @Article {maggioni-2012-nonlocal
number = 1,
pages = {119--133},
year = 2012,
doi = {10.1109/TIP.2012.2210725}
doi = {10.1109/TIP.2012.2210725}
}

@InProceedings {makinen-2019-exact,
Expand Down Expand Up @@ -414,7 +428,7 @@ @Book {nocedal-2006-numerical

@Book {paganin-2006-coherent,
doi = {10.1093/acprof:oso/9780198567288.001.0001},
isbn = {9780198567288},
isbn = 9780198567288,
year = 2006,
month = Jan,
publisher = {Oxford University Press},
Expand Down Expand Up @@ -481,19 +495,6 @@ @Article {sauer-1993-local
doi = {10.1109/78.193196}
}

@Article {balke-2022-scico,
author = {Thilo Balke and Fernando Davis and Cristina
Garcia-Cardona and Soumendu Majee and Michael McCann
and Luke Pfister and Brendt Wohlberg},
title = {Scientific Computational Imaging Code ({SCICO})},
journal = {Journal of Open Source Software},
year = {2022},
volume = {7},
number = {78},
pages = {4722},
doi = {10.21105/joss.04722}
}

@Article {soulez-2016-proximity,
author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut
and Antony Schutz and Andr{\'{e}} Ferrari and
Expand All @@ -506,7 +507,6 @@ @Article {soulez-2016-proximity
volume = 55,
number = 26,
pages = {7412--7421}

}

@Article {sreehari-2016-plug,
Expand Down Expand Up @@ -541,7 +541,7 @@ @Article {valkonen-2014-primal
journal = {Inverse Problems},
volume = 30,
number = 5,
pages = {055012},
pages = 055012,
year = 2014,
doi = {10.1088/0266-5611/30/5/055012}
}
Expand Down Expand Up @@ -609,6 +609,20 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}

@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
title = {Plug-and-Play Image Restoration With Deep Denoiser
Prior},
journal = {IEEE Transactions on Pattern Analysis and Machine
Intelligence},
year = 2022,
volume = 44,
number = 10,
doi = {10.1109/TPAMI.2021.3088914},
pages = {6360--6376}
}

@Article {zhou-2006-adaptive,
author = {Bin Zhou and Li Gao and Yu-Hong Dai},
title = {Gradient Methods with Adaptive Step-Sizes},
Expand Down
1 change: 1 addition & 0 deletions docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ Contributors
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Andrew Leong <https://scholar.google.com/citations?user=-2wRWbcAAAAJ&hl=en>`_ (Improvements to optics module documentation)
- `Weijie Gan <https://github.com/wjgancn>`_ (Non-blind variant of DnCNN)
2 changes: 2 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ Miscellaneous
Comparison of Optimization Algorithms for Total Variation Denoising
`denoise_cplx_tv_pdhg.py <denoise_cplx_tv_pdhg.py>`_
Complex Total Variation Denoising
`denoise_dncnn_universal.py <denoise_dncnn_universal.py>`_
Comparison of DnCNN Variants for Image Denoising
`video_rpca_admm.py <video_rpca_admm.py>`_
Video Decomposition via Robust PCA

Expand Down
82 changes: 82 additions & 0 deletions examples/scripts/denoise_dncnn_universal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""
Comparison of DnCNN Variants for Image Denoising
================================================
This example demonstrates the solution of an image denoising problem
using DnCNN :cite:`zhang-2017-dncnn` networks trained for different noise
levels, as well as custom variants with fewer network layers, and with a
noise level input.
The networks trained for specific noise levels are labeled 6L, 6M, 6H,
17L, 17M, and 17H, where {6, 17} denote the number of layers, and {L, M,
H} represent noise standard deviation of the training images (0.06, 0.10,
and 0.20 respectively). The networks with a noise standard deviation
input are labeled 6N and 17N, where {6, 17} again denote the number of
layers.
"""

import numpy as np

import jax

from xdesign import Foam, discrete_phantom

import scico.random
from scico import metric, plot
from scico.denoiser import DnCNN

"""
Create a ground truth image.
"""
np.random.seed(1234)
N = 512 # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jax.device_put(x_gt) # convert to jax array, push to GPU

"""
Test different DnCNN variants on images with different noise levels.
"""
print(" σ | variant | noisy image PSNR (dB) | denoised image PSNR (dB)")
for σ in [0.06, 0.10, 0.20]:
print("------+---------+-------------------------+-------------------------")
for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:

# Instantiate a DnCNN.
denoiser = DnCNN(variant=variant)

# Generate a noisy image.
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

if variant in ["6N", "17N"]:
x_hat = denoiser(y, sigma=σ)
else:
x_hat = denoiser(y)

x_hat = np.clip(x_hat, a_min=0, a_max=1.0)

if variant[0] == "6":
variant += " " # add spaces to maintain alignment

print(
" %.2f | %s | %.2f | %.2f "
% (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))
)


"""
Show reference and denoised images for σ=0.2 and variant=6N.
"""
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(x_gt, title="Reference", fig=fig, ax=ax[0])
plot.imview(y, title="Noisy image: %.2f (dB)" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])
plot.imview(x_hat, title="Denoised image: %.2f (dB)" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2])
fig.show()

input("\nWaiting for input to close figures and exit")
1 change: 1 addition & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Miscellaneous
- denoise_tv_pgm.py
- denoise_tv_multi.py
- denoise_cplx_tv_pdhg.py
- denoise_dncnn_universal.py
- video_rpca_admm.py


Expand Down
76 changes: 62 additions & 14 deletions scico/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""Interfaces to standard denoisers."""


from typing import Any, Union
from typing import Any, Optional, Union

import numpy as np

Expand Down Expand Up @@ -53,8 +53,8 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False, profile: Union[BM3DPro
x: Input image. Expected to be a 2D array (gray-scale denoising)
or 3D array (color denoising). Higher-dimensional arrays are
tolerated only if the additional dimensions are singletons.
For color denoising, the color channel is assumed to be in the
last non-singleton dimension.
For color denoising, the color channel is assumed to be in
the last non-singleton dimension.
sigma: Noise parameter.
is_rgb: Flag indicating use of BM3D with a color transform.
Default: ``False``.
Expand Down Expand Up @@ -182,42 +182,75 @@ class DnCNN(FlaxMap):
Note that :class:`.DnCNNNet` represents an untrained form of the
generic DnCNN CNN structure, while this class represents a trained
form with six or seventeen layers.
The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not
have a noise level input. This implementation of DnCNN also supports
a custom variant that includes a noise standard deviation input,
`sigma`, which is included in the network as an additional channel
consisting of a constant array with value `sigma`. This network was
trained with image data on the range [0, 1], and with noise standard
deviations ranging from 0.0 to 0.2. It is worth noting that DRUNet
:cite:`zhang-2021-plug`, another recent approach to including a noise
level input in a CNN denoiser, is based on a substantially different
network architecture.
"""

def __init__(self, variant: str = "6M"):
"""
Note that all DnCNN models are trained for single-channel image
input. Multi-channel input is supported via independent denoising
of each channel.
of each channel. Input images are expected to have pixel values
in the range [0, 1].
Args:
variant: Identify the DnCNN model to be used. Options are
'6L', '6M' (default), '6H', '17L', '17M', and '17H',
where the integer indicates the number of layers in the
network, and the postfix indicates the training noise
standard deviation: L (low) = 0.06, M (mid) = 0.1,
H (high) = 0.2, where the standard deviations are
with respect to data in the range [0, 1].
'6L', '6M' (default), '6H', '6N', '17L', '17M', '17H',
and '17N', where the integer indicates the number of
layers in the network, and the postfix indicates the
training noise standard deviation (with respect to data
in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10,
H (high) = 0.20, or N indicating that a noise standard
deviation input, `sigma`, is available.
"""
if variant not in ["6L", "6M", "6H", "17L", "17M", "17H"]:

self.variant = variant

if variant not in ["6L", "6M", "6H", "17L", "17M", "17H", "6N", "17N"]:
raise ValueError(f"Invalid value {variant} of parameter variant.")
if variant[0] == "6":
nlayer = 6
else:
nlayer = 17
model = DnCNNNet(depth=nlayer, channels=1, num_filters=64, dtype=np.float32)

channels = 2 if variant in ["6N", "17N"] else 1

model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32)
variables = load_weights(_flax_data_path("dncnn%s.npz" % variant))
super().__init__(model, variables)

def __call__(self, x: JaxArray) -> JaxArray:
def __call__(self, x: JaxArray, sigma: Optional[float] = None) -> JaxArray:
r"""Apply DnCNN denoiser.
Args:
x: Input array.
sigma: Noise standard deviation (for variants `6N` and `17N`).
Returns:
Denoised output.
"""
if sigma is not None and self.variant not in ["6N", "17N"]:
raise ValueError(
"A non-default value for the sigma parameter may "
"only be specified when the variant is 6N or 17N"
f"; got variant = {self.variant}."
)

if sigma is None and self.variant in ["6N", "17N"]:
raise ValueError(
"A float value must be specified for the sigma "
"parameter when the variant is 6N or 17N."
)

if snp.util.is_complex_dtype(x.dtype):
raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}.")

Expand All @@ -238,13 +271,28 @@ def __call__(self, x: JaxArray) -> JaxArray:
)

if x.ndim == 3:
y = snp.swapaxes(x, 0, -1)

if sigma is not None:
y = snp.stack([y, snp.ones_like(y) * sigma], -1)
else:
y = y[..., np.newaxis]

# swap channel axis to batch axis and add singleton axis at end
y = super().__call__(snp.swapaxes(x, 0, -1)[..., np.newaxis])
y = super().__call__(y)
# drop singleton axis and swap axes back to original positions
y = snp.swapaxes(y[..., 0], 0, -1)

else:
if sigma is not None:
x = snp.stack([x, snp.ones_like(x) * sigma], -1)
x = x[np.newaxis, ...]

y = super().__call__(x)

if sigma is not None:
y = y[0, ..., 0]

y = y.reshape(x_in_shape)

return y
Loading

0 comments on commit 2e64c7d

Please sign in to comment.