Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose checkpoint name in cuDNN SDPA #26374

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions jax/_src/cudnn/fused_attention_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jax.interpreters.mlir import ir
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax.ad_checkpoint import checkpoint_name

Array = jnp.ndarray

Expand Down Expand Up @@ -388,7 +389,7 @@ def check_compute_capability(capability):
def _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
sliding_window_length, cudnn_version, outputs_checkpoint_name):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, False,
Expand All @@ -404,7 +405,7 @@ def _dot_product_attention_fwd(
def _dot_product_attention_fwd_rule(
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
sliding_window_length, cudnn_version, outputs_checkpoint_name):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, True,
Expand All @@ -414,13 +415,19 @@ def _dot_product_attention_fwd_rule(
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=True)
if outputs_checkpoint_name:
output = checkpoint_name(outputs[0], outputs_checkpoint_name)
softmax_stat = checkpoint_name(outputs[1], outputs_checkpoint_name)
else:
output, softmax_stat = outputs
res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets,
kv_offsets, outputs[1], outputs[0])
return outputs[0], res
kv_offsets, softmax_stat, output)
return output, res

def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training, res, grad_output):
sliding_window_length, cudnn_version, outputs_checkpoint_name,
res, grad_output):
(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
activation, fwd_output) = res
grads = _dot_product_attention_bwd_p_wrapper.bind(
Expand Down Expand Up @@ -1090,7 +1097,7 @@ def sharded_impl(*args):
_dot_product_attention_bwd_p_wrapper
)

@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
@functools.partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
Expand All @@ -1106,13 +1113,15 @@ def _dot_product_attention(query: Array,
mask_type: bool,
layout: int,
sliding_window_length: int | None,
cudnn_version: int):
cudnn_version: int,
outputs_checkpoint_name: str):
output = _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
cudnn_version=cudnn_version)
cudnn_version=cudnn_version,
outputs_checkpoint_name=outputs_checkpoint_name)
return output

_dot_product_attention.defvjp(
Expand Down Expand Up @@ -1712,7 +1721,8 @@ def dot_product_attention(
dropout_rate: float = 0.,
qkv_layout: str = "BTNH",
sliding_window_length: int | None = None,
use_fp8: bool = False
use_fp8: bool = False,
outputs_checkpoint_name: str = ""
):
"""Computes dot-product attention given query (Q), key (K), and value (V).

Expand Down Expand Up @@ -1768,6 +1778,8 @@ def dot_product_attention(
is the index of each token. E.g., if sliding_window_length == 3 and the
sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c].
use_fp8: Whether to use FP8 attention mechanism.
outputs_checkpoint_name: checkpoint name for output tensor and softmax stat
tensor.
Returns:
Output of the same shape as the query.
amax_s: amax of state. (fp8 only)
Expand Down Expand Up @@ -1843,5 +1855,5 @@ def dot_product_attention(
output = _dot_product_attention(
query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets,
scale, seed, dropout_rate, variadic_args, mask_type, layout.value,
sliding_window_length, cudnn_version)
sliding_window_length, cudnn_version, outputs_checkpoint_name)
return output
59 changes: 54 additions & 5 deletions tests/fused_attention_stablehlo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
MaskType,
AttentionLayout,
)
from jax._src.ad_checkpoint import saved_residuals, checkpoint as new_checkpoint
from jax._src import api

config.parse_flags_with_absl()
Array = jnp.ndarray
Expand Down Expand Up @@ -104,19 +106,27 @@ def sdpa_train(query: Array,
mask_type: MaskType = MaskType.NO_MASK,
is_bnth: bool = False,
dropout_rate: float = 0.1,
sliding_window_length: int | None = None) -> Array:
sliding_window_length: int | None = None,
outputs_checkpoint_name: str = "",
policy = None) -> Array:
if mask_type == MaskType.PADDING:
if is_bnth:
B, _, S, _ = query.shape
else:
B, S, _, _ = query.shape
q_seqlen = kv_seqlen = jnp.full((B,), S // 2, jnp.int32)
out, sdpa_vjp = jax.vjp(
partial(dot_product_attention, scale=scale, mask_type=mask_type,

f = partial(dot_product_attention, scale=scale, mask_type=mask_type,
dropout_rate=dropout_rate,
qkv_layout="BNTH" if is_bnth else "BTNH",
sliding_window_length=sliding_window_length),
query, key, value, bias, mask, q_seqlen, kv_seqlen, q_offsets, kv_offsets)
sliding_window_length=sliding_window_length,
outputs_checkpoint_name=outputs_checkpoint_name)

if policy is not None:
f = new_checkpoint(f, policy=policy)

out, sdpa_vjp = jax.vjp(f, query, key, value, bias, mask, q_seqlen,
kv_seqlen, q_offsets, kv_offsets)
query_grad, key_grad, value_grad, bias_grad = sdpa_vjp(grad)[:4]
if bias is not None and len(bias.shape) == 3:
# has dbias
Expand Down Expand Up @@ -744,6 +754,45 @@ def generate_segment_mask(segment_ids, dtype):
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2)

@jtu.run_on_devices("cuda")
def test_sdpa_checkpoint(self):
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if cudnn_version < 90500:
self.skipTest("Requires >= cuDNN 9.5.0")

B, T, N, H = 2, 64, 2, 256
bf16 = jnp.bfloat16
keys = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16)
key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16)
value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16)
grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16)
f = jax.jit(partial(
sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0,
outputs_checkpoint_name="context",
policy=jax.checkpoint_policies.save_only_these_names("context")))
g = jax.jit(partial(
sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0,
policy=jax.checkpoint_policies.nothing_saveable))
out, (query_grad, key_grad, value_grad) = f(query, key, value, grad)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = g(query, key, value, grad)

f_jaxpr = api.make_jaxpr(f)(query, key, value, grad)
g_jaxpr = api.make_jaxpr(g)(query, key, value, grad)
f_jaxpr_text = str(f_jaxpr)
g_jaxpr_text = str(g_jaxpr)
self.assertEqual(f_jaxpr_text.count('dot_product_attention_fwd_wrapper'), 1)
self.assertEqual(g_jaxpr_text.count('dot_product_attention_fwd_wrapper'), 2)

self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)

@jtu.run_on_devices("cuda")
def test_layouts(self):
if jax.device_count() < 4:
Expand Down