Blockwise Scaling for FP8 #1932

Open · wants to merge 3 commits into main from f8_blockwise_scaling_pr_branch

Conversation

@manishucsd (Contributor) commented Nov 8, 2024

Summary

As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with D = alpha * (A @ B) + beta * C, but narrower datatypes necessitate finer-grained scaling techniques. This PR adds a blockwise scaling strategy to improve accuracy while making an effort not to lose performance. Before we dive deep into blockwise scaling, below is a glossary of the various scaling methods:

  1. Tensorwise Scaling: Uses a single scaling factor per tensor, applied in the epilogue.
  2. Rowwise Scaling: Uses a row vector for scaling, with dimensions Mx1 for operand A and 1xN for operand B, avoiding the scaling along the reduction dimension. This can also be handled in the epilogue with EpilogueVisitorTree.
  3. Blockwise Scaling (this diff): Introduces a 2D scaling tensor, assigning one scaling value per CTA Block. Since this scaling involves the reduction dimension (M, N, K), it must be applied during the mainloop, impacting performance. This PR implements blockwise scaling for CUTLASS F8 GEMM, staging scaling tensors via shared memory, and preparing for future support of groupwise scaling.
  4. Groupwise Scaling: Uses a 2D scaling tensor with multiple scaling values per CTA Block. Scaling granularity is independent of CTA Block configuration, allowing greater flexibility for future implementations.

This enhancement focuses on improving GEMM accuracy for narrow datatypes, balancing the trade-off between performance and precision with the addition of blockwise scaling support.
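To make the blockwise case concrete, here is a minimal host-side reference of the math (a sketch, not the PR's kernel). The block sizes, layouts, function name, and the choice to combine the A-side and B-side scales into one factor per partial sum are illustrative assumptions.

#include <algorithm>
#include <vector>

// Reference blockwise-scaled GEMM: one scale per (BlkM x BlkK) block of A and
// per (BlkK x BlkN) block of B, applied to each K-block partial sum before it
// is folded into the accumulator: D = sum_kb sA(m,kb) * sB(kb,n) * (A_kb @ B_kb).
void blockwise_scaled_gemm_ref(int M, int N, int K, int BlkM, int BlkN, int BlkK,
                               std::vector<float> const& A,        // M x K, row-major
                               std::vector<float> const& B,        // K x N, row-major
                               std::vector<float> const& scale_A,  // ceil(M/BlkM) x ceil(K/BlkK)
                               std::vector<float> const& scale_B,  // ceil(K/BlkK) x ceil(N/BlkN)
                               std::vector<float>& D) {            // M x N, row-major
  int kBlocks = (K + BlkK - 1) / BlkK;
  int nBlocks = (N + BlkN - 1) / BlkN;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < kBlocks; ++kb) {
        // Partial accumulation over one K block (what the QGMMAs produce)...
        float partial = 0.f;
        for (int k = kb * BlkK; k < std::min(K, (kb + 1) * BlkK); ++k) {
          partial += A[m * K + k] * B[k * N + n];
        }
        // ...promoted into the main accumulator with the block scales
        // (the FFMA step in the blockwise-scaled mainloop).
        float sA = scale_A[(m / BlkM) * kBlocks + kb];
        float sB = scale_B[kb * nBlocks + (n / BlkN)];
        acc += sA * sB * partial;
      }
      D[m * N + n] = acc;
    }
  }
}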

Blockwise Scaling

The figure below illustrates a blockwise scaled GEMM, with operand tensors A and B shown in grey, block scaling tensors in blue, and output in green. In this implementation, we load operand tensors using UTMALDG and block scaling tensors using LDGSTS, transferring them from global memory to shared memory. Block scaling tensor loads are issued for the same stage as the operand tensor loads. To ensure proper synchronization for LDGSTS, we use cutlass::arch::cpasync_barrier_arrive with noinc modifier. We have modified the PipelineTmaAsync class to accommodate a variable number of producer thread arrival events to support this functionality effectively.
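The synchronization step can be pictured with a small sketch. This is an illustrative helper, not the PR's implementation: the function name is an assumption, and in the actual code cutlass::arch::cpasync_barrier_arrive fills this role.

#include <cstdint>

// Hedged sketch: the cp.async (LDGSTS) producer signals the stage's mbarrier
// once its prior cp.async copies complete, without incrementing the barrier's
// expected arrival count (the .noinc form of cp.async.mbarrier.arrive).
__device__ inline void cpasync_arrive_noinc(uint64_t* smem_barrier) {
  // Convert the generic pointer to the mbarrier into a shared-memory address.
  uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_barrier));
  asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];\n" :: "r"(bar_ptr));
}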

Screenshot 2024-11-07 at 1 01 24 PM

Performance

For the graph below, I used CUDA Toolkit 12.3.2. Please note that with the latest toolkit, 12.6.2, I observe LDLs and STLs in the SASS and the performance of block scaling is terrible. Thus, I am sticking with 12.3.2 for further performance optimizations while looking for sources of improvement. Please find the SASS attached for the example 54_hopper_fp8_warp_specialized_gemm (F8 with slow accumulation, FADDs after the QGMMAs inside the mainloop) and the new 64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling (F8 with slow accumulation, FFMAs after the QGMMAs inside the mainloop).

[Figure: GEMM performance for F8 with tensorwise and blockwise scaling; GEMM problem shape = (M=2816, N=3072, K=16384), BlockTile = (128x128x128), CUDA Toolkit = 12.3.2]

NCU Profile (Kernel compiled with CUDA Toolkit 12.3.2)

I see one major stall for both the FADD and FFMA versions soon after the QGMMA, waiting for the accumulator to be ready to apply the promotions and scaling, respectively. I don't see any other difference, other than that this stall is larger for the FFMA version.

FADD Version (Example 54 with slow accumulation and tensorwise scaling, modified to have the same tiling and kernel schedule as example 64)

[Figure: NCU profile of the FADD version]

FFMA Version (Example 64 with slow accumulation and blockwise scaling)

[Figure: NCU profile of the FFMA version]

Technically, for a large GEMM shape with the cooperative schedule, we would expect both versions to achieve the same performance. Let me know if you have more input on what we are missing here to match the performance. We will eventually need a good implementation of a blockwise scaling kernel in CUTLASS for plain F8 GEMMs, and we can also take these learnings to FlashAttention-3 F8 scaling.

Attn: @IonThruster, @hwu36, @thakkarV

@manishucsd force-pushed the f8_blockwise_scaling_pr_branch branch from f802dfa to 57896c9 on November 9, 2024 at 00:11
@manishucsd force-pushed the f8_blockwise_scaling_pr_branch branch from 57896c9 to 4d45e57 on November 9, 2024 at 05:47
@thakkarV (Collaborator) commented Nov 9, 2024

Super cool! Thanks for upstreaming :) Will do a full review soon. One comment to start with: please do not extend the existing primary types in CUTLASS. E.g., the new collective builder should just be another specialized dispatch of the existing one. If you need to pass in extra arguments, you can include them in the builder dispatch policy itself. We want to ensure that there is always a single point of entry at each conceptual level.

@@ -100,7 +100,7 @@ using LayoutA = cutlass::layout::RowMajor;
 constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

 // B matrix configuration
-using ElementB = cutlass::float_e5m2_t;    // Element type for B matrix operand
+using ElementB = cutlass::float_e4m3_t;    // Element type for B matrix operand
Collaborator:

can probably skip modifying this example

Contributor Author (@manishucsd):

I changed it to match the example 64 configuration so that I could compare the performance of the FADD and FFMA versions with the same configuration and get the PTX in two different files. Also, the absolute performance of the FADD version for a large GEMM is better than with the current configuration, so I kept the change. If it is OK, let us change example 54 (FADD) to match the final example 64 (FFMA). If not, I can add one more instance in example 64 to have both the FFMA and FADD versions in one example.

using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBlockScalingBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
Collaborator:

Is it implicit that ElementBlockScale will always be the same as ElementAccumulator? If so, it would be good to add a static assert in the collective.

Contributor Author (@manishucsd):

For now it is the case that ElementBlockScale = ElementAccumulator. I still prefer to have an alias type in the code to increase readability and to make the scale tensors easy to find when reading and searching. Maybe in the future we will need different datatypes for ElementBlockScale and ElementAccumulator, although I don't see that happening as long as accumulation is in F32. I have added the static_asserts in case a user tries to set them differently, as we don't have a NumericConverter in GmmaFP8Accumulation::scale_core.
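A minimal sketch of what such a guard can look like, with illustrative stand-in aliases for the collective's template parameters:

#include <type_traits>

// Stand-ins for the collective's template parameters (illustrative only).
using ElementAccumulator = float;
using ElementBlockScale  = float;

static_assert(std::is_same_v<ElementBlockScale, ElementAccumulator>,
              "Blockwise scaling currently assumes ElementBlockScale == ElementAccumulator, "
              "since GmmaFP8Accumulation::scale_core applies no numeric conversion.");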

// Wait for the QGMMA results, then promote the temporary accumulator into the
// main accumulator with the per-block scale factor (the FFMA step).
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
  accum_(i) += accum_temp_(i) * scale;
}
Collaborator:

Two suggestions:

  • One can do much more than just scaling accumulators with this approach, so you might optionally want to consider generically adding an interface that allows the user to apply a point-wise operation on an accumulator per block (see the sketch after this list).

  • This implementation places assumptions/restrictions on the types of the scale and accumulator. It could be enhanced to call an appropriate sub-function/implementation to ensure optimal scaling; that way, if one decides to pass a different scale type in the future, it can just be extended with a new overload/specialization.
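As a hedged sketch of the first suggestion: a helper that takes an arbitrary per-element functor instead of hard-coding the scale multiply-add. In the real collective this would be a CUTLASS_DEVICE member operating on cute tensors; plain arrays keep the sketch self-contained, and the function name is an assumption.

#include <cstddef>

// Promote a temporary accumulator into the main accumulator through a
// user-supplied point-wise operation.
template <class T, class Op>
void promote_with_op(T* accum, T const* accum_temp, std::size_t n, Op op) {
  for (std::size_t i = 0; i < n; ++i) {
    accum[i] = op(accum[i], accum_temp[i]);
  }
}

// Blockwise scaling is then just one choice of functor:
//   promote_with_op(accum, accum_temp, n,
//                   [scale](float a, float t) { return a + t * scale; });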

Contributor Author (@manishucsd):

This is a great suggestion and we can consider it in the future. For now, we are looking at scaling the accumulators blockwise and possibly groupwise. Let us discuss this more soon.

@manishucsd (Contributor Author) commented Nov 12, 2024

Pushed a few commits to address some of the comments. Please take a look. I kept the commits after the initial commit separate so the different comments can be reviewed easily. I will squash all commits before the merge.

@manishucsd (Contributor Author) commented Nov 12, 2024

Super cool! Thanks for upstreaming :) Will do a full review soon. One comment to start with: please do not extend the existing primary types in CUTLASS. E.g., the new collective builder should just be another specialized dispatch of the existing one. If you need to pass in extra arguments, you can include them in the builder dispatch policy itself. We want to ensure that there is always a single point of entry at each conceptual level.

Thanks @thakkarV for the comment. Please see the diff here. This diff makes the CollectiveBuilder entry point another specialization by adding a new dispatch policy. I had something similar in the TODO, which is also removed now. Please ignore the accidental CMakeLists change; you won't see that change in the full diff.
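As a rough illustration of the single-entry-point pattern being asked for (the policy names below are assumptions for the sketch, not CUTLASS's actual types): the blockwise-scaled mainloop is selected by specializing the one CollectiveBuilder on a new dispatch policy, rather than introducing a parallel builder class.

#include <iostream>

// Illustrative dispatch-policy tags (names assumed).
struct MainloopWarpSpecializedFP8 {};              // existing policy
struct MainloopWarpSpecializedFP8BlockScaling {};  // new policy carrying the extra information

// Single builder entry point, extended by specialization on the policy...
template <class DispatchPolicy>
struct CollectiveBuilder;

template <>
struct CollectiveBuilder<MainloopWarpSpecializedFP8> {
  static void describe() { std::cout << "plain FP8 mainloop\n"; }
};

// ...rather than a second, parallel builder class.
template <>
struct CollectiveBuilder<MainloopWarpSpecializedFP8BlockScaling> {
  static void describe() { std::cout << "FP8 mainloop with blockwise scaling\n"; }
};

int main() {
  CollectiveBuilder<MainloopWarpSpecializedFP8BlockScaling>::describe();
  return 0;
}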
