Introducing numerical vjp's in the gradient computation #529
base: master
Conversation
I can't speak from the library maintainer's POV, but as a user of autograd, this is of interest to me. Importantly, it's a native part of autograd. We recently did something similar specifically for an important and common set of functions: https://github.com/CamDavidsonPilon/autograd-gamma - however, this library currently satisfies all the missing pieces we need.
Autograd is not actively developed anymore in favor of https://github.com/google/jax, so can this awesome feature be submitted there as a PR? If not, I think this is worth a package of its own.
Thanks for the interest! If I find some time, I could look into inserting this in the JAX code, but I'm really not sure if the developers will be interested in adding this as native functionality. It's a bit hacky, and I guess not so hard to write yourself as an external addition. To that point, note that I have a fairly general external function that can do the same thing in one of my packages: https://github.com/fancompute/legume/blob/d663dd5d4b0fd864537f350db521f464ed7cfc32/legume/utils.py#L103-L137 You can get the vjps with respect to a list of arguments at once, and you can specify a different step size for every argument. The only hiccup is that this again only works for real-valued functions of real arguments.
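For context, a rough sketch of what such an external helper can look like (this is not the legume implementation; the name `numerical_vjp`, the argument handling, and the per-argument step sizes here are illustrative):

```python
import numpy as np

def numerical_vjp(fun, args, step_sizes):
    """Return one vjp function per argument of `fun`, built from centered
    finite differences. Real-valued inputs and outputs only."""
    args = [np.asarray(a, dtype=float) for a in args]
    vjps = []
    for i, (arg, h) in enumerate(zip(args, step_sizes)):
        flat = arg.ravel()
        cols = []
        for k in range(flat.size):
            dx = np.zeros_like(flat)
            dx[k] = h
            args_p = list(args); args_p[i] = (flat + dx).reshape(arg.shape)
            args_m = list(args); args_m[i] = (flat - dx).reshape(arg.shape)
            # One column of the Jacobian of fun w.r.t. argument i
            cols.append((np.ravel(fun(*args_p)) - np.ravel(fun(*args_m))) / (2 * h))
        jac = np.stack(cols, axis=1)  # shape (output size, arg.size)
        # vjp: contract an output-shaped vector v with the Jacobian
        vjps.append(lambda v, jac=jac, shape=arg.shape:
                    np.dot(np.ravel(v), jac).reshape(shape))
    return vjps
```

Usage would then look like `vjp_A, vjp_b = numerical_vjp(fun, [A, b], [1e-6, 1e-6])`.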
Hi there, just weighing in as the (unofficial) main maintainer. I'm interested to know which operations you find this useful for, i.e. which functions you would like to differentiate that aren't currently supported by Autograd/JAX. Our default response as authors of an autodiff tool would be to try to implement proper autodiff-style support for your functions, rather than implementing a numerical workaround. That said, I guess there might be situations where that's impractical for some reason. There is a very simple …
Reading through the comments above a bit more closely I can see that np.linalg.lstsq is one example that isn't implemented, and also the incomplete gamma functions, w.r.t. some arguments. Re: lstsq, I guess you can work around that using pinv, whose derivative is implemented, although that's going to have sub-optimal performance. My understanding is that various people are working to get that into JAX (see e.g. jax-ml/jax#2200), and once it's there it should be fairly straightforward to back-port the derivatives to Autograd. Alternatively someone could implement them from scratch, or see if there's an existing implementation in TF or another autodiff package that could be copied. I don't know much about incomplete gamma functions but I'm guessing there isn't a well known analytical formula for them. I think we only implemented the 'easy' ones. @CamDavidsonPilon if you get round to finishing off accurate derivative functions, I'm sure they would be a welcome addition to JAX as well as Autograd.
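To make the `pinv` workaround concrete, here is a minimal sketch (the function names are only illustrative; `pinv` does have a vjp in autograd, as noted above):

```python
import autograd.numpy as np
from autograd import grad

def lstsq_via_pinv(A, b):
    # Least-squares solution x = pinv(A) @ b; autograd knows how to
    # differentiate pinv, so this composite is differentiable end to end.
    return np.dot(np.linalg.pinv(A), b)

def loss(A, b):
    x = lstsq_via_pinv(A, b)
    return np.sum(x ** 2)

A = np.random.randn(5, 3)
b = np.random.randn(5)
print(grad(loss, 0)(A, b))  # gradient of the loss w.r.t. the matrix A
```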
Thanks for looking into this! I agree that with many functions there usually is a way to include them properly in the autodiff (and if there isn't, then probably the derivative is not well-defined for some reason). The idea of this PR is rather to be able to incorporate an unsupported function without any extra effort. This could be useful especially if that function is not the bottleneck of the computation, such that the added complexity of doing it numerically is negligible. But yeah, in my research I always ended up writing the proper backward pass for unsupported functions. It's clearly the better way; the question is whether it's good to have this workaround easily available or not.
This is a preliminary PR, as I would like to hear your feedback on the following idea and its implementation. Sometimes it might be useful to incorporate a numerically computed derivative into the automatic differentiation process. For example, assume we have a computation flow in which `f1(x)` and `f2(x)` can be handled by `autograd`, but `g(y)` has not been implemented. If `g(y)` is fast to compute and/or if the size of `y` is small, we could compute the associated vjp by numerically computing the Jacobian of `g(y)`. Specifically, this requires `2 * y.size` evaluations of `g(y)` if the centered finite difference is used, or just `y.size` evaluations for the forward difference. This is obviously going to be prohibitively expensive in some cases, but in others it might come in handy.

Particular use cases that come to mind include:

- `g(y)` is one of the numpy/scipy functions which has not yet been implemented in autograd.
- `g(y)` is some user-defined function for which an analytic `vjp` is unknown or too tedious to implement.
- `g(y)` is not even differentiable (e.g. discontinuous), but a numerical derivative could still be of some use (this is a bit more speculative).

I have implemented a function `vjp_numeric` in `autograd.core` that allows the user to do what I describe above. Here is a simple example along the lines of the autograd tutorial:
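As a rough illustration of the idea (not the actual snippet from this PR, and written against autograd's existing `primitive`/`defvjp` extension API rather than the new `vjp_numeric`, whose exact signature isn't shown here; the function names and step size are assumptions):

```python
import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def g(y):
    # Stand-in for a function that autograd has no derivative rule for.
    return np.tanh(y)

def g_vjp_numeric(ans, y, step=1e-6):
    # Build the Jacobian of g at y with centered finite differences
    # (2 * y.size evaluations of g), then contract it with the incoming
    # cotangent vector. Real-valued inputs/outputs only.
    flat = np.ravel(y).astype(float)
    cols = []
    for k in range(flat.size):
        dy = np.zeros_like(flat)
        dy[k] = step
        gp = g(np.reshape(flat + dy, np.shape(y)))
        gm = g(np.reshape(flat - dy, np.shape(y)))
        cols.append((np.ravel(gp) - np.ravel(gm)) / (2 * step))
    jac = np.stack(cols, axis=1)  # shape (g(y).size, y.size)
    return lambda v: np.reshape(np.dot(np.ravel(v), jac), np.shape(y))

defvjp(g, g_vjp_numeric)

def f(x):
    # f2(g(f1(x))): f1 and f2 are handled by autograd as usual,
    # only g falls back to the numerically assembled Jacobian.
    return np.sum(np.sin(g(x ** 2)))

print(grad(f)(np.array([0.1, 0.2, 0.3])))
```

The printed gradient should agree with the exact one (obtained by differentiating `np.sum(np.sin(np.tanh(x ** 2)))` directly) up to finite-difference error.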
And here is a more useful example: the gradient of `numpy.linalg.lstsq` is currently not implemented. This can be circumvented using the new functionality.
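Again not the PR's actual snippet, but a sketch of how the same pattern could be applied to `lstsq` with today's extension API (the wrapper name, the restriction to the vjp w.r.t. `b`, and the step size are choices made here for brevity):

```python
import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def solve_lstsq(A, b):
    # Least-squares solution; autograd has no vjp for lstsq itself.
    return np.linalg.lstsq(A, b, rcond=None)[0]

def solve_lstsq_vjp_b(ans, A, b, step=1e-6):
    # Centered finite differences w.r.t. b only (the same pattern works for A).
    cols = []
    for k in range(b.size):
        db = np.zeros_like(b)
        db[k] = step
        cols.append((solve_lstsq(A, b + db) - solve_lstsq(A, b - db)) / (2 * step))
    jac = np.stack(cols, axis=1)  # shape (x.size, b.size)
    return lambda v: np.dot(v, jac)

# None: the vjp w.r.t. A is treated as zero in this sketch.
defvjp(solve_lstsq, None, solve_lstsq_vjp_b)

A = np.random.randn(6, 3)
b = np.random.randn(6)
loss = lambda b: np.sum(solve_lstsq(A, b) ** 2)
print(grad(loss)(b))
```

This trades `2 * b.size` extra least-squares solves per backward pass for not having to derive the analytic vjp.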
There is some more work that can be put into this (most notably, currently it only works for real numbers, not complex). Before that, however, I have two questions: