Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for: Groupwise scaling along M for FP8 gemm #2095

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

LucasWilkinson
Copy link

@LucasWilkinson LucasWilkinson commented Feb 10, 2025

Various improvements to "Groupwise scaling along M" (#2037) namely to address: #2087, context vllm-project/vllm#11868 (comment)

Improvements:

  1. Multiple threads now participating in copy A scales
  2. Predication when copying A scale loads, this means if there is partial M tile (due to the problem shape not being evenly divided by the M tile shape)
  3. More commonly used scale layouts, currently CUTLASS uses a layout like:
(M_TILES, ScaleMsPerTile, K_TILES, L), ordered: (2, 0, 1, 3)

this PR moves to a layout of (i.e. standard M-major):

(M / ScaleGranularityM, K_TILES, L), ordered: (1, 0, 2)

making it much easier to integrate into inference libraries

These improvements were part of vLLMs adoption of this kernel https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (PR: vllm-project/vllm#11868) and is in current wide scale use. Our goal is to rely on the CUTLASS implementation but that currently not possible given the issues above.

@LucasWilkinson LucasWilkinson changed the title [WIP][Bugfix] Bug fixes for: Groupwise scaling along M for FP8 gemm Improvements: Groupwise scaling along M for FP8 gemm Feb 10, 2025
@LucasWilkinson LucasWilkinson changed the title Improvements: Groupwise scaling along M for FP8 gemm Improvements for: Groupwise scaling along M for FP8 gemm Feb 10, 2025
@LucasWilkinson LucasWilkinson marked this pull request as ready for review February 10, 2025 21:07
@hwu36
Copy link
Collaborator

hwu36 commented Feb 21, 2025

@LucasWilkinson , we upstreamed our change to groupwise scaling kernels. there are some conflicts in this PR that needs to be solved.

Our change is mainly:

Extend groupwise scaling gemm to support both M dimension and N dimension groupwise scaling in FP8 GEMM.
In examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu, two parameters ScaleGranularityM and ScaleGranularityNcontrol the scaling mode:


ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D scaling (block-wise scaling, same as 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu , 2Dx2D refers to the shape of the scaling factor)

ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling

ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling

ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/fix-fp8-blockwise branch from db87722 to 7f541db Compare February 25, 2025 06:50
Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson
Copy link
Author

@LucasWilkinson , we upstreamed our change to groupwise scaling kernels. there are some conflicts in this PR that needs to be solved.

Our change is mainly:

Extend groupwise scaling gemm to support both M dimension and N dimension groupwise scaling in FP8 GEMM.
In examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu, two parameters ScaleGranularityM and ScaleGranularityNcontrol the scaling mode:


ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D scaling (block-wise scaling, same as 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu , 2Dx2D refers to the shape of the scaling factor)

ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling

ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling

ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling

apologies for the delay the PR has been updated, currently I am still vectorizing the loads of B scales along N (like main) but it might actually makes sense to not do this to enable transposing A and B (since we currently have partial tiles along M this would mean partial tiles along N)

@@ -280,8 +278,11 @@ struct CollectiveMma<
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any promblems when transpose A and transpose B?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently this assumes full tiles in N and K so if using this for inference where activations may have partial tiles if you transpose it to Y^T = WX^T it may report not implementable, I think im going to update this since ideally in vLLM we'd like to transpose it to use smaller tensor core instructions, we do lose vectorization on the loads then though

@hwu36
Copy link
Collaborator

hwu36 commented Feb 25, 2025

@manishucsd

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants