[Feature] Support Optimized CUDA Token Bitmask Kernel #186
Conversation
Thanks for the contribution @syuoni ! The enhancement of kernels is definitely useful. Could you please modify the benchmark accordingly? Once that’s done, we can proceed with merging this PR.
The new benchmark script, `examples/benchmark/bench_apply_token_bitmask_inplace.py`, under review:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

import torch

from xgrammar.kernels import (
    apply_token_bitmask_inplace_cuda,
    apply_token_bitmask_inplace_triton,
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--kernel", type=str, choices=["cuda", "triton"], default="cuda")
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--vocab_size", type=int, default=128000)
    parser.add_argument("--num_warmup", type=int, default=10)
    parser.add_argument("--num_iters", type=int, default=50)
    args = parser.parse_args()

    vocab_size = args.vocab_size
    batch_size = args.batch_size
    # Each int32 bitmask word covers 32 tokens.
    bitmask_size = (vocab_size + 32 - 1) // 32

    logits = torch.randn(batch_size, vocab_size, dtype=torch.float32, device="cuda")
    bitmask = torch.randint(
        torch.iinfo(torch.int32).min,
        torch.iinfo(torch.int32).max,
        (batch_size, bitmask_size),
        dtype=torch.int32,
        device="cuda",
    )

    def run():
        if args.kernel == "cuda":
            apply_token_bitmask_inplace_cuda(logits, bitmask)
        elif args.kernel == "triton":
            apply_token_bitmask_inplace_triton(logits, bitmask)

    # Warmup iterations are excluded from timing.
    for i in range(args.num_warmup):
        run()
    torch.cuda.synchronize()

    # Timed iterations; report the mean per-call latency in microseconds.
    start = time.perf_counter()
    for i in range(args.num_iters):
        run()
    torch.cuda.synchronize()
    exec_time = time.perf_counter() - start
    exec_time = (exec_time / args.num_iters) * 10**6

    print(f"Kernel: {args.kernel}")
    print(f"Kernel execution time (us): {exec_time:.4f}")
```
Could you put these tests into `tests/python/test_token_bitmask_operations.py`? Also, could you consider masks that are all full, half full, or almost all empty during benchmarking? We found in testing that this has a major impact on kernel execution time. It should be easy because we have provided all the utilities for this in `test_apply_token_bitmask_inplace_large()`; you can just fit this into the workflow of that function.
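For example, the benchmark could sweep the mask fill level roughly like this (a hedged sketch; the import path of the shared `_bool_mask_to_bitmask` helper is an assumption, adjust it to where the helper actually lives):

```python
# Hedged sketch: build bool masks at three fill levels and pack them into
# int32 bitmasks with the shared helper before feeding them to the kernels.
import torch

from xgrammar.testing import _bool_mask_to_bitmask  # assumed import path

batch_size, vocab_size = 2048, 128000

for masked_cnt in (1_000, 64_000, 127_000):  # ~all full, half full, almost empty
    bool_mask = torch.ones(batch_size, vocab_size, dtype=torch.bool)
    masked_positions = torch.stack(
        [torch.randperm(vocab_size)[:masked_cnt] for _ in range(batch_size)]
    )
    bool_mask.scatter_(1, masked_positions, False)  # False = masked (logit -> -inf)
    bitmask = _bool_mask_to_bitmask(bool_mask).to("cuda")
```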
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems that when `masked_cnt` is small, the Triton kernel skips a lot of memory access, which results in higher perf. Let me check whether the CUDA kernel can adopt this pattern.
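For illustration, the pattern in question looks roughly like this (a hedged sketch of a single row with one block per program, not the actual xgrammar kernel): the Triton implementation never reads the logits and only writes `-inf` where a mask bit is 0, so an (almost) all-true bitmask triggers (almost) no global stores.

```python
# Hedged sketch of the skip pattern; not the actual xgrammar kernel.
import triton
import triton.language as tl


@triton.jit
def mask_logits_block_sketch(logits_ptr, bitmask_ptr, vocab_size, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = offsets < vocab_size
    # Bit (token % 32) of int32 word (token // 32) tells whether the token is allowed.
    words = tl.load(bitmask_ptr + offsets // 32, mask=in_range, other=-1)
    allowed = (words >> (offsets % 32)) & 1
    # Store -inf only where the bit is 0: logits are never read, and an
    # (almost) all-true mask writes (almost) nothing back to global memory.
    neg_inf = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - float("inf")
    tl.store(logits_ptr + offsets, neg_inf, mask=in_range & (allowed == 0))
```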
In a realistic scenario, what is the distribution of the bitmask over the three patterns?
- full
- half full
- almost all empty
BTW, does "full" mean ALL logits filled with -inf or NO logits filled with -inf?
I think it should be possible to borrow the Triton kernel's approach to optimize the CUDA code.
When I say “all full,” I mean the case where all values are -1 (all-true), which corresponds to the case where `masked_cnt` is very small (just a terminology issue).
In real cases, I believe both the “full” and the “almost all empty” cases are quite common. For example, with JSON, almost all tokens are valid inside a string, which corresponds to the “full” case. On the other hand, outside a string, only tokens that conform to the syntax are valid, which corresponds to the “almost all empty” case.
Force-pushed from 16d35ab to 9daf828.
| GPU | Batch size | Vocab size | Masked cnt | Triton (μs) | CUDA (μs) | Speedup |
|:---:|-----------:|-----------:|-----------:|------------:|----------:|--------:|
| | 512 | 128k | 127k | 306.75 | 233.20 | 1.32x |
| | 4096 | 128k | 1k | 955.99 | 777.94 | 1.23x |
| | 4096 | 128k | 64k | 2756.63 | 2707.57 | 1.02x |
| | 4096 | 128k | 127k | 2472.82 | 1782.41 | 1.39x |
Hi @Ubospica, please check out the updated benchmarking results (using `triton.testing.do_bench`). The masked cnt values of 1k, 64k, and 127k correspond to the (almost) all full, half full, and almost empty cases, respectively.

Summary of the refreshed results:
- The CUDA kernel shows latencies similar to the Triton kernel at small batch sizes.
- The CUDA kernel offers significant perf gains at large batch sizes (in most cases).
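For reference, a minimal sketch of timing a kernel call with `triton.testing.do_bench` (shapes and the warmup/rep values are illustrative, not the benchmark's exact settings; `do_bench` returns milliseconds):

```python
# Minimal sketch of timing the CUDA kernel with triton.testing.do_bench.
import torch
import triton.testing

from xgrammar.kernels import apply_token_bitmask_inplace_cuda

batch_size, vocab_size = 4096, 128000
logits = torch.randn(batch_size, vocab_size, dtype=torch.float32, device="cuda")
# All bits set (-1): the (almost) all full case.
bitmask = torch.full((batch_size, (vocab_size + 31) // 32), -1, dtype=torch.int32, device="cuda")

ms = triton.testing.do_bench(
    lambda: apply_token_bitmask_inplace_cuda(logits, bitmask), warmup=25, rep=100
)
print(f"Kernel execution time (us): {ms * 1e3:.2f}")
```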
```diff
@@ -21,7 +21,7 @@ def apply_token_bitmask_inplace_kernel(
     for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
         block_offset = (work_id % num_blocks) * BLOCK_SIZE
         row_id = work_id // num_blocks
-        batch_id = tl.load(indices_ptr + row_id)
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
```
This is an optimization for the Triton kernel.
For a contiguous logits tensor, we can pass `indices_ptr=None` to the kernel, so it can skip loading from `indices_ptr`. This saves global memory access. Surprisingly, this can cause up to a 30% perf difference.
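A hedged usage-level illustration of the fast path (assuming the Triton wrapper shares the `(logits, bitmask, indices=None)` signature shown for the CUDA op elsewhere in this PR; shapes are made up):

```python
import torch

from xgrammar.kernels import apply_token_bitmask_inplace_triton

batch_size, vocab_size = 8, 128000
logits = torch.randn(batch_size, vocab_size, device="cuda")
# All bits set (-1): every token allowed, so both calls leave logits unchanged.
bitmask = torch.full((batch_size, (vocab_size + 31) // 32), -1, dtype=torch.int32, device="cuda")

# Fast path: indices omitted, the kernel uses batch_id = row_id and skips the load.
apply_token_bitmask_inplace_triton(logits, bitmask)

# General path: explicit (here identity) indices force the per-row tl.load.
indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
apply_token_bitmask_inplace_triton(logits, bitmask, indices)
```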
That is quite interesting. Thanks for pointing that out!
```diff
@@ -158,6 +158,34 @@ def _get_masked_tokens_from_bitmask(
     )


+def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
```
I moved the `_bool_mask_to_bitmask` function from `tests/python/test_token_bitmask_operations.py` so that it can be shared with the benchmarking script.
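For context, such a packing helper can look roughly like this (a hedged sketch, not necessarily the exact implementation that was moved):

```python
import torch


def _bool_mask_to_bitmask_sketch(bool_mask: torch.Tensor) -> torch.Tensor:
    """Pack a (batch, vocab) bool mask into a (batch, ceil(vocab/32)) int32 bitmask."""
    batch_size, vocab_size = bool_mask.shape
    bitmask_size = (vocab_size + 31) // 32
    # Pad the vocab dimension to a multiple of 32, then pack 32 bools per int32
    # word (bit j of word i corresponds to token i * 32 + j; 1 = allowed).
    padded = torch.nn.functional.pad(bool_mask.to(torch.int32), (0, bitmask_size * 32 - vocab_size))
    bits = padded.view(batch_size, bitmask_size, 32)
    shifts = torch.arange(32, dtype=torch.int32, device=bool_mask.device)
    return (bits << shifts).sum(dim=-1, dtype=torch.int32)
```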
```python
        [torch.randperm(vocab_size)[:masked_cnt] for _ in range(batch_size)]
    )
    bool_mask.scatter_(1, masked_positions, False)
    assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()
```
I changed the approach used to generate `bool_mask`. The original approach did not guarantee that exactly `masked_cnt` positions are masked, because `masked_positions` could contain repeated values.
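A tiny illustration of the difference (values are made up for clarity):

```python
# Sampling with replacement can repeat positions, so fewer than masked_cnt
# entries end up masked; randperm always yields distinct positions.
import torch

vocab_size, masked_cnt = 8, 4
with_replacement = torch.randint(0, vocab_size, (masked_cnt,))   # may contain duplicates
without_replacement = torch.randperm(vocab_size)[:masked_cnt]    # always distinct
print(with_replacement.unique().numel())      # can be < masked_cnt
print(without_replacement.unique().numel())   # always == masked_cnt
```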
This PR looks great to me. The efficiency of the apply-mask kernel is critical for end-to-end LLM serving. Thanks @syuoni!
Hi @syuoni, I would like to ask about these lines in this PR:

```python
@torch.library.register_fake("xgrammar::apply_token_bitmask_inplace_cuda")
def _(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[torch.Tensor] = None,
) -> None:
    pass
```

Do you think this FakeTensor op registration has any purpose? The …
Hi @Ubospica, removing this fake registration function will break `torch.compile`:

```python
import torch
import xgrammar

@torch.compile(fullgraph=True)
def func(logits, bitmask):
    xgrammar.kernels.apply_token_bitmask_inplace_cuda(logits, bitmask)

if __name__ == "__main__":
    logits = torch.randn(128000, device=0)
    bitmask = torch.randn(128000 // 32, device=0).view(torch.int32)
    func(logits, bitmask)
```

It raises an error:

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2378, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: xgrammar::apply_token_bitmask_inplace_cuda: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
```

If users don't use …

Seems `torch.library.impl_abstract` is the name for …
#206 looks good to me. Thanks!
This PR adds a CUDA-implemented token bitmask kernel. It shows better perf than the Triton-implemented one, especially at large batch sizes.

The PR also provides `examples/benchmark/bench_apply_token_bitmask_inplace.py` for benchmarking. The perf results are:

| GPU | Batch size | Vocab size | Triton (μs) | CUDA (μs) | Speedup ratio |
|:--------------:|-----------:|-----------:|------------:|----------:|--------------:|
| H100 80GB HBM3 | 16 | 128k | 75.4 | 6.5 | 11.65x |
| | 128 | 128k | 79.0 | 52.4 | 1.51x |
| | 2048 | 128k | 1048.5 | 714.5 | 1.47x |
| A100 SXM4 80GB | 16 | 128k | 137.1 | 9.6 | 14.30x |
| | 128 | 128k | 140.3 | 88.0 | 1.59x |
| | 2048 | 128k | 1439.9 | 1293.2 | 1.11x |

Please check out `examples/benchmark/README.md` for the benchmarking results.

Notes:
- The Triton kernel can still be selected by setting `XGRAMMAR_TOKEN_BITMASK_TRITON=1`.
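For completeness, a minimal usage sketch of the kernel (shapes and masked positions are made up for the example; a bit value of 1 means the token is allowed, 0 means its logit is set to -inf, assuming the LSB-first bit layout within each int32 word):

```python
# Minimal usage sketch (illustrative shapes; not taken from the PR itself).
import torch

from xgrammar.kernels import apply_token_bitmask_inplace_cuda

batch_size, vocab_size = 4, 128000
logits = torch.randn(batch_size, vocab_size, device="cuda")
# All bits set (-1): every token allowed.
bitmask = torch.full((batch_size, (vocab_size + 31) // 32), -1, dtype=torch.int32, device="cuda")
bitmask[:, 0] = -2  # ...11111110 in two's complement: clear bit 0, masking only token 0
apply_token_bitmask_inplace_cuda(logits, bitmask)
assert torch.isneginf(logits[:, 0]).all()
assert not torch.isinf(logits[:, 1:]).any()
```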