From 9d3578dd08823922810f120faa9d52a5790c63ba Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Fri, 25 Oct 2024 17:34:47 +0800
Subject: [PATCH 1/2] [torch-frontend] use new register method to register
 byteir.flash_attn ops

---
 .../python/torch_frontend/flash_attn_op.py   | 108 ++++++++++++------
 1 file changed, 73 insertions(+), 35 deletions(-)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
index 33d733e39..d971cad9c 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
@@ -1,34 +1,35 @@
 import torch
 import math
-from torch.library import Library
-
-OPERATORS = []
-
-
-def op(schema):
-    def inner(f):
-        # TODO: Refactor the Library API so this is less rage inducing
-        # TODO: Perhaps the namespace should be directly based on Python
-        # module
-        if "::" in schema:
-            ns = schema.split("::", 2)[0]
-        else:
-            ns = "contrib"
-        # TODO: Library doesn't allow FRAGMENT, need to allow it
-        lib = Library(ns, "FRAGMENT")
-        name = lib.define(schema)
-        if "::" in name:
-            name = name.split("::", 2)[1]
-        lib.impl(name, f, "CompositeExplicitAutograd")
-        OPERATORS.append(lib)
-        return getattr(getattr(torch.ops, ns), name)
-
-    return inner
-
-
-@op(
-    "byteir::flash_attn_fwd(Tensor q, Tensor k, Tensor v, float dropout_p, float softmax_scale, bool causal, bool return_softmax) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
-)
+
+@torch.library.custom_op("byteir::flash_attn_fwd", mutates_args=())
+def byteir_flash_attn_fwd(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, return_softmax: bool
+) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+    sizes = q.shape
+    batch_size = sizes[0]
+    seqlen_q = sizes[1]
+    num_heads = sizes[2]
+    seqlen_k = k.shape[1]
+
+    rng = torch.empty((2), dtype=torch.int64, device="meta")
+    softmax_lse = torch.empty(
+        (batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
+    )
+    p = None
+    if return_softmax:
+        p = torch.empty(
+            (batch_size, num_heads, seqlen_q, seqlen_k),
+            dtype=torch.float,
+            device="meta",
+        )
+    q_padded = q
+    k_padded = k
+    v_padded = v
+    out = torch.empty_like(q_padded)
+    out_padded = torch.empty_like(out)
+    return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng
+
+@torch.library.register_fake("byteir::flash_attn_fwd")
 def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
     sizes = q.shape
     batch_size = sizes[0]
@@ -55,9 +56,32 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
     return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng
 
 
-@op(
-    "byteir::flash_attn_bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, float dropout_p, float softmax_scale, bool causal, Tensor rng) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"
-)
+@torch.library.custom_op("byteir::flash_attn_bwd", mutates_args=())
+def byteir_flash_attn_bwd(
+    dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, rng_state: torch.Tensor
+) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+    sizes = q.shape
+    batch_size = sizes[0]
+    seqlen_q = sizes[1]
+    num_heads = sizes[2]
+    seqlen_q_rounded = ((seqlen_q + 127) // 128) * 128
+    head_size = sizes[3]
+    head_size_rounded = ((head_size + 31) // 32) * 32
+    dq_accum = torch.empty(
+        (batch_size, num_heads, seqlen_q_rounded, head_size_rounded),
+        dtype=torch.float,
+        device="meta",
+    )
+    softmax_d = torch.empty(
+        (batch_size, num_heads, seqlen_q_rounded), dtype=torch.float, device="meta"
+    )
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+    return dq, dk, dv, softmax_d, dq_accum
+
+
+@torch.library.register_fake("byteir::byteir_flash_attn_bwd")
 def byteir_flash_attn_bwd(
     dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, causal, rng_state
 ):
@@ -82,9 +106,23 @@ def byteir_flash_attn_bwd(
     return dq, dk, dv, softmax_d, dq_accum
 
 
-@op(
-    "byteir::flash_attn_kvcache(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache, Tensor seqlen_k, float softmax_scale, bool causal) -> (Tensor, Tensor)"
-)
+@torch.library.custom_op("byteir::flash_attn_kvcache", mutates_args())
+def byteir_flash_attn_kvcache(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kcache: torch.Tensor, vcache: torch.Tensor, seqlen_k: torch.Tensor, softmax_scale: float, causal: bool
+) -> (torch.Tensor, torch.Tensor):
+    sizes = q.shape
+    batch_size = sizes[0]
+    seqlen_q = sizes[1]
+    num_heads = sizes[2]
+
+    softmax_lse = torch.empty(
+        (batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
+    )
+    out = torch.empty_like(q)
+    return out, softmax_lse
+
+
+@torch.library.register_fake("byteir::flash_attn_kvcache")
 def byteir_flash_attn_kvcache(q, k, v, kcache, vcache, seqlen_k, softmax_scale, causal):
     sizes = q.shape
     batch_size = sizes[0]

From 47e783735c01f42691a1d5c77e82228bc935f4f7 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Fri, 25 Oct 2024 17:54:05 +0800
Subject: [PATCH 2/2] [torch-frontend] fix byteir.flash_attn registration:
 register_fake target, mutates_args typo, return type annotations

---
 .../python/torch_frontend/flash_attn_op.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
index d971cad9c..cd67c9081 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/flash_attn_op.py
@@ -1,10 +1,11 @@
+from typing import List
 import torch
 import math
 
 @torch.library.custom_op("byteir::flash_attn_fwd", mutates_args=())
 def byteir_flash_attn_fwd(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, return_softmax: bool
-) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+) -> List[torch.Tensor]:
     sizes = q.shape
     batch_size = sizes[0]
     seqlen_q = sizes[1]
@@ -59,7 +60,7 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
 @torch.library.custom_op("byteir::flash_attn_bwd", mutates_args=())
 def byteir_flash_attn_bwd(
     dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, rng_state: torch.Tensor
-) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+) -> List[torch.Tensor]:
     sizes = q.shape
     batch_size = sizes[0]
     seqlen_q = sizes[1]
@@ -81,7 +82,7 @@ def byteir_flash_attn_bwd(
     return dq, dk, dv, softmax_d, dq_accum
 
 
-@torch.library.register_fake("byteir::byteir_flash_attn_bwd")
+@torch.library.register_fake("byteir::flash_attn_bwd") def byteir_flash_attn_bwd( dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, causal, rng_state ): @@ -106,10 +107,10 @@ def byteir_flash_attn_bwd( return dq, dk, dv, softmax_d, dq_accum -@torch.library.custom_op("byteir::flash_attn_kvcache", mutates_args()) +@torch.library.custom_op("byteir::flash_attn_kvcache", mutates_args=()) def byteir_flash_attn_kvcache( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kcache: torch.Tensor, vcache: torch.Tensor, seqlen_k: torch.Tensor, softmax_scale: float, causal: bool -) -> (torch.Tensor, torch.Tensor): +) -> List[torch.Tensor]: sizes = q.shape batch_size = sizes[0] seqlen_q = sizes[1]