WIP ENH: setdiff1d
for Dask and jax.jit
#124
Draft
+151
−158
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Closes #116
Needs more thorough unit tests + performance benchmarks.
This function's output is of unknown shape, so with the previous API it will never work in jax.jit.
There are a few options:
jax.jit
and you need to hack your way around it with WIP lazy_apply #86 (comment).I'm not a fan of this because UX is very painful as it forces the user to think in graphs.
fill_value
. iff running inside the jax.jit, quietly return a longer array padded with it.I'm not happy about this because it causes jax.jit to quietly diverge from other backends and users will spend a lot of time debugging.
size
andfill_value
.size
becomes mandatory when running inside jax.jit. This is the same design asjax.numpy.unique_values
.This also allows having a known-shape output in Dask. However, implementing it for Dask is fairly complicated.
size
andfill_value
.size
is mandatory when running inside jax.jit and disregarded otherwise. Again, this will cause bugs in the user code that only appear in jax.jit, but at least it demands an initial explicit user intervention. This is the simplest to implement; unsure on the UX. It also has the advantage of not sacrificing performance on other backends. If in the future jax.jit will support arrays of unknown size, it becomes easy to deprecate it as we said that the output size requested by the user may be disregarded anyway.My current favourite is (4).
@rgommers you previously said, talking about functions with the same problem in scipy, that you prefer (1) to (3) because of not being able to retract the API in the future. What's your opinion on (4)?
CC @lucascolley