WIP ENH: setdiff1d for Dask and jax.jit #124

Draft
wants to merge 3 commits into
base: main

Contributor

@crusaderky crusaderky commented Jan 24, 2025

Closes #116
Needs more thorough unit tests + performance benchmarks.

This function's output shape depends on the data, so with the previous API it can never work inside jax.jit.
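To illustrate the problem, here is a minimal sketch using jax.numpy.unique_values, which has the same data-dependent output shape. The input values and the helper names `unique_unsized` and `unique_sized` are made up for the example; this is not code from the PR.

```python
import jax
import jax.numpy as jnp

x = jnp.asarray([3, 1, 3, 2, 1])


def unique_unsized(a):
    # Output length equals the number of distinct values, so it is
    # only known once the data is known.
    return jnp.unique_values(a)


def unique_sized(a):
    # A static size gives the output a shape known at trace time;
    # unused slots are padded with fill_value.
    return jnp.unique_values(a, size=4, fill_value=-1)


# Eagerly the data is concrete, so the unsized call works.
eager = unique_unsized(x)

# Under jit the input is an abstract tracer and the output shape
# cannot be determined, so tracing fails.
try:
    jax.jit(unique_unsized)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# With an explicit size the same function traces fine.
padded = jax.jit(unique_sized)(x)

print(eager, jit_failed, padded)
```

The padded result is longer than the eager one, which is exactly the divergence between backends that the options below try to manage.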

There are a few options:

  1. Stipulate that the function will never work inside jax.jit, so users need to hack their way around it with WIP lazy_apply #86 (comment).
    I'm not a fan of this because the UX is very painful: it forces the user to think in graphs.
  2. Add an optional parameter fill_value. If and only if running inside jax.jit, quietly return a longer array padded with it.
    I'm not happy about this because it causes jax.jit to quietly diverge from other backends, and users will spend a lot of time debugging.
  3. Add optional parameters size and fill_value. size becomes mandatory when running inside jax.jit. This is the same design as jax.numpy.unique_values.
    This also allows having a known-shape output in Dask. However, implementing it for Dask is fairly complicated.
  4. Add optional parameters size and fill_value. size is mandatory when running inside jax.jit and disregarded otherwise. Again, this will cause bugs in user code that only appear in jax.jit, but at least it demands an initial explicit user intervention. This is the simplest to implement; I'm unsure about the UX. It also has the advantage of not sacrificing performance on other backends. If jax.jit supports arrays of unknown size in the future, the parameter becomes easy to deprecate, since we already said that the output size requested by the user may be disregarded.
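As a rough sketch of what option (4) could look like for JAX inputs only, the hypothetical wrapper below ignores size when called eagerly and requires it under jax.jit. The name `setdiff1d_sketch`, the tracer check via `jax.core.Tracer`, and the default `fill_value=0` are assumptions made for illustration, not the implementation in this PR; the actual work is delegated to jax.numpy.setdiff1d.

```python
import functools

import jax
import jax.numpy as jnp


def setdiff1d_sketch(x1, x2, *, size=None, fill_value=0):
    """Hypothetical option (4): size is mandatory under jit, ignored otherwise."""
    # Assumption: checking for tracers is a good enough proxy for
    # "running inside jax.jit" in this sketch.
    traced = isinstance(x1, jax.core.Tracer) or isinstance(x2, jax.core.Tracer)
    if not traced:
        # Eager path: the output shape is data-dependent and size is disregarded.
        return jnp.setdiff1d(x1, x2)
    if size is None:
        raise ValueError("size is required when running inside jax.jit")
    # Traced path: fixed-length output, padded with fill_value.
    return jnp.setdiff1d(x1, x2, size=size, fill_value=fill_value)


x1 = jnp.asarray([1, 2, 3, 4, 5])
x2 = jnp.asarray([2, 4])

# Eager call: size is accepted but has no effect.
eager = setdiff1d_sketch(x1, x2, size=4, fill_value=-1)

# Jitted call: size and fill_value are bound statically via partial.
jitted = jax.jit(functools.partial(setdiff1d_sketch, size=4, fill_value=-1))
padded = jitted(x1, x2)

print(eager, padded)
```

The two calls return arrays of different lengths for the same arguments, which is the jit-only behaviour difference that option (4) accepts in exchange for requiring the user to opt in explicitly.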

My current favourite is (4).
@rgommers you previously said, when discussing functions with the same problem in scipy, that you prefer (1) to (3) because the API could not be retracted in the future. What's your opinion on (4)?

CC @lucascolley

@crusaderky crusaderky force-pushed the in1d branch 2 times, most recently from 2900169 to 0bc3adf Compare January 24, 2025 20:38
@crusaderky crusaderky marked this pull request as draft January 24, 2025 22:02
@lucascolley lucascolley added the enhancement (New feature or request) and lazy arrays labels Jan 26, 2025
@crusaderky crusaderky force-pushed the in1d branch 2 times, most recently from a952ede to 028441c Compare January 26, 2025 19:27
Labels
enhancement New feature or request lazy arrays
Development

Successfully merging this pull request may close these issues.

BUG: setdiff1d for jax.jit and dask NaN-shaped arrays
2 participants