Improvements for: Groupwise scaling along M for FP8 gemm #2095
base: main
Conversation
@LucasWilkinson, we upstreamed our change to the groupwise scaling kernels. There are some conflicts in this PR that need to be resolved. Our change is mainly:
force-pushed from db87722 to 7f541db
Apologies for the delay, the PR has been updated. Currently I am still vectorizing the loads of B scales along N (like …
@@ -280,8 +278,11 @@ struct CollectiveMma<
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
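For readers outside the diff, here is a minimal standalone sketch of the intent behind the new check; the constant and helper names below are illustrative, not CUTLASS API:

```cpp
// Standalone sketch: the promotion interval must be a multiple of 4 because,
// per the comment in the diff, each mainloop iteration issues 4 MMA
// instructions, so accumulator promotion can only occur on a 4-MMA boundary.
#include <cstdio>

constexpr int kMmasPerMainloopIteration = 4;  // assumption taken from the diff comment

bool promotion_interval_is_implementable(int mma_promotion_interval) {
  return mma_promotion_interval % kMmasPerMainloopIteration == 0;
}

int main() {
  std::printf("4  -> %s\n", promotion_interval_is_implementable(4)  ? "ok" : "not implementable");
  std::printf("6  -> %s\n", promotion_interval_is_implementable(6)  ? "ok" : "not implementable");
  std::printf("16 -> %s\n", promotion_interval_is_implementable(16) ? "ok" : "not implementable");
}
```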
Are there any problems when transposing A and transposing B?
Currently this assumes full tiles in N and K, so if you use this for inference, where activations may have partial tiles, and transpose the problem to Y^T = WX^T, it may report not implementable. I think I'm going to update this, since ideally in vLLM we'd like to transpose it to use smaller tensor core instructions, though we do lose vectorization on the loads then.
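To illustrate the constraint being discussed, here is a hypothetical sketch; the tile sizes and the `full_tiles_in_N_and_K` helper are assumptions for illustration, not the kernel's actual configuration:

```cpp
// Illustrative sketch: why the transposed formulation Y^T = W * X^T can be
// rejected when the kernel assumes full tiles along N and K.
#include <cstdio>

// Hypothetical tile shape; real kernels pick this via the collective's TileShape.
constexpr int TileN = 128;
constexpr int TileK = 128;

bool full_tiles_in_N_and_K(int N, int K) {
  return (N % TileN == 0) && (K % TileK == 0);
}

int main() {
  // Inference: X is (num_tokens x K), W is (N x K), Y = X * W^T.
  int num_tokens = 100;  // ragged, changes per batch
  int K = 4096;
  // The transposed form Y^T = W * X^T makes num_tokens the GEMM N dimension,
  // so a partial tile appears along N and the check below fails.
  std::printf("transposed form implementable: %s\n",
              full_tiles_in_N_and_K(/*N=*/num_tokens, /*K=*/K) ? "yes" : "no");
}
```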
Various improvements to "Groupwise scaling along M" (#2037), namely to address #2087; context: vllm-project/vllm#11868 (comment)
Improvements:
this PR moves to a standard M-major layout, making it much easier to integrate into inference libraries
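As a rough illustration of what an M-major scale layout means for integrators, here is a hedged sketch; the granularity values and the `scale_index` helper are assumptions for illustration, not the collective's actual layout code:

```cpp
// Sketch of an M-major layout for the per-group A scales, assuming a scale
// granularity of 128 along M and along K (assumed values, not the kernel's).
#include <cstdio>

constexpr int ScaleGranularityM = 128;  // assumption for illustration
constexpr int ScaleGranularityK = 128;  // assumption for illustration

// With an M-major layout, the scale for block (m_blk, k_blk) of an (M x K)
// operand lives at m_blk + k_blk * num_m_blks: consecutive scales walk down M
// first, which matches how an inference library naturally produces
// per-row-group activation scales.
int scale_index(int m, int k, int M) {
  int m_blk      = m / ScaleGranularityM;
  int k_blk      = k / ScaleGranularityK;
  int num_m_blks = (M + ScaleGranularityM - 1) / ScaleGranularityM;
  return m_blk + k_blk * num_m_blks;
}

int main() {
  // Example: M = 512 tokens, scale for element (m = 300, k = 2048).
  std::printf("scale index: %d\n", scale_index(300, 2048, 512));
}
```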
These improvements were part of vLLM's adoption of this kernel https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (PR: vllm-project/vllm#11868), and it is currently in wide-scale use. Our goal is to rely on the CUTLASS implementation, but that is currently not possible given the issues above.