diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 1286e4b77c8a..4a0316192d28 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -16,11 +16,15 @@
 from jax import ShapeDtypeStruct as ShapeDtypeStruct
 from jax._src.lib import mosaic_gpu_dialect as dialect  # noqa: F401

+# The imports below shadow the module, so we need to rename it.
+from . import wgmma as _wgmma  # noqa: F401
+
 from .core import (
     Barrier as Barrier,
     ClusterBarrier as ClusterBarrier,
     TMABarrier as TMABarrier,
     ThreadSemantics as ThreadSemantics,
+    TMEM as TMEM,
     Union as Union,
     as_gpu_kernel as as_gpu_kernel,
 )
@@ -85,8 +89,6 @@
     warpgroup_idx as warpgroup_idx,
     when as when,
 )
-# The import below shadows the module, so we need to rename it.
-from . import wgmma as _wgmma  # noqa: F401
 from .wgmma import (
     WGMMAAccumulator as WGMMAAccumulator,
     wgmma as wgmma,
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index 40b6e425d278..3c273f1778c9 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -18,12 +18,13 @@
 import ctypes
 import dataclasses
 import enum
+import functools
 import hashlib
 import math
 import os
 import pathlib
 import time
-from typing import Any, Generic, TypeVar
+from typing import Any, Callable, Generic, TypeVar
 import weakref

 import jax
@@ -31,6 +32,7 @@
 from jax._src.lib import mosaic_gpu_dialect as dialect
 from jaxlib.mlir import ir
 from jaxlib.mlir import passmanager
+from jaxlib.mlir.dialects import arith
 from jaxlib.mlir.dialects import builtin
 from jaxlib.mlir.dialects import func
 from jaxlib.mlir.dialects import gpu
@@ -49,6 +51,7 @@
 from . import profiler
 from . import utils
 from . import launch_context
+from . import tcgen05

 # mypy: ignore-errors

@@ -163,6 +166,19 @@ class ClusterBarrier:
   collective_dims: Sequence[gpu.Dimension]
   num_barriers: int = 1


+@dataclasses.dataclass(frozen=True)
+class TMEM:
+  shape: tuple[int, int]
+  dtype: Any
+  layout: tcgen05.TMEMLayout
+
+  def __post_init__(self):
+    if self.shape[0] != self.layout.num_rows:
+      raise ValueError(
+          f"Number of rows must match layout={self.layout}"
+          f" ({self.layout.num_rows}), but got {self.shape[0]}"
+      )
+
 def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
   return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
@@ -179,10 +195,12 @@ def _construct_smem_reftree(
     cluster_shape: tuple[int, int, int],
     dynamic_smem: ir.Value,
     smem_buffers: ShapeTree,
+    delayed_warp_init: list[Callable[[], None]],  # Mutated by this function!
     dynamic_smem_offset: int = 0,
-) -> RefTree:
+) -> Callable[[], RefTree]:
   index = ir.IndexType.get()
   i8 = ir.IntegerType.get_signless(8)
+  i32 = ir.IntegerType.get_signless(32)
   smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
   flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
       smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -205,13 +223,17 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
       return barrier_base_ptr
     match ref_ty:
       case Union(members):
-        member_trees = [
-            _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset)
+        member_thunks = [
+            _construct_smem_reftree(
+                cluster_shape, dynamic_smem, m,
+                delayed_warp_init, dynamic_smem_offset,
+            )
             for m in members
         ]
         # TODO(apaszke): This is quadratic, but it shouldn't matter for now...
         dynamic_smem_offset += _smem_tree_size(ref_ty)
-        ref = Union(member_trees)
+        def ref(member_thunks=member_thunks):
+          return Union([t() for t in member_thunks])
       case TMABarrier(num_barriers):
         ref = utils.BarrierRef.initialize(
             get_barrier_ptr(num_barriers), num_barriers, arrival_count=1
         )
@@ -229,6 +251,20 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
             collective_dims,
             cluster_shape,
         )
+      case TMEM(shape, dtype, layout):
+        addr_ref = memref.view(
+            ir.MemRefType.get([], i32, memory_space=smem),
+            dynamic_smem, c(dynamic_smem_offset, index), [],
+        )
+        delayed_warp_init.append(
+            functools.partial(tcgen05.tmem_alloc, addr_ref, shape[1], exact=False)
+        )
+        def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout):
+          addr = memref.load(addr_ref, [])
+          return tcgen05.TMEMRef(
+              addr, layout, shape[1], utils.dtype_to_ir_type(dtype)
+          )
+        dynamic_smem_offset += 4  # i32 takes up 4 bytes
       case _:
         mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype)
         tile_smem = memref.view(
@@ -238,7 +274,14 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value:
         dynamic_smem_offset += _count_buffer_bytes(ref_ty)
         ref = tile_smem
     smem_refs.append(ref)
-  return jax.tree.unflatten(smem_buffer_tree, smem_refs)
+  def ref_tree_thunk():
+    refs = []
+    for ref in smem_refs:
+      if callable(ref):
+        ref = ref()
+      refs.append(ref)
+    return jax.tree.unflatten(smem_buffer_tree, refs)
+  return ref_tree_thunk


 def _smem_tree_size(smem_buffers: ShapeTree) -> int:
@@ -258,6 +301,8 @@
         if size % utils.MBARRIER_BYTES:
           raise NotImplementedError("Misaligned barrier allocation")
         size += num_barriers * utils.MBARRIER_BYTES
+      case TMEM(_):
+        size += 4  # i32 takes up 4 bytes
       case _:
         size += _count_buffer_bytes(l)
   return size
@@ -336,15 +381,25 @@
   scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr])
   ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof)
   with ctx.named_region("Init"):
-    smem_ref_tree = _construct_smem_reftree(
-        cluster, dynamic_smem, smem_buffers
+    delayed_warp_init = []
+    smem_ref_tree_thunk = _construct_smem_reftree(
+        cluster, dynamic_smem, smem_buffers, delayed_warp_init
     )
-    # TODO(apaszke): Skip the following if no barriers were initialized.
+    # TODO(apaszke): Skip fences if no barriers or TMEM is initialized.
+    # TODO(apaszke): Only initialize cluster barriers before the cluster wait.
     nvvm.fence_mbarrier_init()
     if math.prod(cluster) != 1:
       nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get())
       nvvm.cluster_wait(aligned=ir.UnitAttr.get())
-    gpu.barrier()
+    if delayed_warp_init:
+      eq = arith.CmpIPredicate.eq
+      is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32))
+      with utils.when(is_init_warp):
+        for init in delayed_warp_init:
+          init()
+        tcgen05.tmem_relinquish_alloc_permit()
+    gpu.barrier()  # Make sure the init is visible to all threads.
+    smem_ref_tree = smem_ref_tree_thunk()
     yield ctx, smem_ref_tree

   if prof is not None:
diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
index 32c50095f2b3..d0a5257c2d58 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py
@@ -21,7 +21,7 @@
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import nvvm
 from jax.experimental.mosaic import gpu as mgpu
-from jax.experimental.mosaic.gpu import c, ds, utils
+from jax.experimental.mosaic.gpu import c, ds
 from jax.experimental.mosaic.gpu import tcgen05
 import jax.numpy as jnp
 import jax.random as jr
@@ -65,7 +65,7 @@ def build_kernel(
   tma_tile_kn = 64

   def kernel(ctx, a, b, d, smem):
-    a_smem, b_smem, d_smem, barriers, mma_done_barrier, tmem_addr = smem
+    a_smem, b_smem, d_smem, barriers, mma_done_barrier, acc = smem
     (ab_full_barriers, ab_empty_barriers) = barriers

     warp_idx = mgpu.warp_idx(sync=True)
@@ -109,18 +109,14 @@ def _tma_body(ki, _):
           **common_args,
       )

-    with mgpu.when(is_warp(MMA_WARP)):
-      tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
-      tcgen05.tmem_alloc(tmem_addr_addr, tile_n)
-      tcgen05.tmem_relinquish_alloc_permit()
-      tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
+    with mgpu.when(arith.andi(is_warp(MMA_WARP), warp_leader)):
       with mgpu.when(warp_leader):
         @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
         def _mma_body(ki, accumulate):
           slot = arith.remui(ki, c(max_concurrent_steps, index))
           ab_full_barriers[slot].wait()
           tcgen05.mma(
-              tmem_ref,
+              acc,
               mgpu.memref_slice(a_smem, slot),
               mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (0, 1, 3, 2)),
               a_swizzle=swizzle,
@@ -142,9 +138,9 @@ def _mma_body(ki, accumulate):
     gpu.barrier()
     mma_done_barrier.wait()

-    tmem_ref = tcgen05.TMEMRef.from_alloc(tmem_addr, tcgen05.TMEMLayout.D, tile_n, f32)
-    tmem_ref[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
+    acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128)
     mgpu.commit_shared()
+    # TODO(apaszke): Free up TMEM?
     ctx.async_copy(
         src_ref=d_smem,
         dst_ref=d,
@@ -161,7 +157,7 @@ def _mma_body(ki, accumulate):
       jax.ShapeDtypeStruct(mgpu.tile_shape((tile_m, tile_n), (tma_tile_m, tma_tile_kn)), jnp.float16),
       [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2,
       mgpu.Barrier(arrival_count=1),
-      jax.ShapeDtypeStruct((1,), np.uint32),  # TMEM address
+      mgpu.TMEM((128, tile_n), jnp.float32, tcgen05.TMEMLayout.D),
   )
   return mgpu.as_gpu_kernel(
       kernel,
diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index f4ba56eee36a..66fd79b2c625 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -297,6 +297,12 @@ class TMEMLayout(enum.Enum):
   """
   D = "D"

+  @property
+  def num_rows(self) -> int:
+    match self:
+      case TMEMLayout.D:
+        return 128
+

 @dataclasses.dataclass(frozen=True)
 class TMEMRef:
@@ -327,11 +333,7 @@ def from_alloc(cls, tmem_addr_ref: ir.Value, layout: TMEMLayout, num_cols: int,

   @property
   def num_rows(self):
-    match self.layout:
-      case TMEMLayout.D:
-        return 128
-      case _:
-        raise NotImplementedError(self.layout)
+    return self.layout.num_rows

   @property
   def shape(self):
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 37bc4a3a9eb0..4fece53ed196 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -968,19 +968,17 @@ def test_mma_basic(
       in_jax_dtype,
       out_jax_dtype,
   ):
-    i32 = ir.IntegerType.get_signless(32)
     if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
       raise self.skipTest("Only f16 input is supported for f16 output.")

     in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
-    out_mlir_dtype = utils.dtype_to_ir_type(out_jax_dtype)
     m_tile = 128
     nk_tile = swizzle // bytewidth(in_mlir_dtype)
     k = nk_tile * k_steps
     assert m % m_tile == 0 and n % nk_tile == 0

     def kernel(ctx, lhs, rhs, out, scratch):
-      lhs_smem, rhs_smem, barriers, tmem_addr_ref = scratch
+      lhs_smem, rhs_smem, barriers, acc = scratch
       lhs_transform = (mgpu.TileTransform((m_tile, nk_tile)),)
       if lhs_transpose:
         assert nk_tile == m_tile  # Make sure we didn't have to transpose tiling
@@ -1004,21 +1002,16 @@ def kernel(ctx, lhs, rhs, out, scratch):
       )
       barriers[0].wait()
       barriers[1].wait()
-      with mgpu.when(arith.cmpi(arith.CmpIPredicate.eq, mgpu.warp_idx(), c(0, i32))):
-        tcgen05.tmem_alloc(tmem_addr_ref, n)
-        tcgen05.tmem_relinquish_alloc_permit()
-      acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
-      with mgpu.single_thread():
-        if lhs_transpose:
-          lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
-        if rhs_transpose:
-          rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
-        tcgen05.mma(
-            acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
-        )
-        tcgen05.commit_arrive(barriers[2])
+      with mgpu.single_thread():
+        if lhs_transpose:
+          lhs_smem = memref_transpose(lhs_smem, (0, 1, 3, 2))
+        if rhs_transpose:
+          rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
+        tcgen05.mma(
+            acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
+        )
+        tcgen05.commit_arrive(barriers[2])
       barriers[2].wait()
-      acc = tcgen05.TMEMRef.from_alloc(tmem_addr_ref, tcgen05.TMEMLayout.D, n, out_mlir_dtype)
       acc[:].store_untiled(out)

     in_finfo = jnp.finfo(in_jax_dtype)
@@ -1036,7 +1029,7 @@ def quantize(x):
         jax.ShapeDtypeStruct(tile_shape((m, k), (m_tile, nk_tile)), in_jax_dtype),
         jax.ShapeDtypeStruct(tile_shape((k, n), (nk_tile, nk_tile)), in_jax_dtype),
         mgpu.TMABarrier(3),
-        jax.ShapeDtypeStruct((), jnp.int32),
+        mgpu.TMEM((128, n), out_jax_dtype, tcgen05.TMEMLayout.D),
     ]
     z = mgpu.as_gpu_kernel(
         kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
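
Usage sketch (illustrative, not part of the patch): with this change a kernel declares its TMEM accumulator directly in the scratch tree, next to its SMEM buffers. _launch runs every delayed tcgen05.tmem_alloc from warp 0 inside the "Init" region and then relinquishes the alloc permit, so the kernel body receives a ready tcgen05.TMEMRef instead of a raw address buffer. A minimal sketch in the style of the test above; the SMEM tile shapes, barrier count, and kernel body are hypothetical:

import jax
import jax.numpy as jnp
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import tcgen05

def kernel(ctx, lhs, rhs, out, scratch):
  # The TMEM entry arrives as an already-allocated tcgen05.TMEMRef; no
  # tmem_alloc / tmem_relinquish_alloc_permit calls are needed in the kernel.
  lhs_smem, rhs_smem, barrier, acc = scratch
  ...  # TMA copies, tcgen05.mma(acc, ...), then acc[:].store_untiled(out)

scratch_shape = [
    jax.ShapeDtypeStruct((128, 64), jnp.float16),  # hypothetical lhs SMEM tile
    jax.ShapeDtypeStruct((64, 256), jnp.float16),  # hypothetical rhs SMEM tile
    mgpu.TMABarrier(1),
    # shape[0] must equal layout.num_rows (128 for TMEMLayout.D); otherwise
    # TMEM.__post_init__ raises a ValueError.
    mgpu.TMEM((128, 256), jnp.float32, tcgen05.TMEMLayout.D),
]

Note the design choice visible in core.py: the TMEM entry only reserves a 4-byte i32 slot in dynamic SMEM for the address, and the scratch entry is produced by a thunk that loads that address after the init warp has run, which is why _construct_smem_reftree now returns a Callable[[], RefTree] rather than the tree itself.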