Blockwise Scaling for FP8 #1932
Conversation
Super cool! Thanks for upstreaming :) Will do a full review soon. One comment to start: please do not extend the existing primary types in CUTLASS. E.g., the new collective builder should just be another specialized dispatch of the existing one. If you need to pass in extra arguments, you can include them in the builder dispatch policy itself. We want to ensure that there is always a single point of entry at each conceptual level.
@@ -100,7 +100,7 @@ using LayoutA = cutlass::layout::RowMajor;                  // L
 constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;     // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

 // B matrix configuration
-using ElementB = cutlass::float_e5m2_t;                                     // Element type for B matrix operand
+using ElementB = cutlass::float_e4m3_t;                                     // Element type for B matrix operand
can probably skip modifying this example
I changed it to match the example 64 configuration so I could compare the performance of the FADD version and the FFMA version with the same configuration and get the PTX in two separate files. Also, the absolute performance of the FADD version for a large GEMM is better than with the current configuration, so I kept the change. If that is OK, let us change example 54 (FADD) to match the final example 64 (FFMA). If not, I can add one more instance in example 64 so both the FFMA and FADD versions live in one example.
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBlockScalingBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
Is it implicit that ElementBlockScale will always be the same as ElementAccumulator? If so, it would be good to add a static_assert in the collective.
For now it is the case that ElementBlockScale = ElementAccumulator. I still prefer to have an alias type in the code to increase readability and to make the scale tensors easy to find while reading and searching. Maybe in the future we will need different datatypes for ElementBlockScale and ElementAccumulator, although I don't see that happening as long as accumulation is in F32. I have added the static_asserts in case a user tries to set them differently, since we don't have a NumericConverter in GmmaFP8Accumulation::scale_core.
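For reference, a minimal standalone sketch of the kind of guard being described (the names and the exact placement in the collective are illustrative, not the PR's actual code):

#include <type_traits>

// Hypothetical stand-in for the collective's type check: until a NumericConverter
// is used when scaling the accumulators, the block-scale element type has to
// match the accumulator element type, so we fail loudly at compile time otherwise.
template <class ElementAccumulator, class ElementBlockScale = ElementAccumulator>
struct BlockScaledAccumulation {
  static_assert(std::is_same_v<ElementBlockScale, ElementAccumulator>,
                "ElementBlockScale must match ElementAccumulator (no conversion is "
                "performed when scaling the accumulators).");
  using Element = ElementAccumulator;
};

int main() {
  BlockScaledAccumulation<float> ok{};            // fine: both types are float
  // BlockScaledAccumulation<float, double> bad;  // would trip the static_assert
  (void)ok;
  return 0;
}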
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
  accum_(i) += accum_temp_(i) * scale;
Two suggestions:
1. One can do far more than just scale accumulators with this approach, so you might optionally want to consider generically adding an interface that lets the user apply a point-wise operation to the accumulator per block.
2. This impl places assumptions/restrictions on the types of the scale and the accumulator. It could be enhanced to call an appropriate sub-function/impl to ensure optimal scaling; that way, if in the future someone decides to pass a different scale type, it can simply be extended with a new overload/specialization.
This is a great suggestion and we can consider it in the future. For now, we are looking to scale the accumulators blockwise and possibly group-wise. Let us discuss this more soon.
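For illustration only, one shape such a generic per-block accumulator visitor could take (a sketch under assumed names, not something this PR implements):

#include <cstddef>

// Hypothetical sketch: instead of hard-coding `accum += accum_temp * scale`, the
// promotion step takes a callable that is applied point-wise to each accumulator
// element once per k-block. Blockwise scaling then becomes just one such callable.
template <class Accum, class Op>
void promote_block(Accum* accum, const Accum* accum_temp, std::size_t n, Op op) {
  for (std::size_t i = 0; i < n; ++i) {
    accum[i] = op(accum[i], accum_temp[i]);   // user-supplied point-wise operation
  }
}

// Blockwise scaling expressed as one instance of the generic interface.
struct ScaleAndAdd {
  float scale;
  float operator()(float acc, float partial) const { return acc + partial * scale; }
};
// Usage: promote_block(accum, accum_temp, n, ScaleAndAdd{scale});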
Pushed a few commits to address some of the comments. Please take a look. I kept the commits after the initial commit separate so that the different comments can be reviewed easily. I will squash all commits before the merge.
Thanks @thakkarV for the comment. Please see the diff here. This diff makes the CollectiveBuilder entry point another specialization by adding a new dispatch policy. I had something similar in the TODO, which is also removed now. Please ignore the accidental CMakeLists change; you won't see that change in the full diff.
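To make the "single entry point, dispatch on policy" idea concrete, here is a stripped-down sketch of the pattern (type and member names are hypothetical, not the ones used in the PR or in CUTLASS):

// Existing dispatch policy (stand-in for a real kernel schedule tag).
struct KernelTmaWarpSpecializedCooperative {};

// New policy: the extra block-scaling arguments ride along on the policy itself,
// so the builder's interface does not change.
template <int ScaleGranularityM_>
struct KernelTmaWarpSpecializedCooperativeBlockScaling {
  static constexpr int ScaleGranularityM = ScaleGranularityM_;
};

// Single primary builder template...
template <class DispatchPolicy, class ElementA, class ElementB>
struct CollectiveBuilder;

// ...the default path is one specialization of it...
template <class ElementA, class ElementB>
struct CollectiveBuilder<KernelTmaWarpSpecializedCooperative, ElementA, ElementB> {
  // builds the regular mainloop
};

// ...and the block-scaled path is just another specialization, not a new builder.
template <int ScaleGranularityM, class ElementA, class ElementB>
struct CollectiveBuilder<
    KernelTmaWarpSpecializedCooperativeBlockScaling<ScaleGranularityM>,
    ElementA, ElementB> {
  // builds the blockwise-scaled mainloop
};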
Summary
As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with D = alpha * (A @ B) + beta * C, but narrower datatypes necessitate finer-grained scaling techniques. This PR adds a blockwise scaling strategy to improve accuracy while making an effort not to lose performance. Before we dive deep into blockwise scaling, a quick glossary of the relevant scaling granularities: tensorwise scaling applies a single scale per tensor (as in the formula above); rowwise scaling applies a scale per row and can be handled in the epilogue (e.g., with an EpilogueVisitorTree); blockwise scaling applies a scale per block of the operand tensors and is applied in the mainloop.
This enhancement focuses on improving GEMM accuracy for narrow datatypes, balancing the trade-off between performance and precision with the addition of blockwise scaling support.
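For concreteness, the computation this PR targets can be written roughly as follows (assuming one scale per block of A and one per block of B, with the K-blocking matching the mainloop tiling; the indexing is illustrative rather than the kernel's exact layout):

D_{mn} \;=\; \alpha \sum_{k_b} s^{A}_{\lfloor m/B_M \rfloor,\,k_b}\; s^{B}_{k_b,\,\lfloor n/B_N \rfloor} \Big(\sum_{k \in \text{block } k_b} A_{mk}\,B_{kn}\Big) \;+\; \beta\, C_{mn}

where k_b ranges over the K blocks and B_M, B_N are the block sizes along M and N.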
Blockwise Scaling
The figure below illustrates a blockwise scaled GEMM, with operand tensors A and B shown in grey, block scaling tensors in blue, and the output in green. In this implementation, we load the operand tensors using UTMALDG and the block scaling tensors using LDGSTS, transferring them from global memory to shared memory. Block scaling tensor loads are issued for the same stage as the operand tensor loads. To ensure proper synchronization for LDGSTS, we use cutlass::arch::cpasync_barrier_arrive with the noinc modifier. We have modified the PipelineTmaAsync class to accommodate a variable number of producer thread arrival events to support this functionality effectively.
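To illustrate the synchronization idea only (the PR uses PipelineTmaAsync with cutlass::arch::cpasync_barrier_arrive, not the libcu++ API below), here is a much-simplified CUDA sketch in which an async copy of a stage's scale tile arrives on the same barrier the consumer waits on; the TMA operand loads that share the stage are omitted:

#include <cooperative_groups.h>
#include <cuda/barrier>

namespace cg = cooperative_groups;

// Simplified stand-in: stage one tile of block scales into shared memory with an
// async copy (LDGSTS under the hood) whose completion is tracked by the same
// barrier the consumers wait on, instead of a separate commit/wait sequence.
__global__ void stage_scale_tile(const float* __restrict__ gmem_scales,
                                 int scales_per_stage,
                                 float* __restrict__ out) {
  extern __shared__ float smem_scales[];
  __shared__ cuda::barrier<cuda::thread_scope_block> bar;

  auto block = cg::this_thread_block();
  if (block.thread_rank() == 0) {
    init(&bar, block.size());              // every thread participates in arrival
  }
  block.sync();

  // Producer side: async copy of the scale tile; completion is signalled on bar.
  cuda::memcpy_async(block, smem_scales, gmem_scales,
                     sizeof(float) * scales_per_stage, bar);

  // Consumer side: one wait covers both thread arrival and copy completion.
  bar.arrive_and_wait();

  if (block.thread_rank() < static_cast<unsigned>(scales_per_stage)) {
    out[block.thread_rank()] = smem_scales[block.thread_rank()];  // e.g. feed accumulator scaling
  }
}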
Performance
For the graph below, I used CUDA Toolkit 12.3.2. Please note that with the latest toolkit, 12.6.2, I observe LDLs and STLs in the SASS and the performance of block scaling is terrible. Thus, I am sticking with 12.3.2 for further performance optimizations while looking for sources of improvement. Please find the SASS attached for the example 54_hopper_fp8_warp_specialized_gemm (F8 with slow accumulation, FADDs after the QGMMAs inside the mainloop) and the new 64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling (F8 with slow accumulation, FFMAs after the QGMMAs inside the mainloop).
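As a plain-C++ illustration of what the FADD-vs-FFMA distinction above refers to (loop and names are simplified stand-ins for the per-k-block promotion in examples 54 and 64):

// Example 54 style: slow-accumulation promotion only -> the compiler emits FADDs.
inline void promote_fadd(float* accum, const float* accum_temp, int n) {
  for (int i = 0; i < n; ++i) accum[i] += accum_temp[i];
}

// Example 64 style: blockwise scaling folded into the promotion -> FFMAs instead.
inline void promote_ffma(float* accum, const float* accum_temp, float scale, int n) {
  for (int i = 0; i < n; ++i) accum[i] += accum_temp[i] * scale;
}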
NCU Profile (Kernel compiled with CUDA Toolkit 12.3.2)
I see one major stall for both the FADD version and the FFMA version soon after the QGMMA, waiting for the accumulator to be ready to apply the promotion and the scaling, respectively. I don't see any other difference than that this stall is larger for the FFMA version.
FADD Version (Example 54 with slow accumulation and tensorwise scaling, modified to have the same tiling and kernel schedule as 64)
FFMA Version (Example 64 with slow accumulation and blockwise scaling)
Technically, for a large GEMM shape with the cooperative schedule, we would expect both versions to run at the same performance. Let me know if you have more input on what we are missing here to match the performance. We will eventually need a good implementation of a blockwise scaling kernel in CUTLASS for plain F8 GEMMs, and we can also take these learnings to FlashAttention-3 F8 scaling.
Attn: @IonThruster, @hwu36, @thakkarV