diff --git a/docs/source/include/blockarray.rst b/docs/source/include/blockarray.rst
index 8ced54b70..5c93ae305 100644
--- a/docs/source/include/blockarray.rst
+++ b/docs/source/include/blockarray.rst
@@ -32,8 +32,8 @@ appropriate. For example,
     ((2, 3), (3,))
 
     >>> x * 2  # returns BlockArray  # doctest: +ELLIPSIS
-    BlockArray([DeviceArray([[ 2,  6, 14],
-                 [ 4,  4,  2]], dtype=...), DeviceArray([ 4,  8, 16], dtype=...)])
+    BlockArray([...Array([[ 2,  6, 14],
+                 [ 4,  4,  2]], dtype=...), ...Array([ 4,  8, 16], dtype=...)])
 
     >>> y = snp.blockarray((
     ...     [[.2],
@@ -41,9 +41,9 @@ appropriate. For example,
     ...     [.4]
     ... ))
 
-    >>> x + y  # returns BlockArray  # doctest: +ELLIPSIS
-    BlockArray([DeviceArray([[1.2, 3.2, 7.2],
-                 [2.3, 2.3, 1.3]], dtype=...), DeviceArray([2.4, 4.4, 8.4], dtype=...)])
+    >>> x + y  # returns BlockArray  # doctest: +ELLIPSIS
+    BlockArray([...Array([[1.2, 3.2, 7.2],
+                 [2.3, 2.3, 1.3]], dtype=...), ...Array([2.4, 4.4, 8.4], dtype=...)])
 
 
 .. _numpy_functions_blockarray:
@@ -157,7 +157,7 @@ The recommended way to construct a :class:`.BlockArray` is by using the
 
 While :func:`.snp.blockarray` will accept either :class:`~numpy.ndarray`\ s
 or :obj:`~jax.numpy.DeviceArray`\ s as input, :class:`~numpy.ndarray`\ s
-will be converted to :obj:`~jax.numpy.DeviceArray`\ s.
+will be converted to :obj:`~jax.Array`\ s.
 
 
 Operating on a BlockArray
@@ -177,7 +177,7 @@ Multiplication Between BlockArray and :class:`.LinearOperator`
 
 The :class:`.Operator` and :class:`.LinearOperator` classes are designed
 to work on instances of :class:`.BlockArray` in addition to instances of
-:obj:`~jax.numpy.DeviceArray`. For example
+:obj:`~jax.Array`. For example
 
 ::
diff --git a/docs/source/notes.rst b/docs/source/notes.rst
index 746071294..4cf02f76b 100644
--- a/docs/source/notes.rst
+++ b/docs/source/notes.rst
@@ -153,8 +153,8 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
     >>> import scico
     >>> import scico.numpy as snp
     >>> f = lambda x: snp.linalg.norm(x)**2
-    >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32))  #
-    DeviceArray([nan, nan], dtype=float32)
+    >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32))  # doctest: +SKIP
+    Array([nan, nan], dtype=float32)
 
 This can be fixed by defining the squared :math:`\ell_2` norm directly as
 ``g = lambda x: snp.sum(x**2)``. The gradient will work as expected:
@@ -162,8 +162,8 @@ This can be fixed by defining the squared :math:`\ell_2` norm directly as
 ::
 
     >>> g = lambda x: snp.sum(x**2)
-    >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))
-    DeviceArray([0., 0.], dtype=float32)
+    >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))  #doctest: +SKIP
+    Array([0., 0.], dtype=float32)
 
 An alternative is to define a `custom derivative rule `_
 to enforce a particular derivative convention at a point.
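[Note] The two ``notes.rst`` hunks above only swap the printed type name (``DeviceArray`` to ``Array``) and mark the doctests as skipped; the documented gradient pitfall itself is unchanged. A minimal sketch of that pitfall, written against plain JAX rather than the ``scico`` wrappers (assumed equivalent here, since ``scico.grad`` and ``scico.numpy`` delegate to JAX)::

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.linalg.norm(x) ** 2  # norm is non-differentiable at 0
    g = lambda x: jnp.sum(x ** 2)          # same value, but smooth everywhere

    x0 = jnp.zeros(2, dtype=jnp.float32)
    print(jax.grad(f)(x0))  # expected: [nan nan]
    print(jax.grad(g)(x0))  # expected: [0. 0.]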
diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py
index e4d286b89..fb8782485 100644
--- a/scico/linop/_convolve.py
+++ b/scico/linop/_convolve.py
@@ -18,8 +18,8 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray
 from jax.scipy.signal import convolve
 
 import scico.numpy as snp
@@ -197,9 +197,9 @@ def __init__(
         if x.ndim != len(input_shape):
             raise ValueError(f"x.ndim = {x.ndim} must equal len(input_shape) = {len(input_shape)}.")
 
-        if isinstance(x, DeviceArray):
+        if isinstance(x, jnp.ndarray):
             self.x = x
-        elif isinstance(x, np.ndarray):
+        elif isinstance(x, np.ndarray):  # TODO: this should not be handled at the LinOp level
             self.x = jax.device_put(x)
         else:
             raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(x)}.")
diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py
index d60162cd8..d144ecbbf 100644
--- a/scico/linop/_linop.py
+++ b/scico/linop/_linop.py
@@ -18,8 +18,8 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray
 
 import scico.numpy as snp
 from scico._autograd import linear_adjoint
@@ -234,7 +234,7 @@ def __rmatmul__(self, other):
         if isinstance(other, LinearOperator):
             return other(self)
 
-        if isinstance(other, (np.ndarray, DeviceArray)):
+        if isinstance(other, (np.ndarray, jnp.ndarray)):
             # for real valued inputs: y @ self == (self.T @ y.T).T
             # for complex: y @ self == (self.conj().T @ y.conj().T).conj().T
             # self.conj().T == self.adj
diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py
index d4eab1075..755257ba8 100644
--- a/scico/linop/_matrix.py
+++ b/scico/linop/_matrix.py
@@ -18,8 +18,8 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray
 
 import scico.numpy as snp
 from scico.typing import JaxArray
@@ -40,7 +40,7 @@ def wrapper(a, b):
             raise ValueError(f"MatrixOperator shapes {a.shape} and {b.shape} do not match.")
 
-        if isinstance(b, (DeviceArray, np.ndarray)):
+        if isinstance(b, (jnp.ndarray, np.ndarray)):
             if a.matrix_shape == b.shape:
                 return MatrixOperator(op(a.A, b))
 
@@ -81,10 +81,10 @@ def __init__(self, A: JaxArray, input_cols: int = 0):
         self.A: JaxArray  #: Dense array implementing this matrix
 
         # if A is an ndarray, make sure it gets converted to a DeviceArray
-        if isinstance(A, DeviceArray):
+        if isinstance(A, jnp.ndarray):
             self.A = A
         elif isinstance(A, np.ndarray):
-            self.A = jax.device_put(A)
+            self.A = jax.device_put(A)  # TODO: ensure_on_device?
         else:
             raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(A)}.")
@@ -163,7 +163,7 @@ def __mul__(self, other):
             raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
 
-        if isinstance(other, (DeviceArray, np.ndarray)):
+        if isinstance(other, (jnp.ndarray, np.ndarray)):
             if self.matrix_shape == other.shape:
                 return MatrixOperator(self.A * other)
 
@@ -185,7 +185,7 @@ def __truediv__(self, other):
                 return MatrixOperator(self.A / other.A)
             raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
 
-        if isinstance(other, (DeviceArray, np.ndarray)):
+        if isinstance(other, (jnp.ndarray, np.ndarray)):
             if self.matrix_shape == other.shape:
                 return MatrixOperator(self.A / other)
 
@@ -199,7 +199,7 @@ def __rtruediv__(self, other):
         if np.isscalar(other):
            return MatrixOperator(other / self.A)
 
-        if isinstance(other, (DeviceArray, np.ndarray)):
+        if isinstance(other, (jnp.ndarray, np.ndarray)):
            if self.matrix_shape == other.shape:
                return MatrixOperator(other / self.A)
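[Note] Every ``scico/linop`` change above is the same substitution: ``isinstance(x, DeviceArray)`` becomes ``isinstance(x, jnp.ndarray)``. A short sketch of why that check is the appropriate replacement, assuming JAX >= 0.4 semantics, where ``jnp.ndarray`` is an alias of ``jax.Array`` and is also satisfied by tracers::

    import numpy as np
    import jax
    import jax.numpy as jnp

    assert isinstance(jnp.zeros(3), jnp.ndarray)     # concrete jax arrays match
    assert not isinstance(np.zeros(3), jnp.ndarray)  # NumPy arrays do not

    @jax.jit
    def double(x):
        # tracer inputs seen under jit also pass the check, so the type
        # dispatch in these operators keeps working when traced
        assert isinstance(x, jnp.ndarray)
        return 2 * x

    double(jnp.ones(3))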
diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py
index d2bfb478a..f9be49eae 100644
--- a/scico/numpy/_wrappers.py
+++ b/scico/numpy/_wrappers.py
@@ -88,35 +88,34 @@ def mapped(*args, **kwargs):
     return mapped
 
 
-def map_func_over_blocks(func, is_reduction=False):
+def map_func_over_blocks(func):
     """Wrap a function so that it maps over all of its `BlockArray`
     arguments.
-
-    is_reduction: function is handled in a special way in order to allow
-    full reductions of `BlockArray`s. If the axis parameter exists but
-    is not specified, the function is called on a fully ravelled version
-    of all `BlockArray` inputs.
     """
-    sig = signature(func)
 
     @wraps(func)
     def mapped(*args, **kwargs):
-        bound_args = sig.bind(*args, **kwargs)
-
-        ba_args = {}
-        for k, v in list(bound_args.arguments.items()):
-            if isinstance(v, BlockArray):
-                ba_args[k] = bound_args.arguments.pop(k)
-
-        if not ba_args:  # no BlockArray arguments
-            return func(*args, **kwargs)  # no mapping
-
-        num_blocks = len(list(ba_args.values())[0])
-        return BlockArray(
-            func(*bound_args.args, **bound_args.kwargs, **{k: v[i] for k, v in ba_args.items()})
-            for i in range(num_blocks)
-        )
+        first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
+        if first_ba_arg is None:
+            first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
+            if first_ba_kwarg is None:
+                return func(*args, **kwargs)  # no BlockArray arguments, so no mapping
+            num_blocks = len(first_ba_kwarg)
+        else:
+            num_blocks = len(first_ba_arg)
+
+        # build a list of new args and kwargs, one for each block
+        new_args_list = []
+        new_kwargs_list = []
+        for i in range(num_blocks):
+            new_args_list.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
+            new_kwargs_list.append(
+                {k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
+            )
+
+        # run the function num_blocks times, return results in a BlockArray
+        return BlockArray(func(*new_args_list[i], **new_kwargs_list[i]) for i in range(num_blocks))
 
     return mapped
diff --git a/scico/numpy/util.py b/scico/numpy/util.py
index 06fd11d00..6e9beab85 100644
--- a/scico/numpy/util.py
+++ b/scico/numpy/util.py
@@ -9,8 +9,6 @@
 import numpy as np
 
 import jax
-from jax.interpreters.pxla import ShardedDeviceArray
-from jax.interpreters.xla import DeviceArray
 
 import scico.numpy as snp
 from scico.typing import ArrayIndex, Axes, AxisIndex, BlockShape, DType, JaxArray, Shape
@@ -52,15 +50,6 @@ def ensure_on_device(
                 stacklevel=2,
             )
 
-        elif not isinstance(
-            array,
-            (DeviceArray, BlockArray, ShardedDeviceArray),
-        ):
-            raise TypeError(
-                "Each element of parameter arrays must be ndarray, DeviceArray, BlockArray, or "
-                f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}."
-            )
-
         arrays[i] = jax.device_put(arrays[i])
 
     if len(arrays) == 1:
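[Note] For readers unfamiliar with the wrapper rewritten in ``_wrappers.py``: ``map_func_over_blocks`` turns a function over ordinary arrays into one that applies itself block-by-block, substituting block ``i`` of every ``BlockArray`` argument (positional or keyword) and passing all other arguments through unchanged. A hedged usage sketch, assuming the wrapped ``scico.numpy`` functions behave as the docs above describe::

    import scico.numpy as snp

    x = snp.blockarray(([-1.0, 2.0], [[3.0], [-4.0]]))  # blocks: (2,) and (2, 1)

    y = snp.abs(x)     # abs runs once per block; results re-wrapped in a BlockArray
    z = snp.add(x, 1)  # non-BlockArray arguments are repeated for every block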
diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py
index dde02724f..947b967a8 100644
--- a/scico/operator/_operator.py
+++ b/scico/operator/_operator.py
@@ -17,8 +17,8 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 from jax.dtypes import result_type
-from jax.interpreters.xla import DeviceArray
 
 import scico.numpy as snp
 from scico.numpy import BlockArray
@@ -207,15 +207,12 @@ def __call__(
             )
             raise ValueError(f"Incompatible shapes {self.shape}, {x.shape}.")
 
-        if isinstance(x, (np.ndarray, DeviceArray, BlockArray)):
-            if self.input_shape == x.shape:
-                return self._eval(x)
+        if self.input_shape != x.shape:
             raise ValueError(
                 f"Cannot evaluate {type(self)} with input_shape={self.input_shape} "
                 f"on array with shape={x.shape}."
             )
-        # What is the context under which this gets called?
-        # Currently: in jit and grad tracers
+        return self._eval(x)
 
     def __add__(self, other: Operator) -> Operator:
@@ -386,7 +383,7 @@ def concat_args(args):
         # concat_args(args) = snp.blockarray([val, args]) if argnum = 0
         # concat_args(args) = snp.blockarray([args, val]) if argnum = 1
 
-        if isinstance(args, (DeviceArray, np.ndarray)):
+        if isinstance(args, (jnp.ndarray, np.ndarray)):
            # In the case that the original operator takes a blockkarray with two
            # blocks, wrap in a list so we can use the same indexing as >2 block case
            args = [args]
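[Note] The ``__call__`` rewrite above also resolves the removed ``# What is the context ...`` comment: as that comment noted, tracer inputs under ``jax.jit`` and ``jax.grad`` previously fell through the ``isinstance`` check to the trailing evaluation. The new control flow, condensed here as a sketch (the full method also handles ``Operator`` inputs and shape inference earlier in the body)::

    def __call__(self, x):
        # only the shape is checked, so tracers under jit/grad take the
        # same path as concrete arrays and BlockArrays
        if self.input_shape != x.shape:
            raise ValueError(
                f"Cannot evaluate {type(self)} with input_shape={self.input_shape} "
                f"on array with shape={x.shape}."
            )
        return self._eval(x)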
diff --git a/scico/test/linop/test_convolve.py b/scico/test/linop/test_convolve.py
index bfd8ca202..1979a4626 100644
--- a/scico/test/linop/test_convolve.py
+++ b/scico/test/linop/test_convolve.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 import jax.scipy.signal as signal
 
 import pytest
@@ -174,7 +175,7 @@ def test_ndarray_h():
     h = np.random.randn(3, 3).astype(np.float32)
     A = Convolve(input_shape=(16, 16), h=h)
-    assert isinstance(A.h, jax.interpreters.xla.DeviceArray)
+    assert isinstance(A.h, jnp.ndarray)
 
 
 class TestConvolveByX:
@@ -339,4 +340,4 @@ def test_ndarray_x():
     x = np.random.randn(3, 3).astype(np.float32)
     A = ConvolveByX(input_shape=(16, 16), x=x)
-    assert isinstance(A.x, jax.interpreters.xla.DeviceArray)
+    assert isinstance(A.x, jnp.ndarray)
diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py
index e8a4a3e1b..3f00b2c7a 100644
--- a/scico/test/linop/test_matrix.py
+++ b/scico/test/linop/test_matrix.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 import jax
-from jax.interpreters.xla import DeviceArray
+import jax.numpy as jnp
 
 import pytest
 
@@ -244,7 +244,7 @@ def test_matmul_identity(self):
     def test_init_devicearray(self):
         A = np.random.randn(4, 6)
         Ao = MatrixOperator(A)
-        assert isinstance(Ao.A, DeviceArray)
+        assert isinstance(Ao.A, jnp.ndarray)
 
         with pytest.raises(TypeError):
             MatrixOperator([1.0, 3.0])
diff --git a/scico/test/test_array.py b/scico/test/test_array.py
index e37497f8c..884428043 100644
--- a/scico/test/test_array.py
+++ b/scico/test/test_array.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from jax.interpreters.xla import DeviceArray
+import jax.numpy as jnp
 
 import pytest
 
@@ -34,20 +34,18 @@ def test_ensure_on_device():
     BA = snp.blockarray([NP, SNP])
     NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA)
 
-    assert isinstance(NP_, DeviceArray)
+    assert isinstance(NP_, jnp.ndarray)
 
-    assert isinstance(SNP_, DeviceArray)
+    assert isinstance(SNP_, jnp.ndarray)
     assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer()
 
     assert isinstance(BA_, BlockArray)
-    assert isinstance(BA_[0], DeviceArray)
-    assert isinstance(BA_[1], DeviceArray)
+    assert isinstance(BA_[0], jnp.ndarray)
+    assert isinstance(BA_[1], jnp.ndarray)
     assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer()
 
-    np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1])
-
     NP_ = ensure_on_device(NP)
-    assert isinstance(NP_, DeviceArray)
+    assert isinstance(NP_, jnp.ndarray)
 
 
 def test_no_nan_divide_array():
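[Note] Finally, note what the ``test_array.py`` hunk encodes: with the ``TypeError`` branch deleted from ``ensure_on_device`` (see the ``scico/numpy/util.py`` hunk above), the ``assert_raises(TypeError, ensure_on_device, [1, 1, 1])`` check is gone, and unrecognized inputs are now handed to ``jax.device_put`` rather than rejected. A sketch of the behavior the surviving asserts pin down::

    import numpy as np
    import jax.numpy as jnp
    from scico.numpy.util import ensure_on_device

    # a NumPy input is transferred to the default device (with a warning
    # suggesting that jax arrays be used directly); jax arrays and
    # BlockArrays pass through unchanged
    x = ensure_on_device(np.zeros(3, dtype=np.float32))
    assert isinstance(x, jnp.ndarray)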