Skip to content

Commit

Permalink
Add more tests for jit decorator, and fix some issues
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescAlted committed Feb 5, 2025
1 parent 485c462 commit 77cc75e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 9 deletions.
9 changes: 5 additions & 4 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def extract_numpy_scalars(expr: str):
return transformed_expr, transformer.replacements


def validate_inputs(inputs: dict, out=None) -> tuple: # noqa: C901
def validate_inputs(inputs: dict, out=None, reduce=False) -> tuple: # noqa: C901
"""Validate the inputs for the expression."""
if len(inputs) == 0:
if out is None:
Expand Down Expand Up @@ -624,7 +624,7 @@ def validate_inputs(inputs: dict, out=None) -> tuple: # noqa: C901
fast_path = True
first_input = NDinputs[0]
# Check the out NDArray (if present) first
if isinstance(out, blosc2.NDArray):
if isinstance(out, blosc2.NDArray) and not reduce:
if first_input.shape != out.shape:
raise ValueError("Output shape does not match the first input shape")
if first_input.chunks != out.chunks:
Expand Down Expand Up @@ -1590,14 +1590,15 @@ def chunked_eval( # noqa: C901
if where:
# Make the where arguments part of the operands
operands = {**operands, **where}
_, _, _, fast_path = validate_inputs(operands, out)

reduce_args = kwargs.pop("_reduce_args", {})
_, _, _, fast_path = validate_inputs(operands, out, reduce=reduce_args != {})

# Activate last read cache for NDField instances
for op in operands:
if isinstance(operands[op], blosc2.NDField):
operands[op].ndarr.keep_last_read = True

reduce_args = kwargs.pop("_reduce_args", {})
if reduce_args:
# Eval and reduce the expression in a single step
return reduce_slices(expression, operands, reduce_args=reduce_args, _slice=item, **kwargs)
Expand Down
10 changes: 5 additions & 5 deletions src/blosc2/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from abc import ABC, abstractmethod

import numpy as np
from traitlets import Callable

import blosc2

Expand Down Expand Up @@ -581,7 +580,7 @@ def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
return self.src[item]


def jit(func : Callable, out=None, **kwargs): # noqa: C901
def jit(func=None, *, out=None, **kwargs): # noqa: C901
"""
Prepare a function so that it can be used with the Blosc2 compute engine.
Expand Down Expand Up @@ -610,9 +609,10 @@ def jit(func : Callable, out=None, **kwargs): # noqa: C901
-----
* Although many NumPy functions are supported, some may not be implemented yet.
If you find a function that is not supported, please open an issue.
* `kwargs` parameters are not supported for all expressions (e.g. when using a
reduction as the last function). In this case, you can still use the `out`
parameter of the reduction function for some custom control over the output.
* `out` and `kwargs` parameters are not supported for all expressions
(e.g. when using a reduction as the last function). In this case, you can
still use the `out` parameter of the reduction function for some custom
control over the output.
Examples
--------
Expand Down
139 changes: 139 additions & 0 deletions tests/ndarray/test_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#######################################################################
# Copyright (c) 2019-present, Blosc Development Team <[email protected]>
# All rights reserved.
#
# This source code is licensed under a BSD-style license (found in the
# LICENSE file in the root directory of this source tree)
#######################################################################

import pytest

import blosc2

import numpy as np

###### General expressions

def expr_nojit(a, b, c):
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)

@blosc2.jit
def expr_jit(a, b, c):
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)

# Define the parameters
test_params = [
((10, 100), (10, 100,), "float32"),
((10, 100), (100,), "float64"), # using broadcasting
]

@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_expr(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)

d_jit = expr_jit(a, b, c)
d_nojit = expr_nojit(a, b, c)

np.testing.assert_equal(d_jit[...], d_nojit[...])


@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_expr_out(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
d_nojit = expr_nojit(a, b, c)

# Testing jit decorator with an out param
out = blosc2.zeros(shape, dtype=np.bool_)

@blosc2.jit(out=out)
def expr_jit_out(a, b, c):
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)

d_jit = expr_jit_out(a, b, c)
np.testing.assert_equal(d_jit[...], d_nojit[...])
np.testing.assert_equal(out[...], d_nojit[...])

@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_expr_kwargs(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
d_nojit = expr_nojit(a, b, c)

# Testing jit decorator with kwargs
cparams = blosc2.CParams(clevel=1, codec=blosc2.Codec.LZ4, filters=[blosc2.Filter.BITSHUFFLE])

@blosc2.jit(**{"cparams": cparams})
def expr_jit_cparams(a, b, c):
return ((a ** 3 + np.sin(a * 2)) < c) & (b > 0)

d_jit = expr_jit_cparams(a, b, c)
np.testing.assert_equal(d_jit[...], d_nojit[...])
assert d_jit.schunk.cparams.clevel == 1
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5


###### Reductions

def reduc_nojit(a, b, c):
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1)

@blosc2.jit
def reduc_jit(a, b, c):
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1)

@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_reduc(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)

d_jit = reduc_jit(a, b, c)
d_nojit = reduc_nojit(a, b, c)

np.testing.assert_equal(d_jit[...], d_nojit[...])

@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_reduc_out(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
d_nojit = reduc_nojit(a, b, c)

# Testing jit decorator with an out param via the reduction function
out = np.zeros((shape[0],), dtype=np.int64)

# Note that out does not work with reductions as the last function call
@blosc2.jit
def reduc_jit_out(a, b, c):
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)

d_jit = reduc_jit_out(a, b, c)
np.testing.assert_equal(d_jit[...], d_nojit[...])
np.testing.assert_equal(out[...], d_nojit[...])

@pytest.mark.parametrize("shape, cshape, dtype", test_params)
def test_reduc_kwargs(shape, cshape, dtype):
a = blosc2.linspace(0, 1, shape[0] * shape[1], dtype=dtype, shape=shape)
b = blosc2.linspace(1, 2, shape[0] * shape[1], dtype=dtype, shape=shape)
c = blosc2.linspace(-10, 10, cshape[0], dtype=dtype, shape=cshape)
d_nojit = reduc_nojit(a, b, c)

# Testing jit decorator with kwargs via an out param in the reduction function
cparams = blosc2.CParams(clevel=1, codec=blosc2.Codec.LZ4, filters=[blosc2.Filter.BITSHUFFLE])
out = blosc2.zeros((shape[0],), dtype=np.int64, cparams=cparams)

@blosc2.jit
def reduc_jit_cparams(a, b, c):
return np.sum(((a ** 3 + np.sin(a * 2)) < c) & (b > 0), axis=1, out=out)

d_jit = reduc_jit_cparams(a, b, c)
np.testing.assert_equal(d_jit[...], d_nojit[...])
assert d_jit.schunk.cparams.clevel == 1
assert d_jit.schunk.cparams.codec == blosc2.Codec.LZ4
assert d_jit.schunk.cparams.filters == [blosc2.Filter.BITSHUFFLE] + [blosc2.Filter.NOFILTER] * 5

0 comments on commit 77cc75e

Please sign in to comment.