Make compatible with jax_0.4.1 (#372)
* Replace checks for DeviceArray with checks for ndarray, other fixes

* Skip some doctests to make compatible with both old and new scico
Michael-T-McCann authored Dec 22, 2022
1 parent d05ff9c commit 029a595
Showing 11 changed files with 59 additions and 75 deletions.
14 changes: 7 additions & 7 deletions docs/source/include/blockarray.rst
@@ -32,18 +32,18 @@ appropriate. For example,
((2, 3), (3,))

>>> x * 2 # returns BlockArray # doctest: +ELLIPSIS
-BlockArray([DeviceArray([[ 2, 6, 14],
-            [ 4, 4, 2]], dtype=...), DeviceArray([ 4, 8, 16], dtype=...)])
+BlockArray([...Array([[ 2, 6, 14],
+            [ 4, 4, 2]], dtype=...), ...Array([ 4, 8, 16], dtype=...)])

>>> y = snp.blockarray((
... [[.2],
... [.3]],
... [.4]
... ))

->>> x + y # returns BlockArray # doctest: +ELLIPSIS
-BlockArray([DeviceArray([[1.2, 3.2, 7.2],
-            [2.3, 2.3, 1.3]], dtype=...), DeviceArray([2.4, 4.4, 8.4], dtype=...)])
+>>> x + y # returns BlockArray # doctest: +ELLIPSIS
+BlockArray([...Array([[1.2, 3.2, 7.2],
+            [2.3, 2.3, 1.3]], dtype=...), ...Array([2.4, 4.4, 8.4], dtype=...)])


.. _numpy_functions_blockarray:
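
A note on the ``...Array`` expected outputs above: with the ``+ELLIPSIS`` doctest option, ``...`` matches any text, so a single expected string accepts both the pre-0.4 ``DeviceArray(...)`` repr and the jax 0.4 ``Array(...)`` repr. A standalone illustration of the matching (an editor's sketch, not part of the commit)::

    import doctest

    checker = doctest.OutputChecker()
    want = "...Array([1, 2])\n"
    # "...Array" matches the old repr ...
    assert checker.check_output(want, "DeviceArray([1, 2])\n", doctest.ELLIPSIS)
    # ... and the new one.
    assert checker.check_output(want, "Array([1, 2])\n", doctest.ELLIPSIS)
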
@@ -157,7 +157,7 @@ The recommended way to construct a :class:`.BlockArray` is by using the

While :func:`.snp.blockarray` will accept either :class:`~numpy.ndarray`\ s or
:obj:`~jax.numpy.DeviceArray`\ s as input, :class:`~numpy.ndarray`\ s
-will be converted to :obj:`~jax.numpy.DeviceArray`\ s.
+will be converted to :obj:`~jax.Array`\ s.


Operating on a BlockArray
@@ -177,7 +177,7 @@ Multiplication Between BlockArray and :class:`.LinearOperator`

The :class:`.Operator` and :class:`.LinearOperator` classes are designed
to work on instances of :class:`.BlockArray` in addition to instances of
-:obj:`~jax.numpy.DeviceArray`. For example
+:obj:`~jax.Array`. For example


::
8 changes: 4 additions & 4 deletions docs/source/notes.rst
@@ -153,17 +153,17 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
>>> import scico
>>> import scico.numpy as snp
>>> f = lambda x: snp.linalg.norm(x)**2
->>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) #
-DeviceArray([nan, nan], dtype=float32)
+>>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # doctest: +SKIP
+Array([nan, nan], dtype=float32)

This can be fixed by defining the squared :math:`\ell_2` norm directly as
``g = lambda x: snp.sum(x**2)``. The gradient will work as expected:

::

>>> g = lambda x: snp.sum(x**2)
->>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))
-DeviceArray([0., 0.], dtype=float32)
+>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
+Array([0., 0.], dtype=float32)

An alternative is to define a `custom derivative rule <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_ to enforce a particular derivative convention at a point.
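
For example, a custom JVP along these lines (a minimal sketch, not scico code) pins the derivative of the norm to zero at the origin::

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def l2norm(x):
        return jnp.sqrt(jnp.sum(x**2))

    @l2norm.defjvp
    def l2norm_jvp(primals, tangents):
        x, = primals
        t, = tangents
        y = l2norm(x)
        # avoid the 0/0 that produces nan: declare the derivative at 0 to be 0
        safe_y = jnp.where(y == 0, 1.0, y)
        return y, jnp.where(y == 0, 0.0, jnp.vdot(x, t) / safe_y)

    print(jax.grad(l2norm)(jnp.zeros(2, dtype=jnp.float32)))  # [0. 0.]
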

6 changes: 3 additions & 3 deletions scico/linop/_convolve.py
@@ -18,8 +18,8 @@
import numpy as np

import jax
+import jax.numpy as jnp
from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray
from jax.scipy.signal import convolve

import scico.numpy as snp
@@ -197,9 +197,9 @@ def __init__(
        if x.ndim != len(input_shape):
            raise ValueError(f"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.")

-        if isinstance(x, DeviceArray):
+        if isinstance(x, jnp.ndarray):
            self.x = x
-        elif isinstance(x, np.ndarray):
+        elif isinstance(x, np.ndarray):  # TODO: this should not be handled at the LinOp level
            self.x = jax.device_put(x)
        else:
            raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(x)}.")
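
``jnp.ndarray`` works as an isinstance target on both sides of the transition: it matches ``DeviceArray`` instances in older jax releases and ``jax.Array`` instances from 0.4 on, which is why it replaces the ``DeviceArray`` checks throughout this commit. A quick sanity check (sketch)::

    import numpy as np
    import jax.numpy as jnp

    assert isinstance(jnp.ones(3), jnp.ndarray)      # jax arrays match
    assert not isinstance(np.ones(3), jnp.ndarray)   # numpy arrays do not
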
4 changes: 2 additions & 2 deletions scico/linop/_linop.py
@@ -18,8 +18,8 @@
import numpy as np

import jax
+import jax.numpy as jnp
from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray

import scico.numpy as snp
from scico._autograd import linear_adjoint
@@ -234,7 +234,7 @@ def __rmatmul__(self, other):
        if isinstance(other, LinearOperator):
            return other(self)

-        if isinstance(other, (np.ndarray, DeviceArray)):
+        if isinstance(other, (np.ndarray, jnp.ndarray)):
            # for real valued inputs: y @ self == (self.T @ y.T).T
            # for complex: y @ self == (self.conj().T @ y.conj().T).conj().T
            # self.conj().T == self.adj
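
The identities quoted in the comments are easy to verify with dense matrices (a numpy-only sketch)::

    import numpy as np

    A = np.random.randn(3, 4)
    y = np.random.randn(2, 3)
    assert np.allclose(y @ A, (A.T @ y.T).T)  # real case

    Ac = A + 1j * np.random.randn(3, 4)
    yc = y + 1j * np.random.randn(2, 3)
    assert np.allclose(yc @ Ac, (Ac.conj().T @ yc.conj().T).conj().T)  # complex case
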
14 changes: 7 additions & 7 deletions scico/linop/_matrix.py
@@ -18,8 +18,8 @@
import numpy as np

import jax
+import jax.numpy as jnp
from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray

import scico.numpy as snp
from scico.typing import JaxArray
@@ -40,7 +40,7 @@ def wrapper(a, b):

raise ValueError(f"MatrixOperator shapes {a.shape} and {b.shape} do not match.")

if isinstance(b, (DeviceArray, np.ndarray)):
if isinstance(b, (jnp.ndarray, np.ndarray)):
if a.matrix_shape == b.shape:
return MatrixOperator(op(a.A, b))

@@ -81,10 +81,10 @@ def __init__(self, A: JaxArray, input_cols: int = 0):
        self.A: JaxArray  #: Dense array implementing this matrix

        # if A is an ndarray, make sure it gets converted to a DeviceArray
-        if isinstance(A, DeviceArray):
+        if isinstance(A, jnp.ndarray):
            self.A = A
        elif isinstance(A, np.ndarray):
-            self.A = jax.device_put(A)
+            self.A = jax.device_put(A)  # TODO: ensure_on_device?
        else:
            raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(A)}.")

@@ -163,7 +163,7 @@ def __mul__(self, other):

raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")

if isinstance(other, (DeviceArray, np.ndarray)):
if isinstance(other, (jnp.ndarray, np.ndarray)):
if self.matrix_shape == other.shape:
return MatrixOperator(self.A * other)

@@ -185,7 +185,7 @@ def __truediv__(self, other):
                return MatrixOperator(self.A / other.A)
            raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")

-        if isinstance(other, (DeviceArray, np.ndarray)):
+        if isinstance(other, (jnp.ndarray, np.ndarray)):
            if self.matrix_shape == other.shape:
                return MatrixOperator(self.A / other)

@@ -199,7 +199,7 @@ def __rtruediv__(self, other):
        if np.isscalar(other):
            return MatrixOperator(other / self.A)

-        if isinstance(other, (DeviceArray, np.ndarray)):
+        if isinstance(other, (jnp.ndarray, np.ndarray)):
            if self.matrix_shape == other.shape:
                return MatrixOperator(other / self.A)

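
All of these branches follow the same conversion pattern as ``__init__``: jax arrays are stored directly, numpy arrays are moved to the device first. For instance (sketch, assuming ``MatrixOperator`` is exported from ``scico.linop``)::

    import numpy as np
    import jax.numpy as jnp
    from scico.linop import MatrixOperator

    Ao = MatrixOperator(np.random.randn(4, 6))
    assert isinstance(Ao.A, jnp.ndarray)  # np.ndarray input was device_put to a jax array
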
43 changes: 21 additions & 22 deletions scico/numpy/_wrappers.py
@@ -88,35 +88,34 @@ def mapped(*args, **kwargs):
return mapped


-def map_func_over_blocks(func, is_reduction=False):
+def map_func_over_blocks(func):
    """Wrap a function so that it maps over all of its `BlockArray`
    arguments.
-    is_reduction: function is handled in a special way in order to allow
-    full reductions of `BlockArray`s. If the axis parameter exists but
-    is not specified, the function is called on a fully ravelled version
-    of all `BlockArray` inputs.
    """
-    sig = signature(func)

    @wraps(func)
    def mapped(*args, **kwargs):
-        bound_args = sig.bind(*args, **kwargs)
-
-        ba_args = {}
-        for k, v in list(bound_args.arguments.items()):
-            if isinstance(v, BlockArray):
-                ba_args[k] = bound_args.arguments.pop(k)
-
-        if not ba_args:  # no BlockArray arguments
-            return func(*args, **kwargs)  # no mapping
-
-        num_blocks = len(list(ba_args.values())[0])
-
-        return BlockArray(
-            func(*bound_args.args, **bound_args.kwargs, **{k: v[i] for k, v in ba_args.items()})
-            for i in range(num_blocks)
-        )
+        first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
+        if first_ba_arg is None:
+            first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
+            if first_ba_kwarg is None:
+                return func(*args, **kwargs)  # no BlockArray arguments, so no mapping
+            num_blocks = len(first_ba_kwarg)
+        else:
+            num_blocks = len(first_ba_arg)
+
+        # build a list of new args and kwargs, one for each block
+        new_args_list = []
+        new_kwargs_list = []
+        for i in range(num_blocks):
+            new_args_list.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
+            new_kwargs_list.append(
+                {k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
+            )
+
+        # run the function num_blocks times, return results in a BlockArray
+        return BlockArray(func(*new_args_list[i], **new_kwargs_list[i]) for i in range(num_blocks))

return mapped

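
The rewritten wrapper calls ``func`` once per block, substituting block ``i`` of every ``BlockArray`` argument and passing all other arguments through unchanged. A usage sketch (assuming the private ``scico.numpy._wrappers`` module path)::

    import jax.numpy as jnp
    import scico.numpy as snp
    from scico.numpy._wrappers import map_func_over_blocks

    blockwise_clip = map_func_over_blocks(jnp.clip)
    x = snp.blockarray([jnp.array([-2.0, 0.5]), jnp.array([[3.0, -4.0]])])
    y = blockwise_clip(x, -1.0, 1.0)  # clips each block; the scalar bounds go to every call
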
11 changes: 0 additions & 11 deletions scico/numpy/util.py
@@ -9,8 +9,6 @@
import numpy as np

import jax
-from jax.interpreters.pxla import ShardedDeviceArray
-from jax.interpreters.xla import DeviceArray

import scico.numpy as snp
from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, JaxArray, Shape
@@ -52,15 +50,6 @@ def ensure_on_device(
                stacklevel=2,
            )

-        elif not isinstance(
-            array,
-            (DeviceArray, BlockArray, ShardedDeviceArray),
-        ):
-            raise TypeError(
-                "Each element of parameter arrays must be ndarray, DeviceArray, BlockArray, or "
-                f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}."
-            )
-
        arrays[i] = jax.device_put(arrays[i])

    if len(arrays) == 1:
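
With the type check removed, ``ensure_on_device`` simply defers to ``jax.device_put``, which converts numpy inputs and passes jax arrays through; list inputs no longer raise ``TypeError``, which is presumably why the corresponding assertion is dropped from ``test_array.py`` below. A sketch of the behavior being relied on::

    import jax
    import numpy as np

    xd = jax.device_put(np.ones(4))  # numpy -> jax array on the default device
    xd2 = jax.device_put(xd)         # already a jax array: no transfer needed
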
11 changes: 4 additions & 7 deletions scico/operator/_operator.py
@@ -17,8 +17,8 @@
import numpy as np

import jax
+import jax.numpy as jnp
from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray

import scico.numpy as snp
from scico.numpy import BlockArray
@@ -207,15 +207,12 @@ def __call__(
            )
            raise ValueError(f"Incompatible shapes {self.shape}, {x.shape}.")

-        if isinstance(x, (np.ndarray, DeviceArray, BlockArray)):
-            if self.input_shape == x.shape:
-                return self._eval(x)
+        if self.input_shape != x.shape:
            raise ValueError(
                f"Cannot evaluate {type(self)} with input_shape={self.input_shape} "
                f"on array with shape={x.shape}."
            )
-        # What is the context under which this gets called?
-        # Currently: in jit and grad tracers
-
+
        return self._eval(x)

def __add__(self, other: Operator) -> Operator:
@@ -386,7 +383,7 @@ def concat_args(args):
        # concat_args(args) = snp.blockarray([val, args]) if argnum = 0
        # concat_args(args) = snp.blockarray([args, val]) if argnum = 1

-        if isinstance(args, (DeviceArray, np.ndarray)):
+        if isinstance(args, (jnp.ndarray, np.ndarray)):
            # In the case that the original operator takes a blockarray with two
            # blocks, wrap in a list so we can use the same indexing as >2 block case
            args = [args]
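
Dropping the isinstance gate in ``__call__`` also matters under jax transformations: inside ``jax.jit`` or ``jax.grad`` the input is a tracer rather than a concrete array, and the new code only inspects ``x.shape``, which tracers provide. A sketch (assuming ``scico.linop.Identity``)::

    import jax
    import scico.numpy as snp
    from scico.linop import Identity

    A = Identity(input_shape=(16,))
    y = jax.jit(lambda v: A(v))(snp.ones((16,)))  # the tracer passes the shape check and reaches _eval
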
5 changes: 3 additions & 2 deletions scico/test/linop/test_convolve.py
@@ -4,6 +4,7 @@
import numpy as np

import jax
+import jax.numpy as jnp
import jax.scipy.signal as signal

import pytest
@@ -174,7 +175,7 @@ def test_ndarray_h():

    h = np.random.randn(3, 3).astype(np.float32)
    A = Convolve(input_shape=(16, 16), h=h)
-    assert isinstance(A.h, jax.interpreters.xla.DeviceArray)
+    assert isinstance(A.h, jnp.ndarray)


class TestConvolveByX:
@@ -339,4 +340,4 @@ def test_ndarray_x():

    x = np.random.randn(3, 3).astype(np.float32)
    A = ConvolveByX(input_shape=(16, 16), x=x)
-    assert isinstance(A.x, jax.interpreters.xla.DeviceArray)
+    assert isinstance(A.x, jnp.ndarray)
4 changes: 2 additions & 2 deletions scico/test/linop/test_matrix.py
@@ -3,7 +3,7 @@
import numpy as np

import jax
-from jax.interpreters.xla import DeviceArray
+import jax.numpy as jnp

import pytest

@@ -244,7 +244,7 @@ def test_matmul_identity(self):
    def test_init_devicearray(self):
        A = np.random.randn(4, 6)
        Ao = MatrixOperator(A)
-        assert isinstance(Ao.A, DeviceArray)
+        assert isinstance(Ao.A, jnp.ndarray)

        with pytest.raises(TypeError):
            MatrixOperator([1.0, 3.0])
14 changes: 6 additions & 8 deletions scico/test/test_array.py
@@ -2,7 +2,7 @@

import numpy as np

-from jax.interpreters.xla import DeviceArray
+import jax.numpy as jnp

import pytest

@@ -34,20 +34,18 @@ def test_ensure_on_device():
    BA = snp.blockarray([NP, SNP])
    NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA)

-    assert isinstance(NP_, DeviceArray)
+    assert isinstance(NP_, jnp.ndarray)

-    assert isinstance(SNP_, DeviceArray)
+    assert isinstance(SNP_, jnp.ndarray)
    assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer()

    assert isinstance(BA_, BlockArray)
-    assert isinstance(BA_[0], DeviceArray)
-    assert isinstance(BA_[1], DeviceArray)
+    assert isinstance(BA_[0], jnp.ndarray)
+    assert isinstance(BA_[1], jnp.ndarray)
    assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer()

-    np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1])
-
    NP_ = ensure_on_device(NP)
-    assert isinstance(NP_, DeviceArray)
+    assert isinstance(NP_, jnp.ndarray)


def test_no_nan_divide_array():
Expand Down
