[Mosaic GPU] Handle TMEM allocation in the compiler
The allocation code itself is uninteresting, but it is the only set of
primitives that is executed by a single warp (other APIs, such as TMA, have
single-thread or warpgroup issue granularity).

PiperOrigin-RevId: 725583720
apaszke authored and Google-ML-Automation committed Feb 11, 2025
1 parent 5a22351 commit 0209eee
Showing 5 changed files with 94 additions and 46 deletions.
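In short, TMEM becomes an ordinary scratch declaration: the compiler reserves an i32 slot in dynamic SMEM for the allocated address, a single warp performs the allocation at launch, and the kernel receives a ready `tcgen05.TMEMRef`. A minimal before/after sketch, borrowing names from the Blackwell matmul example below (not a complete kernel):

    # Before this change: the kernel allocated TMEM by hand.
    #   tcgen05.tmem_alloc(utils.memref_ptr(tmem_addr, memory_space=3), tile_n)
    #   tcgen05.tmem_relinquish_alloc_permit()
    #   acc = tcgen05.TMEMRef.from_alloc(
    #       tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)

    # After: declare TMEM among the scratch shapes...
    smem_scratch = (
        ...,  # SMEM tiles and barriers, unchanged
        mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
    )

    # ...and unpack a ready-to-use TMEMRef inside the kernel.
    def kernel(ctx, a, b, d, smem):
        *other_scratch, acc = smem  # acc: tcgen05.TMEMRef
        out_regs = acc[:]           # slicing works exactly as before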
6 changes: 4 additions & 2 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -16,11 +16,15 @@
from jax import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401

# The imports below shadow the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401

from .core import (
Barrier as Barrier,
ClusterBarrier as ClusterBarrier,
TMABarrier as TMABarrier,
ThreadSemantics as ThreadSemantics,
TMEM as TMEM,
Union as Union,
as_gpu_kernel as as_gpu_kernel,
)
@@ -85,8 +89,6 @@
warpgroup_idx as warpgroup_idx,
when as when,
)
# The import below shadows the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401
from .wgmma import (
WGMMAAccumulator as WGMMAAccumulator,
wgmma as wgmma,
75 changes: 65 additions & 10 deletions jax/experimental/mosaic/gpu/core.py
@@ -18,19 +18,21 @@
import ctypes
import dataclasses
import enum
import functools
import hashlib
import math
import os
import pathlib
import time
from typing import Any, Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import weakref

import jax
from jax._src.interpreters import mlir
from jax._src.lib import mosaic_gpu_dialect as dialect
from jaxlib.mlir import ir
from jaxlib.mlir import passmanager
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import gpu
@@ -49,6 +51,7 @@
from . import profiler
from . import utils
from . import launch_context
from . import tcgen05

# mypy: ignore-errors

@@ -163,6 +166,19 @@ class ClusterBarrier:
collective_dims: Sequence[gpu.Dimension]
num_barriers: int = 1

@dataclasses.dataclass(frozen=True)
class TMEM:
shape: tuple[int, int]
dtype: Any
layout: tcgen05.TMEMLayout

def __post_init__(self):
if self.shape[0] != self.layout.num_rows:
raise ValueError(
f"Row must match layout={self.layout} ({self.layout.num_rows}), but"
f" got {self.shape[0]}"
)


def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
@@ -179,10 +195,12 @@ def _construct_smem_reftree(
cluster_shape: tuple[int, int, int],
dynamic_smem: ir.Value,
smem_buffers: ShapeTree,
delayed_warp_init: list[Callable[[], None]], # Mutated by this function!
dynamic_smem_offset: int = 0,
) -> RefTree:
) -> Callable[[], RefTree]:
index = ir.IndexType.get()
i8 = ir.IntegerType.get_signless(8)
i32 = ir.IntegerType.get_signless(32)
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -205,13 +223,17 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
return barrier_base_ptr
match ref_ty:
case Union(members):
member_trees = [
_construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset)
member_thunks = [
_construct_smem_reftree(
cluster_shape, dynamic_smem, m,
delayed_warp_init, dynamic_smem_offset,
)
for m in members
]
# TODO(apaszke): This is quadratic, but it shouldn't matter for now...
dynamic_smem_offset += _smem_tree_size(ref_ty)
ref = Union(member_trees)
def ref(member_thunks=member_thunks):
return Union([t() for t in member_thunks])
case TMABarrier(num_barriers):
ref = utils.BarrierRef.initialize(
get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
@@ -229,6 +251,20 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
collective_dims,
cluster_shape,
)
case TMEM(shape, dtype, layout):
addr_ref = memref.view(
ir.MemRefType.get([], i32, memory_space=smem),
dynamic_smem, c(dynamic_smem_offset, index), [],
)
delayed_warp_init.append(
functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], exact=False)
)
def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
addr = memref.load(addr_ref, [])
return tcgen05.TMEMRef(
addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
)
dynamic_smem_offset += 4 # i32 takes up 4 bytes
case _:
mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype)
tile_smem = memref.view(
@@ -238,7 +274,14 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
dynamic_smem_offset += _count_buffer_bytes(ref_ty)
ref = tile_smem
smem_refs.append(ref)
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
def ref_tree_thunk():
refs = []
for ref in smem_refs:
if callable(ref):
ref = ref()
refs.append(ref)
return jax.tree.unflatten(smem_buffer_tree, refs)
return ref_tree_thunk


def _smem_tree_size(smem_buffers: ShapeTree) -> int:
@@ -258,6 +301,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
if size % utils.MBARRIER_BYTES:
raise NotImplementedError("Misaligned barrier allocation")
size += num_barriers * utils.MBARRIER_BYTES
case TMEM(_):
size += 4 # i32 takes up 4 bytes
case _:
size += _count_buffer_bytes(l)
return size
@@ -336,15 +381,25 @@ def _launch(
scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof)
with ctx.named_region("Init"):
smem_ref_tree = _construct_smem_reftree(
cluster, dynamic_smem, smem_buffers
delayed_warp_init = []
smem_ref_tree_thunk = _construct_smem_reftree(
cluster, dynamic_smem, smem_buffers, delayed_warp_init
)
# TODO(apaszke): Skip the following if no barriers were initialized.
# TODO(apaszke): Skip fences if no barriers or TMEM is initialized.
# TODO(apaszke): Only initialize cluster barriers before the cluster wait.
nvvm.fence_mbarrier_init()
if math.prod(cluster) != 1:
nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
nvvm.cluster_wait(aligned=ir.UnitAttr.get())
gpu.barrier()
if delayed_warp_init:
eq = arith.CmpIPredicate.eq
is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32))
with utils.when(is_init_warp):
for init in delayed_warp_init:
init()
tcgen05.tmem_relinquish_alloc_permit()
gpu.barrier() # Make sure the init is visible to all threads.
smem_ref_tree = smem_ref_tree_thunk()

yield ctx, smem_ref_tree
if prof is not None:
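One structural note on the hunks above: `_construct_smem_reftree` now returns a thunk because a `TMEM` ref cannot be materialized at tree-construction time; its address only appears in the i32 SMEM slot after warp 0 has run the delayed `tmem_alloc` calls and a barrier has published it. A minimal pure-Python model of this deferred-init pattern (illustrative only, not the Mosaic GPU API):

    import functools

    delayed_init = []  # populated while the "ref tree" is being built
    storage = {}       # stands in for the i32 SMEM slot holding the address

    def build_ref(name):
        # Schedule the allocation for later (in the compiler: warp 0 only)...
        delayed_init.append(
            functools.partial(storage.__setitem__, name, 0x1234))
        # ...and return a thunk that reads the address once it exists.
        return lambda: storage[name]

    acc_thunk = build_ref("acc")
    for init in delayed_init:  # in the compiler: under when(is_init_warp)
        init()
    # (a gpu.barrier() goes here so every thread sees the address)
    acc = acc_thunk()          # only now is the ref safe to materialize
    assert acc == 0x1234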
18 changes: 7 additions & 11 deletions jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
@@ -21,7 +21,7 @@
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import nvvm
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
from jax.experimental.mosaic.gpu import c, ds
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import jax.random as jr
@@ -65,7 +65,7 @@ def build_kernel(
tma_tile_kn = 64

def kernel(ctx, a, b, d, smem):
a_smem, b_smem, d_smem, barriers, mma_done_barrier, tmem_addr = smem
a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
(ab_full_barriers, ab_empty_barriers) = barriers

warp_idx = mgpu.warp_idx(sync=True)
@@ -109,18 +109,14 @@ def _tma_body(ki, _):
**common_args,
)

with mgpu.when(is_warp(MMA_WARP)):
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
tcgen05.tmem_relinquish_alloc_permit()
tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
with mgpu.when(warp_leader):
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
slot = arith.remui(ki, c(max_concurrent_steps, index))
ab_full_barriers[slot].wait()
tcgen05.mma(
tmem_ref,
acc,
mgpu.memref_slice(a_smem, slot),
mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
a_swizzle=swizzle,
@@ -142,9 +138,9 @@ def _mma_body(ki, accumulate):
gpu.barrier()
mma_done_barrier.wait()

tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
mgpu.commit_shared()
# TODO(apaszke): Free up TMEM?
ctx.async_copy(
src_ref=d_smem,
dst_ref=d,
@@ -161,7 +157,7 @@ def _mma_body(ki, accumulate):
jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
[mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
mgpu.Barrier(arrival_count=1),
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
)
return mgpu.as_gpu_kernel(
kernel,
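With the hand-rolled allocation gone from this example, the single-warp issue discipline described in the commit message now lives in `_launch` (core.py above): warp 0 runs every delayed allocation, relinquishes the allocation permit, and a barrier publishes the TMEM addresses before any ref is materialized. Restating that hunk with explanatory comments (illustrative, not verbatim compiler output):

    # Inside _launch, after the cluster arrives (names as in the diff above):
    is_init_warp = arith.cmpi(
        arith.CmpIPredicate.eq, utils.warp_idx(sync=False), c(0, i32))
    with utils.when(is_init_warp):
        for init in delayed_warp_init:   # each one issues tcgen05.tmem_alloc
            init()
        tcgen05.tmem_relinquish_alloc_permit()  # no further allocs follow
    gpu.barrier()                        # publish the addresses written to SMEM
    smem_ref_tree = smem_ref_tree_thunk()  # now safe to load TMEM addresses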
12 changes: 7 additions & 5 deletions jax/experimental/mosaic/gpu/tcgen05.py
@@ -297,6 +297,12 @@ class TMEMLayout(enum.Enum):
"""
D = "D"

@property
def num_rows(self) -> int:
match self:
case TMEMLayout.D:
return 128


@dataclasses.dataclass(frozen=True)
class TMEMRef:
@@ -327,11 +333,7 @@ def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int,

@property
def num_rows(self):
match self.layout:
case TMEMLayout.D:
return 128
case _:
raise NotImplementedError(self.layout)
return self.layout.num_rows

@property
def shape(self):
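The `num_rows` property added here is also what backs the shape check in `TMEM.__post_init__` (core.py above). A quick sketch of the validation, assuming a JAX build that ships the Mosaic GPU dialect:

    import jax.numpy as jnp
    from jax.experimental.mosaic import gpu as mgpu
    from jax.experimental.mosaic.gpu import tcgen05

    assert tcgen05.TMEMLayout.D.num_rows == 128
    mgpu.TMEM((128, 64), jnp.float32, tcgen05.TMEMLayout.D)  # rows match: OK
    try:
        mgpu.TMEM((64, 64), jnp.float32, tcgen05.TMEMLayout.D)  # rows mismatch
    except ValueError as err:
        print(err)  # Row must match layout=TMEMLayout.D (128), but got 64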
29 changes: 11 additions & 18 deletions tests/mosaic/gpu_test.py
@@ -968,19 +968,17 @@ def test_mma_basic(
in_jax_dtype,
out_jax_dtype,
):
i32 = ir.IntegerType.get_signless(32)
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
raise self.skipTest("Only f16 input is supported for f16 output.")

in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
out_mlir_dtype = utils.dtype_to_ir_type(out_jax_dtype)
m_tile = 128
nk_tile = swizzle // bytewidth(in_mlir_dtype)
k = nk_tile * k_steps
assert m % m_tile == 0 and n % nk_tile == 0

def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, tmem_addr_ref = scratch
lhs_smem, rhs_smem, barriers, acc = scratch
lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
if lhs_transpose:
assert nk_tile == m_tile # Make sure we didn't have to transpose tiling
@@ -1004,21 +1002,16 @@ def kernel(ctx, lhs, rhs, out, scratch):
)
barriers[0].wait()
barriers[1].wait()
with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, mgpu.warp_idx(), c(0, i32))):
tcgen05.tmem_alloc(tmem_addr_ref, n)
tcgen05.tmem_relinquish_alloc_permit()
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
with mgpu.single_thread():
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
)
tcgen05.commit_arrive(barriers[2])
with mgpu.single_thread():
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
)
tcgen05.commit_arrive(barriers[2])
barriers[2].wait()
acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
acc[:].store_untiled(out)

in_finfo = jnp.finfo(in_jax_dtype)
@@ -1036,7 +1029,7 @@ def quantize(x):
jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
mgpu.TMABarrier(3),
jax.ShapeDtypeStruct((), jnp.int32),
mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
]
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape