Introducing numerical vjp's in the gradient computation #529

Open · momchilmm wants to merge 4 commits into master

Conversation

@momchilmm (Contributor) commented Jul 31, 2019

This is a preliminary PR as I would like to hear your feedback on the following idea and its implementation. Sometimes, it might be useful to incorporate a numerically computed derivative in the automatic differentiation process. For example, assume we have the following flow

def fun(x):
   y = f1(x)
   z = g(y)
   return f2(z)

where f1(x) and f2(z) can be handled by autograd but g(y) has not been implemented. If g(y) is fast to compute and/or the size of y is small, we could compute the associated vjp by numerically evaluating the Jacobian of g(y). Specifically, this requires 2 * y.size evaluations of g(y) for a centered finite difference, or just y.size evaluations for a forward difference. This is obviously going to be prohibitively expensive in some cases, but in others it might come in handy.
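
To make the idea concrete, here is a minimal sketch of such a numerically computed vjp using a centered difference. This is only illustrative, not the exact implementation in this PR; the name numeric_vjp, the eps step and the flattening are my own choices.

import numpy as np

def numeric_vjp(g_fun, y, eps=1e-6):
    """Illustrative vjp for g_fun at y via centered finite differences."""
    y = np.asarray(y, dtype=float)
    y_flat = y.ravel()
    def vjp(v):
        out = np.zeros(y_flat.size)
        for i in range(y_flat.size):            # 2 * y.size evaluations of g_fun
            dy = np.zeros(y_flat.size)
            dy[i] = eps / 2
            g_plus = g_fun((y_flat + dy).reshape(y.shape))
            g_minus = g_fun((y_flat - dy).reshape(y.shape))
            # i-th column of the Jacobian, contracted with the upstream vector v
            out[i] = np.sum(v * (g_plus - g_minus) / eps)
        return out.reshape(y.shape)
    return vjp

A forward difference would instead reuse the already available g(y) and need only y.size additional evaluations, at the cost of lower accuracy.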

Particular use cases that come to mind include:

  • When g(y) is one of the numpy/scipy functions which has not yet been implemented in autograd.
  • When g(y) is some user-defined function for which an analytic vjp is unknown or too tedious to implement.
  • Potentially, when g(y) is not even differentiable (e.g. discontinuous), but a numerical derivative could still be of some use (this is a bit more speculative).

I have implemented a function vjp_numeric in autograd.core that allows the user to do what I describe above. Here is a simple example along the lines of the autograd tutorial:

import autograd.numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp, vjp_numeric

@primitive
def logsumexp(x):
    """Numerically stable log(sum(exp(x)))"""
    max_x = np.max(x)
    return max_x + np.log(np.sum(np.exp(x - max_x)))

def logsumexp_vjp(ans, x):
    x_shape = x.shape
    return lambda g: np.full(x_shape, g) * np.exp(x - np.full(x_shape, ans))

# Random input
x = np.random.rand(4)

# First compute analytic gradient
defvjp(logsumexp, logsumexp_vjp)
grad_analytic = grad(logsumexp)(x)

# Then compute numeric gradient
logsumexp_vjp_num = vjp_numeric(logsumexp)
defvjp(logsumexp, logsumexp_vjp_num)
grad_numeric = grad(logsumexp)(x)

print('Analytic gradient :', grad_analytic)
print('Numeric gradient  :', grad_numeric)

Output:

Analytic gradient : [0.2806353  0.3378267  0.14405427 0.23748373]
Numeric gradient  : [0.2806353  0.3378267  0.14405427 0.23748373]

And here is a more useful example: the gradient of numpy.linalg.lstsq is currently not implemented. This can be worked around using the new functionality.

import numpy as np
import autograd.numpy as npa
from autograd import grad
from autograd.extend import primitive, defvjp, vjp_numeric

@primitive
def lstsq(*args, **kwargs):
    return np.linalg.lstsq(*args, **kwargs)[0]

vjp_lstsq_0 = vjp_numeric(lstsq, 0)
vjp_lstsq_1 = vjp_numeric(lstsq, 1)

defvjp(lstsq, vjp_lstsq_0, vjp_lstsq_1)

def linreg_slope(x, y):
    A = npa.vstack([x, npa.ones(len(x))]).T
    m, c = lstsq(A, y, rcond=None)
    return m

x = np.array([0, 1, 2, 3])
y = np.array([-1, 0.2, 0.7, 2.5])

print('Slope of linear fit of y vs. x :', linreg_slope(x, y))
print('Gradient of slope w.r.t. x     :', grad(linreg_slope, 0)(x, y))
print('Gradient of slope w.r.t. y     :', grad(linreg_slope, 1)(x, y))

Output:

Slope of linear fit of y vs. x : 1.0999999999999994
Gradient of slope w.r.t. x     : [ 0.34  0.14 -0.2  -0.28]
Gradient of slope w.r.t. y     : [-0.3 -0.1  0.1  0.3]

There is more work that could be put into this (most notably, it currently only works for real numbers, not complex ones). Before that, however, I have two questions:

  • Is this of interest?
  • Do you have remarks on the approach I'm taking in implementing it?

@CamDavidsonPilon (Contributor) commented Jul 31, 2019

I can't speak from the library maintainers' POV, but as a user of autograd, this is of interest to me. Importantly, it would be a native part of autograd.

We recently did something similar specifically for an important and common set of functions: https://github.com/CamDavidsonPilon/autograd-gamma - however, that library currently covers all the missing pieces we need.

@schneiderfelipe commented

Autograd is no longer actively developed in favor of https://github.com/google/jax, so could this awesome feature be submitted there as a PR? If not, I think it is worth a package of its own.

@momchilmm (Contributor, Author) commented

Thanks for the interest! If I find some time, I could look into inserting this into the JAX code, but I'm really not sure whether the developers would be interested in adding it as native functionality. It's a bit hacky, and I guess not so hard to write yourself as an external addition.

To that point, note that I have a fairly general external function that can do the same thing in one of my packages: https://github.com/fancompute/legume/blob/d663dd5d4b0fd864537f350db521f464ed7cfc32/legume/utils.py#L103-L137

You can get the vjps with respect to a list of arguments at once, and you can specify a different step size for every argument. The only hiccup is that this, again, only works for real-valued functions of real arguments.

@j-towns (Collaborator) commented Apr 15, 2020

Hi there, just weighing in as the (unofficial) main maintainer. I'm interested to know which operations you find this useful for, i.e. which functions you would like to differentiate that aren't currently supported by Autograd/JAX. Our default response as authors of an autodiff tool would be to try to implement proper autodiff-style support for your functions, rather than implementing a numerical workaround.

That said, I guess there might be situations where that's impractical for some reason. There is a very simple make_numerical_jvp function in autograd.test_util which you might find useful for doing this (obviously, to generate the whole Jacobian numerically you'd need to compute the Jacobian-vector product on each of a sequence of basis vectors). It's unlikely that we'd ever merge a more heavyweight/general solution into Autograd (or JAX, for that matter), but you're right that it should be possible to extend Autograd or JAX with your own tool.
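
For reference, a rough sketch of that approach, assembling the full Jacobian column by column from numerical jvps on basis vectors. The helper name numeric_jacobian is my own, and the make_numerical_jvp(f, x) signature is assumed from the description above.

import numpy as np
from autograd.test_util import make_numerical_jvp

def numeric_jacobian(f, x):
    x = np.asarray(x, dtype=float)
    jvp = make_numerical_jvp(f, x)
    cols = []
    for i in range(x.size):
        e = np.zeros(x.size)            # i-th basis vector
        e[i] = 1.0
        cols.append(np.ravel(jvp(e.reshape(x.shape))))
    return np.stack(cols, axis=-1)      # column i approximates J @ e_i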

@j-towns (Collaborator) commented Apr 15, 2020

Reading through the comments above a bit more closely, I can see that np.linalg.lstsq is one example that isn't implemented, and also the incomplete gamma functions, w.r.t. some arguments.

Re: lstsq, I guess you can work around that using pinv, whose derivative is implemented, although that's going to have sub-optimal performance. My understanding is that various people are working to get that into JAX (see e.g. jax-ml/jax#2200), and once it's there it should be fairly straightforward to back-port the derivatives to Autograd. Alternatively someone could implement them from scratch, or see if there's an existing implementation in TF or another autodiff package that could be copied.
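
Since the pinv derivative is implemented, a minimal sketch of that workaround for the slope example above could look like this (the name linreg_slope_pinv is my own; the data is reused from the lstsq example in the PR description):

import autograd.numpy as npa
from autograd import grad

def linreg_slope_pinv(x, y):
    A = npa.vstack([x, npa.ones(len(x))]).T
    m, c = npa.dot(npa.linalg.pinv(A), y)  # least-squares solution via the pseudoinverse
    return m

x = npa.array([0., 1., 2., 3.])
y = npa.array([-1., 0.2, 0.7, 2.5])
print('Gradient of slope w.r.t. x :', grad(linreg_slope_pinv, 0)(x, y))
print('Gradient of slope w.r.t. y :', grad(linreg_slope_pinv, 1)(x, y))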

I don't know much about incomplete gamma functions, but I'm guessing there isn't a well-known analytical formula for them. I think we only implemented the 'easy' ones. @CamDavidsonPilon if you get round to finishing off accurate derivative functions, I'm sure they would be a welcome addition to JAX as well as Autograd.

@momchilmm (Contributor, Author) commented

Thanks for looking into this! I agree that for many functions there is usually a way to include them properly in the autodiff (and if there isn't, the derivative is probably not well-defined for some reason). The idea of this PR is rather to make it possible to incorporate an unsupported function without any extra effort. This could be useful especially if that function is not the bottleneck of the computation, such that the added cost of differentiating it numerically is negligible.

But yeah, in my research I always ended up writing the proper backward pass for unsupported functions. It's clearly the better way; the question is whether it's worth having this workaround easily available or not.
