Continue removing scico.array, allow snp.blockarray syntax
Michael-T-McCann committed Apr 20, 2022
1 parent 8326630 commit c5e595c
Showing 22 changed files with 101 additions and 102 deletions.
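The changes below follow two mechanical patterns: calls to the `BlockArray.array(...)` classmethod become calls to the `snp.blockarray(...)` constructor, and helpers such as `ensure_on_device`, `no_nan_divide`, `is_real_dtype`, and `complex_dtype` are imported from `scico.numpy.util` rather than from `scico.array` or `scico.numpy`. A minimal sketch of the post-change usage (assuming a SCICO build at or after this commit):

```python
# Sketch of the new-style calls this commit migrates to (old forms in comments).
import numpy as np

import scico.numpy as snp
from scico.numpy.util import ensure_on_device  # was: from scico import array

x0 = np.ones((32, 32), dtype=np.float32)
x1 = np.ones((16,), dtype=np.float32)

# was: BlockArray.array([x0, x1])
xb = snp.blockarray([x0, x1])
print(xb.shape)  # ((32, 32), (16,)) -- one shape tuple per block

# was: array.ensure_on_device(x0); moves the host numpy array onto the device
x0_dev = ensure_on_device(x0)
```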
2 changes: 1 addition & 1 deletion examples/scripts/denoise_tv_iso_pgm.py
@@ -1,4 +1,4 @@
#!/Usr/bin/env python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
6 changes: 3 additions & 3 deletions scico/_generic_operators.py
@@ -320,8 +320,8 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
def concat_args(args):
# Creates a blockarray with args and the frozen value in the correct place
# Eg if this operator takes a blockarray with two blocks, then
# concat_args(args) = BlockArray.array([val, args]) if argnum = 0
# concat_args(args) = BlockArray.array([args, val]) if argnum = 1
# concat_args(args) = snp.blockarray([val, args]) if argnum = 0
# concat_args(args) = snp.blockarray([args, val]) if argnum = 1

if isinstance(args, (DeviceArray, np.ndarray)):
# In the case that the original operator takes a blockarray with two
@@ -336,7 +336,7 @@ def concat_args(args):
arg_list.append(args[i - 1])
else:
arg_list.append(val)
return BlockArray.array(arg_list)
return snp.blockarray(arg_list)

return Operator(
input_shape=input_shape,
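The comment in the hunk above defines `concat_args` only by where the frozen block lands. A hypothetical standalone sketch of that placement rule (the function signature and the two-block assumption are mine, not the operator's internal closure):

```python
# Hypothetical sketch of the concat_args placement rule described above,
# for an operator whose input is a two-block BlockArray.
import numpy as np

import scico.numpy as snp


def concat_args(args, val, argnum):
    # args: the remaining free block; val: the frozen block; argnum: frozen slot
    arg_list = [val, args] if argnum == 0 else [args, val]
    return snp.blockarray(arg_list)


free = np.arange(3.0, dtype=np.float32)
frozen = np.full((3,), 7.0, dtype=np.float32)
print(concat_args(free, frozen, argnum=0).shape)  # ((3,), (3,))
```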
2 changes: 1 addition & 1 deletion scico/functional/_functional.py
@@ -252,7 +252,7 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
"""
if len(v.shape) == len(self.functional_list):
return BlockArray.array([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)])
return snp.blockarray([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)])
raise ValueError(
f"Number of blocks in v, {len(v.shape)}, and length of functional_list, "
f"{len(self.functional_list)}, do not match"
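The hunk above encodes the block-wise prox contract: the prox of a separable functional applied to a BlockArray is the BlockArray of per-component proxes. A sketch of that contract, assuming the enclosing class is `functional.SeparableFunctional` (the class name is not visible in this hunk):

```python
# Sketch of the block-wise prox identity described above.
import scico.numpy as snp
from scico import functional
from scico.random import randn

f = functional.L1Norm()
g = functional.SquaredL2Norm()
fg = functional.SeparableFunctional([f, g])

v1, key = randn((7,))
v2, _ = randn((3,), key=key)
vb = snp.blockarray([v1, v2])

lam = 0.1
blockwise = snp.blockarray([f.prox(v1, lam), g.prox(v2, lam)])
snp.testing.assert_allclose(fg.prox(vb, lam), blockwise, rtol=5e-2)
```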
3 changes: 2 additions & 1 deletion scico/functional/_norm.py
@@ -12,8 +12,9 @@
from jax import jit, lax

from scico import numpy as snp
from scico.numpy import BlockArray, count_nonzero, no_nan_divide
from scico.numpy import BlockArray, count_nonzero
from scico.numpy.linalg import norm
from scico.numpy.util import no_nan_divide
from scico.typing import JaxArray

from ._functional import Functional
4 changes: 2 additions & 2 deletions scico/linop/_convolve.py
@@ -23,8 +23,8 @@
from jax.scipy.signal import convolve

import scico.numpy as snp
from scico import array
from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar
from scico.numpy.util import ensure_on_device
from scico.typing import DType, JaxArray, Shape


@@ -66,7 +66,7 @@ def __init__(

if h.ndim != len(input_shape):
raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}")
self.h = array.ensure_on_device(h)
self.h = ensure_on_device(h)

if mode not in ["full", "valid", "same"]:
raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'")
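The change above is purely an import move: `ensure_on_device` now comes from `scico.numpy.util`. For context, a sketch of constructing `Convolve` with a host-side kernel (kernel values and shapes are arbitrary; the `h, input_shape` argument order is assumed from the constructor, which appears here only in truncated form):

```python
# Sketch: Convolve stores a host numpy kernel via ensure_on_device and applies
# a 2-D convolution; with mode="full" the output shape is (8 + 3 - 1, 8 + 3 - 1).
import numpy as np

import scico.numpy as snp
from scico.linop import Convolve

h = np.ones((3, 3), dtype=np.float32) / 9.0   # averaging kernel on the host
A = Convolve(h, input_shape=(8, 8), mode="full")
y = A(snp.ones((8, 8)))
print(y.shape)                                 # (10, 10)
```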
7 changes: 3 additions & 4 deletions scico/linop/_linop.py
@@ -16,10 +16,9 @@
from typing import Any, Callable, Optional, Union

import scico.numpy as snp
from scico import array
from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar
from scico.numpy import BlockArray
from scico.numpy.util import is_nested
from scico.numpy.util import ensure_on_device, indexed_shape, is_nested
from scico.random import randn
from scico.typing import ArrayIndex, BlockShape, DType, JaxArray, PRNGKey, Shape

@@ -182,7 +181,7 @@ def __init__(
"""

self.diagonal = array.ensure_on_device(diagonal)
self.diagonal = ensure_on_device(diagonal)

if input_shape is None:
input_shape = self.diagonal.shape
@@ -286,7 +285,7 @@ def __init__(
if is_nested(input_shape):
output_shape = input_shape[idx]
else:
output_shape = array.indexed_shape(input_shape, idx)
output_shape = indexed_shape(input_shape, idx)

self.idx: ArrayIndex = idx
super().__init__(
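Both hunks above only swap `array.*` helpers for their `scico.numpy.util` counterparts; behavior is unchanged. As a reminder of what the `Diagonal` constructor does with the `ensure_on_device`-processed diagonal, a small sketch with `input_shape` left to its default:

```python
# Sketch: Diagonal accepts a host-side numpy diagonal, which ensure_on_device
# transfers to the device; input_shape defaults to diagonal.shape.
import numpy as np

import scico.numpy as snp
from scico.linop import Diagonal

d = np.arange(1.0, 5.0, dtype=np.float32)  # diagonal entries on the host
D = Diagonal(d)                            # input_shape defaults to d.shape
x = snp.ones(4)
print(D(x))                                # elementwise product d * x -> [1. 2. 3. 4.]
```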
2 changes: 1 addition & 1 deletion scico/linop/optics.py
@@ -61,7 +61,7 @@

import scico.numpy as snp
from scico.linop import Diagonal, Identity, LinearOperator
from scico.numpy import no_nan_divide
from scico.numpy.util import no_nan_divide
from scico.typing import Shape

from ._dft import DFT
4 changes: 2 additions & 2 deletions scico/loss.py
@@ -17,8 +17,8 @@

import scico.numpy as snp
from scico import functional, linop, operator
from scico.numpy import BlockArray, no_nan_divide
from scico.numpy.util import ensure_on_device
from scico.numpy import BlockArray
from scico.numpy.util import ensure_on_device, no_nan_divide
from scico.scipy.special import gammaln
from scico.solver import cg
from scico.typing import JaxArray
8 changes: 4 additions & 4 deletions scico/numpy/blockarray.py
@@ -124,7 +124,7 @@
>>> x_v, _ = scico.random.randn((n-1, m), key=key)
# Form the blockarray
>>> x_B = BlockArray.array([x_h, x_v])
>>> x_B = snp.blockarray([x_h, x_v])
# The blockarray shape is a tuple of tuples
>>> x_B.shape
@@ -146,15 +146,15 @@
>>> import numpy as np
>>> x0, key = scico.random.randn((32, 32))
>>> x1, _ = scico.random.randn((16,), key=key)
>>> X = BlockArray.array((x0, x1))
>>> X = snp.blockarray((x0, x1))
>>> X.shape
((32, 32), (16,))
>>> X.size
(1024, 16)
>>> len(X)
2
While :func:`.BlockArray.array` will accept either `ndarray` or
While :func:`.snp.blockarray` will accept either `ndarray` or
`DeviceArray` as input, the resulting :class:`.BlockArray` will be backed
by a `DeviceArray` memory buffer.
@@ -190,7 +190,7 @@
>>> A_2.shape # array -> BlockArray
(((2, 4), (3, 3)), (3, 4))
>>> diag = BlockArray.array([np.array(1.0), np.array(2.0)])
>>> diag = snp.blockarray([np.array(1.0), np.array(2.0)])
>>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape))
>>> A_3.shape # BlockArray -> BlockArray
(((2, 4), (3, 3)), ((2, 4), (3, 3)))
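The docstring above states that `snp.blockarray` accepts either `ndarray` or `DeviceArray` blocks but always produces a `DeviceArray`-backed result. A quick check of that claim (a sketch; it assumes integer indexing returns the corresponding block, as done elsewhere in this diff):

```python
# Sketch: blocks passed in as numpy ndarrays come back device-backed.
import numpy as np

import scico.numpy as snp

xb = snp.blockarray([np.zeros((2, 2), dtype=np.float32), np.zeros((3,), dtype=np.float32)])
print(type(xb[0]))  # a jax DeviceArray rather than numpy.ndarray
print(xb.shape)     # ((2, 2), (3,))
```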
2 changes: 1 addition & 1 deletion scico/operator/biconvolve.py
@@ -26,7 +26,7 @@ class BiConvolve(Operator):
blocks of equal ndims, and convolves the first block with the second.
If `A` is a BiConvolve operator, then
`A(BlockArray.array([x, h]))` equals `jax.scipy.signal.convolve(x, h)`.
`A(snp.blockarray([x, h]))` equals `jax.scipy.signal.convolve(x, h)`.
"""

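The docstring above gives the defining identity for `BiConvolve` in the new syntax. A sketch that exercises it, mirroring the test near the end of this diff (the tolerance is an arbitrary choice of mine):

```python
# Sketch of the BiConvolve identity A(snp.blockarray([x, h])) == convolve(x, h).
from jax.scipy import signal

import scico.numpy as snp
from scico.operator.biconvolve import BiConvolve
from scico.random import randn

x, key = randn((32, 32))
h, _ = randn((4, 4), key=key)
xh = snp.blockarray([x, h])

A = BiConvolve(input_shape=xh.shape, mode="full")
snp.testing.assert_allclose(A(xh), signal.convolve(x, h, mode="full"), rtol=1e-5)
```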
4 changes: 2 additions & 2 deletions scico/optimize/admm.py
@@ -22,9 +22,9 @@
from scico.functional import Functional
from scico.linop import CircularConvolve, Identity, LinearOperator
from scico.loss import SquaredL2Loss
from scico.numpy import BlockArray, is_real_dtype
from scico.numpy import BlockArray
from scico.numpy.linalg import norm
from scico.numpy.util import ensure_on_device
from scico.numpy.util import ensure_on_device, is_real_dtype
from scico.solver import cg as scico_cg
from scico.solver import minimize
from scico.typing import JaxArray
4 changes: 2 additions & 2 deletions scico/solver.py
@@ -145,7 +145,7 @@ def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArr
BlockArray.
"""
if isinstance(x, BlockArray):
return BlockArray.array([_split_real_imag(_) for _ in x])
return snp.blockarray([_split_real_imag(_) for _ in x])
return snp.stack((snp.real(x), snp.imag(x)))


@@ -163,7 +163,7 @@ def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArra
and `x[1]` respectively.
"""
if isinstance(x, BlockArray):
return BlockArray.array([_join_real_imag(_) for _ in x])
return snp.blockarray([_join_real_imag(_) for _ in x])
return x[0] + 1j * x[1]


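Both helpers above are private to `scico.solver`. The sketch below restates the array branch of each (the BlockArray branch simply maps the same operation over blocks via `snp.blockarray`) and checks that they round-trip; the function names are stand-ins, not the module's API:

```python
# Hypothetical standalone versions of the array branches shown above.
import scico.numpy as snp
from scico.random import randn


def split_real_imag(x):
    # stack real and imaginary parts along a new leading axis
    return snp.stack((snp.real(x), snp.imag(x)))


def join_real_imag(y):
    return y[0] + 1j * y[1]


z, _ = randn((4,), dtype=snp.complex64)
snp.testing.assert_allclose(join_real_imag(split_real_imag(z)), z)
```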
7 changes: 3 additions & 4 deletions scico/test/functional/test_core.py
@@ -13,7 +13,6 @@

import scico.numpy as snp
from scico import functional
from scico.numpy import BlockArray
from scico.random import randn

NO_BLOCK_ARRAY = [functional.L21Norm, functional.NuclearNorm]
@@ -48,7 +47,7 @@ def __init__(self, dtype):

self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval
self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval
self.vb = BlockArray.array([self.v1, self.v2])
self.vb = snp.blockarray([self.v1, self.v2])


@pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128])
@@ -68,7 +67,7 @@ def test_separable_prox(test_separable_obj):
fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha)
gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha)
fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha)
out = BlockArray.array((fv1, gv2))
out = snp.blockarray((fv1, gv2))
snp.testing.assert_allclose(out, fgv, rtol=5e-2)


@@ -86,7 +85,7 @@ def test_separable_grad(test_separable_obj):
fv1 = test_separable_obj.f.grad(test_separable_obj.v1)
gv2 = test_separable_obj.g.grad(test_separable_obj.v2)
fgv = test_separable_obj.fg.grad(test_separable_obj.vb)
out = BlockArray.array((fv1, gv2))
out = snp.blockarray((fv1, gv2))
snp.testing.assert_allclose(out, fgv, rtol=5e-2)


2 changes: 1 addition & 1 deletion scico/test/functional/test_loss.py
@@ -10,7 +10,7 @@
from prox import prox_test
import scico.numpy as snp
from scico import functional, linop, loss
from scico.numpy import complex_dtype
from scico.numpy.util import complex_dtype
from scico.random import randn, uniform


8 changes: 4 additions & 4 deletions scico/test/functional/test_separable.py
@@ -10,7 +10,7 @@
import pytest

from scico import functional
from scico.numpy import BlockArray
from scico.numpy import blockarray
from scico.numpy.testing import assert_allclose
from scico.random import randn

@@ -27,7 +27,7 @@ def __init__(self, dtype):

self.v1, key = randn((n,), key=key, dtype=dtype) # point for prox eval
self.v2, key = randn((m,), key=key, dtype=dtype) # point for prox eval
self.vb = BlockArray.array([self.v1, self.v2])
self.vb = blockarray([self.v1, self.v2])


@pytest.fixture(params=[np.float32, np.complex64, np.float64, np.complex128])
@@ -47,7 +47,7 @@ def test_separable_prox(test_separable_obj):
fv1 = test_separable_obj.f.prox(test_separable_obj.v1, alpha)
gv2 = test_separable_obj.g.prox(test_separable_obj.v2, alpha)
fgv = test_separable_obj.fg.prox(test_separable_obj.vb, alpha)
out = BlockArray.array((fv1, gv2)).ravel()
out = blockarray((fv1, gv2)).ravel()
assert_allclose(out, fgv.ravel(), rtol=5e-2)


@@ -65,5 +65,5 @@ def test_separable_grad(test_separable_obj):
fv1 = test_separable_obj.f.grad(test_separable_obj.v1)
gv2 = test_separable_obj.g.grad(test_separable_obj.v2)
fgv = test_separable_obj.fg.grad(test_separable_obj.vb)
out = BlockArray.array((fv1, gv2)).ravel()
out = blockarray((fv1, gv2)).ravel()
assert_allclose(out, fgv.ravel(), rtol=5e-2)
2 changes: 1 addition & 1 deletion scico/test/optimize/test_ladmm.py
@@ -69,7 +69,7 @@ def callback(obj):
class TestBlockArray:
def setup_method(self, method):
np.random.seed(12345)
self.y = BlockArray.array(
self.y = snp.blockarray(
(
np.random.randn(32, 33).astype(np.float32),
np.random.randn(
2 changes: 1 addition & 1 deletion scico/test/optimize/test_pdhg.py
@@ -69,7 +69,7 @@ def callback(obj):
class TestBlockArray:
def setup_method(self, method):
np.random.seed(12345)
self.y = BlockArray.array(
self.y = snp.blockarray(
(
np.random.randn(32, 33).astype(np.float32),
np.random.randn(
11 changes: 7 additions & 4 deletions scico/test/test_array.py
@@ -7,16 +7,19 @@
import pytest

import scico.numpy as snp
from scico.numpy import (
BlockArray,
from scico.numpy import BlockArray
from scico.numpy.util import (
complex_dtype,
ensure_on_device,
indexed_shape,
is_complex_dtype,
is_nested,
is_real_dtype,
no_nan_divide,
parse_axes,
real_dtype,
slice_length,
)
from scico.numpy.util import ensure_on_device, indexed_shape, parse_axes, slice_length
from scico.random import randn


@@ -28,7 +31,7 @@ def test_ensure_on_device():

NP = np.ones(2)
SNP = snp.ones(2)
BA = BlockArray.array([NP, SNP])
BA = snp.blockarray([NP, SNP])
NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA)

assert isinstance(NP_, DeviceArray)
4 changes: 2 additions & 2 deletions scico/test/test_biconvolve.py
@@ -6,7 +6,7 @@
import pytest

from scico.linop import Convolve, ConvolveByX
from scico.numpy import BlockArray
from scico.numpy import blockarray
from scico.operator.biconvolve import BiConvolve
from scico.random import randn

@@ -22,7 +22,7 @@ def test_eval(self, input_dtype, mode, jit):
x, key = randn((32, 32), dtype=input_dtype, key=self.key)
h, key = randn((4, 4), dtype=input_dtype, key=self.key)

x_h = BlockArray.array([x, h])
x_h = blockarray([x, h])

A = BiConvolve(input_shape=x_h.shape, mode=mode, jit=jit)
signal_out = signal.convolve(x, h, mode=mode)