Skip to content

Commit

Permalink
[better_errors] Continue adding debug info to Jaxprs (step 4)
Browse files Browse the repository at this point in the history
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
  • Loading branch information
gnecula committed Feb 7, 2025
1 parent c6e8390 commit 3150109
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 77 deletions.
3 changes: 1 addition & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,10 +2060,9 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
shape/dtypes/structure as ``primals``.
>>> import jax
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,8 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
"to handle custom_jvp primitives")
raise NotImplementedError(msg)

def process_custom_transpose(self, prim, call, tracers, **params):
def process_custom_transpose(self, prim: Primitive,
call: lu.WrappedFun, tracers, **params):
msg = (f"{type(self)} must override process_custom_transpose "
"to handle custom_transpose_call primitives")
raise NotImplementedError(msg)
Expand Down
38 changes: 26 additions & 12 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax._src import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
from jax._src.custom_transpose import CustomTranspose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
Expand Down Expand Up @@ -57,11 +57,13 @@

### util

def _initial_style_jaxpr(fun, in_avals):
def _initial_style_jaxpr(fun: lu.WrappedFun,
in_avals: Sequence[core.AbstractValue]
) -> tuple[core.Jaxpr, Sequence[Any]]:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr, consts

def _close_jaxpr(jaxpr):
def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr:
return pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))

def _sum_tangents(_, x, *xs):
Expand Down Expand Up @@ -1298,7 +1300,8 @@ def merge(l1, l2):

### Custom transposition

