
[Feature] Support Optimized CUDA Token Bitmask Kernel #186

Merged: 8 commits merged into mlc-ai:main from enweiz/bitmask-cuda-kernel on Feb 13, 2025

Conversation

@syuoni (Collaborator) commented Feb 7, 2025:

This PR adds a CUDA implementation of the token bitmask kernel. It shows better performance than the Triton implementation, especially at large batch sizes.

The PR also provides examples/benchmark/bench_apply_token_bitmask_inplace.py for benchmarking. The performance results are:

| GPU | Batch size | Vocab size | Triton (μs) | CUDA (μs) | Speedup ratio |
|:--------------:|-----------:|-----------:|-------------:|----------:|---------------:|
| H100 80GB HBM3 | 16 | 128k | 75.4 | 6.5 | 11.65x |
| | 128 | 128k | 79.0 | 52.4 | 1.51x |
| | 2048 | 128k | 1048.5 | 714.5 | 1.47x |
| A100 SXM4 80GB | 16 | 128k | 137.1 | 9.6 | 14.30x |
| | 128 | 128k | 140.3 | 88.0 | 1.59x |
| | 2048 | 128k | 1439.9 | 1293.2 | 1.11x |

Please check out examples/benchmark/README.md for the benchmarking results.
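For reference, the script can be run as follows (flag names as defined in the script; the values shown are its defaults):

python examples/benchmark/bench_apply_token_bitmask_inplace.py --kernel cuda --batch_size 2048 --vocab_size 128000 --num_warmup 10 --num_iters 50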



@Ubospica (Collaborator) left a comment:

Thanks for the contribution, @syuoni! The kernel enhancement is definitely useful. Could you please modify the benchmark accordingly? Once that's done, we can proceed with merging this PR.

Comment on lines 1 to 66 of examples/benchmark/bench_apply_token_bitmask_inplace.py:
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time

import torch

from xgrammar.kernels import (
    apply_token_bitmask_inplace_cuda,
    apply_token_bitmask_inplace_triton,
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--kernel", type=str, choices=["cuda", "triton"], default="cuda")
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--vocab_size", type=int, default=128000)
    parser.add_argument("--num_warmup", type=int, default=10)
    parser.add_argument("--num_iters", type=int, default=50)
    args = parser.parse_args()

    vocab_size = args.vocab_size
    batch_size = args.batch_size
    bitmask_size = (vocab_size + 32 - 1) // 32

    logits = torch.randn(batch_size, vocab_size, dtype=torch.float32, device="cuda")
    bitmask = torch.randint(
        torch.iinfo(torch.int32).min,
        torch.iinfo(torch.int32).max,
        (batch_size, bitmask_size),
        dtype=torch.int32,
        device="cuda",
    )

    def run():
        if args.kernel == "cuda":
            apply_token_bitmask_inplace_cuda(logits, bitmask)
        elif args.kernel == "triton":
            apply_token_bitmask_inplace_triton(logits, bitmask)

    for i in range(args.num_warmup):
        run()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for i in range(args.num_iters):
        run()
    torch.cuda.synchronize()
    exec_time = time.perf_counter() - start
    exec_time = (exec_time / args.num_iters) * 10**6

    print(f"Kernel: {args.kernel}")
    print(f"Kernel execution time (us): {exec_time:.4f}")
@Ubospica (Collaborator) commented:

Could you put these tests into tests/python/test_token_bitmask_operations.py? Also, can you consider the mask to be all full, half full, or almost all empty during benchmarking? We found in testing that this has a major impact on the kernel execution time.

It should be easy because we have provided all the utilities for this in test_apply_token_bitmask_inplace_large(). You can just fit it into the workflow of that function.

@syuoni (Collaborator, Author) replied:

It seems that when masked_cnt is small, the Triton kernel skips a lot of memory accesses, which results in higher performance. Let me check whether the CUDA kernel can adopt this pattern.

In a realistic scenario, what is the distribution of the bitmask over the three patterns?

  • full
  • half full
  • almost all empty

BTW, does "full" mean ALL logits filled with -inf or NO logits filled with -inf?

@Ubospica (Collaborator) replied:

I think it should be possible to take advantage of Triton kernels to optimize CUDA code.

When I say “all full,” I mean the case where all values are -1 (all-true), which corresponds to the case where masked_cnt is very small - just a terminology issue.

In real cases, I believe both the “full” and “almost all empty” cases are quite common. For example, in the case of JSON, almost all tokens are valid inside a string, which corresponds to the “full” case. On the other hand, outside a string, only tokens that conform to the syntax are valid, which corresponds to the “all-empty” case.
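To pin down the semantics being discussed: a set bit means the token is allowed, a cleared bit means its logit is set to -inf, so an all -1 ("full") bitmask leaves every logit untouched. Below is a minimal reference sketch of this behavior, assuming the usual LSB-first packing within each int32 word; it is not the library's kernel.

import torch

def apply_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # Unpack each int32 word into 32 bits, LSB-first: token i maps to bit
    # i % 32 of word i // 32.
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, device=bitmask.device)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1
    allowed = bits.flatten(-2)[..., :vocab_size].bool()
    # A set bit keeps the logit; a cleared bit masks it to -inf.
    return torch.where(allowed, logits, torch.full_like(logits, float("-inf")))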

@syuoni force-pushed the enweiz/bitmask-cuda-kernel branch from 16d35ab to 9daf828 on February 12, 2025, 14:41
| GPU | Batch size | Vocab size | Masked cnt | Triton (μs) | CUDA (μs) | Speedup ratio |
|:--------------:|-----------:|-----------:|-----------:|------------:|----------:|--------------:|
| | 512 | 128k | 127k | 306.75 | 233.20 | 1.32x |
| | 4096 | 128k | 1k | 955.99 | 777.94 | 1.23x |
| | 4096 | 128k | 64k | 2756.63 | 2707.57 | 1.02x |
| | 4096 | 128k | 127k | 2472.82 | 1782.41 | 1.39x |
@syuoni (Collaborator, Author) commented:

Hi @Ubospica ,

Please check out the updated benchmarking results (using triton.testing.do_bench). The masked cnt values of 1k, 64k, and 127k correspond to the (almost) all-full, half-full, and almost-empty cases, respectively.

Refreshed benchmarking results:

  • The CUDA kernel shows latencies similar to the Triton kernel at small batch sizes.
  • The CUDA kernel offers significant performance gains at large batch sizes (in most cases).
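For context, here is a minimal sketch of timing one of the kernels with triton.testing.do_bench (here with a random bitmask for brevity; the updated benchmark script additionally controls masked_cnt as discussed above, so its exact setup differs):

import torch
import triton.testing

from xgrammar.kernels import apply_token_bitmask_inplace_cuda

batch_size, vocab_size = 4096, 128000
bitmask_size = (vocab_size + 32 - 1) // 32
logits = torch.randn(batch_size, vocab_size, device="cuda")
bitmask = torch.randint(
    torch.iinfo(torch.int32).min,
    torch.iinfo(torch.int32).max,
    (batch_size, bitmask_size),
    dtype=torch.int32,
    device="cuda",
)

# do_bench handles warmup, repetition, and L2-cache flushing internally,
# and returns the measured time in milliseconds.
ms = triton.testing.do_bench(lambda: apply_token_bitmask_inplace_cuda(logits, bitmask))
print(f"CUDA kernel: {ms * 1e3:.2f} us")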

@@ -21,7 +21,7 @@ def apply_token_bitmask_inplace_kernel(
     for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
         block_offset = (work_id % num_blocks) * BLOCK_SIZE
         row_id = work_id // num_blocks
-        batch_id = tl.load(indices_ptr + row_id)
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
@syuoni (Collaborator, Author) commented Feb 12, 2025:

This is an optimization for the Triton kernel.

For a contiguous logits tensor, we can pass indices_ptr=None to the kernel, so the kernel can skip loading from indices_ptr. This saves global memory accesses. Surprisingly, this can cause up to a 30% performance difference.
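To illustrate the idiom outside this kernel, here is a small self-contained Triton sketch (illustrative kernel and names, not xgrammar code) showing how passing None for a pointer argument lets Triton compile a specialization that skips the load entirely:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(x_ptr, indices_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_id = tl.program_id(0)
    # When indices_ptr is None, this branch is resolved at compile time,
    # so the specialized kernel performs no index load at all.
    src_row = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    vals = tl.load(x_ptr + src_row * n_cols + offsets, mask=mask)
    tl.store(out_ptr + row_id * n_cols + offsets, vals, mask=mask)

x = torch.randn(8, 64, device="cuda")
out = torch.empty_like(x)
# Contiguous 1:1 mapping: pass None and avoid the per-row index load.
copy_rows_kernel[(8,)](x, None, out, 64, BLOCK_SIZE=64)
# Gathering from selected rows: pass an index tensor instead.
idx = torch.randint(0, 8, (8,), device="cuda", dtype=torch.int32)
copy_rows_kernel[(8,)](x, idx, out, 64, BLOCK_SIZE=64)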

@Ubospica (Collaborator) replied:

That is quite interesting. Thanks for pointing that out!

@@ -158,6 +158,34 @@ def _get_masked_tokens_from_bitmask(
)


def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
@syuoni (Collaborator, Author) commented:

I moved the _bool_mask_to_bitmask function from tests/python/test_token_bitmask_operations.py so that it can be shared with the benchmarking script.
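For readers following along: the helper packs a (batch_size, vocab_size) boolean mask into int32 words. A rough sketch of the packing, assuming LSB-first bit order (the actual helper in the test file may differ in details):

import torch

def bool_mask_to_bitmask_sketch(bool_mask: torch.Tensor) -> torch.Tensor:
    # Pack a (batch, vocab) boolean mask into (batch, ceil(vocab/32)) int32
    # words, LSB-first within each word.
    batch_size, vocab_size = bool_mask.shape
    bitmask_size = (vocab_size + 32 - 1) // 32
    padded = torch.zeros(
        batch_size, bitmask_size * 32, dtype=torch.int64, device=bool_mask.device
    )
    padded[:, :vocab_size] = bool_mask.to(torch.int64)
    groups = padded.view(batch_size, bitmask_size, 32)
    weights = torch.arange(32, dtype=torch.int64, device=bool_mask.device)
    words = (groups << weights).sum(dim=-1)  # unsigned 32-bit values held in int64
    # Reinterpret as signed int32, the dtype the kernels expect.
    words = torch.where(words >= 2**31, words - 2**32, words)
    return words.to(torch.int32)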

[torch.randperm(vocab_size)[:masked_cnt] for _ in range(batch_size)]
)
bool_mask.scatter_(1, masked_positions, False)
assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()
@syuoni (Collaborator, Author) commented:

I changed the approach for generating bool_mask.

The original approach did not ensure that exactly masked_cnt positions are masked, because masked_positions could contain repeated values.
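A small illustration of the difference (the "original approach" below is reconstructed from this description, not quoted from the old code):

import torch

batch_size, vocab_size, masked_cnt = 4, 128000, 1000

# Original idea (reconstructed): sampling positions with randint can produce
# duplicates, so fewer than masked_cnt positions may actually end up masked.
dup_positions = torch.randint(0, vocab_size, (batch_size, masked_cnt))

# New approach: randperm guarantees masked_cnt distinct positions per row.
masked_positions = torch.stack(
    [torch.randperm(vocab_size)[:masked_cnt] for _ in range(batch_size)]
)
bool_mask = torch.ones(batch_size, vocab_size, dtype=torch.bool)
bool_mask.scatter_(1, masked_positions, False)
assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()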

@Ubospica merged commit de5e2d6 into mlc-ai:main on Feb 13, 2025 (1 check passed).
@Ubospica (Collaborator) commented:

This PR looks great to me. The efficiency of the apply-mask kernel is critical for end-to-end LLM serving. Thanks, @syuoni!

@Ubospica (Collaborator) commented Feb 17, 2025:

Hi @syuoni, I would like to ask about these lines in this PR:
https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py#L61-L67

@torch.library.register_fake("xgrammar::apply_token_bitmask_inplace_cuda")
def _(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[torch.Tensor] = None,
) -> None:
    pass

Do you think this FakeTensor op registration has any purpose?

torch.library.register_fake is supported starting with PyTorch 2.4, but we want to maintain compatibility with lower PyTorch versions. If this function is not necessary for your application, maybe I can remove it; otherwise, I will add a check for its existence first.

@syuoni (Collaborator, Author) commented Feb 17, 2025:

Hi @Ubospica ,

Removing this fake registration function will break torch.compile with fullgraph=True. Try removing the registration function and running the code below:

import torch
import xgrammar


@torch.compile(fullgraph=True)
def func(logits, bitmask):
    xgrammar.kernels.apply_token_bitmask_inplace_cuda(logits, bitmask)


if __name__ == "__main__":
    logits = torch.randn(128000, device=0)
    bitmask = torch.randn(128000//32, device=0).view(torch.int32)
    func(logits, bitmask)

It raises the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2378, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: xgrammar::apply_token_bitmask_inplace_cuda: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

If users don't use torch.compile with fullgraph=True (or similarly, torch.export), removing the fake registration function is OK.

It seems torch.library.impl_abstract is the name for torch.library.register_fake in lower PyTorch versions. Maybe we can consider that?

@Ubospica (Collaborator) commented Feb 17, 2025:

@syuoni Thanks for the reply. That makes sense.

torch.library.impl_abstract looks good as well, but it is not supported until torch 2.2. #206 now checks whether the attribute exists and calls it if so. I think this should be fine as well (if you still see any issues, please let me know :) ).
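For reference, a sketch of that kind of compatibility check (the actual code in #206 may differ): prefer torch.library.register_fake (PyTorch >= 2.4), fall back to torch.library.impl_abstract (PyTorch >= 2.2), and skip the registration entirely on older versions.

from typing import Optional

import torch

if hasattr(torch.library, "register_fake"):
    _register_fake = torch.library.register_fake
elif hasattr(torch.library, "impl_abstract"):
    _register_fake = torch.library.impl_abstract
else:
    _register_fake = None

if _register_fake is not None:

    @_register_fake("xgrammar::apply_token_bitmask_inplace_cuda")
    def _(
        logits: torch.Tensor,
        bitmask: torch.Tensor,
        indices: Optional[torch.Tensor] = None,
    ) -> None:
        pass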

@syuoni (Collaborator, Author) commented Feb 17, 2025:

#206 looks good to me. Thanks!
