Enh: Accept "duck arrays" for tensordot #833

Open
brendan-m-murphy opened this issue Jan 9, 2025 · 8 comments
Labels: enhancement (Indicates new feature requests), needs triage (Issue has not been confirmed nor labeled)

Comments

@brendan-m-murphy

Please describe the purpose of the new feature or describe the problem to solve.

Sparse's tensordot only allows multiplication between sparse arrays and either scipy sparse arrays or numpy ndarrays.

It would be useful if other array-like objects were allowed.

For instance, in xarray, the dot function can only multiply a sparse DataArray and a dask DataArray if the einsum/tensordot function from dask is used: pydata/xarray#9934
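
For illustration, a minimal sketch of the failure mode (array shapes and names are arbitrary, and the exact exception may differ):

import dask.array as da
import numpy as np
import sparse

sp_arr = sparse.random((10, 10), density=0.1)  # sparse COO array
dk_arr = da.ones((10, 10), chunks=5)           # dask "duck array"

# Works: the other operand is a plain NumPy ndarray.
sparse.tensordot(sp_arr, np.ones((10, 10)), axes=1)

# Fails: a dask array is neither an np.ndarray nor a scipy sparse array,
# so sparse's tensordot/_dot rejects it (with a TypeError, I believe).
sparse.tensordot(sp_arr, dk_arr, axes=1)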

Suggest a solution if possible.

The code for multiplying a COO matrix and an np.ndarray in _dot seems to rely mostly on being able to infer the dtype, index into the other operand, and create empty ndarrays, so it seems plausible that other array-like objects could be used here.

I haven't tried to implement this though.
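
To make the idea concrete, here is a purely illustrative sketch (not sparse's actual _dot, and with a dense output) of the generic operations involved: dtype inference, indexing into the other operand, and creating an output array.

import numpy as np

def coo_matmul_sketch(coords, data, shape, other):
    # coords: (2, nnz) row/column indices and data: (nnz,) values of a COO
    # matrix with the given shape; `other` is some 2-D array-like operand.
    out_dtype = np.result_type(data.dtype, other.dtype)
    # A duck-array-aware version would create this via the other array's
    # own creation functions instead of np.zeros.
    out = np.zeros((shape[0], other.shape[1]), dtype=out_dtype)
    for (i, j), v in zip(coords.T, data):
        out[i, :] += v * other[j, :]
    return out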

If you have tried alternatives, please describe them below.

No response

Additional information that may help us understand your needs.

Please see this issue for further discussion in the context of xarray: pydata/xarray#9934

@brendan-m-murphy added the enhancement and needs triage labels on Jan 9, 2025
@hameerabbasi
Collaborator

Unfortunately, it actually uses Numba under the hood (which only accepts NumPy arrays) to do the actual computation; doing this with Dask (or even NumPy without Numba) would be excruciatingly slow.

However, we do support NumPy arrays inside the @ operator -- on either side. It seems to me that XArray is having trouble detecting this.

>>> import numpy as np; import sparse
>>> sp_arr = sparse.zeros((5, 5), dtype=sparse.float32)
>>> np_arr = np.zeros((5, 5), dtype=np.float32)
>>> sp_arr @ np_arr
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> np_arr @ sp_arr
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

I'd be happy to hear an alternative solution that would help XArray/this use-case but supporting duck arrays in @ is beyond the scope of this library.

@brendan-m-murphy
Author

Thanks, I suspected it might not be so simple.

I think Dask is using sp_arr @ np_arr on each chunk, and I believe @ works properly if the underlying array is a np.ndarray rather than a Dask array.

The linked xarray issue has some workarounds to get Dask to initiate the multiplication, rather than sparse, so I will go back and ask if any of those could be incorporated into xarray.
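
For example, one variant of those workarounds (a sketch only; the details are in the linked issue) is to wrap the sparse array in a dask array first, so that dask initiates the multiplication and each chunk-level product is sparse @ numpy, which sparse already supports:

import dask.array as da
import sparse

sp_arr = sparse.random((100, 100), density=0.01)
dk_arr = da.ones((100, 100), chunks=50)

# Wrapping the sparse array means dask, not sparse, drives the matmul.
sp_as_dask = da.from_array(sp_arr, chunks=50)
result = (sp_as_dask @ dk_arr).compute()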

@brendan-m-murphy closed this as not planned on Jan 15, 2025
@keewis

keewis commented Jan 15, 2025

if the variant that works is implemented by dask, wouldn't it be fine to have sparse return NotImplemented and have it use dask.array.Array.__rmatmul__ instead?
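
(For context, that is exactly how Python's binary-operator protocol is meant to work; a toy sketch with stand-in classes, not the real ones:)

class ToySparse:
    def __matmul__(self, other):
        if not isinstance(other, ToySparse):
            return NotImplemented  # defer to the other operand
        return "sparse-driven matmul"

class ToyDask:
    def __rmatmul__(self, other):
        # Called by Python when the left operand returned NotImplemented.
        return "dask-driven matmul"

print(ToySparse() @ ToyDask())  # -> dask-driven matmul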

@hameerabbasi
Collaborator

That we can do -- can you open an issue for that?

@keewis

keewis commented Jan 15, 2025

sorry, looks like that already happens (I think)? See

def __matmul__(self, other):
    from .._common import matmul

    try:
        return matmul(self, other)
    except NotImplementedError:
        return NotImplemented

def __rmatmul__(self, other):
    from .._common import matmul

    try:
        return matmul(other, self)
    except NotImplementedError:
        return NotImplemented

I'll have to look into why that is not triggered.

Edit: maybe because the error is TypeError and not NotImplementedError?
Edit: or maybe because we're calling tensordot and not matmul?

@brendan-m-murphy
Author

brendan-m-murphy commented Jan 15, 2025

sparse @ dask (in xarray) throws a TypeError from sparse's _dot, I believe. If you're using xarray with opt-einsum, then it selects sparse.tensordot as the method to use for sparse @ dask (i.e. xr.dot(sparse, dask)).

@hameerabbasi
Collaborator

hameerabbasi commented Jan 15, 2025

Right -- I just pushed a fix/release for __matmul__ and __rmatmul__, but functions should indeed raise TypeError, which is correct.

@hameerabbasi reopened this on Jan 15, 2025
@keewis

keewis commented Jan 15, 2025

yeah, forwarding to __rmatmul__ only works if we use sparse @ dask, whereas xr.dot uses something else. So we'll have to figure out how sparse can reject unknown array types while still allowing other array types to implement the operation.

Tensordot can be dispatched through various numpy protocols; for __array_function__ I believe rejecting unknown array types is possible by returning NotImplemented, but __array_namespace__ explicitly does not want to allow the interaction of multiple array types. This may mean that the calling code (in this context probably xarray) has to cast the data to the appropriate type.
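
A rough sketch of that __array_function__ behaviour with a toy class (not sparse's actual dispatch code): returning NotImplemented lets NumPy try the other argument's implementation, and only if every implementation declines does the call fail.

import numpy as np

class ToyDuckArray:
    def __init__(self, data):
        self.data = np.asarray(data)
        self.shape = self.data.shape
        self.dtype = self.data.dtype

    def __array_function__(self, func, types, args, kwargs):
        # Decline if any participating type is unknown to us; NumPy will
        # then try the other arguments' __array_function__ implementations.
        if not all(issubclass(t, (ToyDuckArray, np.ndarray)) for t in types):
            return NotImplemented
        unwrapped = [a.data if isinstance(a, ToyDuckArray) else a for a in args]
        return func(*unwrapped, **kwargs)

# np.tensordot(ToyDuckArray(...), other_duck_array) would then be handled
# by the other library's protocol implementation, if it accepts this type.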
