Simplify BlockArray implementation #259
Force-pushed from ef3988b to 2efb2ae.
Codecov Report

```
@@            Coverage Diff             @@
##             main     #259      +/-   ##
==========================================
+ Coverage   93.86%   94.07%   +0.20%
==========================================
  Files          51       49       -2
  Lines        3701     3241     -460
==========================================
- Hits         3474     3049     -425
+ Misses        227      192      -35
```
```diff
@@ -85,7 +85,7 @@ def __init__(
     def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]:
         if self.collapsable and self.collapse:
             return snp.stack([op @ x for op in self.ops])
-        return BlockArray.array([op @ x for op in self.ops])
+        return BlockArray([op @ x for op in self.ops])
```
Any opinions on the syntax `snp.BlockArray(stuff)` vs `snp.blockarray(stuff)`? The second one is a little more like NumPy.
I'd lean towards copying NumPy style, but this is worth a discussion.
I made both do the same thing, but tended to write `snp.blockarray(...)` wherever I could. By comparison, `np.ndarray` and `np.array` do different things, with `np.ndarray` documented as a low-level method.
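For reference, the NumPy split mentioned above looks like this: `np.array` builds an array from data, while the low-level `np.ndarray` constructor takes a shape and returns an array over uninitialized memory.

```python
import numpy as np

# np.array builds an array from existing data:
a = np.array([1.0, 2.0, 3.0])
print(a)  # prints [1. 2. 3.]

# np.ndarray is the low-level constructor: it takes a shape and
# returns an array whose contents are uninitialized memory.
b = np.ndarray((3,))
print(b.shape)  # prints (3,)
```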
I think that's OK, but perhaps add a note in the docs indicating that the `snp.blockarray` form is preferred.
Force-pushed from 33dd939 to 8cb7b31.
scico/numpy/util.py

```diff
 def is_nested(x: Any) -> bool:
     """Check if input is a list/tuple containing at least one list/tuple.

     Args:
         x: Object to be tested.

     Returns:
-        ``True`` if `x` is a list/tuple of list/tuples, otherwise
-        ``False``.
+        ``True`` if `x` is a list/tuple of list/tuples, ``False`` otherwise.
```
Do you mean `list/tuple of lists/tuples`?
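A sketch consistent with the docstring under review (not necessarily the exact scico implementation): the predicate is true only for a list/tuple that contains at least one list/tuple.

```python
from typing import Any


def is_nested(x: Any) -> bool:
    """Return True if x is a list/tuple containing at least one list/tuple."""
    if isinstance(x, (list, tuple)):
        return any(isinstance(item, (list, tuple)) for item in x)
    return False


assert is_nested([[1, 2], 3])    # list containing a list
assert not is_nested([1, 2, 3])  # flat list
assert not is_nested(42)         # not a container at all
```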
scico/test/linop/test_diff.py

```python
A = FiniteDifference(
    input_shape=input_shape, input_dtype=input_dtype, axes=axes, append=append
)
pass
```
Just checking, was this `pass` clause left here intentionally? And if so, why not use `if not`?
Nice catch. Not intentional; probably just a quick hack of the old test. It should be more readable now.
This is a work-in-progress PR to change the underlying implementation of `BlockArray` from a flattened `DeviceArray` to a list of `DeviceArray`s. The goal is to simplify the `BlockArray` implementation, reduce its coupling to jax internals, and provide additional functionality (mixed datatypes, blocks on different GPUs).

Closes #179. Also touches #237, #238, #159, #239.
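A minimal sketch of the new design, using NumPy arrays as a stand-in for `DeviceArray`s: a block array is a list of independent arrays rather than one flattened buffer, so blocks may have different shapes and dtypes.

```python
import numpy as np

# List-of-arrays representation (NumPy stand-in for DeviceArray):
blocks = [
    np.zeros((2, 3), dtype=np.float32),
    np.ones((4,), dtype=np.complex64),  # mixed dtype, now possible
]

# Per-block shapes are recovered by iterating over the list.
shapes = [b.shape for b in blocks]  # [(2, 3), (4,)]
```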
Timing examples (best of 3 runs, total time, scripts edited to remove input, MacBook Pro, CPU):

```
time python examples/scripts/denoise_tv_iso_pgm.py > /dev/null 2>&1
```

old: 2.936, new: 7.185; with additional `jax.jit`, old: 2.936, new: 2.923

```
time python examples/scripts/sparsecode_poisson_pgm.py > /dev/null 2>&1
```

old: 12.964, new: 12.663
Timing simple ops:

```
%timeit -n 3 -r 3 x = snp.ones(10000*((2, 2),))
2.03 s ± 26.8 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
62.7 ms ± 178 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

f = jax.jit(lambda: snp.ones(10000*((2, 2),)))
f()  # trigger jit
%timeit -n 3 -r 3 z = f()
36.6 ms ± 345 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
33.3 µs ± 30.2 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

%timeit -n 3 -r 3 x = snp.ones(5*((512, 512),))
2.1 ms ± 176 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
1.38 ms ± 125 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

x = snp.ones(512*((512, 512),))
%timeit -n 3 -r 3 y = x @ x
406 ms ± 60.5 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
2.21 s ± 55.7 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)

%timeit -n 3 -r 3 z = snp.linalg.norm(x)
567 ms ± 50.5 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
152 ms ± 9.02 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)

f = jax.jit(lambda x: snp.linalg.norm(x))
f(x)  # trigger jit
%timeit -n 3 -r 3 z = f(x)
269 ms ± 1.6 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
147 ms ± 1.35 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)

x = snp.ones(5*((512, 512),))
%timeit -n 3 -r 3 z = snp.linalg.norm(x)
3.86 ms ± 266 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
1.63 ms ± 135 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
1.49 ms ± 203 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
1.65 ms ± 228 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
```
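Outside IPython, the `%timeit -n 3 -r 3` measurements above can be reproduced with the stdlib `timeit` module; the `work` function below is a placeholder for whichever op is being timed (e.g. `snp.linalg.norm(x)`).

```python
import timeit


def work():
    # Placeholder workload standing in for the op under test.
    sum(range(1000))


# Equivalent of `%timeit -n 3 -r 3`: 3 runs of 3 loops each;
# each entry of `times` is the total time of one 3-loop run.
times = timeit.repeat(work, repeat=3, number=3)
best_per_loop = min(times) / 3  # per-loop time of the best run
```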