def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
def linear_call(fun: Callable,
fun_transpose: Callable, residual_args,
linear_args):
"""Call a linear function, with a custom implementation for its transpose.
Expand Down Expand Up @@ -1388,19 +1391,30 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
operands_lin, lin_tree = tree_flatten(linear_args)

f_in_tree = treedef_tuple((res_tree, lin_tree))
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)
f, out_tree = flatten_fun_nokwargs(
lu.wrap_init(
fun,
debug_info=debug_info("linear_call fun", fun,
(residual_args, linear_args), {})),
f_in_tree)

res_avals = map(core.get_aval, operands_res)
lin_avals = map(core.get_aval, operands_lin)
f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
f_jaxpr = _close_jaxpr(f_jaxpr)
out_avals = f_jaxpr.out_avals
f_jaxpr_closed = _close_jaxpr(f_jaxpr)
out_avals = f_jaxpr_closed.out_avals

t_in_tree = treedef_tuple((res_tree, out_tree()))
t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)
t, t_out_tree = flatten_fun_nokwargs(
lu.wrap_init(
fun_transpose,
# TODO(necula): the fun_transpose takes residual and output of fun!
debug_info=debug_info("linear_call fun_transpose", fun_transpose,
(residual_args, linear_args), {})),
t_in_tree)

t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
t_jaxpr = _close_jaxpr(t_jaxpr)
t_jaxpr_closed = _close_jaxpr(t_jaxpr)

if t_out_tree() != lin_tree:
raise TypeError(
Expand All @@ -1409,8 +1423,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
f'and input structure {lin_tree}.')

out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
callee=f_jaxpr,
transpose=t_jaxpr,
callee=f_jaxpr_closed,
transpose=t_jaxpr_closed,
num_callee_consts=len(f_consts),
num_transpose_consts=len(t_consts),
num_res=len(operands_res))
Expand Down Expand Up @@ -1523,7 +1537,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd):
def jvp(primals, tangents):
outs, residuals = fwd(*primals)
tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs)
tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types))
tan_fn = CustomTranspose(partial(disallow_jvp, out_avals=tan_out_types))
tan_fn.def_transpose(bwd)
return outs, tan_fn(tan_out_types, residuals, tangents)

Expand Down
56 changes: 35 additions & 21 deletions jax/_src/custom_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
tree_structure, treedef_tuple, tree_unflatten,
PyTreeDef)


source_info_util.register_exclusion(__file__)
Expand Down Expand Up @@ -66,7 +67,7 @@ def transformation_with_aux(
### api

@custom_api_util.register_custom_decorator_type
class custom_transpose:
class CustomTranspose:
fun: Callable
transpose: Callable | None = None

Expand All @@ -90,10 +91,20 @@ def __call__(self, out_types, res_arg, lin_arg):
# TODO(frostig,mattjj): could, and should, we avoid flattening
# self.fun at this point?

flat_fun, out_tree2 = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
flat_fun, out_tree2 = flatten_fun_nokwargs(
lu.wrap_init(
self.fun,
debug_info=api_util.debug_info("custom_transpose fun", self.fun,
(res_arg, lin_arg), {})),
in_tree)
out_types_flat, out_tree = tree_flatten(out_types)
transpose_wrapped = lu.wrap_init(
self.transpose,
debug_info=api_util.debug_info("custom_transpose transpose_fun",
self.transpose,
(res_arg, out_types), {}) )
out_flat = custom_transpose_p.bind(flat_fun, *args_flat,
transpose=self.transpose,
transpose=transpose_wrapped,
out_types=out_types_flat,
lin_tree=lin_tree,
res_tree=res_tree,
Expand Down Expand Up @@ -125,27 +136,28 @@ def is_treedef_prefix(entire, prefix):
def rule_name(rule):
return getattr(rule, '__name__', '<unnamed transpose rule>')

def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
def check_transpose_rule_trees(rule: lu.WrappedFun,
lin_tree: PyTreeDef,
rule_out_tree: PyTreeDef):
if not is_treedef_prefix(lin_tree, rule_out_tree):
if hasattr(rule, '_transpose_type_error'):
raise rule._transpose_type_error(lin_tree, rule_out_tree)
else:
raise TypeError(
'structure of custom transpose rule\'s output does not prefix-match '
'structure of primal function\'s linear inputs under '
f'custom transpose rule ({rule_name(rule)}).\n'
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')

def make_transpose_from_thunk(thunk, lin_tree):
rule_name = rule.debug_info.func_src_info if rule.debug_info else "<unknown>"
raise TypeError(
'structure of custom transpose rule\'s output does not prefix-match '
'structure of primal function\'s linear inputs under '
f'custom transpose rule ({rule_name}).\n'
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')

def make_transpose_from_thunk(thunk: Callable,
lin_tree: PyTreeDef) -> lu.WrappedFun:
transpose_jaxpr, transpose_consts = thunk()
transpose_jaxpr = core.ClosedJaxpr(
pe.convert_constvars_jaxpr(transpose_jaxpr), ())
def transpose(res_arg, ct_out):
args_flat = tree_leaves((res_arg, ct_out))
ct_ins = core.jaxpr_as_fun(transpose_jaxpr)(*transpose_consts, *args_flat)
return tree_unflatten(lin_tree, ct_ins)
return transpose
return lu.wrap_init(transpose, debug_info=transpose_jaxpr.jaxpr.debug_info)


### custom_transpose primitive and rules
Expand All @@ -167,11 +179,13 @@ def bind_with_trace(self, trace, call_args, params):
def get_bind_params(self, params):
assert 'call_jaxpr' in params
assert 'transpose_jaxpr_thunk' in params
new_params = dict(params)
new_params: dict[str, Any] = dict(params)
new_params['transpose'] = make_transpose_from_thunk(
new_params.pop('transpose_jaxpr_thunk'),
new_params['lin_tree'])
call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
debug_info=call_jaxpr.jaxpr.debug_info)
return [call], new_params


Expand All @@ -183,7 +197,7 @@ def custom_transpose_typecheck(_, *in_atoms, out_types, **params):

def custom_transpose_transpose_rule(
cts, *args, out_types, res_tree, lin_tree, out_tree, **params):

transpose: lu.WrappedFun
if 'transpose_jaxpr_thunk' in params:
assert 'call_jaxpr' in params
transpose = make_transpose_from_thunk(
Expand All @@ -205,7 +219,7 @@ def custom_transpose_transpose_rule(
cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
for ct in cts]
ct_out = tree_unflatten(out_tree, cts)
ct_lin = transpose(res_arg, ct_out)
ct_lin = transpose.call_wrapped(res_arg, ct_out)
check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
ct_lin_flat, _ = tree_flatten(
tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None),
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,9 +2056,12 @@ def fwd_jaxpr_from_zeros(*zeros):
self.frame.add_eqn(eqn)
return out_tracers

def process_custom_transpose(self, prim, call, tracers, *,
transpose, out_types,
lin_tree, res_tree, out_tree):
def process_custom_transpose(self, prim: core.Primitive, # type: ignore[override]
call: lu.WrappedFun, tracers, *,
transpose: lu.WrappedFun,
out_types,
lin_tree: PyTreeDef,
res_tree: PyTreeDef, out_tree: PyTreeDef):
tracers = map(self.to_jaxpr_tracer, tracers)
tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves])

Expand All @@ -2070,7 +2073,7 @@ def process_custom_transpose(self, prim, call, tracers, *,
convert_constvars_jaxpr(call_jaxpr), ())

transpose_flat, in_tree2 = api_util.flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
transpose, treedef_tuple((res_tree, out_tree)))

# the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
@_memoize
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def wrap_init(f: Callable, params=None, *,
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
if debug_info:
if debug_info.result_paths is None:
fun, result_paths_thunk = _get_result_paths_thunk(fun)
Expand Down
1 change: 1 addition & 0 deletions jax/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from jax._src.api_util import (
argnums_partial as argnums_partial,
debug_info as debug_info,
donation_vector as donation_vector,
flatten_axes as flatten_axes,
flatten_fun as flatten_fun,
Expand Down
2 changes: 1 addition & 1 deletion jax/custom_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.

from jax._src.custom_transpose import (
custom_transpose as custom_transpose,
CustomTranspose as CustomTranspose,
)
10 changes: 5 additions & 5 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
import jax._src.custom_transpose
import jax.experimental.custom_dce
from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError,
ConcretizationTypeError, TracerBoolConversionError)
Expand Down Expand Up @@ -9721,7 +9721,7 @@ def transposed(y):
class _custom_transpose:
def __init__(self, out_types, fun):
self.out_types = out_types
self.fun = jax.custom_transpose.custom_transpose(fun)
self.fun = jax._src.custom_transpose.CustomTranspose(fun)

def __getattr__(self, name):
return getattr(self.fun, name)
Expand Down Expand Up @@ -11067,7 +11067,7 @@ class CustomApiTest(jtu.JaxTestCase):
def test_method_forwarding(self):
@jax.custom_batching.custom_vmap
@jax.custom_jvp
@jax.custom_transpose.custom_transpose
@jax.custom_transpose.CustomTranspose
def f(x): return 2. * x

# none of these err:
Expand All @@ -11080,7 +11080,7 @@ def f_transpose(x): return 2. * x

def test_def_method_forwarding_all_permutations(self):
for wraps in it.permutations([
jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]):
jax.custom_jvp, jax.custom_transpose.CustomTranspose, jax.custom_batching.custom_vmap]):
f = lambda x: x + 1.
for wrap in wraps:
f = wrap(f)
Expand All @@ -11089,7 +11089,7 @@ def test_def_method_forwarding_all_permutations(self):
self.assertIsInstance(getattr(f, method), Callable)

for decorators in it.permutations([
jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]):
jax.custom_vjp, jax.custom_transpose.CustomTranspose, jax.custom_batching.custom_vmap]):
f = lambda x: x + 1.
for decorator in decorators:
f = decorator(f)
Expand Down
Loading

0 comments on commit 3150109

Please sign in to comment.