Simplify BlockArray implementation #259

Merged Apr 22, 2022 (37 commits)

Commits
31c9489
Rough in a new BlockArray based on tuple
Michael-T-McCann Mar 22, 2022
39c6864
Add old blockarray which is needed to even import scico
Michael-T-McCann Mar 22, 2022
420ba24
Start on snp
Michael-T-McCann Mar 25, 2022
c22f67f
Wrap all of jnp, add total reduction mechanism
Michael-T-McCann Mar 29, 2022
97ecbdd
Finish reductions
Michael-T-McCann Mar 29, 2022
22dbce0
Finish first pass over the tests
Michael-T-McCann Mar 29, 2022
f78a11f
Start on entire test suite
Michael-T-McCann Mar 29, 2022
8070984
Remove automatic reduction wrapping
Michael-T-McCann Apr 7, 2022
2d54727
Add files
Michael-T-McCann Apr 8, 2022
6fbe32f
Move BlockArray to scico.numpy, like DeviceArray in jax.numpy
Michael-T-McCann Apr 8, 2022
ccbfd60
Change reductions back to ravel
Michael-T-McCann Apr 9, 2022
9af42bd
Hack away at the tests
Michael-T-McCann Apr 12, 2022
99b6127
Fix tests, add dtype temporary fix
Michael-T-McCann Apr 12, 2022
ee5e39a
Get tests passing
Michael-T-McCann Apr 12, 2022
8c6780d
Add missing module
Michael-T-McCann Apr 13, 2022
b05d70d
Work on docs
Michael-T-McCann Apr 13, 2022
4d8b60a
Refactor out function lists
Michael-T-McCann Apr 14, 2022
ae76e8d
Start on new BlockArray tests
Michael-T-McCann Apr 14, 2022
e0b1b03
Wrap additional functions
Michael-T-McCann Apr 14, 2022
72f4d20
Stop tests ending on first failure
Michael-T-McCann Apr 14, 2022
0600cfa
Make ray test less (not?) stochastic
Michael-T-McCann Apr 14, 2022
8fe6930
Add matmul test
Michael-T-McCann Apr 14, 2022
fb76d01
Handle CodeFactor
Michael-T-McCann Apr 14, 2022
ffe96ce
Make ray test less (not?) stochastic
Michael-T-McCann Apr 14, 2022
e017bdf
Fix doc tests
Michael-T-McCann Apr 14, 2022
c61cbe6
Update example scripts
Michael-T-McCann Apr 19, 2022
8d97d1f
Work on docs
Michael-T-McCann Apr 20, 2022
23cb53a
Move array, improve blockarray docs
Michael-T-McCann Apr 20, 2022
8326630
Combine BlockArray tests
Michael-T-McCann Apr 20, 2022
c5e595c
Continue removing scico.array, allow snp.blockarray syntax
Michael-T-McCann Apr 20, 2022
8cb7b31
Remove imports
Michael-T-McCann Apr 20, 2022
f9bac79
Trigger lint
Michael-T-McCann Apr 20, 2022
87baffa
Clean up docs
Michael-T-McCann Apr 20, 2022
81c5d32
Add jits
Michael-T-McCann Apr 21, 2022
5a11b21
consistent spelling of #-dimensional
tbalke Apr 21, 2022
1c0f1e4
consistent spelling of #-dimensional
tbalke Apr 22, 2022
715711c
Thilo review
Michael-T-McCann Apr 22, 2022
4 changes: 2 additions & 2 deletions docs/source/notes.rst
@@ -185,7 +185,7 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
>>> import scico
>>> import scico.numpy as snp
>>> f = lambda x: snp.linalg.norm(x)**2
>>> scico.grad(f)(snp.zeros(2)) #
>>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) #
DeviceArray([nan, nan], dtype=float32)

This can be fixed by defining the squared :math:`\ell_2` norm directly as
@@ -194,7 +194,7 @@
::

>>> g = lambda x: snp.sum(x**2)
>>> scico.grad(g)(snp.zeros(2))
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))
DeviceArray([0., 0.], dtype=float32)

An alternative is to define a `custom derivative rule <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_ to enforce a particular derivative convention at a point.
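
As a sketch of that alternative (our own illustration, not part of SCICO: the `safe_norm` name and its zero-derivative-at-zero convention are assumptions), a custom JVP rule for the norm could be written as:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def safe_norm(x):
        # Same value as jnp.linalg.norm(x), but with a custom derivative rule.
        return jnp.sqrt(jnp.sum(x**2))

    @safe_norm.defjvp
    def safe_norm_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = safe_norm(x)
        denom = jnp.where(y == 0, 1.0, y)  # mask the 0/0 case below
        # Convention enforced here: the derivative at x = 0 is taken to be 0.
        y_dot = jnp.where(y == 0, 0.0, jnp.vdot(x, x_dot) / denom)
        return y, y_dot

With this rule, `jax.grad(lambda x: safe_norm(x)**2)(jnp.zeros(2))` returns zeros rather than `nan`.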
12 changes: 7 additions & 5 deletions examples/scripts/denoise_tv_iso_pgm.py
@@ -1,4 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
@@ -39,8 +39,8 @@
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.array import ensure_on_device
from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.numpy.util import ensure_on_device
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.typing import JaxArray
from scico.util import device_info
@@ -96,6 +96,7 @@ def __init__(
super().__init__(y=y, A=A, scale=1.0)
self.lmbda = lmbda

@jax.jit
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:

xint = self.y - self.lmbda * self.A(x)
@@ -117,14 +118,15 @@ class IsoProjector(functional.Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0

@jax.jit
def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))

x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)
out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1]))
x_out_1 = jax.ops.index_update(x_out, jax.ops.index[0, :, -1], out1)
x_out = x_out.at[0, :, -1].set(out1)
out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :]))
x_out = jax.ops.index_update(x_out_1, jax.ops.index[1, -1, :], out2)
x_out = x_out.at[1, -1, :].set(out2)

return x_out

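The `prox` changes above follow jax's migration from the deprecated `jax.ops.index_update` to the functional `.at[...]` update syntax. A minimal standalone illustration of the idiom (array shape and values chosen arbitrarily here):

    import jax.numpy as jnp

    v = jnp.zeros((2, 3))

    # Old, removed API: jax.ops.index_update(v, jax.ops.index[0, :], 1.0)
    # New idiom: jax arrays are immutable, so .at[...].set(...) returns
    # an updated copy instead of modifying v in place.
    v2 = v.at[0, :].set(1.0)
    # v2 is [[1. 1. 1.], [0. 0. 0.]]; v is unchanged.
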
4 changes: 2 additions & 2 deletions examples/scripts/sparsecode_poisson_pgm.py
@@ -22,7 +22,7 @@
$I(\mathbf{x}^{(0)} \geq 0)$ is the non-negative indicator.
This example also demonstrates the application of
[blockarray.BlockArray](../_autosummary/scico.blockarray.rst#scico.blockarray.BlockArray),
[blockarray.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray),
[functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional),
and
[functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional)
@@ -40,7 +40,7 @@
import scico.numpy as snp
import scico.random
from scico import functional, loss, plot
from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.operator import Operator
from scico.optimize.pgm import (
AcceleratedPGM,
2 changes: 1 addition & 1 deletion scico/_flax.py
@@ -17,7 +17,7 @@
from flax.core import Scope # noqa
from flax.linen.module import _Sentinel # noqa

from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.typing import JaxArray

# The imports of Scope and _Sentinel (above) and the definition of Module
14 changes: 7 additions & 7 deletions scico/_generic_operators.py
@@ -23,8 +23,8 @@

import scico.numpy as snp
from scico._autograd import linear_adjoint
from scico.array import is_complex_dtype, is_nested
from scico.blockarray import BlockArray, block_sizes
from scico.numpy import BlockArray
from scico.numpy.util import is_complex_dtype, is_nested, shape_to_size
from scico.typing import BlockShape, DType, JaxArray, Shape


@@ -152,8 +152,8 @@ def __init__(
# Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m})
# If the function returns a BlockArray we need to compute the size of each block,
# then sum.
self.input_size = int(np.sum(block_sizes(self.input_shape)))
self.output_size = int(np.sum(block_sizes(self.output_shape)))
self.input_size = shape_to_size(self.input_shape)
self.output_size = shape_to_size(self.output_shape)

self.shape = (self.output_shape, self.input_shape)
self.matrix_shape = (self.output_size, self.input_size)
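
Here `shape_to_size` collapses the old `block_sizes`-then-sum pattern into a single helper. A hypothetical reimplementation conveying the intended behavior (not the actual SCICO source):

    import numpy as np

    def shape_to_size(shape):
        # A nested (block) shape such as ((2, 3), (4,)) sums the size
        # of every block; an ordinary shape is just the product.
        if any(isinstance(s, tuple) for s in shape):
            return sum(shape_to_size(block) for block in shape)
        return int(np.prod(shape))

    assert shape_to_size((2, 3)) == 6
    assert shape_to_size(((2, 3), (4,))) == 10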
@@ -320,8 +320,8 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
def concat_args(args):
# Creates a blockarray with args and the frozen value in the correct place
# Eg if this operator takes a blockarray with two blocks, then
# concat_args(args) = BlockArray.array([val, args]) if argnum = 0
# concat_args(args) = BlockArray.array([args, val]) if argnum = 1
# concat_args(args) = snp.blockarray([val, args]) if argnum = 0
# concat_args(args) = snp.blockarray([args, val]) if argnum = 1

if isinstance(args, (DeviceArray, np.ndarray)):
# In the case that the original operator takes a blockarray with two
@@ -336,7 +336,7 @@ def concat_args(args):
arg_list.append(args[i - 1])
else:
arg_list.append(val)
return BlockArray.array(arg_list)
return snp.blockarray(arg_list)

return Operator(
input_shape=input_shape,
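
For context on the `snp.blockarray` syntax adopted in the comments above, a usage sketch based on the commit messages in this PR (the block shapes are illustrative, and the reduction behavior is our reading of the "Change reductions back to ravel" commit):

    import scico.numpy as snp

    # Construct a BlockArray from blocks of different shapes.
    x = snp.blockarray([snp.zeros((2, 3)), snp.ones((4,))])

    print(x.shape)     # ((2, 3), (4,))
    print(snp.sum(x))  # reductions ravel across all blocks -> 4.0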